ray.train.lightning.LightningCheckpoint.from_model
classmethod LightningCheckpoint.from_model(model: torch.nn.modules.module.Module, *, preprocessor: Optional[Preprocessor] = None) → TorchCheckpoint
Create a Checkpoint that stores a Torch model.
Note
PyTorch recommends storing state dictionaries. To create a TorchCheckpoint from a state dictionary, call from_state_dict(). To learn more about state dictionaries, read Saving and Loading Models.
Parameters
model – The Torch model to store in the checkpoint.
preprocessor – A fitted preprocessor to be applied before inference.
Returns
A TorchCheckpoint containing the specified model.
Examples
```python
import torch

from ray.train.torch import TorchCheckpoint, TorchPredictor

# Set a manual seed for reproducibility
torch.manual_seed(42)

# Create an identity model and pass a random tensor through it
model = torch.nn.Identity()
input_tensor = torch.randn(2, 2)
output = model(input_tensor)

# Create a checkpoint from the model
checkpoint = TorchCheckpoint.from_model(model)

# Use the TorchCheckpoint to create a ray.train.torch.TorchPredictor
# and run inference with it
predictor = TorchPredictor.from_checkpoint(checkpoint)
pred = predictor.predict(input_tensor.numpy())

# Convert the "predictions" value of the returned dictionary into a tensor
pred = torch.tensor(pred["predictions"])

# Check that the original model and the checkpointed model produce the same output
assert torch.equal(output, pred)
print("worked")
```