ray.train.torch.TorchPredictor.from_checkpoint
classmethod TorchPredictor.from_checkpoint(checkpoint: ray.air.checkpoint.Checkpoint, model: Optional[torch.nn.modules.module.Module] = None, use_gpu: bool = False) → ray.train.torch.torch_predictor.TorchPredictor
Instantiate the predictor from a Checkpoint.
The checkpoint is expected to be the result of a TorchTrainer run.
Parameters
- checkpoint – The checkpoint to load the model and preprocessor from. It is expected to be the result of a TorchTrainer run.
- model – If the checkpoint contains a model state dict rather than the model itself, the state dict is loaded into this model. If the checkpoint already contains the model itself, this argument is discarded.
- use_gpu – If True, the model is moved to the GPU on instantiation and prediction runs on the GPU.