PBT Function Example
PBT Function Example
The script below runs Population Based Training (PBT) on a toy problem. For a population of 8 trials, the learning rate schedule that PBT discovers roughly matches the optimal learning rate schedule.
#!/usr/bin/env python
import argparse
import json
import os
import random
import tempfile

import numpy as np
import ray
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import PopulationBasedTraining
def pbt_function(config):
    r"""Toy PBT problem for benchmarking adaptive learning rate.

    The goal is to optimize this trainable's accuracy. The accuracy increases
    fastest at the optimal lr, which is a function of the current accuracy.

    The optimal lr schedule for this problem is the triangle wave as follows.
    Note that many lr schedules for real models also follow this shape:

     best lr
      ^
      |    /\
      |   /  \
      |  /    \
      | /      \
      ------------> accuracy

    In this problem, using PBT with a population of 2-4 is sufficient to
    roughly approximate this lr schedule. Higher population sizes will yield
    faster convergence. Training will not converge without PBT.

    Args:
        config: Hyperparameter dict. ``config["lr"]`` is required;
            ``config["checkpoint_interval"]`` is optional and defaults to 1.

    The function loops forever and reports one result per iteration; the
    reported ``"done"`` flag is meant to be used by the caller's stop
    criteria to end the trial.
    """
    lr = config["lr"]
    checkpoint_interval = config.get("checkpoint_interval", 1)

    # The reported "done" flag turns true once accuracy exceeds `midpoint * 2`.
    accuracy = 0.0

    # NOTE: See below why step is initialized to 1
    step = 1

    # Fetch the checkpoint once and restore the saved state, if any.
    # State is stored as a JSON file inside a checkpoint directory rather than
    # via the dict-based Checkpoint constructors, so that it works with
    # directory-based `Checkpoint` objects.
    restored = train.get_checkpoint()
    if restored:
        with restored.as_directory() as checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "checkpoint.json")) as f:
                state = json.load(f)
        accuracy = state["acc"]
        # Current step should be 1 more than the last checkpoint step
        step = state["step"] + 1

    # triangle wave (as a function of the current accuracy):
    #  - start at 0.001 @ accuracy=0,
    #  - peak at 0.01 @ accuracy=midpoint,
    #  - end at 0.001 @ accuracy=midpoint * 2,
    midpoint = 100  # lr starts decreasing after acc > midpoint
    q_tolerance = 3  # penalize exceeding lr by more than this multiple
    noise_level = 2  # add gaussian noise to the acc increase

    # Let `stop={"done": True}` in the configs below handle trial stopping
    while True:
        if accuracy < midpoint:
            optimal_lr = 0.01 * accuracy / midpoint
        else:
            optimal_lr = 0.01 - 0.01 * (accuracy - midpoint) / midpoint
        # Clamp to [0.001, 0.01] so the ratio below never divides by zero
        # on the optimal_lr side.
        optimal_lr = min(0.01, max(0.001, optimal_lr))

        # compute accuracy increase
        # q_err >= 1 is the multiplicative distance between lr and optimal_lr.
        q_err = max(lr, optimal_lr) / min(lr, optimal_lr)
        if q_err < q_tolerance:
            accuracy += (1.0 / q_err) * random.random()
        elif lr > optimal_lr:
            # lr overshoots the optimum by more than the tolerance: regress.
            accuracy -= (q_err - q_tolerance) * random.random()
        accuracy += noise_level * np.random.normal()
        accuracy = max(0, accuracy)

        metrics = {
            "mean_accuracy": accuracy,
            "cur_lr": lr,
            "optimal_lr": optimal_lr,  # for debugging
            "q_err": q_err,  # for debugging
            "done": accuracy > midpoint * 2,  # this stops the training process
        }

        if step % checkpoint_interval == 0:
            # Checkpoint every `checkpoint_interval` steps
            # NOTE: if we initialized `step=0` above, our checkpointing and
            # perturbing would be out of sync by 1 step.
            # Ex: if `checkpoint_interval` = `perturbation_interval` = 3
            # step:                0 (checkpoint)  1     2            3 (checkpoint)
            # training_iteration:  1               2     3 (perturb)  4
            with tempfile.TemporaryDirectory() as tempdir:
                with open(os.path.join(tempdir, "checkpoint.json"), "w") as f:
                    json.dump({"acc": float(accuracy), "step": step}, f)
                # Report while the temporary directory still exists.
                train.report(
                    metrics, checkpoint=Checkpoint.from_directory(tempdir)
                )
        else:
            train.report(metrics)
        step += 1
def run_tune_pbt(smoke_test=False):
    """Run the toy PBT experiment with Tune and print the best config found.

    Args:
        smoke_test: If true, cap trials at 10 training iterations instead of
            1000 so that the run finishes quickly.
    """
    perturbation_interval = 5

    mutations = {
        # distribution for resampling
        "lr": tune.uniform(0.0001, 0.02),
        # allow perturbations within this set of categorical values
        "some_other_factor": [1, 2],
    }
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        perturbation_interval=perturbation_interval,
        hyperparam_mutations=mutations,
    )

    # Stop when done = True or at some # of train steps
    # (whichever comes first)
    iteration_cap = 10 if smoke_test else 1000
    stop_criteria = {
        "done": True,
        "training_iteration": iteration_cap,
    }

    run_config = train.RunConfig(
        name="pbt_function_api_example",
        verbose=False,
        stop=stop_criteria,
        failure_config=train.FailureConfig(fail_fast=True),
        checkpoint_config=train.CheckpointConfig(
            checkpoint_score_attribute="mean_accuracy",
            num_to_keep=2,
        ),
    )

    tune_config = tune.TuneConfig(
        scheduler=scheduler,
        metric="mean_accuracy",
        mode="max",
        num_samples=8,
    )

    initial_params = {
        "lr": 0.0001,
        # Note: `some_other_factor` is perturbed because it is listed in the
        # PBT scheduler's `hyperparam_mutations` argument, but it has no
        # effect on the model training in this example.
        "some_other_factor": 1,
        # Note: `checkpoint_interval` will not be perturbed (it is not part of
        # `hyperparam_mutations`); it determines how many steps are taken
        # between checkpoints.
        # We recommend matching `perturbation_interval` and
        # `checkpoint_interval` (e.g. checkpoint every 4 steps, and perturb on
        # those same steps) or making `perturbation_interval` a multiple of
        # `checkpoint_interval` (e.g. checkpoint every 2 steps, and perturb
        # every 4 steps).
        # This ensures that the latest checkpoints are being used by PBT when
        # trials decide to exploit. If checkpointing and perturbing are not
        # aligned, then PBT may use a stale checkpoint to resume from.
        "checkpoint_interval": perturbation_interval,
    }

    tuner = tune.Tuner(
        pbt_function,
        run_config=run_config,
        tune_config=tune_config,
        param_space=initial_params,
    )
    results = tuner.fit()

    best_config = results.get_best_result().config
    print("Best hyperparameters found were: ", best_config)
if __name__ == "__main__":
    # Command-line entry point: parse the optional --smoke-test flag and
    # launch the PBT experiment.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test",
        action="store_true",  # flag defaults to False when absent
        help="Finish quickly for testing",
    )
    # Unrecognized arguments are accepted and ignored.
    known_args, _unrecognized = cli.parse_known_args()
    is_smoke_test = known_args.smoke_test

    if is_smoke_test:
        ray.init(num_cpus=2)  # force pausing to happen for test

    run_tune_pbt(smoke_test=is_smoke_test)