Source code for ray.train.lightning.lightning_checkpoint

import os
import logging
import pytorch_lightning as pl
import tempfile
import shutil

from inspect import isclass
from typing import Optional, Type, Dict, Any

from ray.air.constants import MODEL_KEY
from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.data import Preprocessor
from ray.train.torch import TorchCheckpoint
from ray.util.annotations import PublicAPI

logger = logging.getLogger(__name__)


[docs]@PublicAPI(stability="alpha") class LightningCheckpoint(TorchCheckpoint): """A :class:`~ray.air.checkpoint.Checkpoint` with Lightning-specific functionality. LightningCheckpoint only support file based checkpoint loading. Create this by calling ``LightningCheckpoint.from_directory(ckpt_dir)``, ``LightningCheckpoint.from_uri(uri)`` or ``LightningCheckpoint.from_path(path)`` LightningCheckpoint loads file named ``model`` under the specified directory. Examples: .. testcode:: :skipif: True from ray.train.lightning import LightningCheckpoint # Suppose we saved a checkpoint in "./checkpoint_000000/model": # Option 1 (Preferred): Load from the checkpoint file checkpoint = LightningCheckpoint.from_path( path="./checkpoint_00000/model" ) # Option 2: Load from a directory checkpoint = LightningCheckpoint.from_directory( path="./checkpoint_00000/" ) # Suppose we saved a checkpoint in an S3 bucket: # Option 3: Load from URI checkpoint = LightningCheckpoint.from_uri( path="s3://path/to/checkpoint/directory/" ) """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._cache_dir = None
[docs] @classmethod def from_path( cls, path: str, *, preprocessor: Optional["Preprocessor"] = None, ) -> "LightningCheckpoint": """Create a ``ray.air.lightning.LightningCheckpoint`` from a checkpoint file. Args: path: The file path to the PyTorch Lightning checkpoint file. preprocessor: A fitted preprocessor to be applied before inference. Returns: An :py:class:`LightningCheckpoint` containing the model. """ assert os.path.exists(path), f"Lightning checkpoint {path} doesn't exists!" if os.path.isdir(path): raise ValueError( f"`from_path()` expects a file path, but `{path}` is a directory. " "A valid checkpoint file name is normally with .ckpt extension." "If you have an AIR checkpoint folder, you can also try to use " "`LightningCheckpoint.from_directory()` instead." ) cache_dir = tempfile.mkdtemp() new_checkpoint_path = os.path.join(cache_dir, MODEL_KEY) shutil.copy(path, new_checkpoint_path) if preprocessor: save_preprocessor_to_dir(preprocessor, cache_dir) checkpoint = cls.from_directory(cache_dir) checkpoint._cache_dir = cache_dir return checkpoint
[docs] def get_model( self, model_class: Type[pl.LightningModule], **load_from_checkpoint_kwargs: Optional[Dict[str, Any]], ) -> pl.LightningModule: """Retrieve the model stored in this checkpoint. Example: .. testcode:: import pytorch_lightning as pl from ray.train.lightning import LightningCheckpoint, LightningPredictor class MyLightningModule(pl.LightningModule): def __init__(self, input_dim, output_dim) -> None: super().__init__() self.linear = nn.Linear(input_dim, output_dim) self.save_hyperparameters() # ... # After the training is finished, LightningTrainer saves AIR # checkpoints in the result directory, for example: # ckpt_dir = "{storage_path}/LightningTrainer_.*/checkpoint_000000" # You can load model checkpoint with model init arguments def load_checkpoint(ckpt_dir): ckpt = LightningCheckpoint.from_directory(ckpt_dir) # `get_model()` takes the argument list of # `LightningModule.load_from_checkpoint()` as additional kwargs. # Please refer to PyTorch Lightning API for more details. return checkpoint.get_model( model_class=MyLightningModule, input_dim=32, output_dim=10, ) # You can also load checkpoint with a hyperparameter file def load_checkpoint_with_hparams( ckpt_dir, hparam_file="./hparams.yaml" ): ckpt = LightningCheckpoint.from_directory(ckpt_dir) return ckpt.get_model( model_class=MyLightningModule, hparams_file=hparam_file ) Args: model_class: A subclass of ``pytorch_lightning.LightningModule`` that defines your model and training logic. **load_from_checkpoint_kwargs: Arguments to pass into ``pl.LightningModule.load_from_checkpoint``. Returns: pl.LightningModule: An instance of the loaded model. """ if not isclass(model_class): raise ValueError( "'model_class' must be a class, not an instantiated Lightning trainer." ) with self.as_directory() as checkpoint_dir: ckpt_path = os.path.join(checkpoint_dir, MODEL_KEY) if not os.path.exists(ckpt_path): raise RuntimeError( f"File {ckpt_path} not found under the checkpoint directory." ) model = model_class.load_from_checkpoint( ckpt_path, **load_from_checkpoint_kwargs ) return model
def __del__(self): if self._cache_dir and os.path.exists(self._cache_dir): shutil.rmtree(self._cache_dir)