Train a PyTorch Lightning Image Classifier

This example shows how to train a PyTorch Lightning module using the Ray AIR LightningTrainer. We will demonstrate how to train a basic neural network on the MNIST dataset with distributed data parallelism.

!pip install "torchmetrics>=0.9" "pytorch_lightning>=1.6" 
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock
from torch.utils.data import DataLoader, random_split, Subset
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import trainer
from pytorch_lightning.loggers.csv_logs import CSVLogger

Prepare Dataset and Module

The PyTorch Lightning Trainer takes either a torch.utils.data.DataLoader or a pl.LightningDataModule as data input. You can keep using them without any changes with the Ray AIR LightningTrainer.

class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule that serves MNIST train/val/test DataLoaders.

    The MNIST training set is split into 55,000 training and 5,000
    validation samples. The test set is loaded lazily in ``test_dataloader``.

    Args:
        batch_size: Batch size used by every DataLoader.
        num_workers: Number of DataLoader worker processes.
        split_seed: Seed for the train/val split. A dedicated, seeded
            generator makes every distributed worker compute the *same*
            split, independent of the global RNG state at ``setup()`` time.
            Otherwise one rank could train on another rank's validation data.
    """

    def __init__(self, batch_size=100, num_workers=4, split_seed=42):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_seed = split_seed
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        # The file lock serializes the download between processes that share
        # this filesystem, so they don't fetch the dataset at the same time.
        with FileLock(f"{self.data_dir}.lock"):
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )

        # Split train/val with an explicitly seeded generator (not the global
        # RNG) so the split is identical on every worker and every run.
        generator = torch.Generator().manual_seed(self.split_seed)
        self.mnist_train, self.mnist_val = random_split(
            mnist, [55000, 5000], generator=generator
        )

    def train_dataloader(self):
        # Reshuffle every epoch; without this, the same order repeats each epoch.
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def test_dataloader(self):
        with FileLock(f"{self.data_dir}.lock"):
            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )
        return DataLoader(
            self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers
        )


datamodule = MNISTDataModule(batch_size=128)

Next, define a simple multi-layer perceptron as a subclass of pl.LightningModule.

class MNISTClassifier(pl.LightningModule):
    """Two-layer MLP that classifies flattened 28x28 MNIST images into 10 classes.

    Args:
        lr: Learning rate passed to the Adam optimizer.
        feature_dim: Width of the hidden layer.
    """

    def __init__(self, lr=1e-3, feature_dim=128):
        # Seed before the layers are created so weight initialization is
        # reproducible.
        torch.manual_seed(421)
        super().__init__()
        # The stack emits raw logits, so no activation follows the final
        # Linear layer. A trailing ReLU would clamp every negative logit to 0
        # and distort the softmax inside cross-entropy.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 10),
        )
        self.lr = lr
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
        # Per-batch validation results; aggregated and cleared in
        # on_validation_epoch_end.
        self.eval_loss = []
        self.eval_accuracy = []
        self.test_accuracy = []
        pl.seed_everything(888)

    def forward(self, x):
        """Flatten a batch of images and return logits of shape (N, 10)."""
        x = x.view(-1, 28 * 28)
        return self.linear_relu_stack(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        # Only buffer per-batch values here. The epoch-level, rank-synchronized
        # "val_loss" / "val_accuracy" are logged once in on_validation_epoch_end,
        # so the same key is not also logged per step without syncing.
        loss, acc = self._shared_eval(val_batch)
        self.eval_loss.append(loss)
        self.eval_accuracy.append(acc)
        return {"val_loss": loss, "val_accuracy": acc}

    def test_step(self, test_batch, batch_idx):
        loss, acc = self._shared_eval(test_batch)
        self.test_accuracy.append(acc)
        self.log("test_accuracy", acc, sync_dist=True, on_epoch=True)
        return {"test_loss": loss, "test_accuracy": acc}

    def _shared_eval(self, batch):
        """Return ``(loss, accuracy)`` for one evaluation batch.

        Uses the same cross-entropy objective as ``training_step``.
        ``F.nll_loss`` expects log-probabilities; applied to raw logits it
        produces meaningless, even negative, loss values.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        return loss, acc

    def on_validation_epoch_end(self):
        # torch.stack raises on an empty list; skip if no batches were buffered.
        if not self.eval_loss:
            return
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

You don’t need to make any changes to the definitions of the PyTorch Lightning model and datamodule.

Define the Configurations for AIR LightningTrainer

The LightningConfigBuilder class stores all the parameters involved in training a PyTorch Lightning module. Its methods take the same parameters as the corresponding PyTorch Lightning APIs.

  • The .module() method takes a subclass of pl.LightningModule and its initialization parameters. LightningTrainer will instantiate a model instance internally in the workers’ training loop.

  • The .trainer() method takes the initialization parameters of pl.Trainer. You can specify training configurations, loggers, and callbacks here.

  • The .fit_params() method stores all the parameters that will be passed into pl.Trainer.fit(), including train/val dataloaders, datamodules, and checkpoint paths.

  • The .checkpointing() method saves the configurations for a RayModelCheckpoint callback. This callback reports the latest metrics to the AIR session along with a newly saved checkpoint.

  • The .build() method generates a dictionary that contains all the configurations in the builder. This dictionary will be passed to LightningTrainer later.

Next, let’s go step-by-step to see how to convert your existing PyTorch Lightning training script to a LightningTrainer.

from pytorch_lightning.callbacks import ModelCheckpoint
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.lightning import (
    LightningTrainer,
    LightningConfigBuilder,
    LightningCheckpoint,
)


def build_lightning_config_from_existing_code(use_gpu):
    """Translate a plain PyTorch Lightning training script into a LightningTrainer config.

    Each builder call below mirrors one step of an ordinary Lightning script,
    which is shown in the comments above it. The builder only records
    parameters: the model is instantiated and fitted later, inside the
    LightningTrainer, not here.

    Args:
        use_gpu: If True, use the ``"gpu"`` accelerator, otherwise ``"cpu"``.

    Returns:
        The dictionary produced by ``LightningConfigBuilder.build()``, to be
        passed to ``LightningTrainer`` as ``lightning_config``.
    """
    accelerator = "gpu" if use_gpu else "cpu"

    return (
        LightningConfigBuilder()
        # 1. model = MNISTClassifier(lr=1e-3, feature_dim=128)
        .module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)
        # 2. checkpoint_callback = ModelCheckpoint(
        #        monitor="val_accuracy", mode="max", save_top_k=3
        #    )
        .checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)
        # 3. trainer = pl.Trainer(
        #        max_epochs=10,
        #        accelerator="cpu",
        #        strategy="ddp",
        #        log_every_n_steps=100,
        #        logger=CSVLogger("logs"),
        #        callbacks=[checkpoint_callback],
        #    )
        # The checkpoint callback and the strategy are left out on purpose:
        # LightningTrainer configures them automatically. Any other callbacks
        # can still be passed to .trainer().
        .trainer(
            max_epochs=10,
            accelerator=accelerator,
            log_every_n_steps=100,
            logger=CSVLogger("logs"),
        )
        # 4. trainer.fit(model, datamodule=datamodule)
        .fit_params(datamodule=datamodule)
        # Compile everything recorded above into one config dictionary.
        .build()
    )

Now put everything together:

# Toggle GPU training; set use_gpu to False to run without GPUs.
use_gpu = True
num_workers = 4

lightning_config = build_lightning_config_from_existing_code(use_gpu=use_gpu)

# One training worker per entry in num_workers, each requesting a GPU if enabled.
scaling_config = ScalingConfig(use_gpu=use_gpu, num_workers=num_workers)

