ray.train.tensorflow.TensorflowCheckpoint.from_model
- classmethod TensorflowCheckpoint.from_model(model: keras.engine.training.Model, *, preprocessor: Optional[Preprocessor] = None) -> TensorflowCheckpoint
Create a Checkpoint that stores a Keras model. The checkpoint created with this method needs to be paired with model when used.
- Parameters
model – The Keras model whose weights are stored in the checkpoint.
preprocessor – A fitted preprocessor to be applied before inference.
- Returns
A TensorflowCheckpoint containing the specified model.
Examples
from ray.train.tensorflow import TensorflowCheckpoint
import tensorflow as tf

model = tf.keras.applications.resnet.ResNet101()
checkpoint = TensorflowCheckpoint.from_model(model)
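The requirement that a checkpoint created with from_model be paired with model when used can be illustrated with a minimal, framework-free sketch. It assumes, for illustration only, that the checkpoint holds just the model's weights (as the model parameter description states) plus the optional preprocessor. ToyModel, WeightsOnlyCheckpoint, load, and infer are hypothetical names. They are not part of Ray or Keras, and the sketch is not Ray's implementation.

```python
from typing import Callable, List, Optional, Sequence


class ToyModel:
    """Hypothetical stand-in for a Keras model: the architecture is defined
    in code, while the learned parameters are a flat list of weights."""

    def __init__(self, num_features: int) -> None:
        self.num_features = num_features
        self._weights: List[float] = [0.0] * num_features

    def get_weights(self) -> List[float]:
        return list(self._weights)

    def set_weights(self, weights: Sequence[float]) -> None:
        # Weights only make sense for a model with a matching architecture.
        if len(weights) != self.num_features:
            raise ValueError("weights do not match this model's architecture")
        self._weights = list(weights)

    def predict(self, features: Sequence[float]) -> float:
        # Linear model: dot product of weights and features.
        return sum(w * x for w, x in zip(self._weights, features))


Preprocessor = Callable[[Sequence[float]], List[float]]


class WeightsOnlyCheckpoint:
    """Hypothetical weights-only checkpoint (NOT Ray's implementation)."""

    def __init__(self, weights: List[float],
                 preprocessor: Optional[Preprocessor] = None) -> None:
        self.weights = weights
        self.preprocessor = preprocessor

    @classmethod
    def from_model(cls, model: ToyModel, *,
                   preprocessor: Optional[Preprocessor] = None) -> "WeightsOnlyCheckpoint":
        # Only the parameters are copied; the architecture is not stored.
        return cls(model.get_weights(), preprocessor)

    def load(self, build_model: Callable[[], ToyModel]) -> ToyModel:
        # The caller supplies the architecture: this is the "pairing" step.
        model = build_model()
        model.set_weights(self.weights)
        return model

    def infer(self, model: ToyModel, features: Sequence[float]) -> float:
        # Apply the stored preprocessor (if any) before running the model.
        if self.preprocessor is not None:
            features = self.preprocessor(features)
        return model.predict(features)


trained = ToyModel(num_features=2)
trained.set_weights([0.5, 2.0])
ckpt = WeightsOnlyCheckpoint.from_model(
    trained, preprocessor=lambda xs: [x * 10 for x in xs]
)
restored = ckpt.load(lambda: ToyModel(num_features=2))
print(ckpt.infer(restored, [1.0, 1.0]))  # 0.5*10 + 2.0*10 -> 25.0
```

Because only weights are saved, loading fails unless the caller rebuilds a model with a compatible architecture. In this sketch, passing a ToyModel with a different num_features raises ValueError.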