ray.train.lightning.LightningPredictor.call_model

LightningPredictor.call_model(inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) → Union[torch.Tensor, Dict[str, torch.Tensor]]

Runs inference on a single batch of tensor data.

This method is called by TorchPredictor.predict after converting the original data batch to torch tensors.

Override this method to add custom logic for processing the model input or output.
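The flow described above (predict converts the batch to tensors, calls call_model, then converts the result back) can be sketched without Ray as follows. MiniTorchPredictor and ScaledPredictor are hypothetical stand-ins that only loosely mirror that flow; they are not Ray's implementation, and the torch.no_grad() wrapper and NumPy conversion are assumptions of this sketch.

```python
import numpy as np
import torch


class MiniTorchPredictor:
    """Hypothetical stand-in for TorchPredictor; NOT Ray's real implementation."""

    def __init__(self, model: torch.nn.Module):
        self.model = model.eval()  # eval() returns the module itself

    def call_model(self, inputs):
        # Default hook (assumed): run the model without tracking gradients.
        with torch.no_grad():
            return self.model(inputs)

    def predict(self, batch: np.ndarray):
        # Convert the raw batch to a tensor, call the hook, convert back to NumPy.
        tensor = torch.as_tensor(batch)
        output = self.call_model(tensor)
        if isinstance(output, dict):
            return {key: value.cpu().numpy() for key, value in output.items()}
        return output.cpu().numpy()


class ScaledPredictor(MiniTorchPredictor):
    # Override only the hook to post-process the model output into a dict.
    def call_model(self, inputs):
        output = super().call_model(inputs)
        return {"raw": output, "doubled": output * 2}


predictor = ScaledPredictor(torch.nn.Identity())
result = predictor.predict(np.array([1.0, 2.0]))
print(result["raw"], result["doubled"])  # [1. 2.] [2. 4.]
```

Because predict handles the conversions, the override only deals with tensors in and tensors out.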

Parameters

inputs – A batch of data to predict on, represented either as a single PyTorch tensor or, for multi-input models, as a dictionary of tensors.
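For the multi-input case, one possible override unpacks the dictionary into keyword arguments of the model's forward method. The sketch below is a free function so that it runs without Ray; TwoInputModel and call_model_sketch are hypothetical names, and it assumes the dictionary keys match the forward parameter names. In a real subclass, the same branching would go in an overridden call_model.

```python
from typing import Dict, Union

import torch


class TwoInputModel(torch.nn.Module):
    """Hypothetical multi-input model that adds two feature tensors."""

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a + b


def call_model_sketch(
    model: torch.nn.Module,
    inputs: Union[torch.Tensor, Dict[str, torch.Tensor]],
) -> torch.Tensor:
    """Hypothetical stand-in for a call_model override (not Ray's code)."""
    with torch.no_grad():
        if isinstance(inputs, dict):
            # Multi-input case: keys are assumed to match forward()'s parameter names.
            return model(**inputs)
        # Single-input case: pass the tensor through unchanged.
        return model(inputs)


model = TwoInputModel().eval()
batch = {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([10.0, 20.0])}
output = call_model_sketch(model, batch)
print(output)  # tensor([11., 22.])
```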

Returns

The model outputs, either as a single tensor or a dictionary of tensors.

Example

import numpy as np
import torch
from ray.train.torch import TorchPredictor

# The default TorchPredictor does not support list outputs,
# so define a custom TorchPredictor that overrides call_model.
class MyModel(torch.nn.Module):
    def forward(self, input_tensor):
        return [input_tensor, input_tensor]

# Use a custom predictor to format model output as a dict.
class CustomPredictor(TorchPredictor):
    def call_model(self, inputs):
        model_output = super().call_model(inputs)
        return {
            str(i): model_output[i] for i in range(len(model_output))
        }

# create our data batch
data_batch = np.array([1, 2])
# create custom predictor and predict
predictor = CustomPredictor(model=MyModel())
predictions = predictor.predict(data_batch)
print(f"Predictions: {predictions.get('0')}, {predictions.get('1')}")
# Output: Predictions: [1 2], [1 2]

DeveloperAPI: This API may change across minor Ray releases.