ray.train.mosaic.MosaicTrainer
- class ray.train.mosaic.MosaicTrainer(*args, **kwargs)
Bases: ray.train.torch.torch_trainer.TorchTrainer
A Trainer for data-parallel Mosaic Composer training on PyTorch.
This Trainer runs the composer.trainer.Trainer.fit() method on multiple Ray Actors. Training is distributed through PyTorch DDP, and each actor already has the torch process group configured for distributed PyTorch training.
The training function run on every Actor first calls the specified trainer_init_per_worker function to obtain an instantiated composer.Trainer object. The trainer_init_per_worker function will have access to preprocessed train and evaluation datasets.
Example:
    import torch.utils.data
    import torchvision
    from torchvision import transforms, datasets

    from composer.models.tasks import ComposerClassifier
    import composer.optim
    from composer.algorithms import LabelSmoothing

    import ray
    import ray.train as train
    from ray.train import ScalingConfig
    from ray.train.mosaic import MosaicTrainer

    # global batch size, split evenly across workers (illustrative value)
    BATCH_SIZE = 64

    def trainer_init_per_worker(config):
        # prepare the model for distributed training and wrap with
        # ComposerClassifier for Composer Trainer compatibility
        model = torchvision.models.resnet18(num_classes=10)
        model = ComposerClassifier(ray.train.torch.prepare_model(model))

        # prepare train/test dataset
        mean = (0.507, 0.487, 0.441)
        std = (0.267, 0.256, 0.276)
        cifar10_transforms = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)]
        )
        data_directory = "~/data"
        train_dataset = datasets.CIFAR10(
            data_directory, train=True, download=True, transform=cifar10_transforms
        )

        # prepare train dataloader; the world size is read from the
        # Ray Train context of the current worker
        world_size = train.get_context().get_world_size()
        batch_size_per_worker = BATCH_SIZE // world_size
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size_per_worker
        )
        train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)

        # prepare optimizer
        optimizer = composer.optim.DecoupledSGDW(
            model.parameters(),
            lr=0.05,
            momentum=0.9,
            weight_decay=2.0e-3,
        )

        return composer.trainer.Trainer(
            model=model,
            train_dataloader=train_dataloader,
            optimizers=optimizer,
            **config
        )

    scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
    trainer_init_config = {
        "max_duration": "1ba",
        "algorithms": [LabelSmoothing()],
    }

    trainer = MosaicTrainer(
        trainer_init_per_worker=trainer_init_per_worker,
        trainer_init_config=trainer_init_config,
        scaling_config=scaling_config,
    )

    trainer.fit()
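The example derives each worker's batch size by dividing a global batch size by the world size. The following is a minimal standalone sketch of that arithmetic, not part of the Ray or Composer API; the helper names per_worker_batch_size and effective_global_batch_size are hypothetical and exist only for illustration.

```python
def per_worker_batch_size(global_batch_size: int, world_size: int) -> int:
    """Split a global batch size evenly across data-parallel workers."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    if global_batch_size < world_size:
        raise ValueError("global_batch_size must be at least world_size")
    # Floor division, as in the example; any remainder is dropped.
    return global_batch_size // world_size


def effective_global_batch_size(global_batch_size: int, world_size: int) -> int:
    """Samples processed per step across all workers after the split."""
    return per_worker_batch_size(global_batch_size, world_size) * world_size
```

Because of the floor division, a global batch size that is not a multiple of the number of workers yields a slightly smaller effective batch, so choosing a multiple of num_workers avoids the discrepancy.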
- Parameters
trainer_init_per_worker – The function that takes a configuration dictionary (config) as its argument and returns an instantiated composer.Trainer object. This dictionary is based on trainer_init_config and is modified for the Ray-Composer integration.
datasets – Any Datasets to use for training. Passing datasets to the trainer and using the dataset shards in the training loop is currently not supported. Instead, configure and load the datasets inside the trainer_init_per_worker function.
trainer_init_config – Configurations to pass into trainer_init_per_worker as kwargs. Although the kwargs can be hard-coded in trainer_init_per_worker, using the config lets you reuse the same worker init function while changing the trainer arguments. For example, when tuning hyperparameters you can reuse one trainer_init_per_worker function with different hyperparameter values rather than writing several trainer_init_per_worker functions with different hard-coded values.
torch_config – Configuration for setting up the PyTorch backend. If set to None, the default configuration is used. This replaces the backend_config arg of DataParallelTrainer. Same as in TorchTrainer.
scaling_config – Configuration for how to scale data parallel training.
dataset_config – Configuration for dataset ingest.
run_config – Configuration for the execution of the training run.
preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.
resume_from_checkpoint – A MosaicCheckpoint to resume training from.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
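The trainer_init_config parameter is described as a way to reuse one trainer_init_per_worker function across hyperparameter settings. The sketch below illustrates only that pattern in plain Python, under stated assumptions: StandInTrainer is a hypothetical placeholder for composer.trainer.Trainer that merely records its keyword arguments, the "lr" key is an illustrative choice, and the optimizer is represented by a dictionary rather than a real optimizer object.

```python
from typing import Any, Dict


class StandInTrainer:
    """Hypothetical placeholder for composer.trainer.Trainer; records kwargs only."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


def trainer_init_per_worker(config: Dict[str, Any]) -> StandInTrainer:
    # Copy so the caller's dictionary is not mutated.
    config = dict(config)
    # "lr" is an illustrative key consumed here instead of being forwarded.
    lr = config.pop("lr", 0.05)
    optimizer = {"name": "sgd", "lr": lr}  # placeholder for a real optimizer
    # Remaining keys are forwarded as trainer keyword arguments.
    return StandInTrainer(optimizers=optimizer, **config)


# One init function, several hyperparameter settings.
configs = [{"max_duration": "1ba", "lr": lr} for lr in (0.01, 0.05, 0.1)]
trainers = [trainer_init_per_worker(c) for c in configs]
learning_rates = [t.kwargs["optimizers"]["lr"] for t in trainers]
```

With the real classes, each dictionary in configs would be passed as trainer_init_config to a separate MosaicTrainer, while trainer_init_per_worker stays unchanged.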
Methods
as_trainable() – Converts self to a tune.Trainable class.
can_restore(path) – Checks whether a given directory contains a restorable Train experiment.
fit() – Runs training.
get_dataset_config() – Returns a copy of this Trainer's final dataset configs.
setup() – Called during fit() to perform initial setup on the Trainer.