ray.train.tensorflow.TensorflowCheckpoint.from_h5

classmethod TensorflowCheckpoint.from_h5(file_path: str, *, preprocessor: Optional[Preprocessor] = None) -> TensorflowCheckpoint

Create a TensorflowCheckpoint that stores a Keras model saved in H5 format.

The checkpoint generated by this method contains everything needed to reconstruct the model, so you don't need to supply a model definition when loading from it.

file_path must remain valid after this function returns. As a side effect, this method may add new files or directories to the parent directory of file_path.
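Because file_path has to stay valid after from_h5 returns, it can help to check the path up front and leave the file in place until the checkpoint has been used. The sketch below uses only the Python standard library. check_h5_path is a hypothetical helper, not part of Ray, and requiring the .h5 suffix is an assumption about what from_h5 accepts.

```python
import os


def check_h5_path(file_path: str) -> str:
    """Hypothetical pre-flight check before calling TensorflowCheckpoint.from_h5.

    Assumption: from_h5 expects an existing file whose name ends in ".h5".
    Returns the absolute path so the caller keeps referring to a stable location.
    """
    abs_path = os.path.abspath(file_path)
    # Reject paths that don't look like H5 files.
    if not abs_path.endswith(".h5"):
        raise ValueError(f"expected a .h5 file, got {file_path!r}")
    # The file must exist now and, per the docs, stay valid afterwards.
    if not os.path.isfile(abs_path):
        raise FileNotFoundError(f"no such file: {abs_path}")
    return abs_path
```

Pass the returned path to TensorflowCheckpoint.from_h5, and avoid deleting or moving the file (for example, by letting a temporary directory go out of scope) while the checkpoint is still in use.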

Parameters
  • file_path – The path to the .h5 file to load the model from. This is the same path passed to model.save(path).

  • preprocessor – A fitted preprocessor to be applied before inference.

Returns

A TensorflowCheckpoint containing the model from the H5 file.

Examples:

import tensorflow as tf

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.batch_predictor import BatchPredictor
from ray.train.tensorflow import (
    TensorflowCheckpoint, TensorflowTrainer, TensorflowPredictor
)

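def train_func():
    # A small model that maps a scalar input to a single output.
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=()),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Dense(1),
        ]
    )
    # Save the full model (architecture and weights) in H5 format.
    model.save("my_model.h5")
    # Wrap the H5 file in a checkpoint and report it to Ray Train.
    checkpoint = TensorflowCheckpoint.from_h5("my_model.h5")
    train.report({"my_metric": 1}, checkpoint=checkpoint)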

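trainer = TensorflowTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(num_workers=2),
)

# The checkpoint reported from train_func is available on the training result.
result_checkpoint = trainer.fit().checkpoint

# Load the model from the checkpoint and run batch inference on a small dataset.
batch_predictor = BatchPredictor.from_checkpoint(
    result_checkpoint, TensorflowPredictor
)
batch_predictor.predict(ray.data.range(3))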