import functools
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Tuple, Union
from ray import train
from ray.train import Checkpoint, RunConfig, ScalingConfig
from ray.train import DataConfig
from ray.train.torch import TorchConfig
from ray.train.trainer import GenDataset
from ray.train.torch import TorchTrainer, get_device
from ray.train.torch.config import _set_torch_distributed_env_vars
# Accelerate is an optional dependency. Importing this module must not fail
# when it is missing, so the ImportError is captured here and re-raised only
# when an ``AccelerateTrainer`` is actually constructed.
ACCELERATE_IMPORT_ERROR: Optional[ImportError] = None

try:
    from ray.train.huggingface.accelerate._accelerate_utils import (
        launch_command,
        AccelerateDefaultNamespace,
        AccelerateConfigWrapper,
        load_accelerate_config,
    )
except ImportError as e:
    ACCELERATE_IMPORT_ERROR = e
    # Bind placeholders so the names still exist at module level (they are
    # referenced in annotations and method bodies below).
    launch_command = None
    AccelerateDefaultNamespace = None
    AccelerateConfigWrapper = None
    load_accelerate_config = None

if TYPE_CHECKING:
    # Only needed for type annotations; not imported at runtime.
    from ray.data.preprocessor import Preprocessor
    from ray.tune.trainable import Trainable
