ray.train.lightning.LightningPredictor.call_model
- LightningPredictor.call_model(inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) → Union[torch.Tensor, Dict[str, torch.Tensor]]
Runs inference on a single batch of tensor data.
This method is called by TorchPredictor.predict after converting the original data batch to torch tensors. Override this method to add custom logic for processing the model input or output.
- Parameters
inputs – A batch of data to predict on, represented as either a single PyTorch tensor or, for multi-input models, a dictionary of tensors.
- Returns
The model outputs, either as a single tensor or a dictionary of tensors.
Example
```python
import numpy as np
import torch
from ray.train.torch import TorchPredictor

# List outputs are not supported by the default TorchPredictor.
# So let's define a custom TorchPredictor and override call_model.
class MyModel(torch.nn.Module):
    def forward(self, input_tensor):
        return [input_tensor, input_tensor]

# Use a custom predictor to format model output as a dict.
class CustomPredictor(TorchPredictor):
    def call_model(self, inputs):
        model_output = super().call_model(inputs)
        return {
            str(i): model_output[i] for i in range(len(model_output))
        }

# Create our data batch.
data_batch = np.array([1, 2])

# Create the custom predictor and predict.
predictor = CustomPredictor(model=MyModel())
predictions = predictor.predict(data_batch)
print(f"Predictions: {predictions.get('0')}, {predictions.get('1')}")
```

```
Predictions: [1 2], [1 2]
```
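The call flow described for this method (predict converts the batch, then delegates the model call to the overridable call_model hook) can also be sketched without Ray or PyTorch. This is a minimal illustration under stated assumptions: StandInPredictor, multi_input_model, and ScalingPredictor are hypothetical names, StandInPredictor is a simplified stand-in for TorchPredictor rather than Ray's implementation, and plain Python lists of floats take the place of torch tensors. The override shows both kinds of custom logic: preprocessing a dict input for a multi-input model before calling super().call_model, and reshaping the output afterwards.

```python
from typing import Dict, List


class StandInPredictor:
    """Hypothetical stand-in for TorchPredictor; NOT Ray's implementation.

    It only mimics the documented flow: predict() converts the raw batch,
    then hands the converted inputs to the overridable call_model() hook.
    """

    def __init__(self, model):
        self.model = model

    def _convert(self, batch):
        # Stand-in for the numpy -> torch tensor conversion:
        # values become plain lists of floats.
        if isinstance(batch, dict):
            return {key: [float(x) for x in values] for key, values in batch.items()}
        return [float(x) for x in batch]

    def call_model(self, inputs):
        # Default hook: pass the converted inputs straight to the model.
        return self.model(inputs)

    def predict(self, batch):
        return self.call_model(self._convert(batch))


def multi_input_model(inputs: Dict[str, List[float]]) -> List[float]:
    # Hypothetical two-input "model": element-wise sum of columns "A" and "B".
    return [a + b for a, b in zip(inputs["A"], inputs["B"])]


class ScalingPredictor(StandInPredictor):
    # Override the hook to add custom input and output processing.
    def call_model(self, inputs):
        # Input logic: double column "B" before the model sees it.
        scaled = dict(inputs, B=[2 * x for x in inputs["B"]])
        outputs = super().call_model(scaled)
        # Output logic: wrap the raw output in a named dict.
        return {"predictions": outputs}


predictor = ScalingPredictor(model=multi_input_model)
print(predictor.predict({"A": [1, 3], "B": [2, 4]}))  # prints {'predictions': [5.0, 11.0]}
```

In real Ray code, the same override would go on a TorchPredictor (or LightningPredictor) subclass, as in the example above, and the inputs would arrive as torch tensors rather than lists.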
DeveloperAPI: This API may change across minor Ray releases.