ray.train.lightning.LightningTrainer

class ray.train.lightning.LightningTrainer(*args, **kwargs)

Bases: ray.train.torch.torch_trainer.TorchTrainer

A Trainer for data parallel PyTorch Lightning training.

This Trainer runs the pytorch_lightning.Trainer.fit() method on multiple Ray Actors. Training is distributed through PyTorch DDP, and the actors already have the Torch process group configured for distributed data parallel training. Support for more distributed training strategies is planned for the future.

The training function run on every actor first initializes an instance of the user-provided lightning_module class, which is a subclass of pytorch_lightning.LightningModule, using the arguments provided in LightningConfigBuilder.module().

For data ingestion, the LightningTrainer will then either convert the Ray Dataset shards to a pytorch_lightning.LightningDataModule, or directly use the datamodule or dataloaders if provided by users.

The trainer also creates a ModelCheckpoint callback based on the configuration provided in LightningConfigBuilder.checkpointing(). In addition to checkpointing, this callback also calls train.report() to report the latest metrics along with the checkpoint.
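As an illustration of the checkpointing configuration, the sketch below forwards ModelCheckpoint arguments through LightningConfigBuilder.checkpointing(). The metric name "ptl/val_accuracy" and the chosen values are assumptions for this sketch, and module_cls stands for any LightningModule subclass that logs that metric. The helper name build_lightning_config is hypothetical. The Ray import is deferred into the function so the snippet can be loaded without Ray installed.

```python
# ModelCheckpoint keyword arguments (values are illustrative assumptions).
checkpoint_kwargs = {
    "monitor": "ptl/val_accuracy",  # metric the LightningModule is assumed to log
    "mode": "max",                  # higher accuracy is better
    "save_top_k": 2,                # keep the two best checkpoints
}


def build_lightning_config(module_cls, **module_kwargs):
    """Hypothetical helper: build a lightning_config with checkpointing set."""
    # Imported lazily; needs a Ray release that ships LightningConfigBuilder.
    from ray.train.lightning import LightningConfigBuilder

    return (
        LightningConfigBuilder()
        .module(cls=module_cls, **module_kwargs)
        .trainer(max_epochs=3, accelerator="cpu")
        .checkpointing(**checkpoint_kwargs)
        .build()
    )
```

The returned dictionary is what LightningTrainer expects as its lightning_config argument.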

For logging, users can continue to use Lightning’s native loggers, such as WandbLogger and TensorBoardLogger. LightningTrainer also logs the latest metrics to the training results directory whenever a new checkpoint is saved.

Finally, the training function initializes an instance of pl.Trainer using the arguments provided in LightningConfigBuilder.trainer() and runs pytorch_lightning.Trainer.fit with the arguments provided in LightningConfigBuilder.fit_params().
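The order of operations described above can be sketched in plain Python. This is a minimal illustration, not Ray's implementation: run_training_function, FakeModule, and FakeTrainer are hypothetical stand-ins, and the config keys module_class, module_init, trainer_init, and fit_params are names assumed for this sketch. It follows the example below in passing trainer arguments to the trainer constructor and fit parameters to fit().

```python
# Hypothetical stand-ins so the sketch runs without Lightning or Ray installed.
class FakeModule:
    def __init__(self, lr, feature_dim):
        self.lr = lr
        self.feature_dim = feature_dim


class FakeTrainer:
    def __init__(self, **trainer_kwargs):
        self.trainer_kwargs = trainer_kwargs
        self.fit_calls = []

    def fit(self, model, **fit_kwargs):
        # A real pl.Trainer would run the training loop here.
        self.fit_calls.append((model, fit_kwargs))


def run_training_function(config, trainer_cls=FakeTrainer):
    """Mimic the per-worker order of operations (assumed config key names)."""
    # 1. Instantiate the user-provided module class with its init arguments.
    module = config["module_class"](**config["module_init"])
    # 2. Build the trainer from the trainer arguments.
    trainer = trainer_cls(**config["trainer_init"])
    # 3. Call fit() with the fit parameters (dataloaders or a datamodule).
    trainer.fit(module, **config["fit_params"])
    return module, trainer


config = {
    "module_class": FakeModule,
    "module_init": {"lr": 1e-3, "feature_dim": 128},
    "trainer_init": {"max_epochs": 3, "accelerator": "cpu"},
    "fit_params": {"train_dataloaders": ["batch0", "batch1"]},
}
module, trainer = run_training_function(config)
```

In the real trainer, each Ray actor runs this sequence with the Torch process group already configured, so the module and trainer are created once per worker.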

Example

import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from ray.air.config import ScalingConfig
from ray.train.lightning import LightningTrainer, LightningConfigBuilder


class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr, feature_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, feature_dim)
        self.fc2 = torch.nn.Linear(feature_dim, 10)
        self.lr = lr
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
        self.val_loss = []
        self.val_acc = []

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        # forward() returns raw logits, which cross_entropy expects
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        self.val_loss.append(loss)
        self.val_acc.append(acc)
        return {"val_loss": loss, "val_accuracy": acc}

    def on_validation_epoch_end(self):
        avg_loss = torch.stack(self.val_loss).mean()
        avg_acc = torch.stack(self.val_acc).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)
        self.val_acc.clear()
        self.val_loss.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

# Prepare MNIST Datasets
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
mnist_train = MNIST(
    './data', train=True, download=True, transform=transform
)
mnist_val = MNIST(
    './data', train=False, download=True, transform=transform
)

# Take small subsets for a smoke test.
# Remove these two lines to train on the full dataset.
mnist_train = Subset(mnist_train, range(1000))
mnist_val = Subset(mnist_val, range(500))

train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=128, shuffle=False)

lightning_config = (
    LightningConfigBuilder()
    .module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)
    .trainer(max_epochs=3, accelerator="cpu")
    .fit_params(train_dataloaders=train_loader, val_dataloaders=val_loader)
    .build()
)

scaling_config = ScalingConfig(
    num_workers=4, use_gpu=False, resources_per_worker={"CPU": 1}
)
trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
)
result = trainer.fit()
result

Parameters
  • lightning_config – Configuration for setting up the PyTorch Lightning Trainer. You can set up the configuration with LightningConfigBuilder and generate this config dictionary through LightningConfigBuilder.build().

  • torch_config – Configuration for setting up the PyTorch backend. If set to None, use the default configuration. This replaces the backend_config arg of DataParallelTrainer. Same as in TorchTrainer.

  • scaling_config – Configuration for how to scale data parallel training.

  • dataset_config – Configuration for dataset ingest.

  • run_config – Configuration for the execution of the training run.

  • datasets

    A dictionary of Ray Datasets to use for training. Use the key “train” to denote the training dataset and, optionally, the key “val” to denote the validation dataset. Internally, LightningTrainer shards the training dataset across all workers and creates a PyTorch DataLoader for each shard.

    The datasets will be transformed by preprocessor if it is provided. If the preprocessor has not already been fit, it will be fit on the training dataset.

    If datasets is not specified, LightningTrainer will use datamodule or dataloaders specified in LightningConfigBuilder.fit_params instead.

  • datasets_iter_config

    Configuration for iterating over the input Ray Datasets. You can configure the per-device batch size, prefetch batch size, collate function, and more. For the valid arguments, see Dataset.iter_torch_batches.

    Note that if you provide a datasets parameter, you must always specify datasets_iter_config for it.

  • preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.

  • resume_from_checkpoint – A checkpoint to resume training from.
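The rules for datasets and datasets_iter_config above can be checked before constructing the trainer. The helper below is a hypothetical pre-flight check written for this sketch, not part of Ray. It only encodes what the parameter descriptions state, and it treats the "train" key as required, which is an assumption. batch_size is one of the arguments accepted by Dataset.iter_torch_batches.

```python
def check_dataset_args(datasets=None, datasets_iter_config=None):
    """Hypothetical helper: validate the documented datasets rules."""
    if datasets is None:
        # Data then comes from LightningConfigBuilder.fit_params instead.
        return "fit_params"
    if "train" not in datasets:
        raise ValueError('datasets needs a "train" key for the training dataset')
    if datasets_iter_config is None:
        raise ValueError("datasets_iter_config is required when datasets is set")
    return "datasets"


# Placeholder strings stand in for ray.data.Dataset objects in this sketch.
source = check_dataset_args(
    datasets={"train": "train_ds", "val": "val_ds"},
    datasets_iter_config={"batch_size": 32},
)
```

The same two keyword arguments, datasets and datasets_iter_config, are then passed to LightningTrainer alongside lightning_config and scaling_config.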

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

Methods

as_trainable()

Converts self to a tune.Trainable class.

can_restore(path)

Checks whether a given directory contains a restorable Train experiment.

fit()

Runs training.

get_dataset_config()

Returns a copy of this Trainer's final dataset configs.

restore(path[, datasets, preprocessor, ...])

Restores a LightningTrainer from a previously interrupted/failed run.

setup()

Called during fit() to perform initial setup on the Trainer.
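A common way to combine can_restore() and restore() is a resume-or-start helper. The sketch below is a minimal illustration: restore_or_create is a hypothetical name, and it assumes both methods can be called on the trainer class itself. FakeTrainer is a stand-in that mimics those two methods so the snippet runs without Ray installed.

```python
def restore_or_create(trainer_cls, path, **trainer_kwargs):
    """Resume from `path` if it holds a restorable run, else start fresh."""
    if trainer_cls.can_restore(path):
        return trainer_cls.restore(path)
    return trainer_cls(**trainer_kwargs)


# Hypothetical stand-in that mimics the two methods used above.
class FakeTrainer:
    restorable_paths = {"/tmp/experiment"}

    def __init__(self, restored=False, **kwargs):
        self.restored = restored
        self.kwargs = kwargs

    @classmethod
    def can_restore(cls, path):
        return path in cls.restorable_paths

    @classmethod
    def restore(cls, path):
        return cls(restored=True)


resumed = restore_or_create(FakeTrainer, "/tmp/experiment")
fresh = restore_or_create(FakeTrainer, "/tmp/other", scaling_config=None)
```

With the real class, LightningTrainer would be passed as trainer_cls. As the restore signature above indicates, datasets and a preprocessor can be supplied again when restoring.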