ray.rllib.core.learner.learner.Learner.compute_loss

Learner.compute_loss(*, fwd_out: Union[ray.rllib.policy.sample_batch.MultiAgentBatch, ray.rllib.utils.nested_dict.NestedDict], batch: Union[ray.rllib.policy.sample_batch.MultiAgentBatch, ray.rllib.utils.nested_dict.NestedDict]) -> Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor, Mapping[str, Any]]

Computes the loss for the module being optimized.

This method must be overridden by multi-agent-specific algorithm learners to define the loss computation logic. If the algorithm is single-agent, override compute_loss_for_module() instead. fwd_out is the output of the forward_train() method of the underlying MultiAgentRLModule, and batch is the data that was used to compute fwd_out. The returned dictionary must contain a key called ALL_MODULES, whose value is used to compute gradients. It is recommended not to run any forward passes within this method and instead to use the forward_train() outputs of the RLModule(s) to compute the tensors required for the loss calculation.

Parameters
  • fwd_out – Output from a call to the forward_train() method of self.module during training (self.update()).

  • batch – The training batch that was used to compute fwd_out.

Returns

A dictionary mapping module IDs to individual loss terms. The dictionary must contain one protected key, ALL_MODULES, whose value is the loss used to compute gradients.
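
Example

The return contract described above can be illustrated with a minimal, framework-free sketch. It does not use RLlib itself: the ALL_MODULES value below is a placeholder for the constant that RLlib provides, the free function compute_loss_for_module is a hypothetical stand-in with a made-up mean-squared-error loss, and the "predictions" and "targets" keys are invented for illustration. Plain Python floats take the place of framework tensors.

```python
from typing import Any, Dict, Mapping, Sequence

# Placeholder for RLlib's ALL_MODULES constant (assumption: real code
# would import the constant from RLlib instead of defining it here).
ALL_MODULES = "__all_modules_placeholder__"


def compute_loss_for_module(
    module_id: str,
    fwd_out: Mapping[str, Sequence[float]],
    batch: Mapping[str, Sequence[float]],
) -> float:
    """Hypothetical per-module loss: mean squared error."""
    preds = fwd_out["predictions"]  # illustrative key, not an RLlib name
    targets = batch["targets"]      # illustrative key, not an RLlib name
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def compute_loss(*, fwd_out: Mapping[str, Any], batch: Mapping[str, Any]) -> Dict[str, float]:
    """Builds a dict of per-module losses plus their sum under ALL_MODULES."""
    loss_per_module: Dict[str, float] = {}
    total = 0.0
    # No forward pass happens here: only the precomputed fwd_out is used.
    for module_id in fwd_out:
        loss = compute_loss_for_module(module_id, fwd_out[module_id], batch[module_id])
        loss_per_module[module_id] = loss
        total += loss
    # The protected key holds the loss that gradients would be computed from.
    loss_per_module[ALL_MODULES] = total
    return loss_per_module


fwd_out = {
    "agent_0": {"predictions": [1.0, 2.0]},
    "agent_1": {"predictions": [0.0, 0.0]},
}
batch = {
    "agent_0": {"targets": [1.0, 4.0]},
    "agent_1": {"targets": [1.0, 1.0]},
}
losses = compute_loss(fwd_out=fwd_out, batch=batch)
print(losses)
```

Summing the per-module losses is only one possible way to form the ALL_MODULES entry; an actual Learner subclass might weight or combine the terms differently, and would return framework tensors so that gradients can flow.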