ray.rllib.utils.torch_utils.concat_multi_gpu_td_errors
- ray.rllib.utils.torch_utils.concat_multi_gpu_td_errors(policy: Union[TorchPolicy, TorchPolicyV2]) → Dict[str, Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]
Concatenates the multi-GPU (per-tower) TD-error tensors of the given TorchPolicy.
The TD errors are read from the tower_stats property of each of the policy's per-GPU model towers.
- Parameters
policy – The TorchPolicy (or TorchPolicyV2) to extract the TD-error values from.
- Returns
A dict mapping “td_error” to the TD errors concatenated across all towers and “mean_td_error” to their mean.
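The extraction and reduction described above can be illustrated with a small, self-contained sketch. It is not RLlib's implementation: FakeTower and FakePolicy are hypothetical stand-ins, the attribute name model_gpu_towers and the 0.0 fallback for a tower without recorded TD errors are assumptions, and NumPy arrays replace torch tensors (np.concatenate plays the role of torch.cat) so the example runs without torch or a GPU.

```python
from dataclasses import dataclass, field
from typing import Dict, List

import numpy as np


@dataclass
class FakeTower:
    """Hypothetical stand-in for one per-GPU model copy ("tower")."""
    tower_stats: Dict[str, np.ndarray] = field(default_factory=dict)


@dataclass
class FakePolicy:
    """Hypothetical stand-in for a TorchPolicy; only its towers are modeled."""
    model_gpu_towers: List[FakeTower]  # assumed attribute name


def concat_td_errors_sketch(policy: FakePolicy) -> Dict[str, np.ndarray]:
    # Collect each tower's TD errors from its tower_stats. A tower that
    # recorded none contributes a single 0.0 (assumed fallback).
    per_tower = [
        t.tower_stats.get("td_error", np.array([0.0]))
        for t in policy.model_gpu_towers
    ]
    # Concatenate along the batch axis: one entry per sample across all towers.
    td_error = np.concatenate(per_tower, axis=0)
    # Mean-reduce the concatenated values into a single scalar.
    return {"td_error": td_error, "mean_td_error": td_error.mean()}


# Two towers that processed batch shards of different sizes.
policy = FakePolicy([
    FakeTower({"td_error": np.array([1.0, 2.0])}),
    FakeTower({"td_error": np.array([3.0])}),
])
out = concat_td_errors_sketch(policy)
print(out["td_error"])       # [1. 2. 3.]
print(out["mean_td_error"])  # 2.0
```

Concatenating before averaging weights every sample equally, so towers that processed larger shards contribute proportionally more to mean_td_error than they would under a plain average of per-tower means.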