# Source code for ray.rllib.algorithms.bandit.bandit
import logging
from typing import Optional, Type, Union
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.bandit.bandit_tf_policy import BanditTFPolicy
from ray.rllib.algorithms.bandit.bandit_torch_policy import BanditTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING
logger = logging.getLogger(__name__)
class BanditConfig(AlgorithmConfig):
    """Defines a contextual bandit configuration class from which
    a contextual bandit algorithm can be built.

    Note this config is shared between BanditLinUCB and BanditLinTS. You likely
    want to use the child classes BanditLinTSConfig or BanditLinUCBConfig
    instead.
    """

    def __init__(
        self,
        algo_class: Optional[Type[Union["BanditLinTS", "BanditLinUCB"]]] = None,
    ):
        """Initializes a BanditConfig instance.

        Args:
            algo_class: The algorithm class (the class object itself, not an
                instance) associated with this config, or None. It is
                forwarded unchanged to `AlgorithmConfig.__init__`.
        """
        super().__init__(algo_class=algo_class)
        # fmt: off
        # __sphinx_doc_begin__
        # Override some of AlgorithmConfig's default values with bandit-specific values.
        self.framework_str = "torch"
        self.rollout_fragment_length = 1
        self.train_batch_size = 1

        # Make sure, a `train()` call performs at least 100 env sampling
        # timesteps, before reporting results. Not setting this (default is 0)
        # would significantly slow down the Bandit Algorithm.
        self.min_sample_timesteps_per_iteration = 100
        # __sphinx_doc_end__
        # fmt: on
class BanditLinTSConfig(BanditConfig):
    """Defines a configuration class from which a Thompson-sampling bandit can be built.

    Example:
        >>> from ray.rllib.algorithms.bandit import BanditLinTSConfig # doctest: +SKIP
        >>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
        >>> config = BanditLinTSConfig().rollouts(num_rollout_workers=4)# doctest: +SKIP
        >>> print(config.to_dict())  # doctest: +SKIP
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> algo = config.build(env=WheelBanditEnv)  # doctest: +SKIP
        >>> algo.train()  # doctest: +SKIP
    """

    def __init__(self):
        """Initializes the config with BanditLinTS as its algorithm class."""
        # BanditLinTS is defined further down in this module; the name is only
        # looked up when a config instance is actually created.
        super().__init__(algo_class=BanditLinTS)
        # fmt: off
        # __sphinx_doc_begin__
        # Override some of AlgorithmConfig's default values with bandit-specific values.
        self.exploration_config = {"type": "ThompsonSampling"}
        # __sphinx_doc_end__
        # fmt: on
class BanditLinUCBConfig(BanditConfig):
    """Defines a config class from which an upper confidence bound bandit can be built.

    Example:
        >>> from ray.rllib.algorithms.bandit import BanditLinUCBConfig# doctest: +SKIP
        >>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
        >>> config = BanditLinUCBConfig()  # doctest: +SKIP
        >>> config = config.rollouts(num_rollout_workers=4)  # doctest: +SKIP
        >>> print(config.to_dict())  # doctest: +SKIP
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> algo = config.build(env=WheelBanditEnv)  # doctest: +SKIP
        >>> algo.train()  # doctest: +SKIP
    """

    def __init__(self):
        """Initializes the config with BanditLinUCB as its algorithm class."""
        # BanditLinUCB is defined further down in this module; the name is only
        # looked up when a config instance is actually created.
        super().__init__(algo_class=BanditLinUCB)
        # fmt: off
        # __sphinx_doc_begin__
        # Override some of AlgorithmConfig's default values with bandit-specific values.
        self.exploration_config = {"type": "UpperConfidenceBound"}
        # __sphinx_doc_end__
        # fmt: on
@Deprecated(
    old="rllib/algorithms/bandit/",
    new="rllib_contrib/bandit/",
    help=ALGO_DEPRECATION_WARNING,
    error=False,
)
class BanditLinTS(Algorithm):
    """Bandit Algorithm using ThompsonSampling exploration.

    Flagged via `Deprecated`, pointing from `rllib/algorithms/bandit/` to
    `rllib_contrib/bandit/`.
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> BanditLinTSConfig:
        """Returns a fresh BanditLinTSConfig instance."""
        return BanditLinTSConfig()

    @classmethod
    @override(Algorithm)
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[Type[Policy]]:
        """Picks the bandit policy class matching the configured framework.

        Args:
            config: Config object; only its "framework" entry is read.

        Returns:
            BanditTorchPolicy for "torch", BanditTFPolicy for "tf2".

        Raises:
            NotImplementedError: For any other framework value.
        """
        framework = config["framework"]
        if framework == "torch":
            return BanditTorchPolicy
        if framework == "tf2":
            return BanditTFPolicy
        raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
@Deprecated(
    old="rllib/algorithms/bandit/",
    new="rllib_contrib/bandit/",
    help=ALGO_DEPRECATION_WARNING,
    error=False,
)
class BanditLinUCB(Algorithm):
    """Bandit Algorithm whose default config uses UpperConfidenceBound exploration.

    Flagged via `Deprecated`, pointing from `rllib/algorithms/bandit/` to
    `rllib_contrib/bandit/`.
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> BanditLinUCBConfig:
        """Returns a fresh BanditLinUCBConfig instance."""
        return BanditLinUCBConfig()

    @classmethod
    @override(Algorithm)
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[Type[Policy]]:
        """Picks the bandit policy class matching the configured framework.

        Args:
            config: Config object; only its "framework" entry is read.

        Returns:
            BanditTorchPolicy for "torch", BanditTFPolicy for "tf2".

        Raises:
            NotImplementedError: For any other framework value.
        """
        framework = config["framework"]
        if framework == "torch":
            return BanditTorchPolicy
        if framework == "tf2":
            return BanditTFPolicy
        raise NotImplementedError("Only `framework=[torch|tf2]` supported!")