Ray Train API#
This page covers the framework-specific integrations for Ray Train and the Ray Train Developer APIs.
For the core Ray AIR APIs, see the AIR package reference.
Ray Train Base Classes (Developer APIs)#
Trainer Base Classes#
Defines interface for distributed training on Ray.
A Trainer for data parallel training.
Class responsible for configuring Train dataset preprocessing.
Abstract class for scaling gradient-boosting decision tree (GBDT) frameworks.
BaseTrainer API#
Runs training.
Called during fit() to perform initial setup on the Trainer.
Called during fit() to preprocess dataset attributes with preprocessor.
Loop called by fit() to run training and report results to Tune.
Converts self to a
Train Backend Base Classes#
Singleton for distributed communication backend.
Parent class for configurations of training backend.
Ray Train Config#
Class responsible for configuring Train dataset preprocessing.
Ray Train Loop#
Context for Ray training executions.
Get or create a singleton training context.
Returns the
Report metrics and optionally save a checkpoint.
Ray Train Checkpoints#
Ray AIR Checkpoint.
Ray Train Context#
Get or create a singleton training context.
Context for Ray training executions.
Ray Train Integrations#
PyTorch#
A Trainer for data parallel PyTorch training.
Configuration for torch process group setup.
A
PyTorch Training Loop Utilities#
Prepares the model for distributed execution.
Wraps optimizer to support automatic mixed precision.
Prepares DataLoader for distributed execution.
Gets the correct torch device configured for this process.
Enables training optimizations.
Computes the gradient of the specified tensor w.r.t.
Limits sources of nondeterministic behavior.
PyTorch Lightning#
A Trainer for data parallel PyTorch Lightning training.
Configuration Class to pass into LightningTrainer.
A
A predictor for PyTorch Lightning modules.
Tensorflow/Keras#
A Trainer for data parallel Tensorflow training.
PublicAPI (beta): This API is in beta and may change before becoming stable.
A
Tensorflow/Keras Training Loop Utilities#
A utility function that overrides default config for Tensorflow Dataset.
Keras callback for Ray AIR reporting and checkpointing.
Horovod#
A Trainer for data parallel Horovod training.
Configurations for Horovod setup.
XGBoost#
A Trainer for data parallel XGBoost training.
A
LightGBM#
A Trainer for data parallel LightGBM training.
A
Hugging Face#
Transformers#
A Trainer for data parallel HuggingFace Transformers on PyTorch training.
A
Accelerate#
A Trainer for data parallel HuggingFace Accelerate training with PyTorch.
Scikit-Learn#
A Trainer for scikit-learn estimator training.
A
Mosaic#
A Trainer for data parallel Mosaic Composers on PyTorch training.
Ray Train Experiment Restoration#
Restores a Train experiment from a previously interrupted/failed run.
Note
All trainer classes have a restore method that takes a path
pointing to the directory of the experiment to be restored.
restore also exposes a subset of constructor arguments that can be re-specified.
See Restoration API for Built-in Trainers
below for details on the restore arguments for the different AIR trainer integrations.
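The restore behavior described in this note can be sketched as a small "resume if possible, otherwise start fresh" helper. This is only an illustrative sketch, not part of the Ray Train API: restore_or_create, init_kwargs, and restore_kwargs are hypothetical names, and the helper assumes the trainer class also offers a can_restore(path) check, which may not be available in every Ray version.

```python
from typing import Any, Dict, Optional


def restore_or_create(
    trainer_cls: type,
    experiment_path: str,
    init_kwargs: Dict[str, Any],
    restore_kwargs: Optional[Dict[str, Any]] = None,
):
    """Return a trainer restored from experiment_path, or a fresh one.

    trainer_cls is any class exposing restore(path, **kwargs) and
    (an assumption of this sketch) can_restore(path), for example one
    of the Ray Train trainer classes.
    """
    if trainer_cls.can_restore(experiment_path):
        # Pass only the subset of constructor arguments that restore
        # exposes for this trainer class; everything else is loaded
        # from the experiment directory.
        return trainer_cls.restore(experiment_path, **(restore_kwargs or {}))
    # Nothing to resume from: build a new trainer from scratch.
    return trainer_cls(**init_kwargs)
```

Taking the trainer class as a parameter keeps the helper independent of any particular integration. Which arguments are valid in restore_kwargs depends on the trainer, so check the restore signature of the class you pass in.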
Restoration API for Built-in Trainers#
Restores a DataParallelTrainer from a previously interrupted/failed run.
Restores a TransformersTrainer from a previously interrupted/failed run.
Note
TorchTrainer.restore, TensorflowTrainer.restore, and HorovodTrainer.restore
accept the same parameters as DataParallelTrainer.restore, the restore method
of their parent class.
Unless otherwise specified, other trainers accept the same parameters as
BaseTrainer.restore.
See also
See Restore a Ray Train Experiment for more details on when and how trainer restore should be used.