Source code for ray.train._internal.session

import os
import logging
import platform
import queue
import sys
import threading
import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto
import functools
from pathlib import Path
import shutil
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union
import warnings

import ray
from ray.air._internal.session import _get_session
from ray.air._internal.util import StartTraceback, RunnerThread
from ray.air.checkpoint import Checkpoint
from ray.air.constants import (
    _RESULT_FETCH_TIMEOUT,
    _ERROR_FETCH_TIMEOUT,
    SESSION_MISUSE_LOG_ONCE_KEY,
    TIMESTAMP,
    TIME_THIS_ITER_S,
)
from ray.data import Dataset, DatasetPipeline
from ray.train._internal.accelerator import Accelerator
from ray.train.constants import (
    CHECKPOINT_METADATA_KEY,
    CHECKPOINT_RANK_KEY,
    DETAILED_AUTOFILLED_KEYS,
    WORKER_HOSTNAME,
    WORKER_NODE_IP,
    WORKER_PID,
    TIME_TOTAL_S,
    LAZY_CHECKPOINT_MARKER_FILE,
)

from ray.train.error import SessionMisuseError
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.debug import log_once
from ray.train._internal.storage import _use_storage_context, StorageContext

if TYPE_CHECKING:
    from ray.data import DataIterator
    from ray.tune.execution.placement_groups import PlacementGroupFactory


_INDEX_FILE_EXTENSION = ".files"
_INDEX_FILE = ".RANK_{0}" + _INDEX_FILE_EXTENSION


class TrainingResultType(Enum):
    REPORT = auto()
    CHECKPOINT = auto()


logger = logging.getLogger(__name__)


@dataclass
class TrialInfo:
    """The trial information to propagate to TrainSession."""

    name: str
    id: str
    resources: Dict[str, float]
    logdir: str
    driver_ip: str
    experiment_name: Optional[str] = None


@dataclass
class TrainingResult:
    type: TrainingResultType
    data: Union[Dict, Checkpoint, str]
    metadata: Optional[Dict] = None


