ray.train.tensorflow.TensorflowPredictor
- class ray.train.tensorflow.TensorflowPredictor(*, model: Optional[keras.engine.training.Model] = None, preprocessor: Optional[Preprocessor] = None, use_gpu: bool = False)
Bases: ray.train._internal.dl_predictor.DLPredictor
A predictor for TensorFlow models.
- Parameters
model – A TensorFlow Keras model to use for predictions.
preprocessor – A preprocessor used to transform data batches prior to prediction.
model_weights – List of weights to use for the model.
use_gpu – If True, the model is moved to the GPU on instantiation and prediction runs on the GPU.
PublicAPI (beta): This API is in beta and may change before becoming stable.
Methods
call_model(inputs) – Runs inference on a single batch of tensor data.
from_checkpoint(checkpoint[, ...]) – Instantiate the predictor from a Checkpoint.
from_pandas_udf(pandas_udf) – Create a Predictor from a Pandas UDF.
get_preprocessor() – Get the preprocessor to use prior to executing predictions.
predict(data[, dtype]) – Run inference on a data batch.
DeveloperAPI: This API may change across minor Ray releases.
set_preprocessor(preprocessor) – Set the preprocessor to use prior to executing predictions.
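
The method list above suggests a simple flow: predict applies the preprocessor (if one is set) to the incoming batch and then hands the result to call_model. The sketch below illustrates that flow in plain Python. It is a hypothetical stand-in and not Ray's implementation: SketchPredictor, scale_by_ten, and sum_rows are made-up names, batches are plain lists rather than tensors, and the dtype argument, GPU placement, and checkpoint loading are left out.

```python
from typing import Callable, List, Optional

Batch = List[List[float]]


class SketchPredictor:
    """Hypothetical stand-in for the predictor pattern; NOT Ray's class."""

    def __init__(
        self,
        model: Callable[[Batch], List[float]],
        preprocessor: Optional[Callable[[Batch], Batch]] = None,
    ) -> None:
        self._model = model
        self._preprocessor = preprocessor

    def get_preprocessor(self) -> Optional[Callable[[Batch], Batch]]:
        # Return the preprocessor applied before prediction, if any.
        return self._preprocessor

    def set_preprocessor(self, preprocessor: Optional[Callable[[Batch], Batch]]) -> None:
        # Replace the preprocessor applied before prediction.
        self._preprocessor = preprocessor

    def call_model(self, inputs: Batch) -> List[float]:
        # Run the model on one already-preprocessed batch.
        return self._model(inputs)

    def predict(self, data: Batch) -> List[float]:
        # Assumed order of operations: preprocess first, then call the model.
        if self._preprocessor is not None:
            data = self._preprocessor(data)
        return self.call_model(data)


def scale_by_ten(batch: Batch) -> Batch:
    # Toy preprocessor (assumption for illustration): multiply every feature by 10.
    return [[x * 10.0 for x in row] for row in batch]


def sum_rows(batch: Batch) -> List[float]:
    # Toy model (assumption for illustration): one output per row, the row sum.
    return [sum(row) for row in batch]


predictor = SketchPredictor(model=sum_rows)
raw = predictor.predict([[1.0, 2.0], [3.0, 4.0]])

predictor.set_preprocessor(scale_by_ten)
scaled = predictor.predict([[1.0, 2.0], [3.0, 4.0]])
```

Keeping call_model separate from predict means the preprocessing step can be swapped with set_preprocessor without touching the model call itself.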