persia.ctx

Module Contents

class persia.ctx.BaseCtx(threadpool_worker_size=10, device_id=None)

Initializes a common context for the other PERSIA contexts, e.g. DataCtx, EmbeddingCtx and TrainCtx. This class should not be instantiated directly.

Parameters
  • threadpool_worker_size (int) – RPC thread pool worker size.

  • device_id (int, optional) – the CUDA device to use for this process.

class persia.ctx.DataCtx(*args, **kwargs)

Bases: BaseCtx

Data context provides the communication functionality for the data generator component. It is used to send a PersiaBatch to the nn-worker and embedding-worker.

If you use DataCtx to send PersiaBatch objects from the data-loader, use StreamingDataset to receive them on the nn-worker.

On data-loader:

from persia.ctx import DataCtx
from persia.embedding.data import PersiaBatch

loader = make_loader()
with DataCtx() as ctx:
    for (non_id_type_features, id_type_features, labels) in loader:
        persia_batch = PersiaBatch(
            id_type_features,
            non_id_type_features=non_id_type_features,
            labels=labels,
            requires_grad=True
        )
        ctx.send_data(persia_batch)

On nn-worker:

from persia.ctx import TrainCtx
from persia.data import StreamingDataset, DataLoader

buffer_size = 15

streaming_dataset = StreamingDataset(buffer_size)
data_loader = DataLoader(streaming_dataset)

with TrainCtx(...):
    for persia_training_batch in data_loader:
        ...
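On the nn-worker side, buffer_size limits how many batches StreamingDataset holds before they are consumed, assuming it behaves as a bounded buffer. As a rough mental model, the data-loader and nn-worker then act like a producer and a consumer around that buffer. The sketch below illustrates the back-pressure with queue.Queue and a thread; it is not how PERSIA transports batches (they travel through the embedding worker), and stream_batches and _produce are hypothetical names.

```python
import queue
import threading
from typing import List

_END = object()  # hypothetical end-of-stream marker for this sketch only


def _produce(buffer: "queue.Queue[object]", batches: List[int]) -> None:
    # Stands in for the data-loader loop that calls ctx.send_data(...).
    for batch in batches:
        buffer.put(batch)  # blocks while the buffer is already full
    buffer.put(_END)


def stream_batches(batches: List[int], buffer_size: int = 15) -> List[int]:
    # maxsize stands in for StreamingDataset's buffer_size (an analogy only).
    buffer: "queue.Queue[object]" = queue.Queue(maxsize=buffer_size)
    producer = threading.Thread(target=_produce, args=(buffer, batches))
    producer.start()
    received = []
    # Stands in for the nn-worker iterating over its DataLoader.
    while (batch := buffer.get()) is not _END:
        received.append(batch)
    producer.join()
    return received
```

A small buffer never drops batches; it only makes the producer wait until the consumer catches up.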

Note

The examples cannot be run directly; you need to launch the nn-worker, embedding-worker, embedding-parameter-server, and nats-server for the examples to produce the correct result.

Parameters
  • threadpool_worker_size (int) – RPC thread pool worker size.

  • device_id (int, optional) – the CUDA device to use for this process.

send_data(persia_batch)

Send a PersiaBatch from the data loader to the nn-worker and embedding-worker.

Parameters

persia_batch (PersiaBatch) – a PersiaBatch that has not been processed yet.

class persia.ctx.EmbeddingCtx(preprocess_mode, model=None, embedding_config=None, *args, **kwargs)

Bases: BaseCtx

Provides the embedding-related functionality. EmbeddingCtx can run offline tests or online inference, depending on preprocess_mode. The simplest way to obtain an EmbeddingCtx instance is to call eval_ctx.

Example for EmbeddingCtx:

from persia.ctx import EmbeddingCtx, PreprocessMode
from persia.embedding.data import PersiaBatch

model = get_dnn_model()
loader = make_dataloader()
device_id = 0

with EmbeddingCtx(
    PreprocessMode.EVAL,
    model=model,
    device_id=device_id
) as ctx:
    for (non_id_type_features, id_type_features, labels) in loader:
        persia_batch = PersiaBatch(
            id_type_features,
            non_id_type_features=non_id_type_features,
            labels=labels,
            requires_grad=False
        )
        persia_training_batch = ctx.get_embedding_from_data(persia_batch)
        (output, label) = ctx.forward(persia_training_batch)

Note

The examples cannot be run directly; you need to launch the nn-worker, embedding-worker, embedding-parameter-server, and nats-server for the examples to produce the correct result.

Note

If you set device_id=None, the training data and the model will be placed in host memory rather than in CUDA device memory by default.

Parameters
  • preprocess_mode (PreprocessMode) – the preprocess mode, which affects the behavior of prepare_features.

  • model (torch.nn.Module) – dense neural network PyTorch model.

  • embedding_config (EmbeddingConfig, optional) – the embedding configuration that will be sent to the embedding server.

clear_embeddings()

Clear all embeddings on all embedding servers.

configure_embedding_parameter_servers(embedding_config)

Apply EmbeddingConfig to embedding servers.

Parameters

embedding_config (EmbeddingConfig) – the embedding configuration that will be sent to the embedding server.

dump_checkpoint(dst_dir, dense_filename='dense.pt', jit_dense_filename='jit_dense.pt', blocking=True, with_jit_model=False)

Save the model checkpoint (both dense and embedding) to the destination directory.

Parameters
  • dst_dir (str) – destination directory.

  • dense_filename (str, optional) – dense checkpoint filename.

  • jit_dense_filename (str, optional) – filename of the jit-scripted (TorchScript) dense checkpoint.

  • blocking (bool, optional) – dump embedding checkpoint in blocking mode or not.

  • with_jit_model (bool, optional) – dump jit script dense checkpoint or not.

dump_embedding(dst_dir, blocking=True)

Dump embeddings to the destination directory. By default, this function is synchronous and waits for the embedding dump to complete before returning; this is done internally through a call to wait_for_dump_embedding. Set blocking=False to allow asynchronous computation, in which case the function returns immediately; then call wait_for_dump_embedding to wait until the dump has finished.

Parameters
  • dst_dir (str) – destination directory.

  • blocking (bool, optional) – dump embedding in blocking mode or not.
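The blocking flag follows a start-then-wait pattern. The sketch below illustrates it with a hypothetical in-process FakeEmbeddingStore (not a PERSIA class) that writes a small TSV file on a background thread; in PERSIA the dump itself runs on the embedding servers. The same pattern applies to load_embedding and wait_for_load_embedding.

```python
import tempfile
import threading
from pathlib import Path
from typing import Dict, Optional


class FakeEmbeddingStore:
    """Hypothetical stand-in that mirrors the dump_embedding(blocking=...) contract."""

    def __init__(self, embeddings: Dict[int, float]) -> None:
        self.embeddings = embeddings
        self._dump_thread: Optional[threading.Thread] = None

    def _write(self, dst_dir: str) -> None:
        lines = [f"{key}\t{value}" for key, value in sorted(self.embeddings.items())]
        (Path(dst_dir) / "embeddings.tsv").write_text("\n".join(lines))

    def dump_embedding(self, dst_dir: str, blocking: bool = True) -> None:
        # Always start the dump in the background...
        self._dump_thread = threading.Thread(target=self._write, args=(dst_dir,))
        self._dump_thread.start()
        # ...and only wait for it here when blocking=True (the default).
        if blocking:
            self.wait_for_dump_embedding()

    def wait_for_dump_embedding(self) -> None:
        if self._dump_thread is not None:
            self._dump_thread.join()
            self._dump_thread = None


def demo() -> str:
    store = FakeEmbeddingStore({2: -0.25, 1: 0.5})
    with tempfile.TemporaryDirectory() as dst_dir:
        store.dump_embedding(dst_dir, blocking=False)
        # Other work (e.g. the next training step) could overlap with the dump here.
        store.wait_for_dump_embedding()
        return (Path(dst_dir) / "embeddings.tsv").read_text()
```

With blocking=False, forgetting the wait call means the files may not be complete when you read them.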

dump_torch_state_dict(torch_instance, dst_dir, file_name, is_jit=False)

Dump a PyTorch model's or optimizer's state dict to the destination directory.

Parameters
  • torch_instance (torch.nn.Module or torch.optim.Optimizer) – dense model or optimizer to be dumped.

  • dst_dir (str) – destination directory.

  • file_name (str) – destination filename.

  • is_jit (bool, optional) – whether to dump model as jit script.

forward(batch)

Call prepare_features, then run a forward step of the context's model.

Parameters

batch (PersiaTrainingBatch) – training data provided by PERSIA upstream, including non_id_type_features, labels, id_type_feature_embeddings and meta info.

Returns

the tuple of output data and target data.

Return type

Tuple[torch.Tensor, Optional[torch.Tensor]]

get_embedding_from_bytes(data, device_id=None)

Get embeddings of the serialized input batch data.

Parameters
  • data (bytes) – serialized PersiaBatch without embeddings (e.g. produced by PersiaBatch.to_bytes()).

  • device_id (int, optional) – the CUDA device to use for this process.

Returns

PersiaTrainingBatch that contains id_type_feature_embeddings.

Return type

persia.prelude.PersiaTrainingBatch

get_embedding_from_data(persia_batch, device_id=None)

Get embeddings of the input batch data.

Parameters
  • persia_batch (PersiaBatch) – input data without embeddings.

  • device_id (int, optional) – the CUDA device to use for this process.

Returns

PersiaTrainingBatch that contains id_type_feature_embeddings.

Return type

persia.prelude.PersiaTrainingBatch

get_embedding_size()

Get the number of IDs on all embedding servers.

Return type

List[int]

load_checkpoint(src_dir, map_location=None, dense_filename='dense.pt', blocking=True)

Load the dense and embedding checkpoint from the source directory.

Parameters
  • src_dir (str) – source directory.

  • map_location (str, optional) – load the dense checkpoint to a specific device.

  • dense_filename (str, optional) – dense checkpoint filename.

  • blocking (bool, optional) – load embedding checkpoint in blocking mode or not.

load_embedding(src_dir, blocking=True)

Load embeddings from src_dir. By default, this function is synchronous and waits for the embedding load to complete before returning; this is done internally through a call to wait_for_load_embedding. Set blocking=False to allow asynchronous computation, in which case the function returns immediately; then call wait_for_load_embedding to wait until loading has finished.

Parameters
  • src_dir (str) – directory to load embeddings.

  • blocking (bool, optional) – load embeddings in blocking mode or not.

load_torch_state_dict(torch_instance, src_dir, map_location=None)

Load a PyTorch state dict from the source directory and apply it to torch_instance.

Parameters
  • torch_instance (torch.nn.Module or torch.optim.Optimizer) – dense model or optimizer to restore.

  • src_dir (str) – directory to load torch state dict.

  • map_location (str, optional) – load the dense checkpoint to a specific device.

prepare_features(persia_training_batch)

This function converts data from PersiaTrainingBatch to torch.Tensor.

PersiaTrainingBatch contains non_id_type_features, id_type_feature_embeddings and labels, but they cannot be used directly in training until they are converted to torch.Tensor.

Parameters

persia_training_batch (PersiaTrainingBatch) – training data provided by PERSIA upstream including non_id_type_features, labels, id_type_feature_embeddings and meta info.

Returns

the tuple of non_id_type_feature_tensors, id_type_feature_embedding_tensors and label_tensors.

Return type

Tuple[List[torch.Tensor], List[torch.Tensor], Optional[List[torch.Tensor]]]

wait_for_dump_embedding()

Wait for the embedding dump process.

wait_for_load_embedding()

Wait for the embedding load process.

class persia.ctx.InferCtx(embedding_worker_address_list, *args, **kwargs)

Bases: EmbeddingCtx

Subclass of EmbeddingCtx that provides inference functionality without a nats-server.

Example for InferCtx:

import numpy as np
from persia.ctx import InferCtx
from persia.embedding.data import PersiaBatch, IDTypeFeatureWithSingleID

device_id = 0
id_type_feature = IDTypeFeatureWithSingleID(
    "id_type_feature",
    np.array([1, 2, 3], np.uint64)
)
persia_batch = PersiaBatch([id_type_feature], requires_grad=False)

embedding_worker_address_list = [
    "localhost:8888",
    "localhost:8889",
    "localhost:8890"
]
with InferCtx(embedding_worker_address_list, device_id=device_id) as infer_ctx:
    persia_training_batch = infer_ctx.get_embedding_from_bytes(
        persia_batch.to_bytes(),
    )
    (
        non_id_type_feature_tensors,
        id_type_feature_embedding_tensors,
        label_tensors
    ) = infer_ctx.prepare_features(persia_training_batch)

Note

The example cannot be run directly; you need to launch the embedding-worker and embedding-parameter-server for the example to produce the correct result.

Parameters
  • embedding_worker_address_list (List[str]) – list of embedding worker addresses, each in ip:port form.
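Each entry of embedding_worker_address_list is an ip:port string. A small pre-flight check such as the hypothetical parse_worker_address below (not part of PERSIA) can catch malformed entries, such as a space after the colon or a missing port, before InferCtx tries to connect.

```python
from typing import List, Tuple


def parse_worker_address(address: str) -> Tuple[str, int]:
    """Split an "ip:port" string into (host, port), rejecting malformed input."""
    if any(ch.isspace() for ch in address):
        raise ValueError(f"address must not contain whitespace: {address!r}")
    host, sep, port_str = address.rpartition(":")
    if not sep or not host or not port_str.isdecimal():
        raise ValueError(f"expected ip:port, got {address!r}")
    port = int(port_str)
    if not 0 < port < 65536:
        raise ValueError(f"port out of range in {address!r}")
    return host, port


def validate_worker_addresses(addresses: List[str]) -> List[Tuple[str, int]]:
    # Fail fast on the first malformed entry before handing the list to InferCtx.
    return [parse_worker_address(address) for address in addresses]
```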

wait_for_serving()

class persia.ctx.PreprocessMode

Bases: enum.Enum

Mode of preprocessing.

Used by prepare_features to generate features of different datatypes.

When set to TRAIN, prepare_features returns torch tensors with the requires_grad attribute set to True. When set to EVAL, prepare_features returns torch tensors with requires_grad set to False. INFERENCE behaves almost identically to EVAL, except that INFERENCE allows EmbeddingCtx to process a PersiaTrainingBatch without a target tensor.

EVAL = 2
INFERENCE = 3
TRAIN = 1

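The mode-dependent behavior described above fits in a few lines. The sketch mirrors the documented enum values in a local PreprocessMode and uses a hypothetical feature_policy helper; the real prepare_features works on torch tensors, which this sketch leaves out.

```python
from enum import Enum
from typing import Optional, Sequence, Tuple


class PreprocessMode(Enum):
    # Local mirror of the documented persia.ctx.PreprocessMode values.
    TRAIN = 1
    EVAL = 2
    INFERENCE = 3


def feature_policy(
    mode: PreprocessMode, labels: Optional[Sequence[float]]
) -> Tuple[bool, bool]:
    """Return (requires_grad, has_target) for a batch prepared in `mode`."""
    # Only TRAIN produces tensors that track gradients.
    requires_grad = mode is PreprocessMode.TRAIN
    # Only INFERENCE may omit the target tensor.
    if labels is None and mode is not PreprocessMode.INFERENCE:
        raise ValueError(f"{mode.name} mode expects a target tensor")
    return requires_grad, labels is not None
```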
class persia.ctx.TrainCtx(embedding_optimizer, dense_optimizer, grad_scalar_update_factor=4, backward_buffer_size=10, backward_workers_size=8, grad_update_buffer_size=60, lookup_emb_directly=True, mixed_precision=True, distributed_option=None, *args, **kwargs)

Bases: EmbeddingCtx

Subclass of EmbeddingCtx that implements a backward function to update the embeddings.

Example for TrainCtx:

import torch
from persia.ctx import TrainCtx
from persia.data import DataLoader, StreamingDataset
from persia.embedding.optim import SGD

device_id = 0
model = get_dnn_model()
model.cuda(device_id)

embedding_optimizer = SGD(lr=1e-3)
dense_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.BCELoss(reduction="mean")

prefetch_size = 15
stream_dataset = StreamingDataset(prefetch_size)

with TrainCtx(
    embedding_optimizer,
    dense_optimizer,
    model=model,
    device_id=device_id
) as ctx:
    dataloader = DataLoader(stream_dataset)
    for persia_training_batch in dataloader:
        output, labels = ctx.forward(persia_training_batch)
        loss = loss_fn(output, labels[0])
        scaled_loss = ctx.backward(loss)

If you want to train a PERSIA task in a distributed environment, set distributed_option to the option you want to use. PyTorch DDP (distributed data-parallel, DDPOption) and Bagua (BaguaDistributedOption) are currently supported, with PyTorch DDP as the default. When the environment variable WORLD_SIZE is greater than 1, the default configuration is determined by get_default_distributed_option.
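A minimal sketch of the default-selection rule described above, assuming that an explicit distributed_option always takes precedence and that the default uses NCCL when a CUDA device is set and Gloo otherwise. The backend choice is an assumption based on the nccl/gloo comment in the DDPOption example, not confirmed PERSIA behavior, and pick_distributed_option is hypothetical: it returns a label rather than a real DDPOption.

```python
import os
from typing import Mapping, Optional


def pick_distributed_option(
    distributed_option: Optional[str],
    device_id: Optional[int],
    environ: Mapping[str, str] = os.environ,
) -> Optional[str]:
    """Hypothetical sketch: return a label for the distributed option to use."""
    if distributed_option is not None:
        # An explicitly passed option always wins.
        return distributed_option
    world_size = int(environ.get("WORLD_SIZE", "1"))
    if world_size <= 1:
        # Single process: no data-parallel wrapper is needed.
        return None
    # Assumption: NCCL on GPU, Gloo on CPU-only clusters.
    backend = "nccl" if device_id is not None else "gloo"
    return f"ddp:{backend}"
```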

You can configure the DDPOption to your specific requirements.

from persia.ctx import TrainCtx
from persia.distributed import DDPOption

backend = "nccl"
# backend = "gloo"  # Use gloo to train PERSIA on a CPU cluster.

ddp_option = DDPOption(
    backend=backend,
    init_method="tcp"
)

with TrainCtx(
    embedding_optimizer,
    dense_optimizer,
    model=model,
    distributed_option=ddp_option
) as ctx:
    ...

We also integrated Bagua into PERSIA as an alternative to PyTorch DDP. Bagua is an advanced data-parallel framework, also developed by AI Platform @ Kuaishou. Using BaguaDistributedOption in place of DDPOption can significantly speed up training (see the Bagua benchmark). For more details on the algorithms used by BaguaDistributedOption and its available options, please refer to the Bagua tutorials.

Example for BaguaDistributedOption:

from persia.ctx import TrainCtx
from persia.distributed import BaguaDistributedOption

algorithm = "gradient_allreduce"
bagua_args = {}
bagua_option = BaguaDistributedOption(
    algorithm,
    **bagua_args
)

with TrainCtx(
    embedding_optimizer,
    dense_optimizer,
    model=model,
    distributed_option=bagua_option
) as ctx:
    ...

Parameters
  • embedding_optimizer (persia.embedding.optim.Optimizer) – optimizer for the embedding parameters.

  • dense_optimizer (torch.optim.Optimizer) – optimizer for dense parameters.

  • grad_scalar_update_factor (float, optional) – update factor of the gradient scaler (GradScaler), used to keep the loss scale finite when mixed_precision=True.

  • backward_buffer_size (int, optional) – maximum number of gradients queued in the buffer between two backward steps.

  • backward_workers_size (int, optional) – number of workers sending embedding gradients in parallel.

  • grad_update_buffer_size (int, optional) – the size of gradient buffers. The buffer will cache the gradient tensor until the embedding update is finished.

  • lookup_emb_directly (bool, optional) – lookup embedding directly without a separate data loader.

  • mixed_precision (bool, optional) – whether to enable mixed precision training.

  • distributed_option (DistributedBaseOption, optional) – option for distributed training.

backward(loss, embedding_gradient_check_frequency=20)

Update the parameters of the current dense model and embedding model.

Parameters
  • loss (torch.Tensor) – loss of current batch.

  • embedding_gradient_check_frequency (int, optional) – number of batches between successive checks that the current embedding gradients are finite.

Return type

torch.Tensor
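With the default of 20, the finiteness check on embedding gradients runs on one batch in twenty rather than on every batch. A sketch of such a schedule, assuming a 0-based step counter where step 0 is checked; should_check, all_finite and checked_steps are hypothetical helpers, and PERSIA's exact counting may differ.

```python
import math
from typing import Iterable, List


def should_check(step: int, check_frequency: int = 20) -> bool:
    # Assumption: steps are counted from 0 and step 0 is checked.
    return step % check_frequency == 0


def all_finite(values: Iterable[float]) -> bool:
    # True only if no value is NaN or +/-inf.
    return all(math.isfinite(v) for v in values)


def checked_steps(num_steps: int, check_frequency: int = 20) -> List[int]:
    # Steps at which the gradient check would run within `num_steps` batches.
    return [step for step in range(num_steps) if should_check(step, check_frequency)]
```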

dump_checkpoint(dst_dir, dense_model_filename='dense.pt', jit_dense_model_filename='jit_dense.pt', opt_filename='opt.pt', blocking=True, with_jit_model=False)

Dump the dense and embedding checkpoint to destination directory.

Parameters
  • dst_dir (str) – destination directory.

  • dense_model_filename (str, optional) – dense model checkpoint filename.

  • jit_dense_model_filename (str, optional) – filename of the jit-scripted (TorchScript) dense checkpoint.

  • opt_filename (str, optional) – optimizer checkpoint filename.

  • blocking (bool, optional) – dump embedding checkpoint in blocking mode or not.

  • with_jit_model (bool, optional) – dump dense checkpoint as jit script or not.

load_checkpoint(src_dir, map_location=None, dense_model_filename='dense.pt', opt_filename='opt.pt', blocking=True)

Load the dense and embedding checkpoint from source directory.

Parameters
  • src_dir (str) – source directory.

  • map_location (str, optional) – load the dense checkpoint to a specific device.

  • dense_model_filename (str, optional) – dense checkpoint filename.

  • opt_filename (str, optional) – optimizer checkpoint filename.

  • blocking (bool, optional) – load embedding checkpoint in blocking mode or not.

wait_servers_ready()

Wait until embedding servers are ready to serve.

persia.ctx.cnt_ctx()

Get the most recently entered BaseCtx, or None if no context has been entered.

Return type

Optional[BaseCtx]
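One plausible way to track the most recently entered context is a stack that __enter__ pushes to and __exit__ pops from, so nested contexts restore the outer one on exit. The sketch below uses a hypothetical SketchCtx and current_ctx; PERSIA's own bookkeeping may differ (for example, it may keep a single global reference instead of a stack).

```python
from typing import List, Optional

_CTX_STACK: List["SketchCtx"] = []  # hypothetical module-level stack


class SketchCtx:
    """Hypothetical stand-in for a BaseCtx-like context manager."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __enter__(self) -> "SketchCtx":
        _CTX_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None lets any exception propagate after the pop.
        _CTX_STACK.pop()


def current_ctx() -> Optional[SketchCtx]:
    # Most recently entered context that has not exited yet, or None.
    return _CTX_STACK[-1] if _CTX_STACK else None
```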

persia.ctx.eval_ctx(*args, **kwargs)

Get the EmbeddingCtx with the EVAL mode.

Return type

EmbeddingCtx