# TODO(xwjiang): This needs a better name.
@DeveloperAPI
class _TrainSession:
    """Holds information for training on each worker."""

    def __init__(
        self,
        training_func: Callable,
        world_rank: int,
        local_rank: int,
        node_rank: int,
        local_world_size: int,
        world_size: int,
        # TODO(xwjiang): Legacy Ray Train trainer clean up!
        trial_info: Optional[TrialInfo] = None,
        dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None,
        # TODO(xwjiang): Legacy Ray Train trainer clean up!
        checkpoint: Optional[Checkpoint] = None,
        # Deprecated
        encode_data_fn: Optional[Callable] = None,
        detailed_autofilled_metrics: bool = False,
        # If True and the worker is on the same node as driver,
        # will send over checkpoint path and metadata instead of
        # the whole checkpoint to avoid unnecessary serialization.
        enable_lazy_checkpointing: bool = True,
        checkpoint_keep_all_ranks: bool = False,
        checkpoint_upload_from_workers: bool = False,
        storage: Optional[StorageContext] = None,
    ):

        self.dataset_shard = dataset_shard

        self.world_rank = world_rank
        self.local_rank = local_rank
        self.node_rank = node_rank
        self.local_world_size = local_world_size
        self.world_size = world_size
        self.trial_info = trial_info
        # TODO(xwjiang): Legacy Ray Train trainer clean up!
        self.loaded_checkpoint = checkpoint
        self.enable_lazy_checkpointing = enable_lazy_checkpointing
        self.checkpoint_keep_all_ranks = checkpoint_keep_all_ranks
        self.checkpoint_upload_from_workers = checkpoint_upload_from_workers

        if _use_storage_context():
            assert storage
            logger.debug(f"StorageContext on TRAIN WORKER {world_rank}:\n{storage}")
            storage._check_validation_file()

        self.storage = storage

        # Only used if checkpoint_upload_from_workers is True.
        self.legacy_checkpoint_uri = None

        # Function to encode checkpoint dict before sending to the driver.
        if not encode_data_fn:

            def noop(x):
                return x

            encode_data_fn = noop
        self._encode_data_fn = encode_data_fn

        if _use_storage_context():
            # Change the working directory to the local trial directory.
            # -> All workers on the same node share a working directory.
            os.makedirs(storage.trial_local_path, exist_ok=True)
            os.chdir(storage.trial_local_path)
        else:
            if trial_info:
                # Change the working directory to `logdir`.
                logdir = os.path.join(trial_info.logdir, f"rank_{self.world_rank}")
                os.makedirs(logdir, exist_ok=True)
                os.chdir(logdir)

        # This lock is used to control the execution of the training thread.
        self.continue_lock = threading.Semaphore(0)

        # Queue for sending results across threads.
        self.result_queue = queue.Queue(1)

        # Queue for raising exceptions from runner thread to main thread.
        # The error queue has a max size of one to prevent stacking error and force
        # error reporting to block until finished.
        self.error_queue = queue.Queue(1)

        # The Thread object that is running the training function.
        self.training_thread = RunnerThread(
            target=training_func, daemon=True, error_queue=self.error_queue
        )

        # Autofilled metrics attributes.
        self.detailed_autofilled_metrics = detailed_autofilled_metrics
        self.last_report_time = time.time()
        self.iteration = 0
        self.time_total = 0.0
        self.local_ip = self.get_current_ip()

        self.ignore_report = False
        self.training_started = False

        self.accelerator = None

    def get_current_ip(self):
        self.local_ip = ray.util.get_node_ip_address()
        return self.local_ip

    def start(self):
        """Starts the training thread."""
        self.training_started = True
        self.training_thread.start()

    def pause_reporting(self):
        """Ignore all future ``session.report()`` calls."""
        self.ignore_report = True

    def finish(self):
        """Finishes the training thread.

        Either returns the output from training or raises any Exception from
        training.
        """

        # Wait for training to finish.
        # This will raise any errors that occur during training, including
        # SystemError
        func_output = self.training_thread.join()
        # If training finished successfully, then return results.
        return func_output

    def get_next(self) -> Optional[TrainingResult]:
        """Gets the next ``TrainingResult`` from the result queue.

        If the result queue is empty, then this function returns ``None``.
        """
        if not self.training_started:
            raise RuntimeError("Please call start before calling get_next.")
        result = None
        # While training is still ongoing, attempt to get the result.
        while result is None and self.training_thread.is_alive():
            try:
                result = self.result_queue.get(
                    block=True, timeout=_RESULT_FETCH_TIMEOUT
                )
            except queue.Empty:
                pass

        # If no result was found, then the runner must no longer be alive.
        if result is None:
            # Try one last time to fetch results in case results were
            # reported in between the time of the last check and the
            # termination of the thread runner.
            try:
                result = self.result_queue.get(
                    block=False, timeout=_RESULT_FETCH_TIMEOUT
                )
            except queue.Empty:
                pass

        # check if error occurred inside the thread runner.
        if result is None:
            # only raise an error from the runner if all results are consumed
            self._report_thread_runner_error(block=True)
        else:
            if not self.error_queue.empty():
                logger.debug(
                    (
                        "Runner error waiting to be raised in main thread. "
                        "Logging all available results first."
                    )
                )

        # Release the lock to trigger training to continue.
        self.continue_lock.release()

        # Return None if there are no more results to fetch.
        return result

    def _auto_fill_metrics(self, result: dict) -> dict:
        """Add autofilled metrics and update attributes."""
        current_time = time.time()
        current_datetime = datetime.now()
        if TIME_THIS_ITER_S in result:
            time_this_iter = result[TIME_THIS_ITER_S]
        else:
            time_this_iter = current_time - self.last_report_time
        self.iteration += 1
        self.time_total += time_this_iter
        self.last_report_time = current_time

        auto_filled_metrics = {
            TIMESTAMP: int(time.mktime(current_datetime.timetuple())),
            TIME_TOTAL_S: self.time_total,
            WORKER_PID: os.getpid(),
            WORKER_HOSTNAME: platform.node(),
            WORKER_NODE_IP: self.local_ip,
        }

        if not self.detailed_autofilled_metrics:
            auto_filled_metrics = {
                k: v
                for k, v in auto_filled_metrics.items()
                if k not in DETAILED_AUTOFILLED_KEYS
            }

        result = result.copy()
        result.update(auto_filled_metrics)
        return result

    def _report_legacy(self, **kwargs):
        """Adds kwargs to the queue to be consumed by main thread."""
        if self.ignore_report:
            return

        kwargs = self._auto_fill_metrics(kwargs)

        result = TrainingResult(type=TrainingResultType.REPORT, data=kwargs)

        # Add result to a thread-safe queue.
        self.result_queue.put(result, block=True)

        # Acquire lock to stop the training thread until main thread
        # triggers resume.
        self.continue_lock.acquire()

    def _auto_fill_checkpoint_metrics(self, result: dict) -> dict:
        """Add autofilled metrics and update attributes."""
        current_datetime = datetime.now()

        auto_filled_metrics = {
            TIMESTAMP: int(time.mktime(current_datetime.timetuple()))
        }
        result = result.copy()
        result.update(auto_filled_metrics)
        return result

    def _report_thread_runner_error(self, block=False):
        try:
            e = self.error_queue.get(block=block, timeout=_ERROR_FETCH_TIMEOUT)
            raise StartTraceback from e
        except queue.Empty:
            pass

    def _create_checkpoint_file_list(self, checkpoint: Checkpoint):
        """Create an index of the folder contents

        So we know which files belong to which rank.
        """
        root = checkpoint._local_path
        ckpt_files = []
        for dir, _, files in os.walk(root):
            # Strip the root path from the path though, since
            # we are only interested in the part relative to
            # the root of this checkpoint.
            dir = dir[len(root) :]
            for fn in files:
                ckpt_files.append(os.path.join(dir, fn))
        # Write these files into the index file.
        with open(os.path.join(root, _INDEX_FILE.format(self.world_rank)), "w") as f:
            for fn in ckpt_files:
                f.write(f"{fn}\n")

    def _remove_uploaded_checkpoint_files(self, checkpoint: Checkpoint):
        """Get rid of already uploaded large checkpoint files.

        This is so they don't get shipped to the driver node.
        """
        root = checkpoint._local_path
        for f in os.listdir(root):
            if f.endswith(_INDEX_FILE_EXTENSION):
                # We will leave the index file in there so local
                # checkpoint has knowledge about the cloud files.
                continue
            fp = os.path.join(root, f)
            if os.path.isfile(fp):
                os.unlink(fp)
            elif os.path.isdir(fp):
                shutil.rmtree(fp)

    def checkpoint(self, checkpoint: Checkpoint):
        """Adds kwargs to the queue to be consumed by main thread.

        Also stores the checkpoint in ``self.loaded_checkpoint``.
        """
        checkpoint_type, _ = checkpoint.get_internal_representation()

        if checkpoint_type == "data_dict" and self.checkpoint_keep_all_ranks:
            if log_once("keep_all_ranks_dict_checkpoint"):
                logger.warning(
                    "Saving checkpoints from all ranks does not work with "
                    "dictionary checkpoints. Set `ray.train.CheckpointConfig"
                    "(_checkpoint_keep_all_ranks=False)`, or write checkpoints "
                    "to a directory and report directory checkpoints that "
                    "contain unique files per worker rank. For example, "
                    "use filenames that contain the unique rank. You can "
                    "retrieve the rank with `session.get_world_rank()` within "
                    "your training loop per worker."
                )

        upload_from_workers = (
            checkpoint_type == "local_path"
            and self.checkpoint_upload_from_workers
            and self.legacy_checkpoint_uri
        )
        if upload_from_workers:
            self._create_checkpoint_file_list(checkpoint)
            logger.info(
                f"Uploading checkpoint files from worker rank {self.world_rank} "
                f"to cloud URI {self.legacy_checkpoint_uri}."
            )
            # We want to upload the files directly to cloud storage,
            # so that they won't need to be shipped to the driver node
            # via object store.
            checkpoint.to_uri(self.legacy_checkpoint_uri)
            logger.info("Done uploading checkpoint files.")
            self._remove_uploaded_checkpoint_files(checkpoint)

        # Update session checkpoint to latest checkpoint.
        self.loaded_checkpoint = checkpoint

        # Only store checkpoints on worker with rank 0.
        if self.world_rank != 0 and not self.checkpoint_keep_all_ranks:
            checkpoint = None
        elif checkpoint:
            checkpoint = self._encode_data_fn(checkpoint)

        metadata = self._auto_fill_checkpoint_metrics({})

        if (
            checkpoint
            and self.enable_lazy_checkpointing
            and checkpoint._local_path
            and (Path(self.trial_info.logdir) / LAZY_CHECKPOINT_MARKER_FILE).exists()
        ):
            metadata.update({CHECKPOINT_METADATA_KEY: checkpoint._metadata})
            checkpoint = str(checkpoint._local_path)

        # Save the rank of the worker that created this checkpoint.
        metadata.update({CHECKPOINT_RANK_KEY: self.world_rank})

        result = TrainingResult(
            type=TrainingResultType.CHECKPOINT,
            data=checkpoint,
            metadata=metadata,
        )

        # Add result to a thread-safe queue.
        self.result_queue.put(result, block=True)

        # Acquire lock to stop the training thread until
        # checkpoint has been processed.
        self.continue_lock.acquire()

    def _set_legacy_checkpoint_uri(self, uri: str):
        """Tell session where to save the next directory checkpoint on the cloud.

        Args:
            uri: URI to the location where next checkpoint should be saved.
        """
        self.legacy_checkpoint_uri = uri

    def new_checkpoint(self, checkpoint):
        from ray.train._checkpoint import Checkpoint as NewCheckpoint

        if not isinstance(checkpoint, NewCheckpoint):
            raise ValueError(
                "You must pass a `ray.train.checkpoint.Checkpoint` "
                "object to `train.report`. `ray.air.Checkpoint` is deprecated."
            )

        # Persist the reported checkpoint files to storage.
        persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)

        self.loaded_checkpoint = persisted_checkpoint

        metadata = self._auto_fill_checkpoint_metrics({})

        # Save the rank of the worker that created this checkpoint.
        metadata.update({CHECKPOINT_RANK_KEY: self.world_rank})

        result = TrainingResult(
            type=TrainingResultType.CHECKPOINT,
            data=persisted_checkpoint,
            metadata=metadata,
        )

        # Add result to a thread-safe queue.
        self.result_queue.put(result, block=True)

        # Acquire lock to stop the training thread until
        # checkpoint has been processed.
        self.continue_lock.acquire()

    def new_report(self, metrics: Dict, checkpoint=None) -> None:
        if checkpoint:
            self.new_checkpoint(checkpoint)

        # TODO(justinvyu): Unify checkpoint / report logic to just report a single
        # (metrics, Checkpoint) result for the consumer to handle.
        self._report_legacy(**metrics)

    def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
        # TODO(xwjiang): tons of optimizations.

        # Special case: early fail for Torch tensors
        if "torch" in sys.modules:
            from ray.air._internal.torch_utils import contains_tensor

            if contains_tensor(metrics):
                raise ValueError(
                    "Passing objects containg Torch tensors as metrics "
                    "is not supported as it will throw an exception on "
                    "deserialization. You can either convert the tensors "
                    "to Python objects or use a `TorchCheckpoint` as the "
                    "`checkpoint` argument of `ray.train.report` to "
                    "store your Torch objects."
                )

        if _use_storage_context():
            return self.new_report(metrics, checkpoint=checkpoint)

        if checkpoint:
            self.checkpoint(checkpoint)
        self._report_legacy(**metrics)

    @property
    def experiment_name(self) -> str:
        return self.trial_info.experiment_name

    @property
    def trial_name(self) -> str:
        return self.trial_info.name

    @property
    def trial_id(self) -> str:
        return self.trial_info.id

    @property
    def trial_resources(self) -> "PlacementGroupFactory":
        return self.trial_info.resources

    @property
    def trial_dir(self) -> str:
        return self.trial_info.logdir

    def get_dataset_shard(
        self,
        dataset_name: Optional[str] = None,
    ) -> Optional["DataIterator"]:
        shard = self.dataset_shard
        if shard is None:
            warnings.warn(
                "No dataset passed in. Returning None. Make sure to "
                "pass in a Dataset to Trainer.run to use this "
                "function."
            )
        elif isinstance(shard, dict):
            if not dataset_name:
                raise RuntimeError(
                    "Multiple datasets were passed into ``Trainer``, "
                    "but no ``dataset_name`` is passed into "
                    "``get_dataset_shard``. Please specify which "
                    "dataset shard to retrieve."
                )
            return shard.get(dataset_name)
        return shard


