persia.data

In PERSIA, we provide the DataLoader class to load the data. The DataLoader preprocesses the PersiaBatch and looks up the embeddings for its id_type_features. To initialize a DataLoader, the dataset must be an "iterable dataset" (an instance of an IterableDatasetBase subclass). To create an iterable dataset, you can either use the StreamingDataset to fetch PersiaBatch from the dataflow, or use the IterableDataset to generate PersiaBatch locally.
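To make the embedding lookup for id_type_features concrete, here is a minimal pure-Python sketch. It assumes a hypothetical in-memory table: `lookup_id_type_feature`, `EmbeddingTable`, and the zero-vector fallback for unknown ids are illustrative inventions, not PERSIA's API. In PERSIA the lookup goes to the embedding server.

```python
from typing import Dict, List

# Hypothetical stand-in for the embedding server: id -> embedding vector.
EmbeddingTable = Dict[int, List[float]]


def lookup_id_type_feature(
    table: EmbeddingTable, samples: List[List[int]], dim: int
) -> List[List[List[float]]]:
    """Return one list of embedding vectors per sample.

    Each sample is a variable-length list of ids, like the per-sample id
    arrays of an id_type_feature. Unknown ids map to a zero vector, which
    is an assumption of this sketch.
    """
    zero = [0.0] * dim
    return [[list(table.get(i, zero)) for i in sample] for sample in samples]


table: EmbeddingTable = {1000: [0.1, 0.2], 10001: [0.3, 0.4], 1003: [0.5, 0.6]}
batch = [[1000, 10001], [1003, 10011]]
embeddings = lookup_id_type_feature(table, batch, dim=2)
print(embeddings)  # [[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.0, 0.0]]]
```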

Module Contents

class persia.data.DataLoader(dataset, forward_buffer_size=10, timeout_ms=1000 * 60 * 10, num_workers=10, reproducible=False, embedding_staleness=None)

The DataLoader preprocesses the data into PersiaTrainingBatch.

The DataLoader is a pipeline that preprocesses the PersiaBatch in several steps. Each step processes its tasks concurrently with multiple threads to improve efficiency.
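The idea of steps that each run on several threads can be approximated with the standard library's ThreadPoolExecutor. This is a simplified sketch, not the DataLoader's implementation. `run_stage` is a hypothetical helper, and each stage here finishes before the next one starts, whereas a real pipeline overlaps them. The `reproducible` flag only mimics the idea of a fixed iteration order.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def run_stage(
    items: Iterable[T],
    work: Callable[[T], R],
    num_workers: int = 4,
    reproducible: bool = False,
) -> List[R]:
    """Apply `work` to every item on a pool of `num_workers` threads.

    With reproducible=True the results keep the input order (executor.map).
    Otherwise they come back in completion order, which can vary per run.
    """
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        if reproducible:
            return list(pool.map(work, items))
        futures = [pool.submit(work, item) for item in items]
        return [f.result() for f in as_completed(futures)]


# Two chained stages, e.g. "preprocess" followed by "look up embeddings".
batches = list(range(8))
stage1 = run_stage(batches, lambda b: b * 10, num_workers=4, reproducible=True)
stage2 = run_stage(stage1, lambda b: b + 1, num_workers=4, reproducible=True)
print(stage2)  # [1, 11, 21, 31, 41, 51, 61, 71]
```

With `reproducible=False` the same values are produced, but their order depends on thread scheduling.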

Warning

When you use the StreamingDataset, the DataLoader cannot stop iterating on its own. Iteration ends only when it raises a TimeoutError (see StreamingDataset for more details).
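The sketch below uses a plain queue.Queue as a stand-in for the dataflow. It shows why a timeout is the only stop signal for an unbounded stream and how a caller can treat TimeoutError as the end of the data. `iterate_with_timeout` and `consume_all` are hypothetical names. Wrapping iteration over a DataLoader in a similar try/except TimeoutError is one way to handle the warning above.

```python
import queue
from typing import Iterator, List


def iterate_with_timeout(q: "queue.Queue[int]", timeout_s: float) -> Iterator[int]:
    """Yield batches forever; raise TimeoutError when none arrives in time."""
    while True:
        try:
            batch = q.get(timeout=timeout_s)
        except queue.Empty:
            # An unbounded stream has no explicit end, so a timeout is the stop signal.
            raise TimeoutError(f"no batch received within {timeout_s}s") from None
        yield batch


def consume_all(q: "queue.Queue[int]", timeout_s: float) -> List[int]:
    seen: List[int] = []
    try:
        for batch in iterate_with_timeout(q, timeout_s):
            seen.append(batch)  # a training step would go here
    except TimeoutError:
        pass  # treat the timeout as "upstream drained" and stop cleanly
    return seen


q: "queue.Queue[int]" = queue.Queue()
for i in range(3):
    q.put(i)
result = consume_all(q, timeout_s=0.05)
print(result)  # [0, 1, 2]
```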

Parameters
  • dataset (IterableDatasetBase) – dataset from which the DataLoader retrieves the replica info and the sender channel.

  • forward_buffer_size (int, optional) – PersiaTrainingBatch buffer size. This argument affects GPU memory usage.

  • timeout_ms (int, optional) – timeout for data fetching, in milliseconds.

  • num_workers (int, optional) – number of thread workers spawned to look up embeddings and prefetch PersiaBatch.

  • reproducible (bool, optional) – iterate over the data in a fixed order, making the dataflow deterministic.

  • embedding_staleness (int, optional) – maximum number of batches with stale embeddings per rank. A stale embedding is one that was prefetched from the embedding server before its gradient was updated.
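One way to picture embedding_staleness is as a counting semaphore over in-flight batches. The StalenessGate class below is a hypothetical illustration, not PERSIA code. A slot is taken when a batch's embeddings are prefetched and returned once its gradient update is applied, so no more than `max_staleness` batches use stale embeddings at once.

```python
import threading


class StalenessGate:
    """Hypothetical bound on batches holding prefetched (stale) embeddings."""

    def __init__(self, max_staleness: int) -> None:
        self._slots = threading.BoundedSemaphore(max_staleness)

    def try_prefetch(self) -> bool:
        # Non-blocking so the example can show refusal; a real loader would wait.
        return self._slots.acquire(blocking=False)

    def gradient_applied(self) -> None:
        # Frees the slot taken by the oldest in-flight batch.
        self._slots.release()


gate = StalenessGate(max_staleness=2)
print(gate.try_prefetch(), gate.try_prefetch(), gate.try_prefetch())  # True True False
gate.gradient_applied()
print(gate.try_prefetch())  # True
```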

class persia.data.IterableDataset(buffer_size=10)

Bases: IterableDatasetBase

The IterableDataset can iterate through the dataset multiple times, whereas the StreamingDataset iterates through it only once. We recommend implementing the TestDataset with IterableDataset.
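The difference between the two dataset types mirrors a familiar Python distinction: an iterable whose `__iter__` starts a fresh pass every time, versus a generator that is exhausted after one pass. The names below are plain-Python analogies, not PERSIA classes.

```python
from typing import Iterator, List


class ReplayableSource:
    """Each call to __iter__ starts a fresh pass, like IterableDataset."""

    def __init__(self, items: List[int]) -> None:
        self.items = items

    def __iter__(self) -> Iterator[int]:
        yield from self.items


def one_shot_source(items: List[int]) -> Iterator[int]:
    """A generator can be consumed only once, like a stream."""
    yield from items


replayable = ReplayableSource([1, 2, 3])
print(list(replayable), list(replayable))  # [1, 2, 3] [1, 2, 3]

stream = one_shot_source([1, 2, 3])
print(list(stream), list(stream))  # [1, 2, 3] []
```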

Implement the __iter__ function to define the PersiaBatch generation phase.

import numpy as np
from persia.data import IterableDataset, DataLoader
from persia.embedding.data import PersiaBatch, IDTypeFeature

class MyTestDataset(IterableDataset):
    def __init__(self, size: int = 10):
        super().__init__()
        # Number of PersiaBatch objects generated per pass over the dataset.
        self.size = size

    def __iter__(self):
        for _ in range(self.size):
            # id_type_features takes a list of IDTypeFeature; each feature
            # holds one uint64 id array per sample in the batch.
            persia_batch = PersiaBatch(
                id_type_features=[
                    IDTypeFeature(
                        "id_type_feature_slot",
                        [
                            np.array([1000, 10001], dtype=np.uint64),
                            np.array([1003, 10011], dtype=np.uint64),
                        ],
                    )
                ],
                requires_grad=False,
            )
            yield persia_batch

dataset = MyTestDataset()
dataloader = DataLoader(dataset)

Parameters

buffer_size (int, optional) – PersiaBatch buffer size

consume_dataset()

Consumes the dataset's own __iter__ and returns an iterator over the preprocess indexes.

Return type

Iterator[int]
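The relationship between consume_dataset, the preprocess indexes, and the data channel can be sketched from the consumer's side. Everything here is hypothetical. ToyDataset uses a queue.Queue in place of PersiaBatchDataChannel, and toy_loader only assumes that each yielded index means one batch is ready in the channel. It does not reproduce how the DataLoader is actually implemented.

```python
import queue
from typing import Iterator, List


class ToyDataset:
    """Hypothetical stand-in for an IterableDatasetBase subclass."""

    def __init__(self, batches: List[str], buffer_size: int = 10) -> None:
        self.batches = batches
        # queue.Queue plays the role of the PersiaBatchDataChannel here.
        self.channel: "queue.Queue[str]" = queue.Queue(maxsize=buffer_size)

    def __iter__(self) -> Iterator[str]:
        yield from self.batches

    def consume_dataset(self) -> Iterator[int]:
        # Send each batch into the channel, then report its preprocess index.
        for preprocess_idx, batch in enumerate(self):
            self.channel.put(batch)
            yield preprocess_idx


def toy_loader(dataset: ToyDataset) -> Iterator[str]:
    # Each index means one more batch is waiting in the channel.
    for _ in dataset.consume_dataset():
        yield dataset.channel.get()


loaded = list(toy_loader(ToyDataset(["batch-0", "batch-1", "batch-2"])))
print(loaded)  # ['batch-0', 'batch-1', 'batch-2']
```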

class persia.data.IterableDatasetBase(buffer_size=10)

Bases: abc.ABC, Iterable[persia.embedding.data.PersiaBatch]

The role of IterableDatasetBase is to transfer the PersiaBatch to the DataLoader. It wraps a PersiaBatchDataChannel, which provides the ability to send data to the DataLoader. The channel has a sender (PersiaBatchDataSender) and a receiver (PersiaBatchDataReceiver), whose roles are illustrated in the example below.
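A sender/receiver pair can be thought of as the two ends of a bounded buffer, where buffer_size sets how many batches can wait before the sender has to block. ToyChannel below is a minimal sketch built on queue.Queue under that assumption. Its method names and keyword arguments are invented for illustration and are not the PersiaBatchDataChannel API.

```python
import queue
from typing import Any, Optional


class ToyChannel:
    """Hypothetical bounded channel; not PERSIA's PersiaBatchDataChannel."""

    def __init__(self, buffer_size: int) -> None:
        self._q: "queue.Queue[Any]" = queue.Queue(maxsize=buffer_size)

    def send(self, item: Any, block: bool = True, timeout: Optional[float] = None) -> None:
        # Blocks (or raises queue.Full) while `buffer_size` items are waiting.
        self._q.put(item, block, timeout)

    def receive(self, block: bool = True, timeout: Optional[float] = None) -> Any:
        # Blocks (or raises queue.Empty) until an item is available.
        return self._q.get(block, timeout)


channel = ToyChannel(buffer_size=2)
channel.send("batch-0")
channel.send("batch-1")
try:
    channel.send("batch-2", block=False)
except queue.Full:
    print("buffer full, producer must wait")
print(channel.receive())  # batch-0
channel.send("batch-2", block=False)  # a slot was freed, so this succeeds
```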

This class cannot be used directly. A subclass must implement the __iter__ and consume_dataset functions to be compatible with the DataLoader: __iter__ generates the PersiaBatch, and consume_dataset sends each PersiaBatch through the PersiaBatchDataSender.

Here is an example that implements a synchronous IterableDatasetBase.

from typing import Iterator

import numpy as np
from persia.data import IterableDatasetBase
from persia.embedding.data import PersiaBatch, IDTypeFeature

class MyPersiaIterableDataset(IterableDatasetBase):

    def __iter__(self):
        # id_type_features takes a list of IDTypeFeature.
        persia_batch = PersiaBatch(
            id_type_features=[
                IDTypeFeature(
                    "id_type_feature_slot",
                    [
                        np.array([1000, 10001], dtype=np.uint64),
                        np.array([1003, 10011], dtype=np.uint64),
                    ],
                )
            ],
            requires_grad=False,
        )

        # Generate two batches in total.
        yield persia_batch
        yield persia_batch

    def consume_dataset(self) -> Iterator[int]:
        # Synchronously send each generated batch to the DataLoader through
        # the channel's sender, then report its preprocess index.
        for preprocess_idx, persia_batch in enumerate(self):
            self.sender.send(persia_batch)
            yield preprocess_idx

Note

The MyPersiaIterableDataset in the example above will be slow on a large dataset because it processes each PersiaBatch synchronously. To improve data processing performance, use the IterableDataset or StreamingDataset instead.
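A common way to avoid the synchronous bottleneck is to generate batches on a background thread so that production overlaps consumption. The `prefetch` helper below is a generic sketch of that technique using only the standard library. It is an assumption about the general approach, not a description of how IterableDataset or StreamingDataset work internally.

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

_DONE = object()  # sentinel: the source is exhausted


class _Failure:
    """Carries an exception from the producer thread to the consumer."""

    def __init__(self, exc: Exception) -> None:
        self.exc = exc


def prefetch(source: Iterable[T], buffer_size: int = 10) -> Iterator[T]:
    """Iterate `source` on a background thread and yield its items in order.

    Up to `buffer_size` items are produced ahead of the consumer, so building
    the next batch overlaps with work done on the current one.
    """
    q: "queue.Queue[object]" = queue.Queue(maxsize=buffer_size)

    def produce() -> None:
        try:
            for item in source:
                q.put(item)
        except Exception as exc:  # forward errors instead of hanging the consumer
            q.put(_Failure(exc))
        else:
            q.put(_DONE)

    threading.Thread(target=produce, daemon=True).start()
    while True:
        item = q.get()
        if item is _DONE:
            return
        if isinstance(item, _Failure):
            raise item.exc
        yield item


def make_batches(n: int) -> Iterator[int]:
    for i in range(n):
        yield i * i  # stands in for expensive PersiaBatch construction


print(list(prefetch(make_batches(5), buffer_size=2)))  # [0, 1, 4, 9, 16]
```

If the consumer stops early, the daemon producer thread may stay blocked on a full queue; a production version would need a cancellation path.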

Parameters

buffer_size (int, optional) – buffer size for PersiaBatchDataChannel.

abstract consume_dataset()

Consumes the dataset's own __iter__ and returns an iterator over the preprocess indexes.

Return type

Iterator[int]

class persia.data.StreamingDataset(buffer_size=10)

Bases: IterableDatasetBase

The StreamingDataset receives PersiaBatch from the upstream dataflow, which is sent by DataCtx.

In the implemented StreamingDataset.consume_dataset, the PersiaBatchDataSender instance automatically binds to the RPC service that receives the data. So it is not necessary to implement the

Warning

StreamingDataset makes the DataLoader raise a TimeoutError once the upstream dataflow is drained.

Parameters

buffer_size (int, optional) – PersiaBatchDataChannel buffer size

consume_dataset()

Consumes the dataset's own __iter__ and returns an iterator over the preprocess indexes.

Return type

Iterator[int]