# Keep the 3 checkpoints with the highest validation accuracy.
run_config = RunConfig(
    name="ptl-mnist-example",
    storage_path="/tmp/ray_results",
    checkpoint_config=CheckpointConfig(
        checkpoint_score_attribute="val_accuracy",
        checkpoint_score_order="max",
        num_to_keep=3,
    ),
)

trainer = LightningTrainer(
    run_config=run_config,
    scaling_config=scaling_config,
    lightning_config=lightning_config,
)

Now fit your trainer:

# Run distributed training and obtain the Result; its `metrics` dict is read below.
result = trainer.fit()
print("Validation Accuracy: ", result.metrics["val_accuracy"])
# Last expression of the notebook cell: displays the Result object.
result
2023-06-13 16:05:12,869	INFO worker.py:1452 -- Connecting to existing Ray cluster at address: 10.0.28.253:6379...
2023-06-13 16:05:12,877	INFO worker.py:1627 -- Connected to Ray cluster. View the dashboard at https://console.anyscale-staging.com/api/v2/sessions/ses_15dlj65vax84ljl7ayeplubryd/services?redirect_to=dashboard 
2023-06-13 16:05:13,036	INFO packaging.py:347 -- Pushing file package 'gcs://_ray_pkg_488e346d50f332edaa288fdaa22b2bdc.zip' (52.65MiB) to Ray cluster...
2023-06-13 16:05:13,221	INFO packaging.py:360 -- Successfully pushed file package 'gcs://_ray_pkg_488e346d50f332edaa288fdaa22b2bdc.zip'.
2023-06-13 16:05:13,314	INFO tune.py:226 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `Trainer(...)`.

Tune Status

Current time:2023-06-13 16:05:52
Running for: 00:00:39.29
Memory: 5.5/30.9 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 1.0/32 CPUs, 4.0/4 GPUs

Trial Status

Trial name status loc iter total time (s) train_loss val_accuracy val_loss
LightningTrainer_c0d28_00000TERMINATED10.0.28.253:16995 10 28.5133 0.0315991 0.970002 -12.3467
(pid=16995) /home/ray/anaconda3/lib/python3.9/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.
(pid=16995)   from pandas import MultiIndex, Int64Index
(LightningTrainer pid=16995) 2023-06-13 16:05:24,007	INFO backend_executor.py:137 -- Starting distributed worker processes: ['17232 (10.0.28.253)', '6371 (10.0.1.80)', '7319 (10.0.58.90)', '6493 (10.0.26.229)']
(RayTrainWorker pid=17232) 2023-06-13 16:05:24,966	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=4]
(RayTrainWorker pid=17232)   from pandas import MultiIndex, Int64Index
(RayTrainWorker pid=17232)   from pandas import MultiIndex, Int64Index
(RayTrainWorker pid=7319, ip=10.0.58.90) /home/ray/anaconda3/lib/python3.9/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.
(RayTrainWorker pid=7319, ip=10.0.58.90)   from pandas import MultiIndex, Int64Index
(RayTrainWorker pid=17232) Global seed set to 888
(RayTrainWorker pid=17232) GPU available: True, used: True
(RayTrainWorker pid=17232) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=17232) IPU available: False, using: 0 IPUs
(RayTrainWorker pid=17232) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=6371, ip=10.0.1.80) Missing logger folder: logs/lightning_logs
(RayTrainWorker pid=6371, ip=10.0.1.80) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
(RayTrainWorker pid=17232) 
(RayTrainWorker pid=17232)   | Name              | Type       | Params
(RayTrainWorker pid=17232) -------------------------------------------------
(RayTrainWorker pid=17232) 0 | linear_relu_stack | Sequential | 101 K 
(RayTrainWorker pid=17232) 1 | accuracy          | Accuracy   | 0     
(RayTrainWorker pid=17232) -------------------------------------------------
(RayTrainWorker pid=17232) 101 K     Trainable params
(RayTrainWorker pid=17232) 0         Non-trainable params
(RayTrainWorker pid=17232) 101 K     Total params
(RayTrainWorker pid=17232) 0.407     Total estimated model params size (MB)
Sanity Checking: 0it [00:00, ?it/s] 
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  1.33it/s]
Epoch 0:   0%|          | 0/118 [00:00<?, ?it/s]                           
(RayTrainWorker pid=6493, ip=10.0.26.229) [W reducer.cpp:1298] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
(RayTrainWorker pid=6371, ip=10.0.1.80) /home/ray/anaconda3/lib/python3.9/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead. [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=6371, ip=10.0.1.80)   from pandas import MultiIndex, Int64Index [repeated 2x across cluster]
Epoch 0:   3%|▎         | 4/118 [00:00<00:07, 16.07it/s, loss=2.09, v_num=0]
Epoch 0:   4%|▍         | 5/118 [00:00<00:05, 19.42it/s, loss=2.09, v_num=0]
Epoch 0:  12%|█▏        | 14/118 [00:00<00:02, 39.49it/s, loss=1.55, v_num=0]
Epoch 0:  12%|█▏        | 14/118 [00:00<00:02, 38.73it/s, loss=1.5, v_num=0] 
Epoch 0:  21%|██        | 25/118 [00:00<00:01, 53.89it/s, loss=0.933, v_num=0]
Epoch 0:  30%|██▉       | 35/118 [00:00<00:01, 61.80it/s, loss=0.522, v_num=0]
Epoch 0:  38%|███▊      | 45/118 [00:00<00:01, 67.21it/s, loss=0.425, v_num=0]
Epoch 0:  45%|████▍     | 53/118 [00:00<00:00, 69.59it/s, loss=0.379, v_num=0]
Epoch 0:  46%|████▌     | 54/118 [00:00<00:00, 69.65it/s, loss=0.373, v_num=0]
Epoch 0:  54%|█████▍    | 64/118 [00:00<00:00, 73.24it/s, loss=0.364, v_num=0]
Epoch 0:  62%|██████▏   | 73/118 [00:00<00:00, 74.68it/s, loss=0.341, v_num=0]
Epoch 0:  63%|██████▎   | 74/118 [00:00<00:00, 75.21it/s, loss=0.341, v_num=0]
Epoch 0:  70%|███████   | 83/118 [00:01<00:00, 76.62it/s, loss=0.335, v_num=0]
Epoch 0:  80%|███████▉  | 94/118 [00:01<00:00, 79.16it/s, loss=0.297, v_num=0]
Epoch 0:  90%|████████▉ | 106/118 [00:01<00:00, 82.26it/s, loss=0.281, v_num=0]
Epoch 0:  92%|█████████▏| 108/118 [00:01<00:00, 83.04it/s, loss=0.284, v_num=0]
Validation: 0it [00:00, ?it/s]2) 
(RayTrainWorker pid=17232) 
Validation:   0%|          | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 0:  92%|█████████▏| 109/118 [00:01<00:00, 73.67it/s, loss=0.284, v_num=0]
Epoch 0:  93%|█████████▎| 110/118 [00:01<00:00, 74.14it/s, loss=0.284, v_num=0]
Epoch 0:  94%|█████████▍| 111/118 [00:01<00:00, 74.57it/s, loss=0.284, v_num=0]
(RayTrainWorker pid=17232) 
Epoch 0:  95%|█████████▍| 112/118 [00:01<00:00, 73.94it/s, loss=0.284, v_num=0]
Epoch 0:  96%|█████████▌| 113/118 [00:01<00:00, 74.45it/s, loss=0.284, v_num=0]
Epoch 0:  97%|█████████▋| 114/118 [00:01<00:00, 74.96it/s, loss=0.284, v_num=0]
Epoch 0:  97%|█████████▋| 115/118 [00:01<00:00, 75.47it/s, loss=0.284, v_num=0]
Epoch 0:  98%|█████████▊| 116/118 [00:01<00:00, 75.05it/s, loss=0.284, v_num=0]
Epoch 0:  99%|█████████▉| 117/118 [00:01<00:00, 75.55it/s, loss=0.284, v_num=0]
Epoch 0: 100%|██████████| 118/118 [00:01<00:00, 75.21it/s, loss=0.284, v_num=0]
Epoch 0: 100%|██████████| 118/118 [00:01<00:00, 75.17it/s, loss=0.284, v_num=0]

