ray.rllib.utils.torch_utils.concat_multi_gpu_td_errors

ray.rllib.utils.torch_utils.concat_multi_gpu_td_errors(policy: Union[TorchPolicy, TorchPolicyV2]) → Dict[str, Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]

Concatenates the multi-GPU (per-tower) TD-error tensors of the given TorchPolicy.

The TD errors are extracted from the tower_stats property of each of the TorchPolicy's per-GPU model towers.

Parameters

policy – The TorchPolicy or TorchPolicyV2 instance to extract the TD-error values from.

Returns

A dict mapping the keys “td_error” and “mean_td_error” to the concatenated TD errors and their mean, respectively.
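The concatenate-then-reduce behavior described above can be illustrated without Ray or PyTorch installed. This is a minimal sketch, not RLlib's implementation: it substitutes NumPy arrays for torch tensors, and FakeTower, FakePolicy, and concat_td_errors_sketch are hypothetical names. It assumes that the policy exposes its per-GPU models as a list named model_gpu_towers, that each tower's tower_stats dict holds a 1-D batch of TD errors under the key "td_error", and that a tower with no recorded TD error contributes a single 0.0. All three are assumptions made for illustration.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Union

import numpy as np


@dataclass
class FakeTower:
    """Hypothetical stand-in for one per-GPU model tower.

    Assumption: the loss function stored a 1-D array of per-sample
    TD errors in tower_stats["td_error"].
    """
    tower_stats: Dict[str, np.ndarray] = field(default_factory=dict)


@dataclass
class FakePolicy:
    """Hypothetical stand-in for a TorchPolicy holding its GPU towers."""
    model_gpu_towers: List[FakeTower]


def concat_td_errors_sketch(
    policy: FakePolicy,
) -> Dict[str, Union[np.ndarray, np.floating]]:
    # Gather each tower's TD-error batch; a tower that recorded none
    # contributes a single 0.0 (assumed fallback, for illustration).
    per_tower = [
        t.tower_stats.get("td_error", np.array([0.0]))
        for t in policy.model_gpu_towers
    ]
    # Towers processed disjoint slices of the train batch, so joining
    # along axis 0 yields one TD error per sample of the full batch.
    td_error = np.concatenate(per_tower, axis=0)
    return {
        "td_error": td_error,
        "mean_td_error": td_error.mean(),
    }


# Usage: two towers, holding two samples and one sample respectively.
policy = FakePolicy(model_gpu_towers=[
    FakeTower({"td_error": np.array([1.0, -2.0])}),
    FakeTower({"td_error": np.array([3.0])}),
])
out = concat_td_errors_sketch(policy)
print(out["td_error"].tolist())       # [1.0, -2.0, 3.0]
print(float(out["mean_td_error"]))    # 0.6666666666666666
```

Concatenating along axis 0, rather than averaging per tower first, preserves one TD error per sample in the combined batch, while “mean_td_error” collapses them into a single scalar.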