ray.train.tensorflow.TensorflowCheckpoint.from_h5
- classmethod TensorflowCheckpoint.from_h5(file_path: str, *, preprocessor: Optional[Preprocessor] = None) → TensorflowCheckpoint
Create a `Checkpoint` that stores a Keras model loaded from an H5 file.

The checkpoint generated by this method contains all the information needed, so no `model` needs to be supplied when using this checkpoint. `file_path` must remain valid even after this function returns. As a side effect of this method, new files or directories may be added to the parent directory of `file_path`.
- Parameters
file_path – The path to the .h5 file to load the model from. This is the same path that is used for `model.save(path)`.
preprocessor – A fitted preprocessor to be applied before inference.
- Returns
A `TensorflowCheckpoint` converted from H5 format.
Examples:
```python
import tensorflow as tf

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.batch_predictor import BatchPredictor
from ray.train.tensorflow import (
    TensorflowCheckpoint,
    TensorflowTrainer,
    TensorflowPredictor,
)


def train_func():
    # Build a small Keras model on each training worker.
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=()),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Dense(1),
        ]
    )
    # Save the model in H5 format and wrap the file in a checkpoint.
    model.save("my_model.h5")
    checkpoint = TensorflowCheckpoint.from_h5("my_model.h5")
    # Report a metric together with the checkpoint to Ray Train.
    train.report({"my_metric": 1}, checkpoint=checkpoint)


trainer = TensorflowTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(num_workers=2),
)
result_checkpoint = trainer.fit().checkpoint

# Load the reported checkpoint into a batch predictor and run inference.
batch_predictor = BatchPredictor.from_checkpoint(
    result_checkpoint, TensorflowPredictor
)
batch_predictor.predict(ray.data.range(3))
```