Trial Progress

Trial name _report_on date done epoch experiment_taghostname iterations_since_restorenode_ip pidshould_checkpoint step time_since_restore time_this_iter_s time_total_s timestamp train_loss training_iterationtrial_id val_accuracy val_loss
LightningTrainer_c0d28_00000train_epoch_end2023-06-13_16-05-50True 9 0ip-10-0-28-253 1010.0.28.25316995True 1080 28.5133 1.73311 28.5133 1686697550 0.0315991 10c0d28_00000 0.970002 -12.3467
Epoch 1:   0%|          | 0/118 [00:00<?, ?it/s, loss=0.284, v_num=0]          
Epoch 1:   2%|▏         | 2/118 [00:00<00:15,  7.71it/s, loss=0.283, v_num=0]
Epoch 1:  11%|█         | 13/118 [00:00<00:02, 35.75it/s, loss=0.268, v_num=0]
Epoch 1:  20%|██        | 24/118 [00:00<00:01, 51.49it/s, loss=0.253, v_num=0]
Epoch 1:  28%|██▊       | 33/118 [00:00<00:01, 57.86it/s, loss=0.252, v_num=0]
Epoch 1:  36%|███▋      | 43/118 [00:00<00:01, 64.22it/s, loss=0.244, v_num=0]
Epoch 1:  37%|███▋      | 44/118 [00:00<00:01, 64.96it/s, loss=0.244, v_num=0]
Epoch 1:  37%|███▋      | 44/118 [00:00<00:01, 64.66it/s, loss=0.245, v_num=0]
Epoch 1:  46%|████▌     | 54/118 [00:00<00:00, 69.28it/s, loss=0.241, v_num=0]
Epoch 1:  55%|█████▌    | 65/118 [00:00<00:00, 73.79it/s, loss=0.245, v_num=0]
Epoch 1:  64%|██████▎   | 75/118 [00:00<00:00, 75.85it/s, loss=0.22, v_num=0] 
Epoch 1:  64%|██████▎   | 75/118 [00:00<00:00, 75.83it/s, loss=0.222, v_num=0]
Epoch 1:  72%|███████▏  | 85/118 [00:01<00:00, 78.43it/s, loss=0.203, v_num=0]
Epoch 1:  73%|███████▎  | 86/118 [00:01<00:00, 78.66it/s, loss=0.203, v_num=0]
Epoch 1:  81%|████████  | 95/118 [00:01<00:00, 79.71it/s, loss=0.199, v_num=0]
Epoch 1:  92%|█████████▏| 108/118 [00:01<00:00, 83.67it/s, loss=0.206, v_num=0]
(RayTrainWorker pid=17232) 
Validation: 0it [00:00, ?it/s]2) 
(RayTrainWorker pid=17232) 
Validation:   0%|          | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 1:  92%|█████████▏| 109/118 [00:01<00:00, 73.73it/s, loss=0.206, v_num=0]
Epoch 1:  93%|█████████▎| 110/118 [00:01<00:00, 74.00it/s, loss=0.206, v_num=0]
Epoch 1:  94%|█████████▍| 111/118 [00:01<00:00, 74.36it/s, loss=0.206, v_num=0]
Epoch 1:  95%|█████████▍| 112/118 [00:01<00:00, 74.72it/s, loss=0.206, v_num=0]
Epoch 1:  96%|█████████▌| 113/118 [00:01<00:00, 75.08it/s, loss=0.206, v_num=0]
(RayTrainWorker pid=17232) 
Epoch 1:  97%|█████████▋| 114/118 [00:01<00:00, 75.42it/s, loss=0.206, v_num=0]
Epoch 1:  97%|█████████▋| 115/118 [00:01<00:00, 75.77it/s, loss=0.206, v_num=0]
Epoch 1:  98%|█████████▊| 116/118 [00:01<00:00, 76.08it/s, loss=0.206, v_num=0]
Epoch 1:  99%|█████████▉| 117/118 [00:01<00:00, 76.59it/s, loss=0.206, v_num=0]
Epoch 1: 100%|██████████| 118/118 [00:01<00:00, 76.69it/s, loss=0.206, v_num=0]
Epoch 1: 100%|██████████| 118/118 [00:01<00:00, 76.64it/s, loss=0.206, v_num=0]
Epoch 2:   0%|          | 0/118 [00:00<?, ?it/s, loss=0.206, v_num=0]          
Epoch 2:   5%|▌         | 6/118 [00:00<00:05, 19.96it/s, loss=0.187, v_num=0]
Epoch 2:   5%|▌         | 6/118 [00:00<00:05, 19.93it/s, loss=0.188, v_num=0]
Epoch 2:  14%|█▎        | 16/118 [00:00<00:02, 39.92it/s, loss=0.176, v_num=0]
Epoch 2:  22%|██▏       | 26/118 [00:00<00:01, 51.69it/s, loss=0.183, v_num=0]
Epoch 2:  31%|███       | 36/118 [00:00<00:01, 59.53it/s, loss=0.18, v_num=0] 
Epoch 2:  31%|███▏      | 37/118 [00:00<00:01, 60.44it/s, loss=0.182, v_num=0]
Epoch 2:  41%|████      | 48/118 [00:00<00:01, 67.23it/s, loss=0.178, v_num=0]
Epoch 2:  49%|████▉     | 58/118 [00:00<00:00, 71.86it/s, loss=0.182, v_num=0]
Epoch 2:  57%|█████▋    | 67/118 [00:00<00:00, 73.02it/s, loss=0.177, v_num=0]
Epoch 2:  65%|██████▌   | 77/118 [00:01<00:00, 75.08it/s, loss=0.155, v_num=0]
Epoch 2:  74%|███████▎  | 87/118 [00:01<00:00, 77.13it/s, loss=0.157, v_num=0]
Epoch 2:  81%|████████▏ | 96/118 [00:01<00:00, 78.76it/s, loss=0.162, v_num=0]
Epoch 2:  92%|█████████▏| 108/118 [00:01<00:00, 81.91it/s, loss=0.149, v_num=0]
(RayTrainWorker pid=17232) 
Validation: 0it [00:00, ?it/s]2) 
(RayTrainWorker pid=17232) 
Validation:   0%|          | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 2:  92%|█████████▏| 109/118 [00:01<00:00, 71.87it/s, loss=0.149, v_num=0]
Epoch 2:  93%|█████████▎| 110/118 [00:01<00:00, 72.36it/s, loss=0.149, v_num=0]
Epoch 2:  94%|█████████▍| 111/118 [00:01<00:00, 72.87it/s, loss=0.149, v_num=0]
Epoch 2:  95%|█████████▍| 112/118 [00:01<00:00, 73.22it/s, loss=0.149, v_num=0]
(RayTrainWorker pid=17232) 
Epoch 2:  96%|█████████▌| 113/118 [00:01<00:00, 73.13it/s, loss=0.149, v_num=0]
Epoch 2:  97%|█████████▋| 114/118 [00:01<00:00, 73.63it/s, loss=0.149, v_num=0]
Epoch 2:  97%|█████████▋| 115/118 [00:01<00:00, 74.14it/s, loss=0.149, v_num=0]
Epoch 2:  98%|█████████▊| 116/118 [00:01<00:00, 74.65it/s, loss=0.149, v_num=0]
Epoch 2:  99%|█████████▉| 117/118 [00:01<00:00, 74.67it/s, loss=0.149, v_num=0]
Epoch 2: 100%|██████████| 118/118 [00:01<00:00, 74.79it/s, loss=0.149, v_num=0]
Epoch 2: 100%|██████████| 118/118 [00:01<00:00, 74.74it/s, loss=0.149, v_num=0]
Epoch 3:   0%|          | 0/118 [00:00<?, ?it/s, loss=0.149, v_num=0]          
Epoch 3:   1%|          | 1/118 [00:00<00:22,  5.27it/s, loss=0.149, v_num=0]
Epoch 3:   7%|▋         | 8/118 [00:00<00:04, 27.40it/s, loss=0.144, v_num=0]
Epoch 3:   7%|▋         | 8/118 [00:00<00:04, 26.95it/s, loss=0.143, v_num=0]
Epoch 3:  14%|█▎        | 16/118 [00:00<00:02, 40.27it/s, loss=0.13, v_num=0] 
Epoch 3:  22%|██▏       | 26/118 [00:00<00:01, 52.98it/s, loss=0.122, v_num=0]
Epoch 3:  30%|██▉       | 35/118 [00:00<00:01, 58.42it/s, loss=0.128, v_num=0]
Epoch 3:  31%|███       | 36/118 [00:00<00:01, 59.33it/s, loss=0.128, v_num=0]
Epoch 3:  39%|███▉      | 46/118 [00:00<00:01, 65.01it/s, loss=0.124, v_num=0]
Epoch 3:  47%|████▋     | 55/118 [00:00<00:00, 67.90it/s, loss=0.138, v_num=0]
Epoch 3:  55%|█████▌    | 65/118 [00:00<00:00, 71.30it/s, loss=0.145, v_num=0]
Epoch 3:  55%|█████▌    | 65/118 [00:00<00:00, 70.84it/s, loss=0.152, v_num=0]
Epoch 3:  64%|██████▍   | 76/118 [00:01<00:00, 74.81it/s, loss=0.138, v_num=0]
Epoch 3:  71%|███████   | 84/118 [00:01<00:00, 74.74it/s, loss=0.125, v_num=0]
Epoch 3:  80%|███████▉  | 94/118 [00:01<00:00, 76.91it/s, loss=0.124, v_num=0]
Epoch 3:  81%|████████  | 95/118 [00:01<00:00, 77.30it/s, loss=0.124, v_num=0]
Epoch 3:  90%|████████▉ | 106/118 [00:01<00:00, 79.61it/s, loss=0.119, v_num=0]
Epoch 3:  92%|█████████▏| 108/118 [00:01<00:00, 80.41it/s, loss=0.123, v_num=0]
Validation: 0it [00:00, ?it/s]2) 
(RayTrainWorker pid=17232) 
Validation:   0%|          | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/10 [00:00<?, ?it/s]
(RayTrainWorker pid=17232) 
Epoch 3:  92%|█████████▏| 109/118 [00:01<00:00, 70.82it/s, loss=0.123, v_num=0]
(RayTrainWorker pid=17232) 
Epoch 3:  93%|█████████▎| 110/118 [00:01<00:00, 71.26it/s, loss=0.123, v_num=0]
Epoch 3:  94%|█████████▍| 111/118 [00:01<00:00, 71.60it/s, loss=0.123, v_num=0]
Epoch 3:  95%|█████████▍| 112/118 [00:01<00:00, 71.92it/s, loss=0.123, v_num=0]
Epoch 3:  96%|█████████▌| 113/118 [00:01<00:00, 72.20it/s, loss=0.123, v_num=0]
Epoch 3:  97%|█████████▋| 114/118 [00:01<00:00, 72.64it/s, loss=0.123, v_num=0]
Epoch 3:  97%|█████████▋| 115/118 [00:01<00:00, 73.05it/s, loss=0.123, v_num=0]
Epoch 3:  98%|█████████▊| 116/118 [00:01<00:00, 73.55it/s, loss=0.123, v_num=0]
Epoch 3:  99%|█████████▉| 117/118 [00:01<00:00, 74.00it/s, loss=0.123, v_num=0]
Epoch 3: 100%|██████████| 118/118 [00:01<00:00, 73.28it/s, loss=0.123, v_num=0]
Epoch 3: 100%|██████████| 118/118 [00:01<00:00, 73.23it/s, loss=0.123, v_num=0]
Epoch 4:   0%|          | 0/118 [00:00<?, ?it/s, loss=0.123, v_num=0]          
Epoch 4:   3%|▎         | 3/118 [00:00<00:10, 11.48it/s, loss=0.114, v_num=0]
Epoch 4:  10%|█         | 12/118 [00:00<00:03, 32.93it/s, loss=0.111, v_num=0]
Epoch 4:  18%|█▊        | 21/118 [00:00<00:02, 45.21it/s, loss=0.102, v_num=0]
Epoch 4:  26%|██▋       | 31/118 [00:00<00:01, 54.51it/s, loss=0.11, v_num=0] 
Epoch 4:  35%|███▍      | 41/118 [00:00<00:01, 60.49it/s, loss=0.112, v_num=0]
Epoch 4:  35%|███▍      | 41/118 [00:00<00:01, 60.40it/s, loss=0.109, v_num=0]
Epoch 4:  43%|████▎     | 51/118 [00:00<00:01, 65.31it/s, loss=0.112, v_num=0]
Epoch 4:  43%|████▎     | 51/118 [00:00<00:01, 64.95it/s, loss=0.11, v_num=0] 
Epoch 4:  52%|█████▏    | 61/118 [00:00<00:00, 68.96it/s, loss=0.116, v_num=0]
Epoch 4:  53%|█████▎    | 62/118 [00:00<00:00, 69.63it/s, loss=0.116, v_num=0]
Epoch 4:  61%|██████    | 72/118 [00:00<00:00, 72.69it/s, loss=0.12, v_num=0] 
Epoch 4:  61%|██████    | 72/118 [00:00<00:00, 72.35it/s, loss=0.119, v_num=0]
Epoch 4:  69%|██████▊   | 81/118 [00:01<00:00, 74.46it/s, loss=0.124, v_num=0]
Epoch 4:  69%|██████▊   | 81/118 [00:01<00:00, 73.59it/s, loss=0.121, v_num=0]
Epoch 4:  78%|███████▊  | 92/118 [00:01<00:00, 76.35it/s, loss=0.105, v_num=0]
Epoch 4:  78%|███████▊  | 92/118 [00:01<00:00, 76.33it/s, loss=0.108, v_num=0]
Epoch 4:  87%|████████▋ | 103/118 [00:01<00:00, 78.93it/s, loss=0.0973, v_num=0]
Epoch 4:  92%|█████████▏| 108/118 [00:01<00:00, 80.62it/s, loss=0.107, v_num=0] 
Validation: 0it [00:00, ?it/s]2) 
(RayTrainWorker pid=17232) 
Validation:   0%|          | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 4:  92%|█████████▏| 109/118 [00:01<00:00, 70.57it/s, loss=0.107, v_num=0]
Epoch 4:  93%|█████████▎| 110/118 [00:01<00:00, 71.05it/s, loss=0.107, v_num=0]
Epoch 4:  94%|█████████▍| 111/118 [00:01<00:00, 71.56it/s, loss=0.107, v_num=0]
Epoch 4:  95%|█████████▍| 112/118 [00:01<00:00, 71.92it/s, loss=0.107, v_num=0]
Epoch 4:  96%|█████████▌| 113/118 [00:01<00:00, 72.04it/s, loss=0.107, v_num=0]
Epoch 4:  97%|█████████▋| 114/118 [00:01<00:00, 72.52it/s, loss=0.107, v_num=0]
Epoch 4:  97%|█████████▋| 115/118 [00:01<00:00, 73.01it/s, loss=0.107, v_num=0]
Epoch 4:  98%|█████████▊| 116/118 [00:01<00:00, 73.42it/s, loss=0.107, v_num=0]
Epoch 4:  99%|█████████▉| 117/118 [00:01<00:00, 73.52it/s, loss=0.107, v_num=0]
Epoch 4: 100%|██████████| 118/118 [00:01<00:00, 73.64it/s, loss=0.107, v_num=0]
Epoch 4: 100%|██████████| 118/118 [00:01<00:00, 73.59it/s, loss=0.107, v_num=0]
Epoch 5:   0%|          | 0/118 [00:00<?, ?it/s, loss=0.107, v_num=0]          
Epoch 5:   2%|▏         | 2/118 [00:00<00:13,  8.38it/s, loss=0.103, v_num=0]
Epoch 5:   8%|▊         | 10/118 [00:00<00:03, 29.51it/s, loss=0.101, v_num=0] 
Epoch 5:  18%|█▊        | 21/118 [00:00<00:02, 47.55it/s, loss=0.103, v_num=0] 
Epoch 5:  26%|██▋       | 31/118 [00:00<00:01, 56.79it/s, loss=0.0998, v_num=0]
Epoch 5:  34%|███▍      | 40/118 [00:00<00:01, 61.21it/s, loss=0.104, v_num=0] 
Epoch 5:  42%|████▏     | 50/118 [00:00<00:01, 66.50it/s, loss=0.0978, v_num=0]
Epoch 5:  43%|████▎     | 51/118 [00:00<00:00, 67.10it/s, loss=0.0978, v_num=0]
Epoch 5:  53%|█████▎    | 62/118 [00:00<00:00, 71.68it/s, loss=0.0933, v_num=0]
Epoch 5:  60%|██████    | 71/118 [00:00<00:00, 73.91it/s, loss=0.0864, v_num=0]
Epoch 5:  61%|██████    | 72/118 [00:00<00:00, 74.28it/s, loss=0.0864, v_num=0]
Epoch 5:  69%|██████▊   | 81/118 [00:01<00:00, 75.92it/s, loss=0.0845, v_num=0]
Epoch 5:  69%|██████▉   | 82/118 [00:01<00:00, 76.33it/s, loss=0.0845, v_num=0]
Epoch 5:  78%|███████▊  | 92/118 [00:01<00:00, 78.53it/s, loss=0.102, v_num=0] 
Epoch 5:  87%|████████▋ | 103/118 [00:01<00:00, 80.60it/s, loss=0.109, v_num=0]
Epoch 5:  92%|█████████▏| 108/118 [00:01<00:00, 82.44it/s, loss=0.105, v_num=0]
Validation: 0it [00:00, ?it/s]2) 
(RayTrainWorker pid=17232) 
Validation:   0%|          | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 5:  92%|█████████▏| 109/118 [00:01<00:00, 72.14it/s, loss=0.105, v_num=0]
Epoch 5:  93%|█████████▎| 110/118 [00:01<00:00, 72.45it/s, loss=0.105, v_num=0]
Epoch 5:  94%|█████████▍| 111/118 [00:01<00:00, 72.86it/s, loss=0.105, v_num=0]
Epoch 5:  95%|█████████▍| 112/118 [00:01<00:00, 73.21it/s, loss=0.105, v_num=0]
Epoch 5:  96%|█████████▌| 113/118 [00:01<00:00, 73.55it/s, loss=0.105, v_num=0]
Epoch 5:  97%|█████████▋| 114/118 [00:01<00:00, 73.75it/s, loss=0.105, v_num=0]
Epoch 5:  97%|█████████▋| 115/118 [00:01<00:00, 74.15it/s, loss=0.105, v_num=0]
Epoch 5:  98%|█████████▊| 116/118 [00:01<00:00, 74.65it/s, loss=0.105, v_num=0]
Epoch 5:  99%|█████████▉| 117/118 [00:01<00:00, 75.10it/s, loss=0.105, v_num=0]
Epoch 5: 100%|██████████| 118/118 [00:01<00:00, 74.97it/s, loss=0.105, v_num=0]
Epoch 5: 100%|██████████| 118/118 [00:01<00:00, 74.92it/s, loss=0.105, v_num=0]
Epoch 6:   0%|          | 0/118 [00:00<?, ?it/s, loss=0.105, v_num=0]          
Epoch 6:   2%|▏         | 2/118 [00:00<00:13,  8.61it/s, loss=0.0952, v_num=0]
Epoch 6:   8%|▊         | 10/118 [00:00<00:03, 29.86it/s, loss=0.0742, v_num=0]
Epoch 6:  17%|█▋        | 20/118 [00:00<00:02, 45.73it/s, loss=0.0701, v_num=0]
Epoch 6:  25%|██▍       | 29/118 [00:00<00:01, 53.55it/s, loss=0.0818, v_num=0]
Epoch 6:  34%|███▍      | 40/118 [00:00<00:01, 61.80it/s, loss=0.0876, v_num=0]
Epoch 6:  34%|███▍      | 40/118 [00:00<00:01, 61.39it/s, loss=0.0874, v_num=0]
Epoch 6:  43%|████▎     | 51/118 [00:00<00:00, 67.65it/s, loss=0.09, v_num=0]  
Epoch 6:  52%|█████▏    | 61/118 [00:00<00:00, 71.13it/s, loss=0.0883, v_num=0]
Epoch 6:  60%|██████    | 71/118 [00:00<00:00, 73.82it/s, loss=0.08, v_num=0]  
Epoch 6:  69%|██████▉   | 82/118 [00:01<00:00, 77.07it/s, loss=0.0791, v_num=0]
Epoch 6:  69%|██████▉   | 82/118 [00:01<00:00, 76.84it/s, loss=0.0778, v_num=0]
Epoch 6:  77%|███████▋  | 91/118 [00:01<00:00, 78.10it/s, loss=0.0764, v_num=0]
Epoch 6:  78%|███████▊  | 92/118 [00:01<00:00, 78.52it/s, loss=0.0764, v_num=0]
Epoch 6:  84%|████████▍ | 99/118 [00:01<00:00, 78.46it/s, loss=0.0668, v_num=0]
Epoch 6:  92%|█████████▏| 108/118 [00:01<00:00, 81.21it/s, loss=0.0822, v_num=0]
Validation: 0it [00:00, ?it/s]2) 
(RayTrainWorker pid=17232) 
Validation:   0%|          | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 6:  92%|█████████▏| 109/118 [00:01<00:00, 70.99it/s, loss=0.0822, v_num=0]
Epoch 6:  93%|█████████▎| 110/118 [00:01<00:00, 71.41it/s, loss=0.0822, v_num=0]
Epoch 6:  94%|█████████▍| 111/118 [00:01<00:00, 71.74it/s, loss=0.0822, v_num=0]
Epoch 6:  95%|█████████▍| 112/118 [00:01<00:00, 72.00it/s, loss=0.0822, v_num=0]
Epoch 6:  96%|█████████▌| 113/118 [00:01<00:00, 72.26it/s, loss=0.0822, v_num=0]
Epoch 6:  97%|█████████▋| 114/118 [00:01<00:00, 72.67it/s, loss=0.0822, v_num=0]
Epoch 6:  97%|█████████▋| 115/118 [00:01<00:00, 73.08it/s, loss=0.0822, v_num=0]
Epoch 6:  98%|█████████▊| 116/118 [00:01<00:00, 73.40it/s, loss=0.0822, v_num=0]
(RayTrainWorker pid=17232) 
Epoch 6:  99%|█████████▉| 117/118 [00:01<00:00, 73.85it/s, loss=0.0822, v_num=0]
Epoch 6: 100%|██████████| 118/118 [00:01<00:00, 73.97it/s, loss=0.0822, v_num=0]
Epoch 6: 100%|██████████| 118/118 [00:01<00:00, 73.91it/s, loss=0.0822, v_num=0]
Epoch 7:   0%|          | 0/118 [00:00<?, ?it/s, loss=0.0822, v_num=0]          
Epoch 7:   1%|          | 1/118 [00:00<00:24,  4.76it/s, loss=0.0816, v_num=0]
Epoch 7:   8%|▊         | 10/118 [00:00<00:03, 32.68it/s, loss=0.0774, v_num=0]
Epoch 7:  17%|█▋        | 20/118 [00:00<00:02, 48.49it/s, loss=0.0633, v_num=0]
Epoch 7:  25%|██▍       | 29/118 [00:00<00:01, 55.60it/s, loss=0.0682, v_num=0]
Epoch 7:  34%|███▍      | 40/118 [00:00<00:01, 63.73it/s, loss=0.0633, v_num=0]
Epoch 7:  43%|████▎     | 51/118 [00:00<00:00, 70.11it/s, loss=0.0596, v_num=0]
Epoch 7:  43%|████▎     | 51/118 [00:00<00:00, 69.74it/s, loss=0.0599, v_num=0]
Epoch 7:  53%|█████▎    | 62/118 [00:00<00:00, 74.75it/s, loss=0.0601, v_num=0]
Epoch 7:  62%|██████▏   | 73/118 [00:00<00:00, 77.99it/s, loss=0.0696, v_num=0]
Epoch 7:  62%|██████▏   | 73/118 [00:00<00:00, 77.77it/s, loss=0.0715, v_num=0]
Epoch 7:  70%|███████   | 83/118 [00:01<00:00, 80.11it/s, loss=0.0793, v_num=0]
Epoch 7:  71%|███████   | 84/118 [00:01<00:00, 80.45it/s, loss=0.0793, v_num=0]
Epoch 7:  81%|████████  | 95/118 [00:01<00:00, 82.87it/s, loss=0.0778, v_num=0]
Epoch 7:  91%|█████████ | 107/118 [00:01<00:00, 85.72it/s, loss=0.0743, v_num=0]
Epoch 7:  91%|█████████ | 107/118 [00:01<00:00, 85.50it/s, loss=0.0753, v_num=0]
Epoch 7:  92%|█████████▏| 108/118 [00:01<00:00, 85.88it/s, loss=0.0742, v_num=0]
Validation: 0it [00:00, ?it/s]2) 
(RayTrainWorker pid=17232) 
Validation:   0%|          | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 7:  92%|█████████▏| 109/118 [00:01<00:00, 75.70it/s, loss=0.0742, v_num=0]
Epoch 7:  93%|█████████▎| 110/118 [00:01<00:00, 76.11it/s, loss=0.0742, v_num=0]
Epoch 7:  94%|█████████▍| 111/118 [00:01<00:00, 76.41it/s, loss=0.0742, v_num=0]
(RayTrainWorker pid=17232) 
Epoch 7:  95%|█████████▍| 112/118 [00:01<00:00, 76.72it/s, loss=0.0742, v_num=0]
Epoch 7:  96%|█████████▌| 113/118 [00:01<00:00, 77.06it/s, loss=0.0742, v_num=0]
Epoch 7:  97%|█████████▋| 114/118 [00:01<00:00, 77.34it/s, loss=0.0742, v_num=0]
Epoch 7:  97%|█████████▋| 115/118 [00:01<00:00, 77.75it/s, loss=0.0742, v_num=0]
Epoch 7:  98%|█████████▊| 116/118 [00:01<00:00, 77.90it/s, loss=0.0742, v_num=0]
Epoch 7:  99%|█████████▉| 117/118 [00:01<00:00, 78.41it/s, loss=0.0742, v_num=0]
Epoch 7: 100%|██████████| 118/118 [00:01<00:00, 77.83it/s, loss=0.0742, v_num=0]
Epoch 7: 100%|██████████| 118/118 [00:01<00:00, 77.77it/s, loss=0.0742, v_num=0]
Epoch 8:   0%|          | 0/118 [00:00<?, ?it/s, loss=0.0742, v_num=0]          
Epoch 8:   5%|▌         | 6/118 [00:00<00:04, 22.78it/s, loss=0.0719, v_num=0]
Epoch 8:  12%|█▏        | 14/118 [00:00<00:02, 38.90it/s, loss=0.0655, v_num=0]
Epoch 8:  13%|█▎        | 15/118 [00:00<00:02, 40.47it/s, loss=0.0663, v_num=0]
Epoch 8:  21%|██        | 25/118 [00:00<00:01, 52.81it/s, loss=0.064, v_num=0] 
Epoch 8:  29%|██▉       | 34/118 [00:00<00:01, 58.98it/s, loss=0.0592, v_num=0]
Epoch 8:  37%|███▋      | 44/118 [00:00<00:01, 64.76it/s, loss=0.0537, v_num=0]
Epoch 8:  46%|████▌     | 54/118 [00:00<00:00, 68.99it/s, loss=0.0569, v_num=0]
Epoch 8:  56%|█████▌    | 66/118 [00:00<00:00, 74.51it/s, loss=0.0609, v_num=0]
Epoch 8:  64%|██████▎   | 75/118 [00:00<00:00, 76.72it/s, loss=0.0608, v_num=0]
Epoch 8:  72%|███████▏  | 85/118 [00:01<00:00, 78.28it/s, loss=0.0573, v_num=0]
Epoch 8:  73%|███████▎  | 86/118 [00:01<00:00, 78.57it/s, loss=0.0573, v_num=0]
Epoch 8:  73%|███████▎  | 86/118 [00:01<00:00, 78.45it/s, loss=0.0561, v_num=0]
Epoch 8:  81%|████████  | 95/118 [00:01<00:00, 79.63it/s, loss=0.0485, v_num=0]
Epoch 8:  81%|████████▏ | 96/118 [00:01<00:00, 80.00it/s, loss=0.0497, v_num=0]
Epoch 8:  92%|█████████▏| 108/118 [00:01<00:00, 82.93it/s, loss=0.059, v_num=0] 
(RayTrainWorker pid=17232) 
Validation: 0it [00:00, ?it/s]2) 
(RayTrainWorker pid=17232) 
Validation:   0%|          | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 8:  92%|█████████▏| 109/118 [00:01<00:00, 72.21it/s, loss=0.059, v_num=0]
(RayTrainWorker pid=17232) 
Epoch 8:  93%|█████████▎| 110/118 [00:01<00:00, 72.61it/s, loss=0.059, v_num=0]
Epoch 8:  94%|█████████▍| 111/118 [00:01<00:00, 72.91it/s, loss=0.059, v_num=0]
Epoch 8:  95%|█████████▍| 112/118 [00:01<00:00, 73.36it/s, loss=0.059, v_num=0]
Epoch 8:  96%|█████████▌| 113/118 [00:01<00:00, 73.53it/s, loss=0.059, v_num=0]
Epoch 8:  97%|█████████▋| 114/118 [00:01<00:00, 74.02it/s, loss=0.059, v_num=0]
Epoch 8:  97%|█████████▋| 115/118 [00:01<00:00, 74.44it/s, loss=0.059, v_num=0]
Epoch 8:  98%|█████████▊| 116/118 [00:01<00:00, 74.94it/s, loss=0.059, v_num=0]
Epoch 8:  99%|█████████▉| 117/118 [00:01<00:00, 75.06it/s, loss=0.059, v_num=0]
Epoch 8: 100%|██████████| 118/118 [00:01<00:00, 75.20it/s, loss=0.059, v_num=0]
Epoch 8: 100%|██████████| 118/118 [00:01<00:00, 75.15it/s, loss=0.059, v_num=0]
Epoch 9:   0%|          | 0/118 [00:00<?, ?it/s, loss=0.059, v_num=0]          
Epoch 9:   4%|▍         | 5/118 [00:00<00:06, 18.72it/s, loss=0.0632, v_num=0]
Epoch 9:   5%|▌         | 6/118 [00:00<00:05, 21.95it/s, loss=0.0632, v_num=0]
Epoch 9:  14%|█▎        | 16/118 [00:00<00:02, 42.58it/s, loss=0.0603, v_num=0]
Epoch 9:  21%|██        | 25/118 [00:00<00:01, 52.25it/s, loss=0.0543, v_num=0]
Epoch 9:  31%|███       | 36/118 [00:00<00:01, 61.88it/s, loss=0.0572, v_num=0]
Epoch 9:  31%|███       | 36/118 [00:00<00:01, 61.81it/s, loss=0.0562, v_num=0]
Epoch 9:  39%|███▉      | 46/118 [00:00<00:01, 66.98it/s, loss=0.0504, v_num=0]
Epoch 9:  47%|████▋     | 56/118 [00:00<00:00, 71.09it/s, loss=0.0501, v_num=0]
Epoch 9:  57%|█████▋    | 67/118 [00:00<00:00, 75.35it/s, loss=0.0489, v_num=0]
Epoch 9:  64%|██████▍   | 76/118 [00:00<00:00, 76.36it/s, loss=0.0412, v_num=0]
Epoch 9:  74%|███████▎  | 87/118 [00:01<00:00, 79.08it/s, loss=0.0452, v_num=0]
Epoch 9:  82%|████████▏ | 97/118 [00:01<00:00, 81.13it/s, loss=0.0506, v_num=0]
Epoch 9:  83%|████████▎ | 98/118 [00:01<00:00, 81.21it/s, loss=0.0506, v_num=0]
Epoch 9:  92%|█████████▏| 108/118 [00:01<00:00, 84.51it/s, loss=0.0563, v_num=0]
(RayTrainWorker pid=17232) 
Validation: 0it [00:00, ?it/s]2) 
(RayTrainWorker pid=17232) 
Validation:   0%|          | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 9:  92%|█████████▏| 109/118 [00:01<00:00, 74.09it/s, loss=0.0563, v_num=0]
Epoch 9:  93%|█████████▎| 110/118 [00:01<00:00, 74.36it/s, loss=0.0563, v_num=0]
Epoch 9:  94%|█████████▍| 111/118 [00:01<00:00, 74.67it/s, loss=0.0563, v_num=0]
Epoch 9:  95%|█████████▍| 112/118 [00:01<00:00, 75.07it/s, loss=0.0563, v_num=0]
Epoch 9:  96%|█████████▌| 113/118 [00:01<00:00, 75.46it/s, loss=0.0563, v_num=0]
Epoch 9:  97%|█████████▋| 114/118 [00:01<00:00, 75.88it/s, loss=0.0563, v_num=0]
Epoch 9:  97%|█████████▋| 115/118 [00:01<00:00, 76.36it/s, loss=0.0563, v_num=0]
(RayTrainWorker pid=17232) 
Epoch 9:  98%|█████████▊| 116/118 [00:01<00:00, 75.64it/s, loss=0.0563, v_num=0]
Epoch 9:  99%|█████████▉| 117/118 [00:01<00:00, 76.00it/s, loss=0.0563, v_num=0]
Epoch 9: 100%|██████████| 118/118 [00:01<00:00, 76.04it/s, loss=0.0563, v_num=0]
Epoch 9: 100%|██████████| 118/118 [00:01<00:00, 75.99it/s, loss=0.0563, v_num=0]
Epoch 9: 100%|██████████| 118/118 [00:01<00:00, 68.09it/s, loss=0.0563, v_num=0]
2023-06-13 16:05:52,777	INFO tune.py:1111 -- Total run time: 39.46 seconds (39.28 seconds for the tuning loop).
Validation Accuracy:  0.9700015783309937
Result(
  metrics={'_report_on': 'train_epoch_end', 'train_loss': 0.03159911185503006, 'val_accuracy': 0.9700015783309937, 'val_loss': -12.346744537353516, 'epoch': 9, 'step': 1080, 'should_checkpoint': True, 'done': True, 'trial_id': 'c0d28_00000', 'experiment_tag': '0'},
  path='/tmp/ray_results/ptl-mnist-example/LightningTrainer_c0d28_00000_0_2023-06-13_16-05-13',
  checkpoint=LightningCheckpoint(local_path=/tmp/ray_results/ptl-mnist-example/LightningTrainer_c0d28_00000_0_2023-06-13_16-05-13/checkpoint_000009)
)

