Ray AIR Integrations with ML Libraries

PyTorch

There are 2 recommended ways to train PyTorch models on a Ray cluster.

Note

If you’re training PyTorch models with PyTorch Lightning, see below for the available PyTorch Lightning Ray AIR integrations.

See the options 1️⃣ 2️⃣ below, along with the usage scenarios and API references for each:

1️⃣ Vanilla PyTorch with Ray Tune

Usage Scenario: Non-distributed training, where the dataset is relatively small and there are many trials (e.g., many hyperparameter configurations). Use vanilla PyTorch with Ray Tune to parallelize model training.

2️⃣ TorchTrainer

Usage Scenario: Data-parallel training, such as multi-GPU or multi-node training.

TorchTrainer(*args, **kwargs)

A Trainer for data-parallel PyTorch training.

PyTorch Lightning

There are 2 recommended ways to train with PyTorch Lightning on a Ray cluster.

See the options 1️⃣ 2️⃣ below, along with the usage scenarios and API references for each:

1️⃣ Vanilla PyTorch Lightning with a Ray Callback

Usage Scenario: Non-distributed training, where the dataset is relatively small and there are many trials (e.g., many hyperparameter configurations). Use vanilla PyTorch Lightning with Ray Tune to parallelize model training.

TuneReportCallback([metrics, on])

PyTorch Lightning to Ray Tune reporting callback.

TuneReportCheckpointCallback([metrics, ...])

PyTorch Lightning report and checkpoint callback.

2️⃣ LightningTrainer

Usage Scenario: Distributed training, such as multi-GPU or multi-node data-parallel training.

LightningTrainer(*args, **kwargs)

A Trainer for data-parallel PyTorch Lightning training.

TensorFlow/Keras

There are 2 recommended ways to train TensorFlow/Keras models with Ray.

See the options 1️⃣ 2️⃣ below, along with the usage scenarios and API references for each:

1️⃣ Vanilla Keras with a Ray Callback

Usage Scenario: Non-distributed training, where the dataset is relatively small and there are many trials (e.g., many hyperparameter configurations). Use vanilla TensorFlow/Keras with Ray Tune to parallelize model training.

ReportCheckpointCallback([checkpoint_on, ...])

Keras callback for Ray AIR reporting and checkpointing.

2️⃣ TensorflowTrainer

Usage Scenario: Data-parallel training, such as multi-GPU or multi-node training.

TensorflowTrainer(*args, **kwargs)

A Trainer for data-parallel TensorFlow training.

ReportCheckpointCallback([checkpoint_on, ...])

Keras callback for Ray AIR reporting and checkpointing.

XGBoost

There are 3 recommended ways to train XGBoost models with Ray.

See the options 1️⃣ 2️⃣ 3️⃣ below, along with the usage scenarios and API references for each:

1️⃣ Vanilla XGBoost with a Ray Callback

Usage Scenario: Non-distributed training, where the dataset is relatively small and there are many trials (e.g., many hyperparameter configurations). Use vanilla XGBoost with these Ray Tune callbacks to parallelize model training.

TuneReportCallback([metrics, ...])

XGBoost to Ray Tune reporting callback.

TuneReportCheckpointCallback([metrics, ...])

XGBoost report and checkpoint callback.

2️⃣ XGBoostTrainer

Usage Scenario: Data-parallel training, such as multi-GPU or multi-node training.

XGBoostTrainer(*args, **kwargs)

A Trainer for data-parallel XGBoost training.

3️⃣ xgboost_ray

Usage Scenario: Use as a (nearly) drop-in replacement for the regular xgboost API, with added support for distributed training on a Ray cluster.

See the xgboost_ray documentation.

LightGBM

There are 3 recommended ways to train LightGBM models with Ray.

See the options 1️⃣ 2️⃣ 3️⃣ below, along with the usage scenarios and API references for each:

1️⃣ Vanilla LightGBM with a Ray Callback

Usage Scenario: Non-distributed training, where the dataset is relatively small and there are many trials (e.g., many hyperparameter configurations). Use vanilla LightGBM with these Ray Tune callbacks to parallelize model training.

TuneReportCallback([metrics, ...])

Creates a callback that reports metrics to Ray Tune.

TuneReportCheckpointCallback([metrics, ...])

Creates a callback that reports metrics and checkpoints the model.

2️⃣ LightGBMTrainer

Usage Scenario: Data-parallel training, such as multi-GPU or multi-node training.

LightGBMTrainer(*args, **kwargs)

A Trainer for data-parallel LightGBM training.

3️⃣ lightgbm_ray

Usage Scenario: Use as a (nearly) drop-in replacement for the regular lightgbm API, with added support for distributed training on a Ray cluster.

See the lightgbm_ray documentation.