ray.train.tensorflow.TensorflowCheckpoint.from_saved_model

classmethod TensorflowCheckpoint.from_saved_model(dir_path: str, *, preprocessor: Optional[Preprocessor] = None) -> TensorflowCheckpoint

Create a Checkpoint that stores a Keras model saved in SavedModel format.

The checkpoint created by this method contains all the information needed to restore the model, so you don't need to supply a model when using this checkpoint.

dir_path must remain valid after this method returns. As a side effect, this method may add new files or directories to dir_path.
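To see which files or directories the method adds to dir_path (the exact set depends on the Ray version and is not specified here), you could snapshot the directory before and after the call. The helpers snapshot and call_and_diff below are hypothetical, use only the Python standard library, and are not part of Ray; treat them as a minimal sketch.

```python
import os
from typing import Callable, Set, Tuple, TypeVar

T = TypeVar("T")


def snapshot(dir_path: str) -> Set[str]:
    """Return the relative paths of every file and subdirectory under dir_path."""
    entries = set()
    for root, dirs, files in os.walk(dir_path):
        for name in dirs + files:
            entries.add(os.path.relpath(os.path.join(root, name), dir_path))
    return entries


def call_and_diff(dir_path: str, fn: Callable[[str], T]) -> Tuple[T, Set[str]]:
    """Call fn(dir_path); return its result and the entries it added to dir_path."""
    before = snapshot(dir_path)
    result = fn(dir_path)
    # Only newly created paths are reported; files modified in place are not.
    added = snapshot(dir_path) - before
    return result, added
```

For example, after model.save("my_model"), calling call_and_diff("my_model", TensorflowCheckpoint.from_saved_model) would return the checkpoint together with the relative paths the method created. The helper returns fn's result so the checkpoint isn't lost while you inspect the side effects.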

Parameters
  • dir_path – The directory containing the saved model, that is, the directory passed to model.save(dir_path).

  • preprocessor – A fitted preprocessor to be applied before inference.
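A TensorFlow SavedModel directory normally contains a saved_model.pb file (or, less commonly, a text-format saved_model.pbtxt) alongside a variables/ subdirectory. Assuming that layout, the hypothetical helper below is a quick sanity check you could run on dir_path before calling from_saved_model; it is a heuristic sketch, not a check that Ray itself performs.

```python
import os

# Marker files expected at the top of a SavedModel directory
# (binary or text protobuf). Assumption based on the SavedModel layout.
SAVED_MODEL_FILES = ("saved_model.pb", "saved_model.pbtxt")


def looks_like_saved_model(dir_path: str) -> bool:
    """Heuristically check that dir_path holds a SavedModel, e.g. from model.save(dir_path)."""
    if not os.path.isdir(dir_path):
        return False
    return any(
        os.path.isfile(os.path.join(dir_path, name)) for name in SAVED_MODEL_FILES
    )
```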

Returns

A TensorflowCheckpoint converted from SavedModel format.

Examples:

import tensorflow as tf

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.batch_predictor import BatchPredictor
from ray.train.tensorflow import (
    TensorflowCheckpoint, TensorflowTrainer, TensorflowPredictor)

def train_fn():
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=()),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Dense(1),
        ])
    # Save the model in SavedModel format, then wrap that directory in a checkpoint.
    model.save("my_model")
    checkpoint = TensorflowCheckpoint.from_saved_model("my_model")
    train.report({"my_metric": 1}, checkpoint=checkpoint)

trainer = TensorflowTrainer(
    train_loop_per_worker=train_fn,
    scaling_config=ScalingConfig(num_workers=2))

result_checkpoint = trainer.fit().checkpoint

# The checkpoint holds the full model, so no model object is passed here.
batch_predictor = BatchPredictor.from_checkpoint(
    result_checkpoint, TensorflowPredictor)
batch_predictor.predict(ray.data.range(3))