ray.train.tensorflow.TensorflowCheckpoint.from_model

classmethod TensorflowCheckpoint.from_model(model: keras.engine.training.Model, *, preprocessor: Optional[Preprocessor] = None) → TensorflowCheckpoint

Create a Checkpoint that stores a Keras model.

Because the checkpoint stores the model's weights, it must be paired with `model` (a Keras model with the same architecture) when it is used.
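The pairing requirement follows from what a weights-only checkpoint contains: arrays of parameters, not the layer structure that gives them meaning. The sketch below illustrates this with plain Keras rather than Ray, and it assumes the checkpoint holds something equivalent to `model.get_weights()`. It is not Ray's implementation, and the `build_model` helper and its layer sizes are hypothetical. Restoring the weights only works after you rebuild a model with the same architecture.

```python
import numpy as np
import tensorflow as tf


def build_model() -> tf.keras.Model:
    # Hypothetical small architecture, used only for illustration.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1),
    ])


original = build_model()

# A weights-only snapshot: a list of NumPy arrays with no layer definitions.
weights = original.get_weights()

# To use the weights, first recreate the same architecture, then load them.
restored = build_model()
restored.set_weights(weights)

# Both models now produce identical outputs for the same input.
x = np.random.rand(2, 4).astype("float32")
same = np.allclose(original(x).numpy(), restored(x).numpy())
print(same)
```

Calling `set_weights` on a model with a different layer layout fails because the array shapes no longer line up. This is why the architecture must be supplied separately from the checkpoint.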

Parameters
  • model – The Keras model whose weights are stored in the checkpoint.

  • preprocessor – A fitted preprocessor to be applied before inference.

Returns

A TensorflowCheckpoint containing the specified model.

Examples

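import tensorflow as tf
from ray.train.tensorflow import TensorflowCheckpoint

# Build a Keras model; by default, ResNet101 loads pretrained ImageNet weights.
model = tf.keras.applications.resnet.ResNet101()

# Create a checkpoint that stores the model's weights.
checkpoint = TensorflowCheckpoint.from_model(model)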