Source code for ray.rllib.algorithms.bc.bc
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
from ray.rllib.utils.annotations import override
[docs]class BCConfig(MARWILConfig):
"""Defines a configuration class from which a new BC Algorithm can be built
Example:
>>> from ray.rllib.algorithms.bc import BCConfig
>>> # Run this from the ray directory root.
>>> config = BCConfig().training(lr=0.00001, gamma=0.99)
>>> config = config.offline_data( # doctest: +SKIP
... input_="./rllib/tests/data/cartpole/large.json")
>>> print(config.to_dict()) # doctest:+SKIP
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> algo = config.build() # doctest: +SKIP
>>> algo.train() # doctest: +SKIP
Example:
>>> from ray.rllib.algorithms.bc import BCConfig
>>> from ray import tune
>>> config = BCConfig()
>>> # Print out some default values.
>>> print(config.beta) # doctest: +SKIP
>>> # Update the config object.
>>> config.training( # doctest:+SKIP
... lr=tune.grid_search([0.001, 0.0001]), beta=0.75
... )
>>> # Set the config object's data path.
>>> # Run this from the ray directory root.
>>> config.offline_data( # doctest:+SKIP
... input_="./rllib/tests/data/cartpole/large.json"
... )
>>> # Set the config object's env, used for evaluation.
>>> config.environment(env="CartPole-v1") # doctest:+SKIP
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.Tuner( # doctest:+SKIP
... "BC",
... param_space=config.to_dict(),
... ).fit()
"""
def __init__(self, algo_class=None):
super().__init__(algo_class=algo_class or BC)
# fmt: off
# __sphinx_doc_begin__
# No need to calculate advantages (or do anything else with the rewards).
self.beta = 0.0
# Advantages (calculated during postprocessing)
# not important for behavioral cloning.
self.postprocess_inputs = False
# __sphinx_doc_end__
# fmt: on
@override(MARWILConfig)
def validate(self) -> None:
super().validate()
if self.beta != 0.0:
raise ValueError("For behavioral cloning, `beta` parameter must be 0.0!")
class BC(MARWIL):
"""Behavioral Cloning (derived from MARWIL).
Simply uses MARWIL with beta force-set to 0.0.
"""
@classmethod
@override(MARWIL)
def get_default_config(cls) -> AlgorithmConfig:
return BCConfig()