Evaluate your model on test dataset#

Next, let’s evaluate the model’s performance on the MNIST test set. We will first retrieve the checkpoint from the fitting result (result.checkpoint, the latest checkpoint reported during training) and restore the trained model from it.

If you lost the in-memory result object, you can also restore the model from the checkpoint file. Here the checkpoint path is: /tmp/ray_results/ptl-mnist-example/LightningTrainer_c0d28_00000_0_2023-06-13_16-05-13/checkpoint_000009/model.

# The checkpoint attached to the training ``Result``; ``get_model`` instantiates
# an ``MNISTClassifier`` from the checkpointed state.
checkpoint: LightningCheckpoint = result.checkpoint
best_model: pl.LightningModule = checkpoint.get_model(
    MNISTClassifier
)

Single-node Testing#

If you have a relatively small test set, like MNIST, the easiest way is to use PyTorch Lightning’s native interface to evaluate the best model. Pass the loaded model and test data loader to pl.Trainer.test(), which will execute the test loop using your custom pl.LightningModule.test_step() method on your head node.

# Prepare the MNIST datamodule on the head node. ``setup()`` downloads the
# training split and creates the train/val split; ``test_dataloader()``
# downloads the test split itself (under the same file lock).
datamodule.setup()
test_dataloader = datamodule.test_dataloader()

