ray.air.session.get_checkpoint

ray.air.session.get_checkpoint() -> Optional[ray.air.checkpoint.Checkpoint]

Access the session’s last checkpoint to resume from if applicable.

Returns

A Checkpoint object if the session is currently being resumed from a checkpoint; otherwise, None.

import tempfile

import tensorflow as tf

# Using it in the *per worker* train loop (TrainSession).
from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer


def train_func():
    """Per-worker training loop that resumes from a checkpoint when available.

    If the run was started with a checkpoint, the Keras model is restored
    from it; otherwise a fresh ResNet50 is built. The model is then saved
    and reported back to Ray Train as a new checkpoint.
    """
    ckpt = train.get_checkpoint()
    # get_checkpoint() returns None when the session is not being resumed.
    if ckpt is not None:
        with ckpt.as_directory() as loaded_checkpoint_dir:
            model = tf.keras.models.load_model(loaded_checkpoint_dir)
    else:
        model = tf.keras.applications.resnet50.ResNet50()

    # Every worker runs this function. A fixed relative path such as
    # "my_model" would be shared by workers on the same node, which could
    # then overwrite each other's files. Use a fresh directory per call
    # instead. mkdtemp() does not delete the directory, so it still exists
    # whenever Ray reads the checkpoint.
    checkpoint_dir = tempfile.mkdtemp()
    model.save(checkpoint_dir, overwrite=True)
    train.report(
        metrics={"iter": 1},
        checkpoint=Checkpoint.from_directory(checkpoint_dir),
    )


scaling_config = ScalingConfig(num_workers=2)
trainer = TensorflowTrainer(
    train_loop_per_worker=train_func, scaling_config=scaling_config
)
result = trainer.fit()

# trainer2 will pick up from the checkpoint saved by trainer1.
trainer2 = TensorflowTrainer(
    train_loop_per_worker=train_func,
    scaling_config=scaling_config,
    # this is ultimately what is accessed through
    # ``ray.train.get_checkpoint()``
    resume_from_checkpoint=result.checkpoint,
)
result2 = trainer2.fit()

PublicAPI (beta): This API is in beta and may change before becoming stable.