ray.train.lightning.LightningPredictor.from_checkpoint
- classmethod LightningPredictor.from_checkpoint(checkpoint: ray.train.lightning.lightning_checkpoint.LightningCheckpoint, model_class: Type[pytorch_lightning.core.lightning.LightningModule], *, preprocessor: Optional[ray.data.preprocessor.Preprocessor] = None, use_gpu: bool = False, **load_from_checkpoint_kwargs) → ray.train.lightning.lightning_predictor.LightningPredictor
Instantiate the LightningPredictor from a Checkpoint.
The checkpoint is expected to be a result of LightningTrainer.
Example
    import pytorch_lightning as pl
    import torch.nn as nn

    from ray.train.lightning import LightningCheckpoint, LightningPredictor


    class MyLightningModule(pl.LightningModule):
        def __init__(self, input_dim, output_dim) -> None:
            super().__init__()
            self.linear = nn.Linear(input_dim, output_dim)

        # ...


    # After the training is finished, LightningTrainer saves AIR
    # checkpoints in the result directory, for example:
    # ckpt_dir = "{storage_path}/LightningTrainer_.*/checkpoint_000000"


    def load_predictor_from_checkpoint(ckpt_dir):
        checkpoint = LightningCheckpoint.from_directory(ckpt_dir)

        # `from_checkpoint()` takes the argument list of
        # `LightningModule.load_from_checkpoint()` as additional kwargs.
        return LightningPredictor.from_checkpoint(
            checkpoint=checkpoint,
            use_gpu=False,
            model_class=MyLightningModule,
            input_dim=32,
            output_dim=10,
        )
- Parameters
  - checkpoint – The checkpoint to load the model and preprocessor from. It is expected to be from the result of a LightningTrainer run.
  - model_class – A subclass of pytorch_lightning.LightningModule that defines your model and training logic. Note that this is a class type instead of a model instance.
  - preprocessor – A preprocessor used to transform data batches prior to prediction.
  - use_gpu – If set, the model will be moved to GPU on instantiation and prediction happens on GPU.
  - **load_from_checkpoint_kwargs – Arguments to pass into pl.LightningModule.load_from_checkpoint.
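The parameter list says that model_class is a class rather than an instance, and that any extra keyword arguments are passed on to load_from_checkpoint. The following is a minimal, dependency-free sketch of that forwarding pattern. ToyModule, ToyPredictor, and the "checkpoint_000000" path string are hypothetical stand-ins invented for illustration; this is not Ray's or PyTorch Lightning's actual implementation.

```python
from typing import Any, Dict, Type


class ToyModule:
    """Hypothetical stand-in for a LightningModule subclass."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.state: Dict[str, Any] = {}

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path: str, **init_kwargs: Any) -> "ToyModule":
        # The class builds its own instance from the constructor kwargs,
        # then records where its state would have been restored from.
        model = cls(**init_kwargs)
        model.state = {"restored_from": checkpoint_path}
        return model


class ToyPredictor:
    """Hypothetical stand-in for a predictor with a from_checkpoint factory."""

    def __init__(self, model: ToyModule, use_gpu: bool = False) -> None:
        self.model = model
        self.use_gpu = use_gpu

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: str,
        model_class: Type[ToyModule],
        *,
        use_gpu: bool = False,
        **load_from_checkpoint_kwargs: Any,
    ) -> "ToyPredictor":
        # model_class is a class, not an instance: the factory asks it to
        # construct the model, forwarding the extra kwargs unchanged.
        model = model_class.load_from_checkpoint(
            checkpoint_path, **load_from_checkpoint_kwargs
        )
        return cls(model, use_gpu=use_gpu)


# Hypothetical path string; nothing is read from disk in this sketch.
predictor = ToyPredictor.from_checkpoint(
    "checkpoint_000000", ToyModule, input_dim=32, output_dim=10
)
```

In this sketch, input_dim and output_dim are not named in the signature of from_checkpoint, so they end up in load_from_checkpoint_kwargs and reach the ToyModule constructor. Passing a ToyModule instance as model_class would not fit this pattern, because the factory needs the class in order to build a new model.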