ray.train.lightning.LightningTrainer
class ray.train.lightning.LightningTrainer(*args, **kwargs)
Bases: ray.train.torch.torch_trainer.TorchTrainer
A Trainer for data parallel PyTorch Lightning training.
This Trainer runs the pytorch_lightning.Trainer.fit() method on multiple Ray Actors. Training is distributed through PyTorch DDP. These actors already have the necessary Torch process group configured for distributed data parallel training. We will support more distributed training strategies in the future.

The training function run on every Actor first initializes an instance of the user-provided lightning_module class, which is a subclass of pytorch_lightning.LightningModule, using the arguments provided in LightningConfigBuilder.module().

For data ingestion, the LightningTrainer then either converts the Ray Dataset shards to a pytorch_lightning.LightningDataModule, or directly uses the datamodule or dataloaders if the user provides them.

The trainer also creates a ModelCheckpoint callback based on the configuration provided in LightningConfigBuilder.checkpointing(). In addition to checkpointing, this callback calls train.report() to report the latest metrics along with the checkpoint.

For logging, users can continue to use Lightning's native loggers, such as WandbLogger, TensorboardLogger, etc. LightningTrainer also logs the latest metrics to the training results directory whenever a new checkpoint is saved.

Finally, the training function initializes an instance of pl.Trainer using the arguments provided in LightningConfigBuilder.trainer(), and then runs pytorch_lightning.Trainer.fit with the arguments provided in LightningConfigBuilder.fit_params().

Example
    import torch
    import torch.nn.functional as F
    from torchmetrics import Accuracy
    from torch.utils.data import DataLoader, Subset
    from torchvision.datasets import MNIST
    from torchvision import transforms
    import pytorch_lightning as pl
    from ray.air.config import ScalingConfig
    from ray.train.lightning import LightningTrainer, LightningConfigBuilder


    class MNISTClassifier(pl.LightningModule):
        def __init__(self, lr, feature_dim):
            super().__init__()
            self.fc1 = torch.nn.Linear(28 * 28, feature_dim)
            self.fc2 = torch.nn.Linear(feature_dim, 10)
            self.lr = lr
            self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
            # Per-batch validation results, aggregated at the end of each epoch
            self.val_loss = []
            self.val_acc = []

        def forward(self, x):
            x = x.view(-1, 28 * 28)
            x = torch.relu(self.fc1(x))
            x = self.fc2(x)
            return x

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.log("train_loss", loss)
            return loss

        def validation_step(self, val_batch, batch_idx):
            x, y = val_batch
            logits = self(x)
            # forward() returns raw logits, so use cross_entropy here as well
            # (nll_loss would expect log-probabilities)
            loss = F.cross_entropy(logits, y)
            acc = self.accuracy(logits, y)
            self.val_loss.append(loss)
            self.val_acc.append(acc)
            return {"val_loss": loss, "val_accuracy": acc}

        def on_validation_epoch_end(self):
            avg_loss = torch.stack(self.val_loss).mean()
            avg_acc = torch.stack(self.val_acc).mean()
            self.log("ptl/val_loss", avg_loss)
            self.log("ptl/val_accuracy", avg_acc)
            self.val_acc.clear()
            self.val_loss.clear()

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
            return optimizer


    # Prepare MNIST Datasets
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    mnist_train = MNIST("./data", train=True, download=True, transform=transform)
    mnist_val = MNIST("./data", train=False, download=True, transform=transform)

    # Take small subsets for smoke test
    # Please remove these two lines if you want to train the full dataset
    mnist_train = Subset(mnist_train, range(1000))
    mnist_val = Subset(mnist_val, range(500))

    train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)
    val_loader = DataLoader(mnist_val, batch_size=128, shuffle=False)

    lightning_config = (
        LightningConfigBuilder()
        .module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)
        .trainer(max_epochs=3, accelerator="cpu")
        .fit_params(train_dataloaders=train_loader, val_dataloaders=val_loader)
        .build()
    )

    scaling_config = ScalingConfig(
        num_workers=4, use_gpu=False, resources_per_worker={"CPU": 1}
    )

    trainer = LightningTrainer(
        lightning_config=lightning_config,
        scaling_config=scaling_config,
    )
    result = trainer.fit()
    result
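
The description above says the module class and its arguments are recorded through the builder and only instantiated later, inside the training function on each Actor. The following is a minimal, standard-library-only sketch of that idea. SketchConfigBuilder, ToyModule, worker_setup, and the dictionary keys are all hypothetical names chosen for illustration; this is not Ray's implementation, and the real config dictionary may be laid out differently.

```python
from typing import Any, Dict, Type


class SketchConfigBuilder:
    """Hypothetical chained builder; a stand-in, not Ray's LightningConfigBuilder."""

    def __init__(self) -> None:
        self._config: Dict[str, Any] = {
            "module_cls": None,
            "module_kwargs": {},
            "trainer_kwargs": {},
            "fit_kwargs": {},
        }

    def module(self, cls: Type, **kwargs: Any) -> "SketchConfigBuilder":
        # Record the class and its arguments; nothing is instantiated here.
        self._config["module_cls"] = cls
        self._config["module_kwargs"] = kwargs
        return self

    def trainer(self, **kwargs: Any) -> "SketchConfigBuilder":
        self._config["trainer_kwargs"] = kwargs
        return self

    def fit_params(self, **kwargs: Any) -> "SketchConfigBuilder":
        self._config["fit_kwargs"] = kwargs
        return self

    def build(self) -> Dict[str, Any]:
        # Return a shallow copy so later builder calls do not mutate it.
        return dict(self._config)


class ToyModule:
    """Hypothetical placeholder for a LightningModule subclass."""

    def __init__(self, lr: float, feature_dim: int) -> None:
        self.lr = lr
        self.feature_dim = feature_dim


def worker_setup(config: Dict[str, Any]) -> Any:
    """What each worker would do first: build its own module instance."""
    return config["module_cls"](**config["module_kwargs"])


config = (
    SketchConfigBuilder()
    .module(cls=ToyModule, lr=1e-3, feature_dim=128)
    .trainer(max_epochs=3)
    .build()
)

# Simulate four workers, each creating an independent module from the config.
modules = [worker_setup(config) for _ in range(4)]
```

Storing the class and keyword arguments, rather than a live object, keeps the config small and lets every worker construct its own copy of the model locally.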
Parameters
lightning_config – Configuration for setting up the PyTorch Lightning Trainer. You can set up the configurations with LightningConfigBuilder, and generate this config dictionary through LightningConfigBuilder.build().
torch_config – Configuration for setting up the PyTorch backend. If set to None, use the default configuration. This replaces the backend_config arg of DataParallelTrainer. Same as in TorchTrainer.
scaling_config – Configuration for how to scale data parallel training.
dataset_config – Configuration for dataset ingest.
run_config – Configuration for the execution of the training run.
datasets – A dictionary of Ray Datasets to use for training. Use the key "train" to denote which dataset is the training dataset and (optionally) the key "val" to denote the validation dataset. Internally, LightningTrainer shards the training dataset across all workers, and creates a PyTorch Dataloader for each shard.
    The datasets will be transformed by preprocessor if it is provided. If the preprocessor has not already been fit, it will be fit on the training dataset.
    If datasets is not specified, LightningTrainer will use the datamodule or dataloaders specified in LightningConfigBuilder.fit_params instead.
datasets_iter_config – Configuration for iterating over the input Ray datasets. You can configure the per-device batch size, prefetch batch size, collate function, and more. For valid arguments to pass, please refer to Dataset.iter_torch_batches.
    Note that if you provide a datasets parameter, you must always specify datasets_iter_config for it.
preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.
resume_from_checkpoint – A checkpoint to resume training from.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
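
The datasets and datasets_iter_config parameters describe sharding the training dataset across workers and iterating each shard with a per-device batch size. The toy sketch below illustrates that idea on a plain Python list. The helper names shard_rows and iter_batches are hypothetical, and the round-robin split is an assumption made for simplicity; how Ray actually splits a Dataset may differ.

```python
from typing import Iterator, List, Sequence


def shard_rows(rows: Sequence[int], num_workers: int) -> List[List[int]]:
    """Assign rows to workers round-robin (an illustrative strategy only)."""
    return [list(rows[i::num_workers]) for i in range(num_workers)]


def iter_batches(shard: Sequence[int], batch_size: int) -> Iterator[List[int]]:
    """Yield consecutive batches of at most batch_size rows from one shard."""
    for start in range(0, len(shard), batch_size):
        yield list(shard[start:start + batch_size])


rows = list(range(10))
shards = shard_rows(rows, num_workers=4)

# Each worker only iterates over its own shard, using the per-device batch size.
worker0_batches = list(iter_batches(shards[0], batch_size=2))
```

Every row lands in exactly one shard, so the workers together see the full training dataset once per epoch without overlap.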
Methods
Converts self to a tune.Trainable class.
can_restore(path) – Checks whether a given directory contains a restorable Train experiment.
fit() – Runs training.
Returns a copy of this Trainer's final dataset configs.
restore(path[, datasets, preprocessor, ...]) – Restores a LightningTrainer from a previously interrupted/failed run.
setup() – Called during fit() to perform initial setup on the Trainer.
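
A common way to combine can_restore(path) and restore(path) is to resume an interrupted run when one exists and otherwise start a new one. The sketch below shows that pattern, assuming both methods can be called on the class itself. FakeTrainer is a hypothetical stand-in so the snippet runs without Ray installed, and restore_or_create is an illustrative helper, not part of the library.

```python
from typing import Any, Optional, Set


class FakeTrainer:
    """Hypothetical stand-in that mirrors the method names listed above."""

    _saved_paths: Set[str] = set()

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs
        self.restored_from: Optional[str] = None

    @classmethod
    def can_restore(cls, path: str) -> bool:
        # The stub just checks an in-memory set instead of a real directory.
        return path in cls._saved_paths

    @classmethod
    def restore(cls, path: str) -> "FakeTrainer":
        trainer = cls()
        trainer.restored_from = path
        return trainer


def restore_or_create(trainer_cls: Any, path: str, **new_trainer_kwargs: Any) -> Any:
    """Resume from path if a restorable run exists, else build a new trainer."""
    if trainer_cls.can_restore(path):
        return trainer_cls.restore(path)
    return trainer_cls(**new_trainer_kwargs)


FakeTrainer._saved_paths.add("/tmp/exp_a")
resumed = restore_or_create(FakeTrainer, "/tmp/exp_a", lightning_config={})
fresh = restore_or_create(FakeTrainer, "/tmp/exp_b", lightning_config={})
```

With the real class, the same helper would be called as restore_or_create(LightningTrainer, path, lightning_config=..., scaling_config=...), followed by fit() on the returned trainer.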