Train a PyTorch Lightning Image Classifier
Contents
Train a PyTorch Lightning Image Classifier
This example shows how to train a PyTorch Lightning module using the Ray AIR LightningTrainer. We will train a basic neural network on the MNIST dataset with distributed data parallelism.
!pip install "torchmetrics>=0.9" "pytorch_lightning>=1.6"
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock
from torch.utils.data import DataLoader, random_split, Subset
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning import trainer
from pytorch_lightning.loggers.csv_logs import CSVLogger
Prepare the Dataset and Module
The PyTorch Lightning Trainer takes either torch.utils.data.DataLoader or pl.LightningDataModule as data inputs. You can keep using them without any changes for the Ray AIR LightningTrainer.
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving normalized MNIST train/val/test loaders.

    Args:
        batch_size: Batch size used by every dataloader.
        num_workers: Number of loader worker processes per ``DataLoader``.
        split_seed: Seed for the 55k/5k train/validation split. Every
            distributed worker calls ``setup`` on its own, so the split must
            not depend on the global RNG state; otherwise ranks could disagree
            on which samples are held out. Pass ``None`` to draw the split
            from the global RNG instead.
    """

    def __init__(self, batch_size=100, num_workers=4, split_seed=42):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_seed = split_seed
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        # The file lock serializes the download between processes that share
        # data_dir, so they do not write the same files at once.
        with FileLock(f"{self.data_dir}.lock"):
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )
        # Split the training set into train (55,000) and val (5,000) using a
        # dedicated generator so the split is identical on every worker.
        generator = (
            torch.Generator().manual_seed(self.split_seed)
            if self.split_seed is not None
            else torch.default_generator
        )
        self.mnist_train, self.mnist_val = random_split(
            mnist, [55000, 5000], generator=generator
        )

    def train_dataloader(self):
        # shuffle=True makes Lightning build a shuffling DistributedSampler;
        # without it every epoch sees the samples in the same order.
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def test_dataloader(self):
        with FileLock(f"{self.data_dir}.lock"):
            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )
        return DataLoader(
            self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers
        )


datamodule = MNISTDataModule(batch_size=128)
Next, define a simple multi-layer perceptron as a subclass of pl.LightningModule.
class MNISTClassifier(pl.LightningModule):
    """Multi-layer perceptron that classifies flattened 28x28 MNIST images.

    Args:
        lr: Learning rate passed to ``torch.optim.Adam``.
        feature_dim: Number of units in the hidden layer.
    """

    def __init__(self, lr=1e-3, feature_dim=128):
        # Seed before the layers are created so initial weights are reproducible.
        torch.manual_seed(421)
        super().__init__()
        # No activation after the last Linear: the losses below expect raw,
        # unbounded logits. A trailing ReLU would clamp every negative class
        # score to 0, causing ties and zero gradients for those outputs.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 10),
        )
        self.lr = lr
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
        # Per-batch validation results; averaged and cleared at epoch end.
        self.eval_loss = []
        self.eval_accuracy = []
        # Per-batch test accuracies; accumulated and not cleared in this class.
        self.test_accuracy = []
        pl.seed_everything(888)

    def forward(self, x):
        # Flatten each image into a 784-vector; output is (batch, 10) logits.
        x = x.view(-1, 28 * 28)
        return self.linear_relu_stack(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        loss, acc = self._shared_eval(val_batch)
        # val_accuracy is logged once per epoch (synced across workers) in
        # on_validation_epoch_end. Logging it here as well would record a
        # second, unsynced value under the same key.
        self.eval_loss.append(loss)
        self.eval_accuracy.append(acc)
        return {"val_loss": loss, "val_accuracy": acc}

    def test_step(self, test_batch, batch_idx):
        loss, acc = self._shared_eval(test_batch)
        self.test_accuracy.append(acc)
        self.log("test_accuracy", acc, sync_dist=True, on_epoch=True)
        return {"test_loss": loss, "test_accuracy": acc}

    def _shared_eval(self, batch):
        """Return ``(loss, accuracy)`` for one evaluation batch."""
        x, y = batch
        logits = self.forward(x)
        # cross_entropy applies log_softmax internally, matching training_step.
        # F.nll_loss expects log-probabilities; on raw logits it is not a
        # valid loss and can go negative.
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        return loss, acc

    def on_validation_epoch_end(self):
        # torch.stack raises on an empty list, e.g. if no validation batch ran.
        if not self.eval_loss:
            return
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
You don’t need to make any changes to the definitions of the PyTorch Lightning model and datamodule.
Define the Configurations for the AIR LightningTrainer
The LightningConfigBuilder class stores all the parameters involved in training a PyTorch Lightning module. It takes the same parameter lists as those in PyTorch Lightning.
- The `.module()` method takes a subclass of `pl.LightningModule` and its initialization parameters. `LightningTrainer` will instantiate a model instance internally in the workers’ training loop.
- The `.trainer()` method takes the initialization parameters of `pl.Trainer`. You can specify training configurations, loggers, and callbacks here.
- The `.fit_params()` method stores all the parameters that will be passed into `pl.Trainer.fit()`, including train/val dataloaders, datamodules, and checkpoint paths.
- The `.checkpointing()` method saves the configurations for a `RayModelCheckpoint` callback. This callback reports the latest metrics to the AIR session along with a newly saved checkpoint.
- The `.build()` method generates a dictionary that contains all the configurations in the builder. This dictionary will be passed to `LightningTrainer` later.
Next, let’s go step-by-step to see how to convert your existing PyTorch Lightning training script to a LightningTrainer.
from pytorch_lightning.callbacks import ModelCheckpoint
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.lightning import (
LightningTrainer,
LightningConfigBuilder,
LightningCheckpoint,
)
def build_lightning_config_from_existing_code(use_gpu):
    """Translate a plain PyTorch Lightning setup into a LightningTrainer config.

    Each step below mirrors one piece of a conventional Lightning script,
    which is shown in the comments. The builder only records arguments; the
    model is instantiated and fitted later, inside LightningTrainer.

    Args:
        use_gpu: When true, the "gpu" accelerator is requested; otherwise "cpu".

    Returns:
        The dictionary produced by ``LightningConfigBuilder.build()``.
    """
    builder = LightningConfigBuilder()

    # Step 1 -- the model. Plain Lightning equivalent:
    #     model = MNISTClassifier(lr=1e-3, feature_dim=128)
    builder.module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)

    # Step 2 -- checkpointing. Plain Lightning equivalent:
    #     checkpoint_callback = ModelCheckpoint(
    #         monitor="val_accuracy", mode="max", save_top_k=3
    #     )
    builder.checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)

    # Step 3 -- the trainer. Plain Lightning equivalent:
    #     trainer = pl.Trainer(
    #         max_epochs=10,
    #         accelerator="cpu",
    #         strategy="ddp",
    #         log_every_n_steps=100,
    #         logger=CSVLogger("logs"),
    #         callbacks=[checkpoint_callback],
    #     )
    # The strategy and checkpoint callback are left out on purpose because
    # LightningTrainer configures them itself. Any other callbacks can still
    # be passed to builder.trainer().
    accelerator = "gpu" if use_gpu else "cpu"
    trainer_kwargs = dict(
        max_epochs=10,
        accelerator=accelerator,
        log_every_n_steps=100,
        logger=CSVLogger("logs"),
    )
    builder.trainer(**trainer_kwargs)

    # Step 4 -- arguments for fitting. Plain Lightning equivalent:
    #     trainer.fit(model, datamodule=datamodule)
    builder.fit_params(datamodule=datamodule)

    # Collapse everything recorded above into a single dict for LightningTrainer.
    return builder.build()
Now put everything together:
# Flip use_gpu to False to run without GPUs.
use_gpu = True
num_workers = 4

# Resources: num_workers distributed training workers, each requesting a GPU
# when use_gpu is true.
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
lightning_config = build_lightning_config_from_existing_code(use_gpu=use_gpu)

# Results go under storage_path/name; only the 3 checkpoints with the highest
# reported val_accuracy are retained.
run_config = RunConfig(
    name="ptl-mnist-example",
    storage_path="/tmp/ray_results",
    checkpoint_config=CheckpointConfig(
        checkpoint_score_attribute="val_accuracy",
        checkpoint_score_order="max",
        num_to_keep=3,
    ),
)

trainer = LightningTrainer(
    scaling_config=scaling_config,
    run_config=run_config,
    lightning_config=lightning_config,
)
Now fit your trainer:
# Run distributed training; result.metrics holds the last reported metrics.
result = trainer.fit()
# A single f-string avoids the doubled space that print()'s separator added
# after the literal's own trailing space.
print(f"Validation Accuracy: {result.metrics['val_accuracy']}")
result
2023-06-13 16:05:12,869 INFO worker.py:1452 -- Connecting to existing Ray cluster at address: 10.0.28.253:6379...
2023-06-13 16:05:12,877 INFO worker.py:1627 -- Connected to Ray cluster. View the dashboard at https://console.anyscale-staging.com/api/v2/sessions/ses_15dlj65vax84ljl7ayeplubryd/services?redirect_to=dashboard
2023-06-13 16:05:13,036 INFO packaging.py:347 -- Pushing file package 'gcs://_ray_pkg_488e346d50f332edaa288fdaa22b2bdc.zip' (52.65MiB) to Ray cluster...
2023-06-13 16:05:13,221 INFO packaging.py:360 -- Successfully pushed file package 'gcs://_ray_pkg_488e346d50f332edaa288fdaa22b2bdc.zip'.
2023-06-13 16:05:13,314 INFO tune.py:226 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `Trainer(...)`.
Tune Status
| Current time: | 2023-06-13 16:05:52 |
| Running for: | 00:00:39.29 |
| Memory: | 5.5/30.9 GiB |
System Info
Using FIFO scheduling algorithm.Logical resource usage: 1.0/32 CPUs, 4.0/4 GPUs
Trial Status
| Trial name | status | loc | iter | total time (s) | train_loss | val_accuracy | val_loss |
|---|---|---|---|---|---|---|---|
| LightningTrainer_c0d28_00000 | TERMINATED | 10.0.28.253:16995 | 10 | 28.5133 | 0.0315991 | 0.970002 | -12.3467 |
(pid=16995) /home/ray/anaconda3/lib/python3.9/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.
(pid=16995) from pandas import MultiIndex, Int64Index
(LightningTrainer pid=16995) 2023-06-13 16:05:24,007 INFO backend_executor.py:137 -- Starting distributed worker processes: ['17232 (10.0.28.253)', '6371 (10.0.1.80)', '7319 (10.0.58.90)', '6493 (10.0.26.229)']
(RayTrainWorker pid=17232) 2023-06-13 16:05:24,966 INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=4]
(RayTrainWorker pid=17232) from pandas import MultiIndex, Int64Index
(RayTrainWorker pid=17232) from pandas import MultiIndex, Int64Index
(RayTrainWorker pid=7319, ip=10.0.58.90) /home/ray/anaconda3/lib/python3.9/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.
(RayTrainWorker pid=7319, ip=10.0.58.90) from pandas import MultiIndex, Int64Index
(RayTrainWorker pid=17232) Global seed set to 888
(RayTrainWorker pid=17232) GPU available: True, used: True
(RayTrainWorker pid=17232) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=17232) IPU available: False, using: 0 IPUs
(RayTrainWorker pid=17232) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=6371, ip=10.0.1.80) Missing logger folder: logs/lightning_logs
(RayTrainWorker pid=6371, ip=10.0.1.80) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
(RayTrainWorker pid=17232)
(RayTrainWorker pid=17232) | Name | Type | Params
(RayTrainWorker pid=17232) -------------------------------------------------
(RayTrainWorker pid=17232) 0 | linear_relu_stack | Sequential | 101 K
(RayTrainWorker pid=17232) 1 | accuracy | Accuracy | 0
(RayTrainWorker pid=17232) -------------------------------------------------
(RayTrainWorker pid=17232) 101 K Trainable params
(RayTrainWorker pid=17232) 0 Non-trainable params
(RayTrainWorker pid=17232) 101 K Total params
(RayTrainWorker pid=17232) 0.407 Total estimated model params size (MB)
Sanity Checking: 0it [00:00, ?it/s]
Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 1.33it/s]
Epoch 0: 0%| | 0/118 [00:00<?, ?it/s]
(RayTrainWorker pid=6493, ip=10.0.26.229) [W reducer.cpp:1298] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
(RayTrainWorker pid=6371, ip=10.0.1.80) /home/ray/anaconda3/lib/python3.9/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead. [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=6371, ip=10.0.1.80) from pandas import MultiIndex, Int64Index [repeated 2x across cluster]
Epoch 0: 3%|▎ | 4/118 [00:00<00:07, 16.07it/s, loss=2.09, v_num=0]
Epoch 0: 4%|▍ | 5/118 [00:00<00:05, 19.42it/s, loss=2.09, v_num=0]
Epoch 0: 12%|█▏ | 14/118 [00:00<00:02, 39.49it/s, loss=1.55, v_num=0]
Epoch 0: 12%|█▏ | 14/118 [00:00<00:02, 38.73it/s, loss=1.5, v_num=0]
Epoch 0: 21%|██ | 25/118 [00:00<00:01, 53.89it/s, loss=0.933, v_num=0]
Epoch 0: 30%|██▉ | 35/118 [00:00<00:01, 61.80it/s, loss=0.522, v_num=0]
Epoch 0: 38%|███▊ | 45/118 [00:00<00:01, 67.21it/s, loss=0.425, v_num=0]
Epoch 0: 45%|████▍ | 53/118 [00:00<00:00, 69.59it/s, loss=0.379, v_num=0]
Epoch 0: 46%|████▌ | 54/118 [00:00<00:00, 69.65it/s, loss=0.373, v_num=0]
Epoch 0: 54%|█████▍ | 64/118 [00:00<00:00, 73.24it/s, loss=0.364, v_num=0]
Epoch 0: 62%|██████▏ | 73/118 [00:00<00:00, 74.68it/s, loss=0.341, v_num=0]
Epoch 0: 63%|██████▎ | 74/118 [00:00<00:00, 75.21it/s, loss=0.341, v_num=0]
Epoch 0: 70%|███████ | 83/118 [00:01<00:00, 76.62it/s, loss=0.335, v_num=0]
Epoch 0: 80%|███████▉ | 94/118 [00:01<00:00, 79.16it/s, loss=0.297, v_num=0]
Epoch 0: 90%|████████▉ | 106/118 [00:01<00:00, 82.26it/s, loss=0.281, v_num=0]
Epoch 0: 92%|█████████▏| 108/118 [00:01<00:00, 83.04it/s, loss=0.284, v_num=0]
Validation: 0it [00:00, ?it/s]2)
(RayTrainWorker pid=17232)
Validation: 0%| | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/10 [00:00<?, ?it/s]
Epoch 0: 92%|█████████▏| 109/118 [00:01<00:00, 73.67it/s, loss=0.284, v_num=0]
Epoch 0: 93%|█████████▎| 110/118 [00:01<00:00, 74.14it/s, loss=0.284, v_num=0]
Epoch 0: 94%|█████████▍| 111/118 [00:01<00:00, 74.57it/s, loss=0.284, v_num=0]
(RayTrainWorker pid=17232)
Epoch 0: 95%|█████████▍| 112/118 [00:01<00:00, 73.94it/s, loss=0.284, v_num=0]
Epoch 0: 96%|█████████▌| 113/118 [00:01<00:00, 74.45it/s, loss=0.284, v_num=0]
Epoch 0: 97%|█████████▋| 114/118 [00:01<00:00, 74.96it/s, loss=0.284, v_num=0]
Epoch 0: 97%|█████████▋| 115/118 [00:01<00:00, 75.47it/s, loss=0.284, v_num=0]
Epoch 0: 98%|█████████▊| 116/118 [00:01<00:00, 75.05it/s, loss=0.284, v_num=0]
Epoch 0: 99%|█████████▉| 117/118 [00:01<00:00, 75.55it/s, loss=0.284, v_num=0]
Epoch 0: 100%|██████████| 118/118 [00:01<00:00, 75.21it/s, loss=0.284, v_num=0]
Epoch 0: 100%|██████████| 118/118 [00:01<00:00, 75.17it/s, loss=0.284, v_num=0]
Trial Progress
| Trial name | _report_on | date | done | epoch | experiment_tag | hostname | iterations_since_restore | node_ip | pid | should_checkpoint | step | time_since_restore | time_this_iter_s | time_total_s | timestamp | train_loss | training_iteration | trial_id | val_accuracy | val_loss |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| LightningTrainer_c0d28_00000 | train_epoch_end | 2023-06-13_16-05-50 | True | 9 | 0 | ip-10-0-28-253 | 10 | 10.0.28.253 | 16995 | True | 1080 | 28.5133 | 1.73311 | 28.5133 | 1686697550 | 0.0315991 | 10 | c0d28_00000 | 0.970002 | -12.3467 |
Epoch 1: 0%| | 0/118 [00:00<?, ?it/s, loss=0.284, v_num=0]
Epoch 1: 2%|▏ | 2/118 [00:00<00:15, 7.71it/s, loss=0.283, v_num=0]
Epoch 1: 11%|█ | 13/118 [00:00<00:02, 35.75it/s, loss=0.268, v_num=0]
Epoch 1: 20%|██ | 24/118 [00:00<00:01, 51.49it/s, loss=0.253, v_num=0]
Epoch 1: 28%|██▊ | 33/118 [00:00<00:01, 57.86it/s, loss=0.252, v_num=0]
Epoch 1: 36%|███▋ | 43/118 [00:00<00:01, 64.22it/s, loss=0.244, v_num=0]
Epoch 1: 37%|███▋ | 44/118 [00:00<00:01, 64.96it/s, loss=0.244, v_num=0]
Epoch 1: 37%|███▋ | 44/118 [00:00<00:01, 64.66it/s, loss=0.245, v_num=0]
Epoch 1: 46%|████▌ | 54/118 [00:00<00:00, 69.28it/s, loss=0.241, v_num=0]
Epoch 1: 55%|█████▌ | 65/118 [00:00<00:00, 73.79it/s, loss=0.245, v_num=0]
Epoch 1: 64%|██████▎ | 75/118 [00:00<00:00, 75.85it/s, loss=0.22, v_num=0]
Epoch 1: 64%|██████▎ | 75/118 [00:00<00:00, 75.83it/s, loss=0.222, v_num=0]
Epoch 1: 72%|███████▏ | 85/118 [00:01<00:00, 78.43it/s, loss=0.203, v_num=0]
Epoch 1: 73%|███████▎ | 86/118 [00:01<00:00, 78.66it/s, loss=0.203, v_num=0]
Epoch 1: 81%|████████ | 95/118 [00:01<00:00, 79.71it/s, loss=0.199, v_num=0]
Epoch 1: 92%|█████████▏| 108/118 [00:01<00:00, 83.67it/s, loss=0.206, v_num=0]
(RayTrainWorker pid=17232)
Validation: 0it [00:00, ?it/s]2)
(RayTrainWorker pid=17232)
Validation: 0%| | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/10 [00:00<?, ?it/s]
Epoch 1: 92%|█████████▏| 109/118 [00:01<00:00, 73.73it/s, loss=0.206, v_num=0]
Epoch 1: 93%|█████████▎| 110/118 [00:01<00:00, 74.00it/s, loss=0.206, v_num=0]
Epoch 1: 94%|█████████▍| 111/118 [00:01<00:00, 74.36it/s, loss=0.206, v_num=0]
Epoch 1: 95%|█████████▍| 112/118 [00:01<00:00, 74.72it/s, loss=0.206, v_num=0]
Epoch 1: 96%|█████████▌| 113/118 [00:01<00:00, 75.08it/s, loss=0.206, v_num=0]
(RayTrainWorker pid=17232)
Epoch 1: 97%|█████████▋| 114/118 [00:01<00:00, 75.42it/s, loss=0.206, v_num=0]
Epoch 1: 97%|█████████▋| 115/118 [00:01<00:00, 75.77it/s, loss=0.206, v_num=0]
Epoch 1: 98%|█████████▊| 116/118 [00:01<00:00, 76.08it/s, loss=0.206, v_num=0]
Epoch 1: 99%|█████████▉| 117/118 [00:01<00:00, 76.59it/s, loss=0.206, v_num=0]
Epoch 1: 100%|██████████| 118/118 [00:01<00:00, 76.69it/s, loss=0.206, v_num=0]
Epoch 1: 100%|██████████| 118/118 [00:01<00:00, 76.64it/s, loss=0.206, v_num=0]
Epoch 2: 0%| | 0/118 [00:00<?, ?it/s, loss=0.206, v_num=0]
Epoch 2: 5%|▌ | 6/118 [00:00<00:05, 19.96it/s, loss=0.187, v_num=0]
Epoch 2: 5%|▌ | 6/118 [00:00<00:05, 19.93it/s, loss=0.188, v_num=0]
Epoch 2: 14%|█▎ | 16/118 [00:00<00:02, 39.92it/s, loss=0.176, v_num=0]
Epoch 2: 22%|██▏ | 26/118 [00:00<00:01, 51.69it/s, loss=0.183, v_num=0]
Epoch 2: 31%|███ | 36/118 [00:00<00:01, 59.53it/s, loss=0.18, v_num=0]
Epoch 2: 31%|███▏ | 37/118 [00:00<00:01, 60.44it/s, loss=0.182, v_num=0]
Epoch 2: 41%|████ | 48/118 [00:00<00:01, 67.23it/s, loss=0.178, v_num=0]
Epoch 2: 49%|████▉ | 58/118 [00:00<00:00, 71.86it/s, loss=0.182, v_num=0]
Epoch 2: 57%|█████▋ | 67/118 [00:00<00:00, 73.02it/s, loss=0.177, v_num=0]
Epoch 2: 65%|██████▌ | 77/118 [00:01<00:00, 75.08it/s, loss=0.155, v_num=0]
Epoch 2: 74%|███████▎ | 87/118 [00:01<00:00, 77.13it/s, loss=0.157, v_num=0]
Epoch 2: 81%|████████▏ | 96/118 [00:01<00:00, 78.76it/s, loss=0.162, v_num=0]
Epoch 2: 92%|█████████▏| 108/118 [00:01<00:00, 81.91it/s, loss=0.149, v_num=0]
(RayTrainWorker pid=17232)
Validation: 0it [00:00, ?it/s]2)
(RayTrainWorker pid=17232)
Validation: 0%| | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/10 [00:00<?, ?it/s]
Epoch 2: 92%|█████████▏| 109/118 [00:01<00:00, 71.87it/s, loss=0.149, v_num=0]
Epoch 2: 93%|█████████▎| 110/118 [00:01<00:00, 72.36it/s, loss=0.149, v_num=0]
Epoch 2: 94%|█████████▍| 111/118 [00:01<00:00, 72.87it/s, loss=0.149, v_num=0]
Epoch 2: 95%|█████████▍| 112/118 [00:01<00:00, 73.22it/s, loss=0.149, v_num=0]
(RayTrainWorker pid=17232)
Epoch 2: 96%|█████████▌| 113/118 [00:01<00:00, 73.13it/s, loss=0.149, v_num=0]
Epoch 2: 97%|█████████▋| 114/118 [00:01<00:00, 73.63it/s, loss=0.149, v_num=0]
Epoch 2: 97%|█████████▋| 115/118 [00:01<00:00, 74.14it/s, loss=0.149, v_num=0]
Epoch 2: 98%|█████████▊| 116/118 [00:01<00:00, 74.65it/s, loss=0.149, v_num=0]
Epoch 2: 99%|█████████▉| 117/118 [00:01<00:00, 74.67it/s, loss=0.149, v_num=0]
Epoch 2: 100%|██████████| 118/118 [00:01<00:00, 74.79it/s, loss=0.149, v_num=0]
Epoch 2: 100%|██████████| 118/118 [00:01<00:00, 74.74it/s, loss=0.149, v_num=0]
Epoch 3: 0%| | 0/118 [00:00<?, ?it/s, loss=0.149, v_num=0]
Epoch 3: 1%| | 1/118 [00:00<00:22, 5.27it/s, loss=0.149, v_num=0]
Epoch 3: 7%|▋ | 8/118 [00:00<00:04, 27.40it/s, loss=0.144, v_num=0]
Epoch 3: 7%|▋ | 8/118 [00:00<00:04, 26.95it/s, loss=0.143, v_num=0]
Epoch 3: 14%|█▎ | 16/118 [00:00<00:02, 40.27it/s, loss=0.13, v_num=0]
Epoch 3: 22%|██▏ | 26/118 [00:00<00:01, 52.98it/s, loss=0.122, v_num=0]
Epoch 3: 30%|██▉ | 35/118 [00:00<00:01, 58.42it/s, loss=0.128, v_num=0]
Epoch 3: 31%|███ | 36/118 [00:00<00:01, 59.33it/s, loss=0.128, v_num=0]
Epoch 3: 39%|███▉ | 46/118 [00:00<00:01, 65.01it/s, loss=0.124, v_num=0]
Epoch 3: 47%|████▋ | 55/118 [00:00<00:00, 67.90it/s, loss=0.138, v_num=0]
Epoch 3: 55%|█████▌ | 65/118 [00:00<00:00, 71.30it/s, loss=0.145, v_num=0]
Epoch 3: 55%|█████▌ | 65/118 [00:00<00:00, 70.84it/s, loss=0.152, v_num=0]
Epoch 3: 64%|██████▍ | 76/118 [00:01<00:00, 74.81it/s, loss=0.138, v_num=0]
Epoch 3: 71%|███████ | 84/118 [00:01<00:00, 74.74it/s, loss=0.125, v_num=0]
Epoch 3: 80%|███████▉ | 94/118 [00:01<00:00, 76.91it/s, loss=0.124, v_num=0]
Epoch 3: 81%|████████ | 95/118 [00:01<00:00, 77.30it/s, loss=0.124, v_num=0]
Epoch 3: 90%|████████▉ | 106/118 [00:01<00:00, 79.61it/s, loss=0.119, v_num=0]
Epoch 3: 92%|█████████▏| 108/118 [00:01<00:00, 80.41it/s, loss=0.123, v_num=0]
Validation: 0it [00:00, ?it/s]2)
(RayTrainWorker pid=17232)
Validation: 0%| | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/10 [00:00<?, ?it/s]
(RayTrainWorker pid=17232)
Epoch 3: 92%|█████████▏| 109/118 [00:01<00:00, 70.82it/s, loss=0.123, v_num=0]
(RayTrainWorker pid=17232)
Epoch 3: 93%|█████████▎| 110/118 [00:01<00:00, 71.26it/s, loss=0.123, v_num=0]
Epoch 3: 94%|█████████▍| 111/118 [00:01<00:00, 71.60it/s, loss=0.123, v_num=0]
Epoch 3: 95%|█████████▍| 112/118 [00:01<00:00, 71.92it/s, loss=0.123, v_num=0]
Epoch 3: 96%|█████████▌| 113/118 [00:01<00:00, 72.20it/s, loss=0.123, v_num=0]
Epoch 3: 97%|█████████▋| 114/118 [00:01<00:00, 72.64it/s, loss=0.123, v_num=0]
Epoch 3: 97%|█████████▋| 115/118 [00:01<00:00, 73.05it/s, loss=0.123, v_num=0]
Epoch 3: 98%|█████████▊| 116/118 [00:01<00:00, 73.55it/s, loss=0.123, v_num=0]
Epoch 3: 99%|█████████▉| 117/118 [00:01<00:00, 74.00it/s, loss=0.123, v_num=0]
Epoch 3: 100%|██████████| 118/118 [00:01<00:00, 73.28it/s, loss=0.123, v_num=0]
Epoch 3: 100%|██████████| 118/118 [00:01<00:00, 73.23it/s, loss=0.123, v_num=0]
Epoch 4: 0%| | 0/118 [00:00<?, ?it/s, loss=0.123, v_num=0]
Epoch 4: 3%|▎ | 3/118 [00:00<00:10, 11.48it/s, loss=0.114, v_num=0]
Epoch 4: 10%|█ | 12/118 [00:00<00:03, 32.93it/s, loss=0.111, v_num=0]
Epoch 4: 18%|█▊ | 21/118 [00:00<00:02, 45.21it/s, loss=0.102, v_num=0]
Epoch 4: 26%|██▋ | 31/118 [00:00<00:01, 54.51it/s, loss=0.11, v_num=0]
Epoch 4: 35%|███▍ | 41/118 [00:00<00:01, 60.49it/s, loss=0.112, v_num=0]
Epoch 4: 35%|███▍ | 41/118 [00:00<00:01, 60.40it/s, loss=0.109, v_num=0]
Epoch 4: 43%|████▎ | 51/118 [00:00<00:01, 65.31it/s, loss=0.112, v_num=0]
Epoch 4: 43%|████▎ | 51/118 [00:00<00:01, 64.95it/s, loss=0.11, v_num=0]
Epoch 4: 52%|█████▏ | 61/118 [00:00<00:00, 68.96it/s, loss=0.116, v_num=0]
Epoch 4: 53%|█████▎ | 62/118 [00:00<00:00, 69.63it/s, loss=0.116, v_num=0]
Epoch 4: 61%|██████ | 72/118 [00:00<00:00, 72.69it/s, loss=0.12, v_num=0]
Epoch 4: 61%|██████ | 72/118 [00:00<00:00, 72.35it/s, loss=0.119, v_num=0]
Epoch 4: 69%|██████▊ | 81/118 [00:01<00:00, 74.46it/s, loss=0.124, v_num=0]
Epoch 4: 69%|██████▊ | 81/118 [00:01<00:00, 73.59it/s, loss=0.121, v_num=0]
Epoch 4: 78%|███████▊ | 92/118 [00:01<00:00, 76.35it/s, loss=0.105, v_num=0]
Epoch 4: 78%|███████▊ | 92/118 [00:01<00:00, 76.33it/s, loss=0.108, v_num=0]
Epoch 4: 87%|████████▋ | 103/118 [00:01<00:00, 78.93it/s, loss=0.0973, v_num=0]
Epoch 4: 92%|█████████▏| 108/118 [00:01<00:00, 80.62it/s, loss=0.107, v_num=0]
Validation: 0it [00:00, ?it/s]2)
(RayTrainWorker pid=17232)
Validation: 0%| | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/10 [00:00<?, ?it/s]
Epoch 4: 92%|█████████▏| 109/118 [00:01<00:00, 70.57it/s, loss=0.107, v_num=0]
Epoch 4: 93%|█████████▎| 110/118 [00:01<00:00, 71.05it/s, loss=0.107, v_num=0]
Epoch 4: 94%|█████████▍| 111/118 [00:01<00:00, 71.56it/s, loss=0.107, v_num=0]
Epoch 4: 95%|█████████▍| 112/118 [00:01<00:00, 71.92it/s, loss=0.107, v_num=0]
Epoch 4: 96%|█████████▌| 113/118 [00:01<00:00, 72.04it/s, loss=0.107, v_num=0]
Epoch 4: 97%|█████████▋| 114/118 [00:01<00:00, 72.52it/s, loss=0.107, v_num=0]
Epoch 4: 97%|█████████▋| 115/118 [00:01<00:00, 73.01it/s, loss=0.107, v_num=0]
Epoch 4: 98%|█████████▊| 116/118 [00:01<00:00, 73.42it/s, loss=0.107, v_num=0]
Epoch 4: 99%|█████████▉| 117/118 [00:01<00:00, 73.52it/s, loss=0.107, v_num=0]
Epoch 4: 100%|██████████| 118/118 [00:01<00:00, 73.64it/s, loss=0.107, v_num=0]
Epoch 4: 100%|██████████| 118/118 [00:01<00:00, 73.59it/s, loss=0.107, v_num=0]
Epoch 5: 0%| | 0/118 [00:00<?, ?it/s, loss=0.107, v_num=0]
Epoch 5: 2%|▏ | 2/118 [00:00<00:13, 8.38it/s, loss=0.103, v_num=0]
Epoch 5: 8%|▊ | 10/118 [00:00<00:03, 29.51it/s, loss=0.101, v_num=0]
Epoch 5: 18%|█▊ | 21/118 [00:00<00:02, 47.55it/s, loss=0.103, v_num=0]
Epoch 5: 26%|██▋ | 31/118 [00:00<00:01, 56.79it/s, loss=0.0998, v_num=0]
Epoch 5: 34%|███▍ | 40/118 [00:00<00:01, 61.21it/s, loss=0.104, v_num=0]
Epoch 5: 42%|████▏ | 50/118 [00:00<00:01, 66.50it/s, loss=0.0978, v_num=0]
Epoch 5: 43%|████▎ | 51/118 [00:00<00:00, 67.10it/s, loss=0.0978, v_num=0]
Epoch 5: 53%|█████▎ | 62/118 [00:00<00:00, 71.68it/s, loss=0.0933, v_num=0]
Epoch 5: 60%|██████ | 71/118 [00:00<00:00, 73.91it/s, loss=0.0864, v_num=0]
Epoch 5: 61%|██████ | 72/118 [00:00<00:00, 74.28it/s, loss=0.0864, v_num=0]
Epoch 5: 69%|██████▊ | 81/118 [00:01<00:00, 75.92it/s, loss=0.0845, v_num=0]
Epoch 5: 69%|██████▉ | 82/118 [00:01<00:00, 76.33it/s, loss=0.0845, v_num=0]
Epoch 5: 78%|███████▊ | 92/118 [00:01<00:00, 78.53it/s, loss=0.102, v_num=0]
Epoch 5: 87%|████████▋ | 103/118 [00:01<00:00, 80.60it/s, loss=0.109, v_num=0]
Epoch 5: 92%|█████████▏| 108/118 [00:01<00:00, 82.44it/s, loss=0.105, v_num=0]
Validation: 0it [00:00, ?it/s]2)
(RayTrainWorker pid=17232)
Validation: 0%| | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/10 [00:00<?, ?it/s]
Epoch 5: 92%|█████████▏| 109/118 [00:01<00:00, 72.14it/s, loss=0.105, v_num=0]
Epoch 5: 93%|█████████▎| 110/118 [00:01<00:00, 72.45it/s, loss=0.105, v_num=0]
Epoch 5: 94%|█████████▍| 111/118 [00:01<00:00, 72.86it/s, loss=0.105, v_num=0]
Epoch 5: 95%|█████████▍| 112/118 [00:01<00:00, 73.21it/s, loss=0.105, v_num=0]
Epoch 5: 96%|█████████▌| 113/118 [00:01<00:00, 73.55it/s, loss=0.105, v_num=0]
Epoch 5: 97%|█████████▋| 114/118 [00:01<00:00, 73.75it/s, loss=0.105, v_num=0]
Epoch 5: 97%|█████████▋| 115/118 [00:01<00:00, 74.15it/s, loss=0.105, v_num=0]
Epoch 5: 98%|█████████▊| 116/118 [00:01<00:00, 74.65it/s, loss=0.105, v_num=0]
Epoch 5: 99%|█████████▉| 117/118 [00:01<00:00, 75.10it/s, loss=0.105, v_num=0]
Epoch 5: 100%|██████████| 118/118 [00:01<00:00, 74.97it/s, loss=0.105, v_num=0]
Epoch 5: 100%|██████████| 118/118 [00:01<00:00, 74.92it/s, loss=0.105, v_num=0]
Epoch 6: 0%| | 0/118 [00:00<?, ?it/s, loss=0.105, v_num=0]
Epoch 6: 2%|▏ | 2/118 [00:00<00:13, 8.61it/s, loss=0.0952, v_num=0]
Epoch 6: 8%|▊ | 10/118 [00:00<00:03, 29.86it/s, loss=0.0742, v_num=0]
Epoch 6: 17%|█▋ | 20/118 [00:00<00:02, 45.73it/s, loss=0.0701, v_num=0]
Epoch 6: 25%|██▍ | 29/118 [00:00<00:01, 53.55it/s, loss=0.0818, v_num=0]
Epoch 6: 34%|███▍ | 40/118 [00:00<00:01, 61.80it/s, loss=0.0876, v_num=0]
Epoch 6: 34%|███▍ | 40/118 [00:00<00:01, 61.39it/s, loss=0.0874, v_num=0]
Epoch 6: 43%|████▎ | 51/118 [00:00<00:00, 67.65it/s, loss=0.09, v_num=0]
Epoch 6: 52%|█████▏ | 61/118 [00:00<00:00, 71.13it/s, loss=0.0883, v_num=0]
Epoch 6: 60%|██████ | 71/118 [00:00<00:00, 73.82it/s, loss=0.08, v_num=0]
Epoch 6: 69%|██████▉ | 82/118 [00:01<00:00, 77.07it/s, loss=0.0791, v_num=0]
Epoch 6: 69%|██████▉ | 82/118 [00:01<00:00, 76.84it/s, loss=0.0778, v_num=0]
Epoch 6: 77%|███████▋ | 91/118 [00:01<00:00, 78.10it/s, loss=0.0764, v_num=0]
Epoch 6: 78%|███████▊ | 92/118 [00:01<00:00, 78.52it/s, loss=0.0764, v_num=0]
Epoch 6: 84%|████████▍ | 99/118 [00:01<00:00, 78.46it/s, loss=0.0668, v_num=0]
Epoch 6: 92%|█████████▏| 108/118 [00:01<00:00, 81.21it/s, loss=0.0822, v_num=0]
Validation: 0it [00:00, ?it/s]2)
(RayTrainWorker pid=17232)
Validation: 0%| | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/10 [00:00<?, ?it/s]
Epoch 6: 92%|█████████▏| 109/118 [00:01<00:00, 70.99it/s, loss=0.0822, v_num=0]
Epoch 6: 93%|█████████▎| 110/118 [00:01<00:00, 71.41it/s, loss=0.0822, v_num=0]
Epoch 6: 94%|█████████▍| 111/118 [00:01<00:00, 71.74it/s, loss=0.0822, v_num=0]
Epoch 6: 95%|█████████▍| 112/118 [00:01<00:00, 72.00it/s, loss=0.0822, v_num=0]
Epoch 6: 96%|█████████▌| 113/118 [00:01<00:00, 72.26it/s, loss=0.0822, v_num=0]
Epoch 6: 97%|█████████▋| 114/118 [00:01<00:00, 72.67it/s, loss=0.0822, v_num=0]
Epoch 6: 97%|█████████▋| 115/118 [00:01<00:00, 73.08it/s, loss=0.0822, v_num=0]
Epoch 6: 98%|█████████▊| 116/118 [00:01<00:00, 73.40it/s, loss=0.0822, v_num=0]
(RayTrainWorker pid=17232)
Epoch 6: 99%|█████████▉| 117/118 [00:01<00:00, 73.85it/s, loss=0.0822, v_num=0]
Epoch 6: 100%|██████████| 118/118 [00:01<00:00, 73.97it/s, loss=0.0822, v_num=0]
Epoch 6: 100%|██████████| 118/118 [00:01<00:00, 73.91it/s, loss=0.0822, v_num=0]
Epoch 7: 0%| | 0/118 [00:00<?, ?it/s, loss=0.0822, v_num=0]
Epoch 7: 1%| | 1/118 [00:00<00:24, 4.76it/s, loss=0.0816, v_num=0]
Epoch 7: 8%|▊ | 10/118 [00:00<00:03, 32.68it/s, loss=0.0774, v_num=0]
Epoch 7: 17%|█▋ | 20/118 [00:00<00:02, 48.49it/s, loss=0.0633, v_num=0]
Epoch 7: 25%|██▍ | 29/118 [00:00<00:01, 55.60it/s, loss=0.0682, v_num=0]
Epoch 7: 34%|███▍ | 40/118 [00:00<00:01, 63.73it/s, loss=0.0633, v_num=0]
Epoch 7: 43%|████▎ | 51/118 [00:00<00:00, 70.11it/s, loss=0.0596, v_num=0]
Epoch 7: 43%|████▎ | 51/118 [00:00<00:00, 69.74it/s, loss=0.0599, v_num=0]
Epoch 7: 53%|█████▎ | 62/118 [00:00<00:00, 74.75it/s, loss=0.0601, v_num=0]
Epoch 7: 62%|██████▏ | 73/118 [00:00<00:00, 77.99it/s, loss=0.0696, v_num=0]
Epoch 7: 62%|██████▏ | 73/118 [00:00<00:00, 77.77it/s, loss=0.0715, v_num=0]
Epoch 7: 70%|███████ | 83/118 [00:01<00:00, 80.11it/s, loss=0.0793, v_num=0]
Epoch 7: 71%|███████ | 84/118 [00:01<00:00, 80.45it/s, loss=0.0793, v_num=0]
Epoch 7: 81%|████████ | 95/118 [00:01<00:00, 82.87it/s, loss=0.0778, v_num=0]
Epoch 7: 91%|█████████ | 107/118 [00:01<00:00, 85.72it/s, loss=0.0743, v_num=0]
Epoch 7: 91%|█████████ | 107/118 [00:01<00:00, 85.50it/s, loss=0.0753, v_num=0]
Epoch 7: 92%|█████████▏| 108/118 [00:01<00:00, 85.88it/s, loss=0.0742, v_num=0]
Validation: 0it [00:00, ?it/s]2)
(RayTrainWorker pid=17232)
Validation: 0%| | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/10 [00:00<?, ?it/s]
Epoch 7: 92%|█████████▏| 109/118 [00:01<00:00, 75.70it/s, loss=0.0742, v_num=0]
Epoch 7: 93%|█████████▎| 110/118 [00:01<00:00, 76.11it/s, loss=0.0742, v_num=0]
Epoch 7: 94%|█████████▍| 111/118 [00:01<00:00, 76.41it/s, loss=0.0742, v_num=0]
(RayTrainWorker pid=17232)
Epoch 7: 95%|█████████▍| 112/118 [00:01<00:00, 76.72it/s, loss=0.0742, v_num=0]
Epoch 7: 96%|█████████▌| 113/118 [00:01<00:00, 77.06it/s, loss=0.0742, v_num=0]
Epoch 7: 97%|█████████▋| 114/118 [00:01<00:00, 77.34it/s, loss=0.0742, v_num=0]
Epoch 7: 97%|█████████▋| 115/118 [00:01<00:00, 77.75it/s, loss=0.0742, v_num=0]
Epoch 7: 98%|█████████▊| 116/118 [00:01<00:00, 77.90it/s, loss=0.0742, v_num=0]
Epoch 7: 99%|█████████▉| 117/118 [00:01<00:00, 78.41it/s, loss=0.0742, v_num=0]
Epoch 7: 100%|██████████| 118/118 [00:01<00:00, 77.83it/s, loss=0.0742, v_num=0]
Epoch 7: 100%|██████████| 118/118 [00:01<00:00, 77.77it/s, loss=0.0742, v_num=0]
Epoch 8: 0%| | 0/118 [00:00<?, ?it/s, loss=0.0742, v_num=0]
Epoch 8: 5%|▌ | 6/118 [00:00<00:04, 22.78it/s, loss=0.0719, v_num=0]
Epoch 8: 12%|█▏ | 14/118 [00:00<00:02, 38.90it/s, loss=0.0655, v_num=0]
Epoch 8: 13%|█▎ | 15/118 [00:00<00:02, 40.47it/s, loss=0.0663, v_num=0]
Epoch 8: 21%|██ | 25/118 [00:00<00:01, 52.81it/s, loss=0.064, v_num=0]
Epoch 8: 29%|██▉ | 34/118 [00:00<00:01, 58.98it/s, loss=0.0592, v_num=0]
Epoch 8: 37%|███▋ | 44/118 [00:00<00:01, 64.76it/s, loss=0.0537, v_num=0]
Epoch 8: 46%|████▌ | 54/118 [00:00<00:00, 68.99it/s, loss=0.0569, v_num=0]
Epoch 8: 56%|█████▌ | 66/118 [00:00<00:00, 74.51it/s, loss=0.0609, v_num=0]
Epoch 8: 64%|██████▎ | 75/118 [00:00<00:00, 76.72it/s, loss=0.0608, v_num=0]
Epoch 8: 72%|███████▏ | 85/118 [00:01<00:00, 78.28it/s, loss=0.0573, v_num=0]
Epoch 8: 73%|███████▎ | 86/118 [00:01<00:00, 78.57it/s, loss=0.0573, v_num=0]
Epoch 8: 73%|███████▎ | 86/118 [00:01<00:00, 78.45it/s, loss=0.0561, v_num=0]
Epoch 8: 81%|████████ | 95/118 [00:01<00:00, 79.63it/s, loss=0.0485, v_num=0]
Epoch 8: 81%|████████▏ | 96/118 [00:01<00:00, 80.00it/s, loss=0.0497, v_num=0]
Epoch 8: 92%|█████████▏| 108/118 [00:01<00:00, 82.93it/s, loss=0.059, v_num=0]
(RayTrainWorker pid=17232)
Validation: 0it [00:00, ?it/s]2)
(RayTrainWorker pid=17232)
Validation: 0%| | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/10 [00:00<?, ?it/s]
Epoch 8: 92%|█████████▏| 109/118 [00:01<00:00, 72.21it/s, loss=0.059, v_num=0]
(RayTrainWorker pid=17232)
Epoch 8: 93%|█████████▎| 110/118 [00:01<00:00, 72.61it/s, loss=0.059, v_num=0]
Epoch 8: 94%|█████████▍| 111/118 [00:01<00:00, 72.91it/s, loss=0.059, v_num=0]
Epoch 8: 95%|█████████▍| 112/118 [00:01<00:00, 73.36it/s, loss=0.059, v_num=0]
Epoch 8: 96%|█████████▌| 113/118 [00:01<00:00, 73.53it/s, loss=0.059, v_num=0]
Epoch 8: 97%|█████████▋| 114/118 [00:01<00:00, 74.02it/s, loss=0.059, v_num=0]
Epoch 8: 97%|█████████▋| 115/118 [00:01<00:00, 74.44it/s, loss=0.059, v_num=0]
Epoch 8: 98%|█████████▊| 116/118 [00:01<00:00, 74.94it/s, loss=0.059, v_num=0]
Epoch 8: 99%|█████████▉| 117/118 [00:01<00:00, 75.06it/s, loss=0.059, v_num=0]
Epoch 8: 100%|██████████| 118/118 [00:01<00:00, 75.20it/s, loss=0.059, v_num=0]
Epoch 8: 100%|██████████| 118/118 [00:01<00:00, 75.15it/s, loss=0.059, v_num=0]
Epoch 9: 0%| | 0/118 [00:00<?, ?it/s, loss=0.059, v_num=0]
Epoch 9: 4%|▍ | 5/118 [00:00<00:06, 18.72it/s, loss=0.0632, v_num=0]
Epoch 9: 5%|▌ | 6/118 [00:00<00:05, 21.95it/s, loss=0.0632, v_num=0]
Epoch 9: 14%|█▎ | 16/118 [00:00<00:02, 42.58it/s, loss=0.0603, v_num=0]
Epoch 9: 21%|██ | 25/118 [00:00<00:01, 52.25it/s, loss=0.0543, v_num=0]
Epoch 9: 31%|███ | 36/118 [00:00<00:01, 61.88it/s, loss=0.0572, v_num=0]
Epoch 9: 31%|███ | 36/118 [00:00<00:01, 61.81it/s, loss=0.0562, v_num=0]
Epoch 9: 39%|███▉ | 46/118 [00:00<00:01, 66.98it/s, loss=0.0504, v_num=0]
Epoch 9: 47%|████▋ | 56/118 [00:00<00:00, 71.09it/s, loss=0.0501, v_num=0]
Epoch 9: 57%|█████▋ | 67/118 [00:00<00:00, 75.35it/s, loss=0.0489, v_num=0]
Epoch 9: 64%|██████▍ | 76/118 [00:00<00:00, 76.36it/s, loss=0.0412, v_num=0]
Epoch 9: 74%|███████▎ | 87/118 [00:01<00:00, 79.08it/s, loss=0.0452, v_num=0]
Epoch 9: 82%|████████▏ | 97/118 [00:01<00:00, 81.13it/s, loss=0.0506, v_num=0]
Epoch 9: 83%|████████▎ | 98/118 [00:01<00:00, 81.21it/s, loss=0.0506, v_num=0]
Epoch 9: 92%|█████████▏| 108/118 [00:01<00:00, 84.51it/s, loss=0.0563, v_num=0]
(RayTrainWorker pid=17232)
Validation: 0it [00:00, ?it/s]2)
(RayTrainWorker pid=17232)
Validation: 0%| | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/10 [00:00<?, ?it/s]
Epoch 9: 92%|█████████▏| 109/118 [00:01<00:00, 74.09it/s, loss=0.0563, v_num=0]
Epoch 9: 93%|█████████▎| 110/118 [00:01<00:00, 74.36it/s, loss=0.0563, v_num=0]
Epoch 9: 94%|█████████▍| 111/118 [00:01<00:00, 74.67it/s, loss=0.0563, v_num=0]
Epoch 9: 95%|█████████▍| 112/118 [00:01<00:00, 75.07it/s, loss=0.0563, v_num=0]
Epoch 9: 96%|█████████▌| 113/118 [00:01<00:00, 75.46it/s, loss=0.0563, v_num=0]
Epoch 9: 97%|█████████▋| 114/118 [00:01<00:00, 75.88it/s, loss=0.0563, v_num=0]
Epoch 9: 97%|█████████▋| 115/118 [00:01<00:00, 76.36it/s, loss=0.0563, v_num=0]
(RayTrainWorker pid=17232)
Epoch 9: 98%|█████████▊| 116/118 [00:01<00:00, 75.64it/s, loss=0.0563, v_num=0]
Epoch 9: 99%|█████████▉| 117/118 [00:01<00:00, 76.00it/s, loss=0.0563, v_num=0]
Epoch 9: 100%|██████████| 118/118 [00:01<00:00, 76.04it/s, loss=0.0563, v_num=0]
Epoch 9: 100%|██████████| 118/118 [00:01<00:00, 75.99it/s, loss=0.0563, v_num=0]
Epoch 9: 100%|██████████| 118/118 [00:01<00:00, 68.09it/s, loss=0.0563, v_num=0]
2023-06-13 16:05:52,777 INFO tune.py:1111 -- Total run time: 39.46 seconds (39.28 seconds for the tuning loop).
Validation Accuracy: 0.9700015783309937
Result(
metrics={'_report_on': 'train_epoch_end', 'train_loss': 0.03159911185503006, 'val_accuracy': 0.9700015783309937, 'val_loss': -12.346744537353516, 'epoch': 9, 'step': 1080, 'should_checkpoint': True, 'done': True, 'trial_id': 'c0d28_00000', 'experiment_tag': '0'},
path='/tmp/ray_results/ptl-mnist-example/LightningTrainer_c0d28_00000_0_2023-06-13_16-05-13',
checkpoint=LightningCheckpoint(local_path=/tmp/ray_results/ptl-mnist-example/LightningTrainer_c0d28_00000_0_2023-06-13_16-05-13/checkpoint_000009)
)
Evaluate your model on the test dataset
Next, let’s evaluate the model’s performance on the MNIST test set. We will first retrieve the checkpoint reported by the training run (`result.checkpoint`, here the one saved after the final epoch) and load it into the model.
If you no longer have the in-memory `result` object, you can also restore the model from the checkpoint file on disk. Here the checkpoint path is: /tmp/ray_results/ptl-mnist-example/LightningTrainer_c0d28_00000_0_2023-06-13_16-05-13/checkpoint_000009/model.
# `result` is the training Result printed above; its `checkpoint` attribute is
# the LightningCheckpoint written by the run (checkpoint_000009 in the log).
checkpoint: LightningCheckpoint = result.checkpoint
# Re-create an MNISTClassifier and load the checkpointed weights into it.
best_model: pl.LightningModule = checkpoint.get_model(MNISTClassifier)
Single-node Testing
If you have a relatively small test set, like MNIST, the easiest way is to use PyTorch Lightning’s native interface to evaluate the best model. Pass the loaded model and test data loader to pl.Trainer.test(), which will execute the test loop using your custom pl.LightningModule.test_step() method on your head node.
# Prepare the MNIST datamodule locally on the head node, then build the
# test DataLoader from it.
datamodule.setup()
test_dataloader = datamodule.test_dataloader()

# A default-configured Trainer runs the test loop in this single process,
# calling the module's own test_step() on every test batch.
trainer = pl.Trainer()
result = trainer.test(model=best_model, dataloaders=test_dataloader)
/home/ray/anaconda3/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py:92: PossibleUserWarning: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
rank_zero_warn(
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/ray/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1814: PossibleUserWarning: GPU available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='gpu', devices=1)`.
rank_zero_warn(
2023-06-13 16:05:53.932195: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-13 16:05:54.097738: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-06-13 16:05:55.022170: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-06-13 16:05:55.022249: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-06-13 16:05:55.022255: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Test metric ┃ DataLoader 0 ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ test_accuracy │ 0.9740999937057495 │ └───────────────────────────┴───────────────────────────┘
Multi-node Testing
Alternatively, if you have a large test set and want to speed up the testing process in parallel, you can create a group of Ray Actors to leverage multiple GPUs across multiple nodes for distributed inference. Here we demonstrate how to set up a process group and do evaluation using 4 GPUs.
import ray
import pytorch_lightning as pl
from pytorch_lightning.plugins.environments.lightning_environment import (
LightningEnvironment,
)
from ray.air.util.torch_dist import (
TorchDistributedWorker,
init_torch_dist_process_group,
shutdown_torch_dist_process_group,
)
class RayEnvironment(LightningEnvironment):
    """Lightning cluster environment for running DDP inside Ray actors.

    Rank and world-size information is read straight from the ``WORLD_SIZE``,
    ``RANK`` and ``LOCAL_RANK`` environment variables, which
    ``init_torch_dist_process_group`` sets up in every participating actor.
    """

    @staticmethod
    def _env_int(name: str) -> int:
        # Parse an integer environment variable; raises KeyError if it is unset.
        return int(os.environ[name])

    def world_size(self) -> int:
        return self._env_int("WORLD_SIZE")

    def global_rank(self) -> int:
        return self._env_int("RANK")

    def local_rank(self) -> int:
        return self._env_int("LOCAL_RANK")

    def set_world_size(self, size: int) -> None:
        # Deliberately a no-op: world_size() always re-reads WORLD_SIZE from
        # the environment, so a value pushed in by Lightning would be ignored.
        return None

    def set_global_rank(self, rank: int) -> None:
        # Deliberately a no-op: global_rank() always re-reads RANK from the
        # environment.
        return None

    def teardown(self) -> None:
        # Intentionally does nothing (overrides the base-class teardown).
        return None
@ray.remote
class TestWorker(TorchDistributedWorker):
    """Ray actor that runs one rank of a distributed ``pl.Trainer.test()`` pass."""

    def run(self, model, dataloader):
        """Evaluate ``model`` on ``dataloader`` with DDP across all workers.

        The model and dataloader are received as arguments instead of being
        read from notebook globals. Referencing the globals inside the actor
        makes Ray pickle them into the actor class definition (the
        "actor TestWorker is very large" warning). Passing object-store refs
        lets every worker fetch one shared copy instead.

        Returns:
            The list of metric dicts produced by ``pl.Trainer.test``.
        """
        trainer = pl.Trainer(
            num_nodes=num_workers,
            accelerator="gpu",
            strategy="ddp",
            plugins=[RayEnvironment()],
        )
        return trainer.test(model, dataloaders=dataloader)


# Create one remote Ray Actor per worker (num_workers in total), each with 1 GPU
workers = [TestWorker.options(num_gpus=1).remote() for _ in range(num_workers)]

# Store the model and test data in the Ray object store once. Top-level
# ObjectRef arguments are resolved to their values before `run` executes.
model_ref = ray.put(best_model)
dataloader_ref = ray.put(test_dataloader)

# Initialize the Torch distributed group among the actors.
# This sets up the required environment variables, including
# RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, ...
init_torch_dist_process_group(workers=workers, backend="nccl")
try:
    # Execute the testing run in parallel
    results = ray.get(
        [worker.run.remote(model_ref, dataloader_ref) for worker in workers]
    )
finally:
    # Shut down the process group even if one of the workers failed
    shutdown_torch_dist_process_group(workers=workers)
2023-06-13 16:05:56,270 WARNING worker.py:2019 -- Warning: The actor TestWorker is very large (53 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.
(RayTrainWorker pid=17232) Global seed set to 888 [repeated 7x across cluster]
(RayTrainWorker pid=7319, ip=10.0.58.90) Missing logger folder: logs/lightning_logs [repeated 3x across cluster]
(RayTrainWorker pid=7319, ip=10.0.58.90) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] [repeated 3x across cluster]
(RayTrainWorker pid=6371, ip=10.0.1.80) [W reducer.cpp:1298] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) [repeated 3x across cluster]
(pid=9162, ip=10.0.26.229) from pandas import MultiIndex, Int64Index
(pid=9162, ip=10.0.26.229) from pandas import MultiIndex, Int64Index
(pid=9976, ip=10.0.58.90) /home/ray/anaconda3/lib/python3.9/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.
(pid=9976, ip=10.0.58.90) from pandas import MultiIndex, Int64Index
(TestWorker pid=20600) /home/ray/anaconda3/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py:92: PossibleUserWarning: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
(TestWorker pid=20600) rank_zero_warn(
(TestWorker pid=20600) GPU available: True, used: True
(TestWorker pid=20600) TPU available: False, using: 0 TPU cores
(TestWorker pid=20600) IPU available: False, using: 0 IPUs
(TestWorker pid=20600) HPU available: False, using: 0 HPUs
(TestWorker pid=20600) /home/ray/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:330: PossibleUserWarning: Using `DistributedSampler` with the dataloaders. During `trainer.test()`, it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates some samples to make sure all devices have same batch size in case of uneven inputs.
(TestWorker pid=20600) rank_zero_warn(
Testing: 0it [00:00, ?it/s]600)
Testing DataLoader 0: 0%| | 0/20 [00:00<?, ?it/s]
Testing DataLoader 0: 10%|█ | 2/20 [00:00<00:13, 1.36it/s]
Testing DataLoader 0: 75%|███████▌ | 15/20 [00:00<00:00, 22.10it/s]
(TestWorker pid=20600) 2023-06-13 16:06:07.550225: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA
(TestWorker pid=20600) To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
(TestWorker pid=9976, ip=10.0.58.90) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] [repeated 4x across cluster]
(pid=20600) /home/ray/anaconda3/lib/python3.9/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead. [repeated 2x across cluster]
(pid=20600) from pandas import MultiIndex, Int64Index [repeated 2x across cluster]
(TestWorker pid=20600) 2023-06-13 16:06:07.708119: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
Testing DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 22.10it/s]
(TestWorker pid=20600) 2023-06-13 16:06:08.680418: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
(TestWorker pid=20600) 2023-06-13 16:06:08.680524: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
(TestWorker pid=20600) 2023-06-13 16:06:08.680532: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Testing DataLoader 0: 100%|██████████| 20/20 [00:02<00:00, 7.31it/s]
(TestWorker pid=20600) ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
(TestWorker pid=20600) ┃ Test metric ┃ DataLoader 0 ┃
(TestWorker pid=20600) ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
(TestWorker pid=20600) │ test_accuracy │ 0.9740999937057495 │
(TestWorker pid=20600) └───────────────────────────┴───────────────────────────┘