Using Experiment Tracking Tools in LightningTrainer

W&B, CometML, MLflow, and TensorBoard are popular machine learning tools for managing, visualizing, and tracking experiments. The LightningTrainer integration in Ray AIR lets you keep using PyTorch Lightning's built-in logger integrations for these tools.

Note

This guide shows how to use the native Logger integrations in PyTorch Lightning. Ray AIR also provides its own experiment tracking integrations for all the tools mentioned in this example; however, we recommend sticking with the PyTorch Lightning loggers.

Define your model and dataloader

In this example, we create a dummy model and a dummy dataset for demonstration purposes; no code changes are needed here. The training loop reports 3 metrics (“train_loss”, “metric_1”, “metric_2”), which Lightning’s loggers capture and report to the corresponding experiment tracking tools.

import os
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import TensorDataset, DataLoader

# Dummy training data: 128 random samples with 3 features each, plus one
# random binary label (0 or 1) per sample.
num_samples, num_features = 128, 3
X = torch.randn(num_samples, num_features)
y = torch.randint(0, 2, (num_samples,))

# Pair features with labels, then serve them in shuffled mini-batches.
dataset = TensorDataset(X, y)
batch_size = 8
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Define a dummy model
class DummyModel(pl.LightningModule):
    """Tiny binary classifier: a single linear layer mapping 3 features to 1 logit."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits for a batch of inputs."""
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        """Compute the BCE-with-logits loss and log it along with two dummy metrics."""
        features, labels = batch
        logits = self(features).flatten()
        loss = F.binary_cross_entropy_with_logits(logits, labels.float())

        # Everything passed to self.log / self.log_dict is forwarded to the
        # Loggers configured on the Trainer.
        self.log("train_loss", loss)
        dummy_metrics = {
            "metric_1": 1 / (batch_idx + 1),
            "metric_2": batch_idx * 100,
        }
        self.log_dict(dummy_metrics)
        return loss

    def configure_optimizers(self):
        """Optimize all parameters with Adam at a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

Define your loggers

For offline loggers, no changes are required in the Logger initialization.

For online loggers (W&B and CometML), you need to do two things:

  • Set your API keys as environment variables (WANDB_API_KEY and COMET_API_KEY).

  • Set rank_zero_only.rank = None to prevent Lightning from creating a new experiment run on the driver node.

from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.loggers.comet import CometLogger
from pytorch_lightning.loggers.mlflow import MLFlowLogger
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.utilities.rank_zero import rank_zero_only
import wandb


# A callback to login wandb in each worker
class WandbLoginCallback(pl.Callback):
    """Log in to W&B from every process that runs the Lightning Trainer.

    Args:
        key: W&B API key forwarded to ``wandb.login``. It may be None when
            ``WANDB_API_KEY`` is unset; ``wandb.login`` then falls back to
            its default credential lookup.
    """

    def __init__(self, key):
        super().__init__()
        self.key = key

    def setup(self, trainer, pl_module, stage) -> None:
        # Lightning invokes ``setup`` at the start of fit/validate/test/predict
        # in each trainer process, so every worker logs in before logging.
        wandb.login(key=self.key)


def create_loggers(
    name, project_name, save_dir="./logs", offline=False, *, wandb_run_id="unique_id"
):
    """Create the W&B, CometML, MLflow and TensorBoard loggers for this example.

    Args:
        name: Run name, used as the W&B run name, the CometML experiment
            name, the MLflow run name, and the TensorBoard log name.
        project_name: W&B / CometML project name; also used as the MLflow
            experiment name.
        save_dir: Root directory for local logger files. Each logger writes
            to its own subdirectory (``wandb``, ``comet``, ``mlflow``,
            ``tensorboard``).
        offline: If True, W&B and CometML are created in offline mode and no
            W&B login callback is returned.
        wandb_run_id: Fixed W&B run id. A fixed id lets a restored training
            run keep reporting to the same W&B run instead of starting a new
            one. Use a distinct id per experiment; otherwise all experiments
            in the project share one W&B run id. Pass None to let W&B
            generate an id.

    Returns:
        ``(loggers, callbacks)``: ``loggers`` is
        ``[wandb_logger, comet_logger, mlflow_logger, tensorboard_logger]``,
        and ``callbacks`` is the list of Lightning callbacks to pass to the
        Trainer (empty when ``offline`` is True).
    """
    # Avoid creating a new experiment run on the driver node; the loggers are
    # only constructed here and are used by the training workers.
    rank_zero_only.rank = None

    # Wandb
    wandb_api_key = os.environ.get("WANDB_API_KEY", None)

    class RayWandbLogger(WandbLogger):
        # wandb.finish() ensures all artifacts get uploaded at the end of training.
        def finalize(self, status):
            super().finalize(status)
            wandb.finish()

    wandb_logger = RayWandbLogger(
        name=name,
        project=project_name,
        # Specify a unique id to avoid reporting to a new run after restoration
        id=wandb_run_id,
        save_dir=f"{save_dir}/wandb",
        offline=offline,
    )
    # Workers may not share the driver's W&B login, so log in on each worker
    # (only needed when talking to the W&B server).
    callbacks = [] if offline else [WandbLoginCallback(key=wandb_api_key)]

    # CometML
    comet_api_key = os.environ.get("COMET_API_KEY", None)
    comet_logger = CometLogger(
        api_key=comet_api_key,
        experiment_name=name,
        project_name=project_name,
        save_dir=f"{save_dir}/comet",
        offline=offline,
    )

    # MLFlow
    mlflow_logger = MLFlowLogger(
        run_name=name,
        experiment_name=project_name,
        tracking_uri=f"file:{save_dir}/mlflow",
    )

    # Tensorboard
    tensorboard_logger = TensorBoardLogger(
        name=name, save_dir=f"{save_dir}/tensorboard"
    )

    return [wandb_logger, comet_logger, mlflow_logger, tensorboard_logger], callbacks
CometLogger will be initialized in online mode
# Root directory for the loggers' local files; each logger writes into its
# own subdirectory beneath it.
YOUR_SAVE_DIR = "./logs"
loggers, callbacks = create_loggers(
    name="demo-run",
    project_name="demo-project",
    save_dir=YOUR_SAVE_DIR,
    offline=False,
)

Train the model and view logged results

from ray.air.config import RunConfig, ScalingConfig
from ray.train.lightning import LightningConfigBuilder, LightningTrainer

# Arguments forwarded to the Lightning Trainer. The experiment-tracking
# loggers and the W&B login callback are plugged in here; logging every step
# makes sure the few steps of this tiny dataset are all recorded.
trainer_kwargs = {
    "max_epochs": 5,
    "accelerator": "cpu",
    "logger": loggers,
    "callbacks": callbacks,
    "log_every_n_steps": 1,
}

builder = LightningConfigBuilder()
builder.module(cls=DummyModel)
builder.trainer(**trainer_kwargs)
builder.fit_params(train_dataloaders=dataloader)
lightning_config = builder.build()

# Ray-side settings: 4 CPU workers, results stored under /tmp/ray_results.
scaling_config = ScalingConfig(num_workers=4, use_gpu=False)
run_config = RunConfig(name="ptl-exp-tracking", storage_path="/tmp/ray_results")

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
    run_config=run_config,
)
trainer.fit()

Now let’s take a look at our experiment results!

W&B

CometML

TensorBoard

MLflow