import logging
import numpy as np
import random
from typing import Optional, Type
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.dreamer.dreamer_torch_policy import DreamerTorchPolicy
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
concat_samples,
convert_ma_batch_to_sample_batch,
)
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.algorithms.dreamer.dreamer_model import DreamerModel
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
SAMPLE_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.typing import (
ResultDict,
)
from ray.rllib.utils.replay_buffers import ReplayBuffer, StorageUnit
logger = logging.getLogger(__name__)
[docs]class DreamerConfig(AlgorithmConfig):
"""Defines a configuration class from which a Dreamer Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.dreamer import DreamerConfig
>>> config = DreamerConfig().training(gamma=0.9, lr=0.01) # doctest: +SKIP
>>> config = config.resources(num_gpus=0) # doctest: +SKIP
>>> config = config.rollouts(num_rollout_workers=4) # doctest: +SKIP
>>> print(config.to_dict()) # doctest: +SKIP
>>> # Build a Algorithm object from the config and run 1 training iteration.
>>> algo = config.build(env="CartPole-v1") # doctest: +SKIP
>>> algo.train() # doctest: +SKIP
Example:
>>> from ray import air
>>> from ray import tune
>>> from ray.rllib.algorithms.dreamer import DreamerConfig
>>> config = DreamerConfig()
>>> # Print out some default values.
>>> print(config.clip_param) # doctest: +SKIP
>>> # Update the config object.
>>> config = config.training( # doctest: +SKIP
... lr=tune.grid_search([0.001, 0.0001]), clip_param=0.2)
>>> # Set the config object's env.
>>> config = config.environment(env="CartPole-v1") # doctest: +SKIP
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.Tuner( # doctest: +SKIP
... "Dreamer",
... run_config=air.RunConfig(stop={"episode_reward_mean": 200}),
... param_space=config.to_dict(),
... ).fit()
"""
def __init__(self):
"""Initializes a PPOConfig instance."""
super().__init__(algo_class=Dreamer)
# fmt: off
# __sphinx_doc_begin__
# Dreamer specific settings:
self.td_model_lr = 6e-4
self.actor_lr = 8e-5
self.critic_lr = 8e-5
self.grad_clip = 100.0
# Note: Only when using _enable_learner_api=True can the clipping mode be
# configured by the user. On the old API stack, RLlib will always clip by
# global_norm, no matter the value of `grad_clip_by`.
self.grad_clip_by = "global_norm"
self.lambda_ = 0.95
self.dreamer_train_iters = 100
self.batch_size = 50
self.batch_length = 50
self.imagine_horizon = 15
self.free_nats = 3.0
self.kl_coeff = 1.0
self.prefill_timesteps = 5000
self.explore_noise = 0.3
self.dreamer_model = {
"custom_model": DreamerModel,
# RSSM/PlaNET parameters
"deter_size": 200,
"stoch_size": 30,
# CNN Decoder Encoder
"depth_size": 32,
# General Network Parameters
"hidden_size": 400,
# Action STD
"action_init_std": 5.0,
}
# Override some of AlgorithmConfig's default values with Dreamer-specific
# values.
# .rollouts()
self.num_envs_per_worker = 1
self.batch_mode = "complete_episodes"
self.clip_actions = False
# .training()
self.gamma = 0.99
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config.multi_agent(count_steps_by=..).
self.num_steps_sampled_before_learning_starts = 0
# .environment()
self.env_config.update({
# Repeats action send by policy for frame_skip times in env
"frame_skip": 2,
})
# .exploration()
# This dreamer implementation does not need an exploration config
self.exploration_config = {}
# __sphinx_doc_end__
# fmt: on
[docs] @override(AlgorithmConfig)
def training(
self,
*,
td_model_lr: Optional[float] = NotProvided,
actor_lr: Optional[float] = NotProvided,
critic_lr: Optional[float] = NotProvided,
grad_clip: Optional[float] = NotProvided,
lambda_: Optional[float] = NotProvided,
dreamer_train_iters: Optional[int] = NotProvided,
batch_size: Optional[int] = NotProvided,
batch_length: Optional[int] = NotProvided,
imagine_horizon: Optional[int] = NotProvided,
free_nats: Optional[float] = NotProvided,
kl_coeff: Optional[float] = NotProvided,
prefill_timesteps: Optional[int] = NotProvided,
explore_noise: Optional[float] = NotProvided,
dreamer_model: Optional[dict] = NotProvided,
num_steps_sampled_before_learning_starts: Optional[int] = NotProvided,
**kwargs,
) -> "DreamerConfig":
"""
Args:
td_model_lr: PlaNET (transition dynamics) model learning rate.
actor_lr: Actor model learning rate.
critic_lr: Critic model learning rate.
grad_clip: If specified, clip the global norm of gradients by this amount.
lambda_: The GAE (lambda) parameter.
dreamer_train_iters: Training iterations per data collection from real env.
batch_size: Number of episodes to sample for loss calculation.
batch_length: Length of each episode to sample for loss calculation.
imagine_horizon: Imagination horizon for training Actor and Critic.
free_nats: Free nats.
kl_coeff: KL coefficient for the model Loss.
prefill_timesteps: Prefill timesteps.
explore_noise: Exploration Gaussian noise.
dreamer_model: Custom model config.
num_steps_sampled_before_learning_starts: Number of timesteps to collect
from rollout workers before we start sampling from replay buffers for
learning. Whether we count this in agent steps or environment steps
depends on config.multi_agent(count_steps_by=..).
Returns:
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if td_model_lr is not NotProvided:
self.td_model_lr = td_model_lr
if actor_lr is not NotProvided:
self.actor_lr = actor_lr
if critic_lr is not NotProvided:
self.critic_lr = critic_lr
if grad_clip is not NotProvided:
self.grad_clip = grad_clip
if lambda_ is not NotProvided:
self.lambda_ = lambda_
if dreamer_train_iters is not NotProvided:
self.dreamer_train_iters = dreamer_train_iters
if batch_size is not NotProvided:
self.batch_size = batch_size
if batch_length is not NotProvided:
self.batch_length = batch_length
if imagine_horizon is not NotProvided:
self.imagine_horizon = imagine_horizon
if free_nats is not NotProvided:
self.free_nats = free_nats
if kl_coeff is not NotProvided:
self.kl_coeff = kl_coeff
if prefill_timesteps is not NotProvided:
self.prefill_timesteps = prefill_timesteps
if explore_noise is not NotProvided:
self.explore_noise = explore_noise
if dreamer_model is not NotProvided:
self.dreamer_model = dreamer_model
if num_steps_sampled_before_learning_starts is not NotProvided:
self.num_steps_sampled_before_learning_starts = (
num_steps_sampled_before_learning_starts
)
return self
@override(AlgorithmConfig)
def validate(self) -> None:
# Call super's validation method.
super().validate()
if self.num_gpus > 1:
raise ValueError("`num_gpus` > 1 not yet supported for Dreamer!")
if self.framework_str != "torch":
raise ValueError("Dreamer not supported in Tensorflow yet!")
if self.batch_mode != "complete_episodes":
raise ValueError("truncate_episodes not supported")
if self.num_rollout_workers != 0:
raise ValueError("Distributed Dreamer not supported yet!")
if self.clip_actions:
raise ValueError("Clipping is done inherently via policy tanh!")
if self.dreamer_train_iters <= 0:
raise ValueError(
"`dreamer_train_iters` must be a positive integer. "
f"Received {self.dreamer_train_iters} instead."
)
if self.env_config.get("frame_skip", 0) > 1:
self.imagine_horizon //= self.env_config["frame_skip"]
def _postprocess_gif(gif: np.ndarray):
"""Process provided gif to a format that can be logged to Tensorboard."""
gif = np.clip(255 * gif, 0, 255).astype(np.uint8)
B, T, C, H, W = gif.shape
frames = gif.transpose((1, 2, 3, 0, 4)).reshape((1, T, C, H, B * W))
return frames
class EpisodeSequenceBuffer(ReplayBuffer):
def __init__(self, capacity: int = 1000, replay_sequence_length: int = 50):
"""Stores episodes and samples sequences of size `replay_sequence_length`.
Args:
capacity: Maximum number of episodes this buffer can store
replay_sequence_length: Episode chunking length in sample()
"""
super().__init__(capacity=capacity, storage_unit=StorageUnit.EPISODES)
self.replay_sequence_length = replay_sequence_length
def sample(self, num_items: int):
"""Samples [batch_size, length] from the list of episodes
Args:
num_items: batch_size to be sampled
"""
episodes_buffer = []
while len(episodes_buffer) < num_items:
episode = super().sample(1)
if episode.count < self.replay_sequence_length:
continue
available = episode.count - self.replay_sequence_length
index = int(random.randint(0, available))
episodes_buffer.append(episode[index : index + self.replay_sequence_length])
return concat_samples(episodes_buffer)
def total_sampled_timesteps(worker):
return worker.policy_map[DEFAULT_POLICY_ID].global_timestep
class DreamerIteration:
def __init__(
self, worker, episode_buffer, dreamer_train_iters, batch_size, act_repeat
):
self.worker = worker
self.episode_buffer = episode_buffer
self.dreamer_train_iters = dreamer_train_iters
self.repeat = act_repeat
self.batch_size = batch_size
def __call__(self, samples):
# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED
if self.config.count_steps_by == "agent_steps"
else NUM_ENV_STEPS_SAMPLED
]
if cur_ts > self.config.num_steps_sampled_before_learning_starts:
# Dreamer training loop.
for n in range(self.dreamer_train_iters):
print(f"sub-iteration={n}/{self.dreamer_train_iters}")
batch = self.episode_buffer.sample(self.batch_size)
fetches = self.worker.learn_on_batch(batch)
else:
fetches = {}
# Custom Logging
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
if "log_gif" in policy_fetches:
gif = policy_fetches["log_gif"]
policy_fetches["log_gif"] = self.postprocess_gif(gif)
# Metrics Calculation
metrics = _get_shared_metrics()
metrics.info[LEARNER_INFO] = fetches
metrics.counters[STEPS_SAMPLED_COUNTER] = self.episode_buffer.timesteps
metrics.counters[STEPS_SAMPLED_COUNTER] *= self.repeat
res = collect_metrics(local_worker=self.worker)
res["info"] = metrics.info
res["info"].update(metrics.counters)
res["timesteps_total"] = metrics.counters[STEPS_SAMPLED_COUNTER]
self.episode_buffer.add(samples)
return res
def postprocess_gif(self, gif: np.ndarray):
return _postprocess_gif(gif=gif)
@Deprecated(
old="rllib/algorithms/dreamer/",
new="rllib_contrib/dreamer/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class Dreamer(Algorithm):
@classmethod
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfig:
return DreamerConfig()
@classmethod
@override(Algorithm)
def get_default_policy_class(
cls, config: AlgorithmConfig
) -> Optional[Type[Policy]]:
return DreamerTorchPolicy
@override(Algorithm)
def setup(self, config: AlgorithmConfig):
super().setup(config)
# Setup buffer.
self.local_replay_buffer = EpisodeSequenceBuffer(
replay_sequence_length=config["batch_length"]
)
# Prefill episode buffer with initial exploration (uniform sampling)
while (
total_sampled_timesteps(self.workers.local_worker())
< self.config.prefill_timesteps
):
samples = self.workers.local_worker().sample()
# Dreamer only ever has one policy and we receive MA batches when
# connectors are on
samples = convert_ma_batch_to_sample_batch(samples)
self.local_replay_buffer.add(samples)
@override(Algorithm)
def training_step(self) -> ResultDict:
local_worker = self.workers.local_worker()
# Number of sub-iterations for Dreamer
dreamer_train_iters = self.config.dreamer_train_iters
batch_size = self.config.batch_size
# Collect SampleBatches from rollout workers.
with self._timers[SAMPLE_TIMER]:
batch = synchronous_parallel_sample(worker_set=self.workers)
self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
fetches = {}
# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED
if self.config.count_steps_by == "agent_steps"
else NUM_ENV_STEPS_SAMPLED
]
if cur_ts > self.config.num_steps_sampled_before_learning_starts:
# Dreamer training loop.
# Run multiple sub-iterations for each training iteration.
for n in range(dreamer_train_iters):
print(f"sub-iteration={n}/{dreamer_train_iters}")
batch = self.local_replay_buffer.sample(batch_size)
fetches = local_worker.learn_on_batch(batch)
if fetches:
# Custom logging.
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
if "log_gif" in policy_fetches:
gif = policy_fetches["log_gif"]
policy_fetches["log_gif"] = self._postprocess_gif(gif)
self.local_replay_buffer.add(batch)
return fetches