_session: Optional[_TrainSession] = None


def init_session(*args, **kwargs) -> None:
    global _session
    if _session:
        raise ValueError(
            "A Train session is already in use. Do not call "
            "`init_session()` manually."
        )
    _session = _TrainSession(*args, **kwargs)


def get_session() -> Optional[_TrainSession]:
    return _session


def shutdown_session():
    """Shuts down the initialized session."""
    global _session
    _session = None


def _raise_accelerator_session_misuse():
    """Raises a SessionMisuseError because a utility function was used improperly."""
    raise SessionMisuseError(
        "prepare/accelerate utility functions should be called inside a training "
        "function executed by `Trainer.run`"
    )


def get_accelerator(default_accelerator_cls: Type[Accelerator]) -> Accelerator:
    """The accelerator for this training session.

    If an accelerator has not been set, then this method will construct an
    accelerator using the provided accelerator class.

    Raises:
        SessionMisuseError: if the session is uninitialized.
    """
    session = get_session()
    if session is None:
        _raise_accelerator_session_misuse()
    if session.accelerator is None:
        session.accelerator = default_accelerator_cls()
    return session.accelerator


def set_accelerator(accelerator: Accelerator) -> None:
    """Sets the accelerator for this training session.

    Args:
        accelerator: The accelerator to use for training.

    Raises:
        SessionMisuseError: if the session is unitialized.
        RuntimeError: if the accelerator has already been set.
    """
    session = get_session()
    if session is None:
        _raise_accelerator_session_misuse()
    if session.accelerator is not None:
        raise RuntimeError("Cannot change accelerator once set.")
    session.accelerator = accelerator


