# Source code for ray.rllib.algorithms.qmix.qmix

from typing import Optional, Type

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.simple_q.simple_q import SimpleQ, SimpleQConfig
from ray.rllib.algorithms.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
from ray.rllib.execution.rollout_ops import (
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
    multi_gpu_train_one_step,
    train_one_step,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
    LAST_TARGET_UPDATE_TS,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
    NUM_TARGET_UPDATES,
    SYNCH_WORKER_WEIGHTS_TIMER,
    SAMPLE_TIMER,
)
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.deprecation import (
    DEPRECATED_VALUE,
    Deprecated,
    deprecation_warning,
    ALGO_DEPRECATION_WARNING,
)


class QMixConfig(SimpleQConfig):
    """Defines a configuration class from which QMix can be built.

    Example:
        >>> from ray.rllib.examples.env.two_step_game import TwoStepGame
        >>> from ray.rllib.algorithms.qmix import QMixConfig
        >>> config = QMixConfig()  # doctest: +SKIP
        >>> config = config.training(gamma=0.9, lr=0.01)  # doctest: +SKIP
        >>> config = config.resources(num_gpus=0)  # doctest: +SKIP
        >>> config = config.rollouts(num_rollout_workers=4)  # doctest: +SKIP
        >>> print(config.to_dict())  # doctest: +SKIP
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> algo = config.build(env=TwoStepGame)  # doctest: +SKIP
        >>> algo.train()  # doctest: +SKIP

    Example:
        >>> from ray.rllib.examples.env.two_step_game import TwoStepGame
        >>> from ray.rllib.algorithms.qmix import QMixConfig
        >>> from ray import air
        >>> from ray import tune
        >>> config = QMixConfig()
        >>> # Print out some default values.
        >>> print(config.optim_alpha)  # doctest: +SKIP
        >>> # Update the config object.
        >>> config.training(  # doctest: +SKIP
        ...     lr=tune.grid_search([0.001, 0.0001]), optim_alpha=0.97
        ... )
        >>> # Set the config object's env.
        >>> config.environment(env=TwoStepGame)  # doctest: +SKIP
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.Tuner(  # doctest: +SKIP
        ...     "QMix",
        ...     run_config=air.RunConfig(stop={"episode_reward_mean": 200}),
        ...     param_space=config.to_dict(),
        ... ).fit()
    """

    def __init__(self):
        """Initializes a QMixConfig instance with QMix's default settings."""
        super().__init__(algo_class=QMix)

        # fmt: off
        # __sphinx_doc_begin__
        # QMix specific settings:
        self.mixer = "qmix"
        self.mixing_embed_dim = 32
        self.double_q = True
        self.optim_alpha = 0.99
        self.optim_eps = 0.00001

        self.grad_clip = 10.0
        # Note: Only when using _enable_learner_api=True can the clipping mode be
        # configured by the user. On the old API stack, RLlib will always clip by
        # global_norm, no matter the value of `grad_clip_by`.
        self.grad_clip_by = "global_norm"

        # QMix-torch overrides the TorchPolicy's `learn_on_batch` without
        # providing an alternative `learn_on_loaded_batch` for the GPU, so the
        # simple (single-step) optimizer path must be used.
        # TODO: This hack will be resolved once we move all algorithms to the new
        #  RLModule/Learner APIs.
        self.simple_optimizer = True

        # Override some of AlgorithmConfig's default values with QMix-specific values.
        # .training()
        self.lr = 0.0005
        self.train_batch_size = 32
        self.target_network_update_freq = 500
        self.num_steps_sampled_before_learning_starts = 1000
        self.replay_buffer_config = {
            "type": "ReplayBuffer",
            # Specify prioritized replay by supplying a buffer type that supports
            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
            "prioritized_replay": DEPRECATED_VALUE,
            # Size of the replay buffer in batches (not timesteps!).
            "capacity": 1000,
            # Choosing `fragments` here makes it so that the buffer stores entire
            # batches, instead of sequences, episodes or timesteps.
            "storage_unit": "fragments",
            # Whether to compute priorities on workers.
            "worker_side_prioritization": False,
        }
        self.model = {
            "lstm_cell_size": 64,
            "max_seq_len": 999999,
        }

        # .framework()
        self.framework_str = "torch"

        # .rollouts()
        self.rollout_fragment_length = 4
        self.batch_mode = "complete_episodes"

        # .reporting()
        self.min_time_s_per_iteration = 1
        self.min_sample_timesteps_per_iteration = 1000

        # .exploration()
        self.exploration_config = {
            # The Exploration class to use.
            "type": "EpsilonGreedy",
            # Config for the Exploration class' constructor:
            "initial_epsilon": 1.0,
            "final_epsilon": 0.01,
            # Timesteps over which to anneal epsilon.
            "epsilon_timesteps": 40000,

            # For soft_q, use:
            # "exploration_config" = {
            #   "type": "SoftQ"
            #   "temperature": [float, e.g. 1.0]
            # }
        }

        # .evaluation()
        # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
        # The evaluation stats will be reported under the "evaluation" metric key.
        self.evaluation(
            evaluation_config=AlgorithmConfig.overrides(explore=False)
        )
        # __sphinx_doc_end__
        # fmt: on

        self.worker_side_prioritization = DEPRECATED_VALUE

    @override(SimpleQConfig)
    def training(
        self,
        *,
        mixer: Optional[str] = NotProvided,
        mixing_embed_dim: Optional[int] = NotProvided,
        double_q: Optional[bool] = NotProvided,
        target_network_update_freq: Optional[int] = NotProvided,
        replay_buffer_config: Optional[dict] = NotProvided,
        optim_alpha: Optional[float] = NotProvided,
        optim_eps: Optional[float] = NotProvided,
        grad_clip: Optional[float] = NotProvided,
        # Deprecated args.
        grad_norm_clipping=DEPRECATED_VALUE,
        **kwargs,
    ) -> "QMixConfig":
        """Sets the training related configuration.

        Args:
            mixer: Mixing network. Either "qmix", "vdn", or None.
            mixing_embed_dim: Size of the mixing network embedding.
            double_q: Whether to use Double_Q learning.
            target_network_update_freq: Update the target network every
                `target_network_update_freq` sample steps.
            replay_buffer_config: Replay buffer config dict. It replaces the
                current `replay_buffer_config` as a whole; it is not merged
                with the defaults.
            optim_alpha: RMSProp alpha.
            optim_eps: RMSProp epsilon.
            grad_clip: If not None, clip gradients during optimization at this
                value.
            grad_norm_clipping: Deprecated in favor of `grad_clip`. Passing it
                triggers `deprecation_warning(..., error=True)`.
            **kwargs: Forwarded unchanged to `SimpleQConfig.training()`.

        Returns:
            This updated AlgorithmConfig object.
        """
        # Reject the deprecated argument before any state is touched, so a
        # rejected call cannot leave this config partially updated.
        if grad_norm_clipping != DEPRECATED_VALUE:
            # `error=True` makes this a hard error rather than a warning, so the
            # old value is never applied to `grad_clip`.
            deprecation_warning(
                old="grad_norm_clipping",
                new="grad_clip",
                help="Parameter `grad_norm_clipping` has been "
                "deprecated in favor of grad_clip in QMix. "
                "This is now the same parameter as in other "
                "algorithms. `grad_clip` will be overwritten by "
                "`grad_norm_clipping={}`".format(grad_norm_clipping),
                error=True,
            )

        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        # Only arguments the caller actually supplied overwrite current values.
        if mixer is not NotProvided:
            self.mixer = mixer
        if mixing_embed_dim is not NotProvided:
            self.mixing_embed_dim = mixing_embed_dim
        if double_q is not NotProvided:
            self.double_q = double_q
        if target_network_update_freq is not NotProvided:
            self.target_network_update_freq = target_network_update_freq
        if replay_buffer_config is not NotProvided:
            self.replay_buffer_config = replay_buffer_config
        if optim_alpha is not NotProvided:
            self.optim_alpha = optim_alpha
        if optim_eps is not NotProvided:
            self.optim_eps = optim_eps
        if grad_clip is not NotProvided:
            self.grad_clip = grad_clip

        return self

    @override(SimpleQConfig)
    def validate(self) -> None:
        """Validates this config; QMix only supports the torch framework.

        Raises:
            ValueError: If `framework_str` is anything other than "torch".
        """
        # Call super's validation method.
        super().validate()

        if self.framework_str != "torch":
            raise ValueError(
                "Only `config.framework('torch')` supported so far for QMix!"
            )
