# Source code for ray.rllib.algorithms.pg.pg

from typing import List, Optional, Type, Union

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING


class PGConfig(AlgorithmConfig):
    """Defines a configuration class from which a PG Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.pg import PGConfig
        >>> config = PGConfig().training(lr=0.01).resources(num_gpus=1)
        >>> print(config.to_dict())  # doctest: +SKIP
        >>> # Build an Algorithm object from the config and run 1 training
        >>> # iteration.
        >>> algo = config.build(env="CartPole-v1")  # doctest: +SKIP
        >>> algo.train()  # doctest: +SKIP

    Example:
        >>> from ray.rllib.algorithms.pg import PGConfig
        >>> from ray import air
        >>> from ray import tune
        >>> config = PGConfig()
        >>> # Print out some default values.
        >>> print(config.lr)  # doctest: +SKIP
        0.0004
        >>> # Update the config object.
        >>> config = config.training(lr=tune.grid_search([0.001, 0.0001]))
        >>> # Set the config object's env.
        >>> config = config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.Tuner(  # doctest: +SKIP
        ...     "PG",
        ...     run_config=air.RunConfig(stop={"episode_reward_mean": 200}),
        ...     param_space=config.to_dict(),
        ... ).fit()
    """

    def __init__(self, algo_class=None):
        """Initializes a PGConfig instance.

        Args:
            algo_class: The algorithm class handed to the parent config's
                constructor. Falls back to `PG` when None (or any other
                falsy value) is given.
        """
        super().__init__(algo_class=algo_class or PG)

        # fmt: off
        # __sphinx_doc_begin__

        # Override some of AlgorithmConfig's default values with PG-specific values.
        self.lr_schedule = None
        self.lr = 0.0004
        self.rollout_fragment_length = "auto"
        self.train_batch_size = 200
        self._disable_preprocessor_api = True
        self.exploration_config = {
            # The Exploration class to use. In the simplest case, this is the name
            # (str) of any class present in the `rllib.utils.exploration` package.
            # You can also provide the python class directly or the full location
            # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
            # EpsilonGreedy").
            "type": "StochasticSampling",
            # Add constructor kwargs here (if any).
        }

        # __sphinx_doc_end__
        # fmt: on

    @override(AlgorithmConfig)
    def training(
        self,
        *,
        lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided,
        **kwargs,
    ) -> "PGConfig":
        """Sets the training related configuration.

        Only `lr_schedule` is handled here; every other keyword argument is
        forwarded unchanged to the parent class' `training()` method.

        Args:
            gamma: Float specifying the discount factor of the Markov Decision
                process. Forwarded via `**kwargs`.
            lr: The default learning rate. Forwarded via `**kwargs`.
            train_batch_size: Training batch size, if applicable. Forwarded via
                `**kwargs`.
            model: Arguments passed into the policy model. See models/catalog.py
                for a full list of the available model options. Forwarded via
                `**kwargs`.
            optimizer: Arguments to pass to the policy optimizer. Forwarded via
                `**kwargs`.
            lr_schedule: Learning rate schedule. In the format of
                [[timestep, lr-value], [timestep, lr-value], ...]
                Intermediary timesteps will be assigned to interpolated learning
                rate values. A schedule should normally start from timestep 0.
                Left untouched when not provided; an explicit None is stored
                as-is.

        Returns:
            This updated PGConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        # `NotProvided` (rather than None) is the "argument omitted" sentinel,
        # so that None remains a settable value for the schedule.
        if lr_schedule is not NotProvided:
            self.lr_schedule = lr_schedule

        return self

    @override(AlgorithmConfig)
    def validate(self) -> None:
        """Validates this config: parent checks first, then the PG-specific one."""
        # Call super's validation method.
        super().validate()

        # Synchronous sampling, on-policy PG algo -> Check mismatches between
        # `rollout_fragment_length` and `train_batch_size` to avoid user confusion.
        self.validate_train_batch_size_vs_rollout_fragment_length()
@Deprecated(
    old="rllib/algorithms/pg/",
    new="rllib_contrib/pg/",
    help=ALGO_DEPRECATION_WARNING,
    error=False,
)
class PG(Algorithm):
    """Policy Gradient (PG) Algorithm.

    Defines the distributed Algorithm class for policy gradients.
    See `pg_[tf|torch]_policy.py` for the definition of the policy losses for
    TensorFlow and PyTorch.

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#pg

    Only overrides the default config- and policy selectors
    (`get_default_policy_class` and `get_default_config`). Utilizes the default
    `training_step()` method of `Algorithm`.

    Marked as deprecated in favor of `rllib_contrib/pg/` (see the class
    decorator).
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        """Returns a fresh `PGConfig` carrying the PG-specific defaults."""
        return PGConfig()

    @classmethod
    @override(Algorithm)
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[Type[Policy]]:
        """Selects the PG policy class matching the configured framework.

        Args:
            config: The algorithm config; its "framework" entry is read via
                key lookup.

        Returns:
            `PGTorchPolicy` for "torch", `PGTF1Policy` for "tf", and
            `PGTF2Policy` for any other framework value.
        """
        # Read the framework setting once and use the same access path for
        # every branch (previously the first branch used key lookup and the
        # second attribute access).
        framework = config["framework"]

        # The policy modules are imported lazily inside each branch, so only
        # the module for the selected framework is imported.
        if framework == "torch":
            from ray.rllib.algorithms.pg.pg_torch_policy import PGTorchPolicy

            return PGTorchPolicy
        elif framework == "tf":
            from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy

            return PGTF1Policy
        else:
            from ray.rllib.algorithms.pg.pg_tf_policy import PGTF2Policy

            return PGTF2Policy