class AccelerateTrainer(TorchTrainer):
    """A Trainer for data parallel HuggingFace Accelerate training with PyTorch.

    This Trainer is a wrapper around the :class:`~ray.train.torch.TorchTrainer`,
    providing the following extra functionality:

    1. Loading and parsing of Accelerate configuration files (created by
       ``accelerate config`` CLI command),
    2. Applying the configuration files on all workers, making sure the environment
       is set up correctly.

    This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
    Actors. These actors already have the necessary torch process group
    configured for distributed PyTorch training, as well as all environment variables
    required by Accelerate, as defined in the configuration file. This allows you
    to use Accelerate APIs (such as ``Accelerator``) inside ``train_loop_per_worker``
    as you would without Ray.

    Inside the ``train_loop_per_worker`` function, in addition to Accelerate APIs, you
    can use any of the :ref:`Ray AIR session methods <air-session-ref>`.
    See full example code below.

    .. testcode::

        def train_loop_per_worker():
            # Report intermediate results for callbacks or logging and
            # checkpoint data.
            train.report(...)

            # Get dict of last saved checkpoint.
            train.get_checkpoint()

            # Get the Dataset shard for the given key.
            train.get_dataset_shard("my_dataset")

            # Get the total number of workers executing training.
            train.get_context().get_world_size()

            # Get the rank of this worker.
            train.get_context().get_world_rank()

            # Get the rank of the worker on the current node.
            train.get_context().get_local_rank()

    For more information, see the documentation of
    :class:`~ray.train.torch.TorchTrainer`.

    .. note::

        You need to use ``ray.train.report()`` to communicate results and checkpoints
        back to Ray Train.

    Accelerate integrations with DeepSpeed, FSDP, MegatronLM etc. are fully supported.
    If the Accelerate configuration contains a path to a DeepSpeed config file
    (``deepspeed_config_file``), that file will also be loaded and applied on the
    workers.

    The following Accelerate configuration options will be ignored and automatically
    set by the Trainer according to Ray AIR configs (e.g. ``ScalingConfig``):

    - Number of machines (``num_machines``)
    - Number of processes (``num_processes``)
    - Rank of the current machine (``machine_rank``)
    - Local rank of the current machine
    - GPU IDs (``gpu_ids``)
    - Number of CPU threads per process (``num_cpu_threads_per_process``)
    - IP of the head process (``main_process_ip``)
    - Port of the head process (``main_process_port``)
    - Whether all machines are on the same network (``same_network``)
    - Whether to force a CPU-only mode (``cpu``/``use_cpu``)
    - rdzv backend (``rdzv_backend``)
    - Main training function (``main_training_function``)
    - Type of launcher

    This Trainer requires ``accelerate>=0.17.0`` package.

    Example:

        .. testcode::

            import torch
            import torch.nn as nn
            from accelerate import Accelerator

            import ray
            from ray import train
            from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
            from ray.train.huggingface import AccelerateTrainer

            # If using GPUs, set this to True.
            use_gpu = False

            # Define NN layers architecture, epochs, and number of workers
            input_size = 1
            layer_size = 32
            output_size = 1
            num_epochs = 30
            num_workers = 3

            # Define your network structure
            class NeuralNetwork(nn.Module):
                def __init__(self):
                    super(NeuralNetwork, self).__init__()
                    self.layer1 = nn.Linear(input_size, layer_size)
                    self.relu = nn.ReLU()
                    self.layer2 = nn.Linear(layer_size, output_size)

                def forward(self, input):
                    return self.layer2(self.relu(self.layer1(input)))

            # Define your train worker loop
            def train_loop_per_worker():
                torch.manual_seed(42)

                # Initialize the Accelerator
                accelerator = Accelerator()

                # Fetch training set
                dataset_shard = train.get_dataset_shard("train")
                model = NeuralNetwork()

                # Loss function, optimizer, prepare model for training.
                # This moves the data and prepares model for distributed
                # execution
                loss_fn = nn.MSELoss()
                optimizer = torch.optim.Adam(
                    model.parameters(), lr=0.01, weight_decay=0.01
                )
                model, optimizer = accelerator.prepare(model, optimizer)

                # Iterate over epochs and batches
                for epoch in range(num_epochs):
                    for batches in dataset_shard.iter_torch_batches(
                        batch_size=32, dtypes=torch.float
                    ):
                        # Add batch or unsqueeze as an additional dimension [32, x]
                        inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"]
                        output = model(inputs)

                        # Make output shape same as the labels
                        loss = loss_fn(output.squeeze(), labels)

                        # Zero out grads, do backward, and update optimizer
                        optimizer.zero_grad()
                        accelerator.backward(loss)
                        optimizer.step()

                        # Print the loss every 20 epochs
                        if epoch % 20 == 0:
                            print(f"epoch: {epoch}/{num_epochs}, loss: {loss:.3f}")

                    # Report and record metrics, checkpoint model at end of each
                    # epoch
                    train.report(
                        {"loss": loss.item(), "epoch": epoch},
                        checkpoint=Checkpoint.from_dict(
                            dict(
                                epoch=epoch,
                                model=accelerator.unwrap_model(model).state_dict(),
                            )
                        ),
                    )

            train_dataset = ray.data.from_items(
                [{"x": x, "y": 2 * x + 1} for x in range(2000)]
            )

            # Define scaling and run configs
            scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
            run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

            trainer = AccelerateTrainer(
                train_loop_per_worker=train_loop_per_worker,
                # Instead of using a dict, you can run ``accelerate config``.
                # The default value of None will then load that configuration
                # file.
                accelerate_config={},
                scaling_config=scaling_config,
                run_config=run_config,
                datasets={"train": train_dataset},
            )

            result = trainer.fit()

            best_checkpoint_loss = result.metrics["loss"]

            # Assert that the loss is at most 0.09
            assert best_checkpoint_loss <= 0.09

        .. testoutput::
            :hide:

            ...

    Args:
        train_loop_per_worker: The training function to execute.
            This can either take in no arguments or a ``config`` dict.
        train_loop_config: Configurations to pass into
            ``train_loop_per_worker`` if it accepts an argument.
        accelerate_config: Accelerate configuration to be applied on every worker.
            This can be a path to a file generated with ``accelerate config``,
            a configuration dict or None, in which case it will load the configuration
            file from the default location as defined by Accelerate.
        torch_config: Configuration for setting up the PyTorch backend. If set to
            None, use the default configuration. This replaces the ``backend_config``
            arg of ``DataParallelTrainer``.
        scaling_config: Configuration for how to scale data parallel training.
        dataset_config: Configuration for dataset ingest.
        run_config: Configuration for the execution of the training run.
        datasets: Any Datasets to use for training. Use
            the key "train" to denote which dataset is the training
            dataset. If a ``preprocessor`` is provided and has not already been fit,
            it will be fit on the training dataset. All datasets will be transformed
            by the ``preprocessor`` if one is provided.
        preprocessor: A ``ray.data.Preprocessor`` to preprocess the
            provided datasets.
        resume_from_checkpoint: A checkpoint to resume training from.
    """

    def __init__(
        self,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        *,
        train_loop_config: Optional[Dict] = None,
        accelerate_config: Optional[Union[dict, str, Path, os.PathLike]] = None,
        torch_config: Optional[TorchConfig] = None,
        scaling_config: Optional[ScalingConfig] = None,
        dataset_config: Optional[DataConfig] = None,
        run_config: Optional[RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
        preprocessor: Optional["Preprocessor"] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        # Surface the deferred import failure only when the trainer is
        # actually instantiated.
        if ACCELERATE_IMPORT_ERROR is not None:
            raise ACCELERATE_IMPORT_ERROR

        self.accelerate_config = accelerate_config
        # Resolve the configuration into raw contents up front so they can be
        # handed to the Trainable and to the workers later on.
        (
            self._accelerate_config_raw,
            self._deepspeed_config_file_raw,
        ) = self._unwrap_accelerate_config_if_needed(accelerate_config)

        super().__init__(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config=train_loop_config,
            torch_config=torch_config,
            scaling_config=scaling_config,
            dataset_config=dataset_config,
            run_config=run_config,
            datasets=datasets,
            preprocessor=preprocessor,
            resume_from_checkpoint=resume_from_checkpoint,
        )

    def _unwrap_accelerate_config_if_needed(
        self,
        accelerate_config: Optional[
            Union[dict, str, Path, os.PathLike, AccelerateConfigWrapper]
        ],
    ) -> Tuple[str, Optional[str]]:
        """Return the raw Accelerate config and the raw DeepSpeed config.

        Args:
            accelerate_config: Either a user-supplied configuration (dict, path
                or None) or an ``AccelerateConfigWrapper`` holding contents that
                were already loaded.

        Returns:
            A ``(accelerate_config_raw, deepspeed_config_raw)`` pair.
        """
        # The AccelerateConfigWrapper is used to deal with the issue of the
        # Trainer being initialized twice (by the user and by us in the Trainable).
        # If it's initialized by the user, accelerate_config will not be an instance
        # of AccelerateConfigWrapper. This means we should read the config file from
        # given path.
        # If accelerate_config is an instance of AccelerateConfigWrapper, that means
        # we are dealing with a file that was already read, and we should instead use
        # the string in the wrapper as the raw contents of the file. This should
        # only happen internally, during initialization of this class in the Trainable.
        if isinstance(accelerate_config, AccelerateConfigWrapper):
            return (
                accelerate_config.config_raw,
                accelerate_config.deepspeed_config_raw,
            )
        return load_accelerate_config(accelerate_config)

    def as_trainable(self) -> Type["Trainable"]:
        """Build the Trainable with the already-loaded config contents.

        ``accelerate_config`` in ``self._param_dict`` is temporarily replaced by
        an ``AccelerateConfigWrapper`` around the raw contents while the parent
        class builds the Trainable, and is put back exactly as it was afterwards.
        """
        # We want to load the config when the Trainer is first instantiated,
        # and share the contents with the Trainables (which may be on different
        # nodes).
        missing = object()
        previous_config = self._param_dict.get("accelerate_config", missing)
        self._param_dict["accelerate_config"] = AccelerateConfigWrapper(
            self._accelerate_config_raw, self._deepspeed_config_file_raw
        )
        try:
            return super().as_trainable()
        finally:
            if previous_config is missing:
                # The key was not there before; do not leave a stray entry.
                self._param_dict.pop("accelerate_config", None)
            else:
                self._param_dict["accelerate_config"] = previous_config

    def training_loop(self) -> None:
        """Run the parent training loop with the Accelerate-aware wrapper.

        ``self._train_loop_per_worker`` is swapped for the wrapped function only
        for the duration of the call and restored afterwards, even on error.
        """
        original_train_loop = self._train_loop_per_worker
        self._train_loop_per_worker = self._wrap_train_loop_per_worker(
            original_train_loop,
            self._accelerate_config_raw,
            self._deepspeed_config_file_raw,
        )
        try:
            return super().training_loop()
        finally:
            self._train_loop_per_worker = original_train_loop

    @classmethod
    def _wrap_train_loop_per_worker(
        cls,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        accelerate_config_raw: str,
        deepspeed_config_file_raw: Optional[Union[str, dict]],
    ):
        """Wrap around train_loop_per_worker to set necessary Accelerate env vars.

        Args:
            train_loop_per_worker: The user training function to wrap.
            accelerate_config_raw: Raw contents of the Accelerate config file.
            deepspeed_config_file_raw: DeepSpeed configuration. A dict is passed
                to Accelerate as-is, a non-empty string is written to a
                temporary file first, and a falsy value is ignored.

        Returns:
            A function that forwards its arguments to ``train_loop_per_worker``
            after the Accelerate environment has been set up.
        """

        @functools.wraps(train_loop_per_worker)
        def _accelerate_train_loop_per_worker(*args, **kwargs):
            # The temporary directory stays alive for the whole training
            # function, so the config files remain readable while it runs.
            with tempfile.TemporaryDirectory() as tempdir:
                # Write Accelerate config to file so it can be read
                # by Accelerate
                temp_config_file = os.path.join(tempdir, "default_config.yaml")
                with open(temp_config_file, "w") as f:
                    f.write(accelerate_config_raw)

                # Set by TorchBackend
                master_addr = os.environ["MASTER_ADDR"]
                master_port = os.environ["MASTER_PORT"]

                context = train.get_context()

                namespace = AccelerateDefaultNamespace()
                namespace.config_file = temp_config_file
                # Each worker is described to Accelerate as a separate machine
                # running a single process.
                namespace.num_processes = 1
                namespace.num_machines = context.get_world_size()
                namespace.machine_rank = context.get_world_rank()
                namespace.num_cpu_threads_per_process = (
                    context.get_trial_resources().bundles[-1].get("CPU", 1)
                )
                namespace.gpu_ids = None
                namespace.main_process_ip = master_addr
                namespace.main_process_port = master_port
                namespace.same_network = False

                device = get_device()
                if isinstance(device, list):
                    device = device[0]
                is_cpu = device.type == "cpu"
                if is_cpu:
                    os.environ["LOCAL_RANK"] = "-1"
                namespace.use_cpu = is_cpu

                # Handle DeepSpeed config
                if isinstance(deepspeed_config_file_raw, dict):
                    namespace.deepspeed_config_file = deepspeed_config_file_raw
                elif deepspeed_config_file_raw:
                    deepspeed_config_file = os.path.join(
                        tempdir, "deepspeed_config.json"
                    )
                    with open(deepspeed_config_file, "w") as f:
                        f.write(deepspeed_config_file_raw)
                    namespace.deepspeed_config_file = deepspeed_config_file

                # Let Accelerate set all env vars
                launch_command(namespace)

                # Set our env vars again to override Accelerate
                os.environ["MASTER_ADDR"] = master_addr
                os.environ["MASTER_PORT"] = master_port
                _set_torch_distributed_env_vars()
                if is_cpu:
                    os.environ["LOCAL_RANK"] = "-1"

                return train_loop_per_worker(*args, **kwargs)

        return _accelerate_train_loop_per_worker