ray.rllib.core.learner.learner.Learner

class ray.rllib.core.learner.learner.Learner(*, module_spec: Optional[Union[ray.rllib.core.rl_module.rl_module.SingleAgentRLModuleSpec, ray.rllib.core.rl_module.marl_module.MultiAgentRLModuleSpec]] = None, module: Optional[ray.rllib.core.rl_module.rl_module.RLModule] = None, learner_group_scaling_config: Optional[ray.rllib.core.learner.scaling_config.LearnerGroupScalingConfig] = None, learner_hyperparameters: Optional[ray.rllib.core.learner.learner.LearnerHyperparameters] = None, framework_hyperparameters: Optional[ray.rllib.core.learner.learner.FrameworkHyperparameters] = None)

Bases: object

Base class for Learners.

This class is used to train RLModules. It is responsible for defining the loss function and for updating the neural network weights that it owns. In a multi-agent scenario, it also provides a way to add modules to, and remove modules from, the RLModule in the middle of training, which is useful for league-based training.

The TF- and Torch-specific subclasses of this class fill in the framework-specific implementation details for distributed training and for computing and applying gradients. Users should not need to subclass this class directly. Instead, inherit from the TF- or Torch-specific subclass to implement algorithm-specific update logic.
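The division of labor described above can be pictured with a small framework-free sketch. It is not RLlib code: the class names `SketchLearner`, `FiniteDifferenceLearner`, and `MeanSquaredErrorLearner` are hypothetical, plain Python floats stand in for tensors, and finite differences stand in for the automatic differentiation a TF or Torch subclass would use.

```python
from abc import ABC, abstractmethod


class SketchLearner(ABC):
    """Hypothetical base class: owns the weights and orchestrates one update."""

    def __init__(self, weights, lr=0.1):
        self.weights = dict(weights)
        self.lr = lr

    @abstractmethod
    def compute_loss(self, weights, batch):
        """Algorithm-specific: return a scalar loss for the given weights."""

    @abstractmethod
    def compute_gradients(self, batch):
        """Framework-specific: return {name: d(loss)/d(weight)}."""

    @abstractmethod
    def apply_gradients(self, gradients):
        """Framework-specific: change self.weights in place."""

    def update(self, batch):
        # The base class fixes the order of the steps; subclasses fill them in.
        gradients = self.compute_gradients(batch)
        self.apply_gradients(gradients)
        return {"loss": self.compute_loss(self.weights, batch)}


class FiniteDifferenceLearner(SketchLearner):
    """Stand-in for a TF/Torch subclass: supplies gradient machinery only."""

    def compute_gradients(self, batch, eps=1e-6):
        grads = {}
        for name, value in self.weights.items():
            upper = self.compute_loss({**self.weights, name: value + eps}, batch)
            lower = self.compute_loss({**self.weights, name: value - eps}, batch)
            grads[name] = (upper - lower) / (2 * eps)  # central difference
        return grads

    def apply_gradients(self, gradients):
        for name, grad in gradients.items():
            self.weights[name] -= self.lr * grad  # plain gradient descent


class MeanSquaredErrorLearner(FiniteDifferenceLearner):
    """Stand-in for an algorithm-specific learner: supplies the loss only."""

    def compute_loss(self, weights, batch):
        errors = [(weights["w"] * x + weights["b"] - y) ** 2 for x, y in batch]
        return sum(errors) / len(errors)


batch = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]  # samples of y = 2x + 1
learner = MeanSquaredErrorLearner({"w": 0.0, "b": 0.0})
first_loss = learner.update(batch)["loss"]
for _ in range(200):
    last_loss = learner.update(batch)["loss"]
```

The algorithm-level class only defines the loss, and it never touches the gradient code. That is the same separation the TF- and Torch-specific Learner subclasses are meant to provide.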

Parameters
  • module_spec – The module specification for the RLModule that is being trained. If the module is a single-agent module, it is converted to a multi-agent module with a default key after it is built. Can be None if the module is provided directly via the module argument. Refer to ray.rllib.core.rl_module.SingleAgentRLModuleSpec or ray.rllib.core.rl_module.MultiAgentRLModuleSpec for more info.

  • module – If the Learner is used stand-alone, the RLModule can optionally be passed in directly instead of through module_spec.

  • learner_group_scaling_config – Configuration for scaling the learner actors. Refer to ray.rllib.core.learner.scaling_config.LearnerGroupScalingConfig for more info.

  • learner_hyperparameters – The hyper-parameters for the Learner. Algorithm-specific learner hyper-parameters are passed in via this argument. For example, in PPO the vf_loss_coeff hyper-parameter is passed in this way. Refer to ray.rllib.core.learner.learner.LearnerHyperparameters for more info.

  • framework_hyperparameters – The framework-specific hyper-parameters. Use this to pass in any framework-specific hyper-parameter that affects module creation, for example eager_tracing in TF or torch.compile() in Torch. Refer to ray.rllib.core.learner.learner.FrameworkHyperparameters for more info.
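The conversion mentioned under module_spec (a single-agent module ends up inside a multi-agent container under a default key) can be illustrated with a toy sketch. Nothing here is RLlib API: `ToySpec`, `build_as_multi_agent`, and the key value `"default_policy"` are assumptions made for illustration, since the text above only says "a default key".

```python
from dataclasses import dataclass

# Assumed default key; the parameter description only says "a default key".
DEFAULT_MODULE_ID = "default_policy"


@dataclass
class ToySpec:
    """Hypothetical stand-in for a single-agent module spec."""

    name: str

    def build(self):
        # A real spec would construct a neural network module here.
        return {"module": self.name}


def build_as_multi_agent(spec_or_specs):
    """Build the module(s), then key a lone single-agent module by the default ID."""
    if isinstance(spec_or_specs, dict):  # already one spec per module ID
        return {mid: spec.build() for mid, spec in spec_or_specs.items()}
    return {DEFAULT_MODULE_ID: spec_or_specs.build()}


single = build_as_multi_agent(ToySpec("ppo"))
multi = build_as_multi_agent({"p0": ToySpec("a"), "p1": ToySpec("b")})
```

After this step, code that iterates over module IDs can treat the single-agent and multi-agent cases the same way.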

Usage pattern:

Note: This example uses PPO and Torch because several of the showcased components need concrete implementations to work together. However, the same pattern is generally applicable.

import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

env = gym.make("CartPole-v1")

# Create a single agent RL module spec.
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"hidden": [128, 128]},
    catalog_class=PPOCatalog,
)

# Create a learner instance that will train the module
learner = TorchLearner(module_spec=module_spec)

# Note: the learner should be built before it can be used.
learner.build()

# Take one gradient update on the module and report the results
# results = learner.update(...)

# Add a new module, perhaps for league based training
learner.add_module(
    module_id="new_player",
    module_spec=SingleAgentRLModuleSpec(
        module_class=PPOTorchRLModule,
        observation_space=env.observation_space,
        action_space=env.action_space,
        model_config_dict={"hidden": [128, 128]},
        catalog_class=PPOCatalog,
    )
)

# Take another gradient update with both previous and new modules.
# results = learner.update(...)

# Remove a module
learner.remove_module("new_player")

# Will train previous modules only.
# results = learner.update(...)

# Get the state of the learner
state = learner.get_state()

# Set the state of the learner
learner.set_state(state)

# Get the weights of the underlying multi-agent RLModule
weights = learner.get_module_state()

# Set the weights of the underlying multi-agent RLModule
learner.set_module_state(weights)

Extension pattern:

from ray.rllib.core.learner.torch.torch_learner import TorchLearner

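class MyLearner(TorchLearner):

    def compute_loss(self, *, fwd_out, batch):
        # Compute the loss based on the batch and the output of the forward pass.
        # To access the learner hyper-parameters, use `self._hps`.
        # `loss` stands for the computed loss value, and `ALL_MODULES` is the
        # key for the total loss across modules (import it before use).
        return {ALL_MODULES: loss}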
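The return shape `{ALL_MODULES: loss}` can be made concrete without any framework. The following is a sketch under assumptions, not RLlib's implementation: the constant is redefined locally (its value `"__all__"` is assumed), plain lists replace tensors, the per-module loss is a hypothetical mean squared error, and the `hps` argument that `compute_loss_for_module` takes in the method listing is left out for brevity.

```python
# Local stand-in for RLlib's ALL_MODULES constant; the value is an assumption.
ALL_MODULES = "__all__"


def compute_loss_for_module(*, module_id, fwd_out, batch):
    """Hypothetical per-module loss: mean squared error on plain lists."""
    preds, targets = fwd_out["predictions"], batch["targets"]
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def compute_loss(*, fwd_out, batch):
    """Return one loss per module plus their sum under ALL_MODULES."""
    losses = {
        module_id: compute_loss_for_module(
            module_id=module_id,
            fwd_out=fwd_out[module_id],
            batch=batch[module_id],
        )
        for module_id in fwd_out
    }
    # The right-hand side is evaluated first, so only module losses are summed.
    losses[ALL_MODULES] = sum(losses.values())
    return losses


fwd_out = {
    "main": {"predictions": [1.0, 2.0]},
    "new_player": {"predictions": [0.0, 0.0]},
}
batch = {
    "main": {"targets": [1.0, 4.0]},
    "new_player": {"targets": [1.0, 1.0]},
}
losses = compute_loss(fwd_out=fwd_out, batch=batch)
```

Keeping the total under a reserved key lets gradient computation use a single scalar while the per-module entries stay available for reporting.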

Methods

add_module(*, module_id, module_spec)

Add a module to the underlying MultiAgentRLModule and the Learner.

additional_update(*[, module_ids_to_update])

Apply additional non-gradient based updates to this Algorithm.

additional_update_for_module(*, module_id, ...)

Apply additional non-gradient based updates for a single module.

apply_gradients(gradients_dict)

Applies the gradients to the MultiAgentRLModule parameters.

build()

Builds the Learner.

compile_results(*, batch, fwd_out, ...)

Compile results from the update in a numpy-friendly format.

compute_gradients(loss_per_module, **kwargs)

Computes the gradients based on the given losses.

compute_loss(*, fwd_out, batch)

Computes the loss for the module being optimized.

compute_loss_for_module(*, module_id, hps, ...)

Computes the loss for a single module.

configure_optimizers()

Configures, creates, and registers the optimizers for this Learner.

configure_optimizers_for_module(module_id, hps)

Configures an optimizer for the given module_id.

filter_param_dict_for_optimizer(param_dict, ...)

Reduces the given ParamDict to contain only parameters for given optimizer.

get_module_state([module_ids])

Returns the state of the underlying MultiAgentRLModule.

get_optimizer([module_id, optimizer_name])

Returns the optimizer object, configured under the given module_id and name.

get_optimizer_state()

Returns the state of all optimizers currently registered in this Learner.

get_optimizers_for_module([module_id])

Returns a list of (optimizer_name, optimizer instance)-tuples for module_id.

get_param_ref(param)

Returns a hashable reference to a trainable parameter.

get_parameters(module)

Returns the list of parameters of a module.

get_state()

Get the state of the learner.

load_state(path)

Load the state of the learner from path.

postprocess_gradients(gradients_dict)

Applies potential postprocessing operations on the gradients.

postprocess_gradients_for_module(*, ...)

Applies postprocessing operations on the gradients of the given module.

register_metric(module_id, key, value)

Registers a single key/value metric pair for loss- and gradient stats.

register_metrics(module_id, metrics_dict)

Registers several key/value metric pairs for loss- and gradient stats.

register_optimizer(*[, module_id, ...])

Registers an optimizer with a ModuleID, name, param list and lr-scheduler.

remove_module(module_id)

Remove a module from the Learner.

save_state(path)

Save the state of the learner to path.

set_module_state(state)

Sets the state of the underlying MultiAgentRLModule.

set_optimizer_state(state)

Sets the state of all optimizers currently registered in this Learner.

set_state(state)

Set the state of the learner.

update(batch, *[, minibatch_size, ...])

Do num_iters minibatch updates given the original batch.
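The one-line summary of update does not spell out how minibatches are drawn. As a rough picture only, the sketch below assumes that num_iters counts full passes over the batch and that each pass is cut into consecutive slices of minibatch_size. The helper `iter_minibatches` is hypothetical and a plain list stands in for the batch.

```python
def iter_minibatches(batch, minibatch_size, num_iters):
    """Yield consecutive slices of `batch`, making `num_iters` full passes."""
    for _ in range(num_iters):
        for start in range(0, len(batch), minibatch_size):
            # The final slice of a pass may be shorter than minibatch_size.
            yield batch[start:start + minibatch_size]


batch = list(range(10))
minibatches = list(iter_minibatches(batch, minibatch_size=4, num_iters=2))

# One stand-in "update" per minibatch; a real learner would compute a loss
# and apply gradients here instead of averaging the samples.
results = [sum(mb) / len(mb) for mb in minibatches]
```

With 10 samples, a minibatch size of 4, and 2 passes, this yields 6 minibatches, so 6 gradient updates would be taken under these assumptions.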

Attributes

TOTAL_LOSS_KEY

distributed

Whether the learner is running in distributed mode.

framework

hps

The hyper-parameters for the learner.

module

The multi-agent RLModule that is being trained.