ray.train.tensorflow.TensorflowPredictor

class ray.train.tensorflow.TensorflowPredictor(*, model: Optional[keras.engine.training.Model] = None, preprocessor: Optional[Preprocessor] = None, use_gpu: bool = False)

Bases: ray.train._internal.dl_predictor.DLPredictor

A predictor for TensorFlow models.

Parameters
  • model – A TensorFlow Keras model to use for predictions.

  • preprocessor – A preprocessor used to transform data batches prior to prediction.

  • model_weights – List of weights to use for the model.

  • use_gpu – If True, the model is moved to GPU on instantiation and prediction runs on GPU.

PublicAPI (beta): This API is in beta and may change before becoming stable.

Methods

call_model(inputs)

Runs inference on a single batch of tensor data.

from_checkpoint(checkpoint[, ...])

Instantiate the predictor from a Checkpoint.

from_pandas_udf(pandas_udf)

Create a Predictor from a Pandas UDF.

get_preprocessor()

Get the preprocessor to use prior to executing predictions.

predict(data[, dtype])

Run inference on a batch of data.

preferred_batch_format()

DeveloperAPI: This API may change across minor Ray releases.

set_preprocessor(preprocessor)

Set the preprocessor to use prior to executing predictions.
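
The way these methods relate to each other can be sketched without Ray at all. The class below is a hypothetical stand-in, not Ray's implementation: it assumes only what the descriptions above state, namely that the preprocessor transforms a batch before prediction and that call_model runs the model on a single batch. SketchPredictor, sum_model, and double are names invented for this example.

```python
from typing import Callable, List, Optional

Batch = List[List[float]]


class SketchPredictor:
    """Hypothetical stand-in that mirrors the method names listed above."""

    def __init__(
        self,
        model: Callable[[Batch], List[float]],
        preprocessor: Optional[Callable[[Batch], Batch]] = None,
    ) -> None:
        self._model = model
        self._preprocessor = preprocessor

    def get_preprocessor(self) -> Optional[Callable[[Batch], Batch]]:
        return self._preprocessor

    def set_preprocessor(self, preprocessor: Optional[Callable[[Batch], Batch]]) -> None:
        self._preprocessor = preprocessor

    def call_model(self, inputs: Batch) -> List[float]:
        # Run the model on one already-preprocessed batch.
        return self._model(inputs)

    def predict(self, data: Batch) -> List[float]:
        # Apply the preprocessor first (if any), then run the model.
        if self._preprocessor is not None:
            data = self._preprocessor(data)
        return self.call_model(data)


def sum_model(batch: Batch) -> List[float]:
    # Toy "model": one prediction per row, the sum of its features.
    return [sum(row) for row in batch]


def double(batch: Batch) -> Batch:
    # Toy "preprocessor": scale every feature by two.
    return [[2 * x for x in row] for row in batch]


predictor = SketchPredictor(model=sum_model)
raw = predictor.predict([[1.0, 2.0], [3.0, 4.0]])     # [3.0, 7.0]

predictor.set_preprocessor(double)
scaled = predictor.predict([[1.0, 2.0], [3.0, 4.0]])  # [6.0, 14.0]
```

Swapping the preprocessor changes what the model sees without touching the model itself, which is the separation that get_preprocessor and set_preprocessor expose.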