ray.train.tensorflow.TensorflowCheckpoint.from_saved_model
- classmethod TensorflowCheckpoint.from_saved_model(dir_path: str, *, preprocessor: Optional[Preprocessor] = None) → TensorflowCheckpoint
Create a Checkpoint that stores a Keras model from SavedModel format.
The checkpoint generated by this method contains all the information needed, so no model needs to be supplied when using this checkpoint.
dir_path must remain valid even after this function returns. As a side effect, this method may add new files or directories to dir_path.
- Parameters
dir_path – The directory containing the saved model. This is the same directory as used by model.save(dir_path).
preprocessor – A fitted preprocessor to be applied before inference.
- Returns
A TensorflowCheckpoint converted from SavedModel format.
Examples:
```python
import tensorflow as tf

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.batch_predictor import BatchPredictor
from ray.train.tensorflow import (
    TensorflowCheckpoint, TensorflowTrainer, TensorflowPredictor)


def train_fn():
    # Build a small Keras model that maps a scalar input to a single output.
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=()),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Dense(1),
        ])
    # Save the model in SavedModel format, then wrap that directory in a checkpoint.
    model.save("my_model")
    checkpoint = TensorflowCheckpoint.from_saved_model("my_model")
    # Report a metric together with the checkpoint to Ray Train.
    train.report({"my_metric": 1}, checkpoint=checkpoint)


trainer = TensorflowTrainer(
    train_loop_per_worker=train_fn,
    scaling_config=ScalingConfig(num_workers=2))
result_checkpoint = trainer.fit().checkpoint

# The checkpoint is self-contained, so no model is passed to the predictor.
batch_predictor = BatchPredictor.from_checkpoint(
    result_checkpoint, TensorflowPredictor)
batch_predictor.predict(ray.data.range(3))
```