ray.rllib.core.learner.learner.Learner.postprocess_gradients_for_module

Learner.postprocess_gradients_for_module(*, module_id: str, hps: ray.rllib.core.learner.learner.LearnerHyperparameters, module_gradients_dict: Dict[Hashable, Union[torch.Tensor, tensorflow.python.ops.variables.Variable]]) -> Dict[Hashable, Union[torch.Tensor, tensorflow.python.ops.variables.Variable]]

Applies postprocessing operations to the gradients of the given module.

Parameters
  • module_id – The ID of the module whose computed gradients are to be postprocessed. Note that module_gradients_dict contains only the gradient tensors that belong to this module_id. Gradients of other module IDs are not available in this call.

  • hps – The LearnerHyperparameters specific to the given module_id.

  • module_gradients_dict – A dictionary of the gradients to be postprocessed, in the same (flat) format as self._params, mapping gradient refs to gradient tensors. You may modify these tensors in place, or create new tensors and return them in a new dict.
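To illustrate the contract described by these parameters, here is a minimal, framework-free sketch of a postprocessing step that clips one module's gradients by their global L2 norm and returns a new dict with the same keys. It is not RLlib's implementation: plain Python lists stand in for torch.Tensor / tf.Variable values, and ToyHyperparameters with its grad_clip field is a hypothetical stand-in for LearnerHyperparameters.

```python
import math
from dataclasses import dataclass
from typing import Dict, Hashable, List, Optional


# Hypothetical stand-in for LearnerHyperparameters (only the field this sketch uses).
@dataclass
class ToyHyperparameters:
    grad_clip: Optional[float] = None


def postprocess_gradients_for_module(
    *,
    module_id: str,
    hps: ToyHyperparameters,
    module_gradients_dict: Dict[Hashable, List[float]],
) -> Dict[Hashable, List[float]]:
    # module_id is unused here; the dict already holds only this module's gradients.
    if hps.grad_clip is None:
        # Nothing to do: return a shallow copy with the exact same keys.
        return dict(module_gradients_dict)

    # Global L2 norm over all gradient entries of this module.
    global_norm = math.sqrt(
        sum(g * g for grad in module_gradients_dict.values() for g in grad)
    )

    # Scale down only if the norm exceeds the configured threshold.
    scale = 1.0
    if global_norm > hps.grad_clip:
        scale = hps.grad_clip / global_norm

    # Build a NEW dict (the input is left untouched) with identical keys.
    return {
        ref: [g * scale for g in grad]
        for ref, grad in module_gradients_dict.items()
    }
```

For example, with gradients {"w": [3.0, 0.0], "b": [4.0]} and grad_clip=1.0, the global norm is 5.0, so every entry is scaled by roughly 0.2 and the result keeps the keys "w" and "b". Returning a new dict rather than mutating in place is one of the two options the parameter description allows; it keeps the caller's original gradients intact.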

Returns

A dictionary of the postprocessed gradients, with exactly the same (flat) structure as the incoming module_gradients_dict argument.
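Because the returned dict must keep exactly the same flat structure as the input, a custom override can be checked with a small helper like the one below. This is a hypothetical utility for your own tests, not part of RLlib; it compares only the set of gradient refs (keys), not tensor shapes or values.

```python
from typing import Dict, Hashable


def check_same_flat_structure(
    incoming: Dict[Hashable, object],
    outgoing: Dict[Hashable, object],
) -> None:
    """Raise ValueError if `outgoing` does not have exactly the keys of `incoming`."""
    # Refs present in the input but dropped from the output.
    missing = [ref for ref in incoming if ref not in outgoing]
    # Refs added by the postprocessing step that were not in the input.
    extra = [ref for ref in outgoing if ref not in incoming]
    if missing or extra:
        raise ValueError(
            f"Gradient dict structure changed: missing={missing!r}, extra={extra!r}"
        )
```

A typical use is to call it on the input and output of your override in a unit test; key order does not matter, only the set of refs.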