ray.train.lightning.LightningPredictor.from_checkpoint

classmethod LightningPredictor.from_checkpoint(checkpoint: ray.train.lightning.lightning_checkpoint.LightningCheckpoint, model_class: Type[pytorch_lightning.core.lightning.LightningModule], *, preprocessor: Optional[ray.data.preprocessor.Preprocessor] = None, use_gpu: bool = False, **load_from_checkpoint_kwargs) → ray.train.lightning.lightning_predictor.LightningPredictor

Instantiate a LightningPredictor from a LightningCheckpoint.

The checkpoint is expected to be produced by a LightningTrainer run.

Example

import pytorch_lightning as pl
import torch.nn as nn
from ray.train.lightning import LightningCheckpoint, LightningPredictor

class MyLightningModule(pl.LightningModule):
    def __init__(self, input_dim, output_dim) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    # ...

# After the training is finished, LightningTrainer saves AIR
# checkpoints in the result directory, for example:
# ckpt_dir = "{storage_path}/LightningTrainer_.*/checkpoint_000000"

def load_predictor_from_checkpoint(ckpt_dir):
    checkpoint = LightningCheckpoint.from_directory(ckpt_dir)

    # Extra keyword arguments to `from_checkpoint()` are forwarded to
    # `LightningModule.load_from_checkpoint()`; here `input_dim` and
    # `output_dim` reach `MyLightningModule.__init__()`.

    return LightningPredictor.from_checkpoint(
        checkpoint=checkpoint,
        use_gpu=False,
        model_class=MyLightningModule,
        input_dim=32,
        output_dim=10,
    )
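
The comment in the example gives the checkpoint location as `{storage_path}/LightningTrainer_.*/checkpoint_000000`. If you need to find that directory programmatically before calling `LightningCheckpoint.from_directory`, a minimal stdlib sketch could look like the following. It assumes that naming scheme (trial folders starting with `LightningTrainer_`, checkpoint folders named `checkpoint_NNNNNN`); `find_latest_checkpoint_dir` is a hypothetical helper, not part of Ray's API, and it picks the highest-numbered checkpoint across all matching trial folders, which is only meaningful when a single run wrote to `storage_path`.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed naming, taken from the example's comment: trial directories start
# with "LightningTrainer_" and contain numbered "checkpoint_NNNNNN" folders.
_CKPT_NAME = re.compile(r"checkpoint_(\d+)")


def find_latest_checkpoint_dir(storage_path: str) -> Optional[Path]:
    """Return the highest-numbered checkpoint directory, or None if none exist.

    Hypothetical helper, not part of Ray's API.
    """
    root = Path(storage_path)
    if not root.is_dir():
        return None

    latest: Optional[Path] = None
    latest_index = -1
    for trial_dir in root.glob("LightningTrainer_*"):
        if not trial_dir.is_dir():
            continue
        for candidate in trial_dir.iterdir():
            match = _CKPT_NAME.fullmatch(candidate.name)
            # Skip files such as result.json and non-checkpoint folders.
            if match is None or not candidate.is_dir():
                continue
            index = int(match.group(1))
            if index > latest_index:
                latest, latest_index = candidate, index
    return latest
```

The returned path (after checking it is not None) can be passed as `str(path)` to `load_predictor_from_checkpoint` from the example.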

Parameters
  • checkpoint – The checkpoint to load the model and preprocessor from. It is expected to come from a LightningTrainer run.

  • model_class – A subclass of pytorch_lightning.LightningModule that defines your model and training logic. Pass the class itself, not a model instance.

  • preprocessor – A preprocessor used to transform data batches prior to prediction.

  • use_gpu – If True, the model is moved to the GPU on instantiation and prediction runs on the GPU.

  • **load_from_checkpoint_kwargs – Additional keyword arguments forwarded to pl.LightningModule.load_from_checkpoint.
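
The parameters above fit together in one way worth spelling out: `model_class` is a class, and any keyword argument not named in the signature is forwarded to `load_from_checkpoint`, which in Lightning passes extra keyword arguments on to the module's constructor. The stdlib-only sketch below illustrates that flow under stated assumptions. `ToyModule`, `ToyPredictor`, and the plain-callable preprocessor are hypothetical stand-ins, not Ray's implementation; the sketch does not restore weights or a preprocessor from the checkpoint, and `use_gpu` only records a device label.

```python
from typing import Any, Callable, List, Optional, Type

Batch = List[float]


class ToyModule:
    """Hypothetical stand-in for a LightningModule subclass."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        self.input_dim = input_dim
        self.output_dim = output_dim

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path: str, **kwargs: Any) -> "ToyModule":
        # Real Lightning would also restore weights and saved hyperparameters
        # from checkpoint_path; here the kwargs simply reach __init__.
        return cls(**kwargs)

    def __call__(self, batch: Batch) -> Batch:
        # Toy "forward pass": scale every value by output_dim.
        return [x * self.output_dim for x in batch]


class ToyPredictor:
    """Hypothetical predictor mirroring the documented parameters."""

    def __init__(
        self,
        model: ToyModule,
        preprocessor: Optional[Callable[[Batch], Batch]] = None,
        use_gpu: bool = False,
    ) -> None:
        self.model = model
        self.preprocessor = preprocessor
        # Only records the choice; no real device transfer happens here.
        self.device = "cuda" if use_gpu else "cpu"

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: str,
        model_class: Type[ToyModule],
        *,
        preprocessor: Optional[Callable[[Batch], Batch]] = None,
        use_gpu: bool = False,
        **load_from_checkpoint_kwargs: Any,
    ) -> "ToyPredictor":
        # Pass the class, not an instance; every keyword argument that
        # from_checkpoint() does not consume is forwarded unchanged.
        model = model_class.load_from_checkpoint(
            checkpoint_path, **load_from_checkpoint_kwargs
        )
        return cls(model, preprocessor=preprocessor, use_gpu=use_gpu)

    def predict(self, batch: Batch) -> Batch:
        # The preprocessor, if any, runs before the model sees the batch.
        if self.preprocessor is not None:
            batch = self.preprocessor(batch)
        return self.model(batch)
```

Calling `ToyPredictor.from_checkpoint("ckpt/model.ckpt", ToyModule, input_dim=32, output_dim=10)` mirrors the documented example: `input_dim` and `output_dim` are not parameters of `from_checkpoint`, so they travel through to the module's constructor.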