def _warn_session_misuse(default_value: Any = None):
    """Warns if fn is being used outside of session and returns ``default_value``."""

    def inner(fn: Callable):
        fn_name = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            session = _get_session()
            if not session:
                if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"):
                    warnings.warn(
                        f"`{fn_name}` is meant to only be "
                        "called inside a function that is executed by a Tuner"
                        f" or Trainer. Returning `{default_value}`."
                    )
                return default_value
            return fn(*args, **kwargs)

        return wrapper

    return inner


[docs]@PublicAPI(stability="beta") @_warn_session_misuse() def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None: """Report metrics and optionally save a checkpoint. Each invocation of this method will automatically increment the underlying iteration number. The physical meaning of this "iteration" is defined by user (or more specifically the way they call ``report``). It does not necessarily map to one epoch. This API is the canonical way to report metrics from Tune and Train, and replaces the legacy ``tune.report``, ``with tune.checkpoint_dir``, ``train.report`` and ``train.save_checkpoint`` calls. Note on directory checkpoints: AIR will take ownership of checkpoints passed to ``report()`` by moving them to a new path. The original directory will no longer be accessible to the caller after the report call. Example: .. testcode:: import tensorflow as tf from ray import train from ray.train import Checkpoint, ScalingConfig from ray.train.tensorflow import TensorflowTrainer ######## Using it in the *per worker* train loop (TrainSession) ####### def train_func(): model = tf.keras.applications.resnet50.ResNet50() model.save("my_model", overwrite=True) train.report( metrics={"foo": "bar"}, checkpoint=Checkpoint.from_directory("my_model") ) # Air guarantees by this point, you can safely write new stuff to # "my_model" directory. scaling_config = ScalingConfig(num_workers=2) trainer = TensorflowTrainer( train_loop_per_worker=train_func, scaling_config=scaling_config ) result = trainer.fit() # If you navigate to result.checkpoint's path, you will find the # content of ``model.save()`` under it. # If you have `SyncConfig` configured, the content should also # show up in the corresponding cloud storage path. .. testoutput:: :hide: ... Args: metrics: The metrics you want to report. checkpoint: The optional checkpoint you want to report. """ _get_session().report(metrics, checkpoint=checkpoint)
[docs]@PublicAPI(stability="beta") @_warn_session_misuse() def get_checkpoint() -> Optional[Checkpoint]: """Access the session's last checkpoint to resume from if applicable. Returns: Checkpoint object if the session is currently being resumed. Otherwise, return None. .. testcode:: import tensorflow as tf ######## Using it in the *per worker* train loop (TrainSession) ###### from ray import train from ray.train import Checkpoint, ScalingConfig from ray.train.tensorflow import TensorflowTrainer def train_func(): ckpt = train.get_checkpoint() if ckpt: with ckpt.as_directory() as loaded_checkpoint_dir: model = tf.keras.models.load_model(loaded_checkpoint_dir) else: model = tf.keras.applications.resnet50.ResNet50() model.save("my_model", overwrite=True) train.report( metrics={"iter": 1}, checkpoint=Checkpoint.from_directory("my_model") ) scaling_config = ScalingConfig(num_workers=2) trainer = TensorflowTrainer( train_loop_per_worker=train_func, scaling_config=scaling_config ) result = trainer.fit() # trainer2 will pick up from the checkpoint saved by trainer1. trainer2 = TensorflowTrainer( train_loop_per_worker=train_func, scaling_config=scaling_config, # this is ultimately what is accessed through # ``ray.train.get_checkpoint()`` resume_from_checkpoint=result.checkpoint, ) result2 = trainer2.fit() .. testoutput:: :hide: ... """ return _get_session().loaded_checkpoint
[docs]@PublicAPI(stability="beta") @_warn_session_misuse() def get_experiment_name() -> str: """Experiment name for the corresponding trial.""" return _get_session().experiment_name
[docs]@PublicAPI(stability="beta") @_warn_session_misuse() def get_trial_name() -> str: """Trial name for the corresponding trial.""" return _get_session().trial_name
[docs]@PublicAPI(stability="beta") @_warn_session_misuse() def get_trial_id() -> str: """Trial id for the corresponding trial.""" return _get_session().trial_id
[docs]@PublicAPI(stability="beta") @_warn_session_misuse() def get_trial_resources() -> "PlacementGroupFactory": """Trial resources for the corresponding trial.""" return _get_session().trial_resources
[docs]@PublicAPI(stability="beta") @_warn_session_misuse() def get_trial_dir() -> str: """Log directory corresponding to the trial directory for a Tune session. If calling from a Train session, this will give the trial directory of its parent Tune session. .. testcode:: from ray import train, tune def train_func(config): print(train.get_context().get_trial_dir()) tuner = tune.Tuner(train_func) tuner.fit() .. testoutput:: :options: +MOCK /Users/root/ray_results/train_func_2023-07-19_15-01-37/train_func_d620c_00000_0_2023-07-19_15-01-40 """ return _get_session().trial_dir
[docs]@PublicAPI(stability="beta") @_warn_session_misuse(default_value=1) def get_world_size() -> int: """Get the current world size (i.e. total number of workers) for this run. .. testcode:: import ray from ray import train from ray.train import ScalingConfig from ray.train.tensorflow import TensorflowTrainer NUM_WORKERS = 2 def train_loop_per_worker(config): assert train.get_context().get_world_size() == NUM_WORKERS train_dataset = ray.data.read_csv("s3://[email protected]/iris.csv") trainer = TensorflowTrainer( train_loop_per_worker, scaling_config=ScalingConfig(num_workers=NUM_WORKERS), datasets={"train": train_dataset} ) trainer.fit() .. testoutput:: :hide: ... """ session = _get_session() if not hasattr(session, "world_size"): raise RuntimeError( "`get_world_size` can only be called for TrainSession! " "Make sure you only use that in `train_loop_per_worker` function" "that is passed into `DataParallelTrainer`." ) return session.world_size
[docs]@PublicAPI(stability="beta") @_warn_session_misuse(default_value=0) def get_world_rank() -> int: """Get the world rank of this worker. .. testcode:: import ray from ray import train from ray.train import ScalingConfig from ray.train.tensorflow import TensorflowTrainer def train_loop_per_worker(config): if train.get_context().get_world_rank() == 0: print("Worker 0") train_dataset = ray.data.read_csv("s3://[email protected]/iris.csv") trainer = TensorflowTrainer( train_loop_per_worker, scaling_config=ScalingConfig(num_workers=2), datasets={"train": train_dataset} ) trainer.fit() .. testoutput:: :hide: ... """ session = _get_session() if not hasattr(session, "world_rank"): raise RuntimeError( "`get_world_rank` can only be called for TrainSession! " "Make sure you only use that in `train_loop_per_worker` function" "that is passed into `DataParallelTrainer`." ) return session.world_rank
[docs]@PublicAPI(stability="beta") @_warn_session_misuse(default_value=0) def get_local_rank() -> int: """Get the local rank of this worker (rank of the worker on its node). .. testcode:: import torch import ray from ray import train from ray.train import ScalingConfig from ray.train.torch import TorchTrainer def train_loop_per_worker(config): if torch.cuda.is_available(): torch.cuda.set_device(train.get_context().get_local_rank()) ... train_dataset = ray.data.read_csv("s3://[email protected]/iris.csv") trainer = TorchTrainer( train_loop_per_worker, scaling_config=ScalingConfig(num_workers=2, use_gpu=True), datasets={"train": train_dataset} ) trainer.fit() .. testoutput:: :hide: ... """ session = _get_session() if not hasattr(session, "local_rank"): raise RuntimeError( "`get_local_rank` can only be called for TrainSession! " "Make sure you only use that in `train_loop_per_worker` function" "that is passed into `DataParallelTrainer`." ) return session.local_rank
[docs]@PublicAPI(stability="beta") @_warn_session_misuse(default_value=0) def get_local_world_size() -> int: """Get the local world size of this node (i.e. number of workers on this node). Example: .. testcode:: import ray from ray import train from ray.train import ScalingConfig from ray.train.torch import TorchTrainer def train_loop_per_worker(): print(train.get_context().get_local_world_size()) train_dataset = ray.data.from_items( [{"x": x, "y": x + 1} for x in range(32)]) trainer = TorchTrainer(train_loop_per_worker, scaling_config=ScalingConfig(num_workers=1), datasets={"train": train_dataset}) trainer.fit() .. testoutput:: :hide: ... """ session = _get_session() if not hasattr(session, "local_world_size"): raise RuntimeError( "`get_local_world_size` can only be called for TrainSession! " "Make sure you only use that in `train_loop_per_worker` function" "that is passed into `DataParallelTrainer`." ) return session.local_world_size
[docs]@PublicAPI(stability="beta") @_warn_session_misuse(default_value=0) def get_node_rank() -> int: """Get the rank of this node. Example: .. testcode:: import ray from ray import train from ray.train import ScalingConfig from ray.train.torch import TorchTrainer def train_loop_per_worker(): print(train.get_context().get_node_rank()) train_dataset = ray.data.from_items( [{"x": x, "y": x + 1} for x in range(32)]) trainer = TorchTrainer(train_loop_per_worker, scaling_config=ScalingConfig(num_workers=1), datasets={"train": train_dataset}) trainer.fit() .. testoutput:: :hide: ... """ session = _get_session() if not hasattr(session, "node_rank"): raise RuntimeError( "`get_node_rank` can only be called for TrainSession! " "Make sure you only use that in `train_loop_per_worker` function" "that is passed into `DataParallelTrainer`." ) return session.node_rank
[docs]@PublicAPI(stability="beta") @_warn_session_misuse() def get_dataset_shard( dataset_name: Optional[str] = None, ) -> Optional["DataIterator"]: """Returns the :class:`ray.data.DataIterator` shard for this worker. Call :meth:`~ray.data.DataIterator.iter_torch_batches` or :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the appropriate framework-specific data type. .. testcode:: import ray from ray import train from ray.train import ScalingConfig from ray.train.torch import TorchTrainer def train_loop_per_worker(config): ... for epoch in range(2): # Trainer will automatically handle sharding. data_shard = train.get_dataset_shard("train") for batch in data_shard.iter_torch_batches(): ... train_dataset = ray.data.read_csv("s3://[email protected]/iris.csv") trainer = TorchTrainer( train_loop_per_worker, scaling_config=ScalingConfig(num_workers=2), datasets={"train": train_dataset} ) trainer.fit() .. testoutput:: :hide: ... Args: dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then specifies which dataset shard to return. Returns: The ``DataIterator`` shard to use for this worker. If no dataset is passed into Trainer, then return None. """ session = _get_session() if not hasattr(session, "get_dataset_shard"): raise RuntimeError( "`get_dataset_shard` can only be called for TrainSession! " "Make sure you only use that in `train_loop_per_worker` function" "that is passed into `DataParallelTrainer`." ) return session.get_dataset_shard(dataset_name)