# Deliberately NOT named ``trainer`` / ``result``: ``trainer`` would shadow the
# module imported from pytorch_lightning, and ``result`` would overwrite the
# training ``Result`` whose ``.checkpoint`` the previous cell reads (re-running
# that cell would then fail because ``trainer.test`` returns a list).
test_trainer = pl.Trainer()
test_result = test_trainer.test(best_model, dataloaders=test_dataloader)
/home/ray/anaconda3/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py:92: PossibleUserWarning: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
  rank_zero_warn(
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/ray/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1814: PossibleUserWarning: GPU available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='gpu', devices=1)`.
  rank_zero_warn(
2023-06-13 16:05:53.932195: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-13 16:05:54.097738: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-06-13 16:05:55.022170: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-06-13 16:05:55.022249: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-06-13 16:05:55.022255: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.9740999937057495     │
└───────────────────────────┴───────────────────────────┘

Multi-node Testing#

Alternatively, if you have a large test set and want to speed up testing by running it in parallel, you can create a group of Ray Actors to leverage multiple GPUs across multiple nodes for distributed inference. Here we demonstrate how to set up a process group and run the evaluation on 4 GPUs, one per actor.

import ray
import pytorch_lightning as pl

from pytorch_lightning.plugins.environments.lightning_environment import (
    LightningEnvironment,
)
from ray.air.util.torch_dist import (
    TorchDistributedWorker,
    init_torch_dist_process_group,
    shutdown_torch_dist_process_group,
)


class RayEnvironment(LightningEnvironment):
    """Lightning cluster environment for DDP ranks running as Ray actors.

    World size and ranks are read from the ``WORLD_SIZE``, ``RANK`` and
    ``LOCAL_RANK`` environment variables, which ``init_torch_dist_process_group``
    is expected to export in every actor before the Trainer is built.
    """

    @staticmethod
    def _int_from_env(name: str) -> int:
        # Raises KeyError if the variable is unset, ValueError if not an integer.
        return int(os.environ[name])

    def world_size(self) -> int:
        return self._int_from_env("WORLD_SIZE")

    def global_rank(self) -> int:
        return self._int_from_env("RANK")

    def local_rank(self) -> int:
        return self._int_from_env("LOCAL_RANK")

    def set_world_size(self, size: int) -> None:
        """No-op: ``world_size()`` always reads ``WORLD_SIZE`` from the
        environment, so any value passed in here is ignored."""

    def set_global_rank(self, rank: int) -> None:
        """No-op: ``global_rank()`` always reads ``RANK`` from the
        environment, so any value passed in here is ignored."""

    def teardown(self) -> None:
        """No-op. NOTE(review): presumably overridden so the base class does not
        clean up the environment variables read above — confirm against
        ``LightningEnvironment.teardown``."""


@ray.remote
class TestWorker(TorchDistributedWorker):
    """Ray actor that runs ``pl.Trainer.test`` as one rank of a DDP group.

    ``init_torch_dist_process_group`` must be called on the actors before
    ``run`` so the torch process group and the RANK / LOCAL_RANK / WORLD_SIZE
    environment variables read by ``RayEnvironment`` are in place.
    """

    def run(self, model, dataloader, num_nodes):
        """Evaluate ``model`` on ``dataloader`` and return ``trainer.test``'s output.

        The model and dataloader are passed as arguments instead of being read
        from notebook globals: globals referenced by the method get pickled
        into the actor class definition, which is what triggered Ray's
        "actor TestWorker is very large" warning.

        Args:
            model: The LightningModule to evaluate.
            dataloader: Test dataloader. Under DDP, Lightning shards it with a
                ``DistributedSampler``, which may repeat samples to even out
                shards when the dataset size is not divisible by the world size.
            num_nodes: Number of actors (one GPU each) in the process group.
        """
        trainer = pl.Trainer(
            num_nodes=num_nodes,
            accelerator="gpu",
            strategy="ddp",
            plugins=[RayEnvironment()],
        )
        return trainer.test(model, dataloaders=dataloader)


# Store the model and test data in the Ray object store once; each actor
# resolves the refs instead of carrying its own pickled copy in the class.
model_ref = ray.put(best_model)
dataloader_ref = ray.put(test_dataloader)

# Create `num_workers` remote Ray Actors, each with 1 GPU
workers = [TestWorker.options(num_gpus=1).remote() for _ in range(num_workers)]

# Initialize the Torch distributed group among the actors.
# This will set up the required environment variables including
# RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDRESS, ...
init_torch_dist_process_group(workers=workers, backend="nccl")

try:
    # Execute the testing run in parallel
    results = ray.get(
        [
            worker.run.remote(model_ref, dataloader_ref, num_workers)
            for worker in workers
        ]
    )
finally:
    # Always shut the process group down, even if a test run fails
    shutdown_torch_dist_process_group(workers=workers)
2023-06-13 16:05:56,270	WARNING worker.py:2019 -- Warning: The actor TestWorker is very large (53 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.
(RayTrainWorker pid=17232) Global seed set to 888 [repeated 7x across cluster]
(RayTrainWorker pid=7319, ip=10.0.58.90) Missing logger folder: logs/lightning_logs [repeated 3x across cluster]
(RayTrainWorker pid=7319, ip=10.0.58.90) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] [repeated 3x across cluster]
(RayTrainWorker pid=6371, ip=10.0.1.80) [W reducer.cpp:1298] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) [repeated 3x across cluster]
(pid=9162, ip=10.0.26.229)   from pandas import MultiIndex, Int64Index
(pid=9162, ip=10.0.26.229)   from pandas import MultiIndex, Int64Index
(pid=9976, ip=10.0.58.90) /home/ray/anaconda3/lib/python3.9/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.
(pid=9976, ip=10.0.58.90)   from pandas import MultiIndex, Int64Index
(TestWorker pid=20600) /home/ray/anaconda3/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py:92: PossibleUserWarning: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
(TestWorker pid=20600)   rank_zero_warn(
(TestWorker pid=20600) GPU available: True, used: True
(TestWorker pid=20600) TPU available: False, using: 0 TPU cores
(TestWorker pid=20600) IPU available: False, using: 0 IPUs
(TestWorker pid=20600) HPU available: False, using: 0 HPUs
(TestWorker pid=20600) /home/ray/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:330: PossibleUserWarning: Using `DistributedSampler` with the dataloaders. During `trainer.test()`, it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates some samples to make sure all devices have same batch size in case of uneven inputs.
(TestWorker pid=20600)   rank_zero_warn(
Testing: 0it [00:00, ?it/s]600) 
Testing DataLoader 0:   0%|          | 0/20 [00:00<?, ?it/s]
Testing DataLoader 0:  10%|█         | 2/20 [00:00<00:13,  1.36it/s]
Testing DataLoader 0:  75%|███████▌  | 15/20 [00:00<00:00, 22.10it/s]
(TestWorker pid=20600) 2023-06-13 16:06:07.550225: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
(TestWorker pid=20600) To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
(TestWorker pid=9976, ip=10.0.58.90) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] [repeated 4x across cluster]
(pid=20600) /home/ray/anaconda3/lib/python3.9/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead. [repeated 2x across cluster]
(pid=20600)   from pandas import MultiIndex, Int64Index [repeated 2x across cluster]
(TestWorker pid=20600) 2023-06-13 16:06:07.708119: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
Testing DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 22.10it/s]
(TestWorker pid=20600) 2023-06-13 16:06:08.680418: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
(TestWorker pid=20600) 2023-06-13 16:06:08.680524: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
(TestWorker pid=20600) 2023-06-13 16:06:08.680532: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Testing DataLoader 0: 100%|██████████| 20/20 [00:02<00:00,  7.31it/s]
(TestWorker pid=20600) ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
(TestWorker pid=20600) ┃        Test metric        ┃       DataLoader 0        ┃
(TestWorker pid=20600) ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
(TestWorker pid=20600) │       test_accuracy       │    0.9740999937057495     │
(TestWorker pid=20600) └───────────────────────────┴───────────────────────────┘

What’s next?#