@Deprecated(
    old="rllib/algorithms/qmix/",
    new="rllib_contrib/qmix/",
    help=ALGO_DEPRECATION_WARNING,
    error=False,
)
class QMix(SimpleQ):
    """QMix algorithm built on top of SimpleQ (marked deprecated here)."""

    @classmethod
    @override(SimpleQ)
    def get_default_config(cls) -> AlgorithmConfig:
        # A fresh QMixConfig carries all QMix-specific defaults.
        return QMixConfig()

    @classmethod
    @override(SimpleQ)
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[Type[Policy]]:
        # Torch is the only framework QMixConfig.validate() accepts.
        return QMixTorchPolicy

    @override(SimpleQ)
    def training_step(self) -> ResultDict:
        """Runs a single QMIX training iteration.

        The iteration proceeds as follows:
            1. Synchronously gather one (non-concatenated) batch per worker,
               bump the env/agent step counters, and add each batch to the
               local replay buffer.
            2. If the relevant step counter has not yet exceeded
               `num_steps_sampled_before_learning_starts`, stop and return
               an empty results dict.
            3. Otherwise, replay at least `train_batch_size` steps and run
               one learning step on them.
            4. Refresh target networks of the trainable policies whenever
               `target_network_update_freq` steps have elapsed since the last
               refresh.
            5. Update replay priorities, then push weights plus the current
               env-step count (as "timestep") to all workers.

        Returns:
            The learner results of this iteration (empty before learning starts).
        """
        counters = self._counters

        # 1) Collect experience: one batch from every worker.
        with self._timers[SAMPLE_TIMER]:
            collected = synchronous_parallel_sample(
                worker_set=self.workers, concat=False
            )
        for fragment in collected:
            counters[NUM_ENV_STEPS_SAMPLED] += fragment.env_steps()
            counters[NUM_AGENT_STEPS_SAMPLED] += fragment.agent_steps()
            self.local_replay_buffer.add(fragment)

        # Which counter drives the learning-start / target-sync schedule.
        by_agent_steps = self.config.count_steps_by == "agent_steps"
        step_key = NUM_AGENT_STEPS_SAMPLED if by_agent_steps else NUM_ENV_STEPS_SAMPLED
        steps_so_far = counters[step_key]

        # 2) Replay buffer still warming up: nothing to learn yet.
        if steps_so_far <= self.config.num_steps_sampled_before_learning_starts:
            return {}

        # 3) Learn on at least `train_batch_size` replayed steps.
        replayed = sample_min_n_steps_from_buffer(
            replay_buffer=self.local_replay_buffer,
            min_steps=self.config.train_batch_size,
            count_by_agent_steps=by_agent_steps,
        )
        # QMixConfig defaults `simple_optimizer` to True because the torch
        # policy has no GPU-loaded learning path.
        if self.config.get("simple_optimizer") is True:
            learn_fn = train_one_step
        else:
            learn_fn = multi_gpu_train_one_step
        results = learn_fn(self, replayed)

        # 4) Periodic target-network refresh for the trainable policies.
        last_sync = counters[LAST_TARGET_UPDATE_TS]
        if steps_so_far - last_sync >= self.config.target_network_update_freq:
            trainable_ids = self.workers.local_worker().get_policies_to_train()

            def _refresh_target(policy, policy_id):
                return policy_id in trainable_ids and policy.update_target()

            self.workers.local_worker().foreach_policy_to_train(_refresh_target)
            counters[NUM_TARGET_UPDATES] += 1
            counters[LAST_TARGET_UPDATE_TS] = steps_so_far

        # 5) Priorities first, then broadcast the freshly trained weights.
        update_priorities_in_replay_buffer(
            self.local_replay_buffer, self.config, replayed, results
        )
        broadcast_vars = {"timestep": counters[NUM_ENV_STEPS_SAMPLED]}
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            self.workers.sync_weights(global_vars=broadcast_vars)

        return results