ray.rllib.core.learner.learner.Learner.compute_loss
- Learner.compute_loss(*, fwd_out: Union[ray.rllib.policy.sample_batch.MultiAgentBatch, ray.rllib.utils.nested_dict.NestedDict], batch: Union[ray.rllib.policy.sample_batch.MultiAgentBatch, ray.rllib.utils.nested_dict.NestedDict]) → Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor, Mapping[str, Any]]
Computes the loss for the module being optimized.
This method must be overridden by multi-agent-specific algorithm learners to define the loss computation logic. If the algorithm is single-agent, compute_loss_for_module() should be overridden instead. fwd_out is the output of the forward_train() method of the underlying MultiAgentRLModule. batch is the data that was used to compute fwd_out. The returned dictionary must contain a key called ALL_MODULES, which will be used to compute gradients. It is recommended not to compute any forward passes within this method, and instead to use the forward_train() outputs of the RLModule(s) to compute the tensors required for loss calculations.
- Parameters
fwd_out – Output from a call to the forward_train() method of self.module during training (self.update()).
batch – The training batch that was used to compute fwd_out.
- Returns
A dictionary mapping module IDs to individual loss terms. The dictionary must contain one protected key, ALL_MODULES, whose value is the loss used to compute gradients.
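The shape of the returned dictionary can be illustrated with a minimal, framework-free sketch. It is not RLlib code: the function name `compute_loss_sketch`, the keys `"predictions"` and `"targets"`, the mean-squared-error loss, and the string value given to `ALL_MODULES` are all assumptions made for illustration. A real override would operate on framework tensors (for example torch or tf) so that gradients can flow through the total loss, and would import the `ALL_MODULES` constant from RLlib rather than redefine it.

```python
# Stand-in for RLlib's ALL_MODULES constant. The value here is an assumption;
# real code should import the constant from RLlib instead of hard-coding it.
ALL_MODULES = "__all__"


def compute_loss_sketch(*, fwd_out, batch):
    """Toy multi-agent loss: per-module mean squared error plus a total."""
    loss_per_module = {}
    for module_id, module_out in fwd_out.items():
        # Hypothetical keys; the real keys depend on the algorithm.
        preds = module_out["predictions"]
        targets = batch[module_id]["targets"]
        squared_errors = [(p - t) ** 2 for p, t in zip(preds, targets)]
        loss_per_module[module_id] = sum(squared_errors) / len(squared_errors)
    # No forward pass happens here: only fwd_out and batch are used.
    # The total stored under ALL_MODULES is the value gradients would come from.
    loss_per_module[ALL_MODULES] = sum(loss_per_module.values())
    return loss_per_module


fwd_out = {
    "agent_0": {"predictions": [1.0, 2.0]},
    "agent_1": {"predictions": [0.0, 0.0]},
}
batch = {
    "agent_0": {"targets": [1.0, 4.0]},
    "agent_1": {"targets": [1.0, 1.0]},
}
losses = compute_loss_sketch(fwd_out=fwd_out, batch=batch)
```

Summing the per-module losses is only one possible way to form the total; an algorithm could weight modules differently, as long as the result is stored under ALL_MODULES.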