ray.rllib.policy.torch_policy_v2.TorchPolicyV2.loss
- TorchPolicyV2.loss(model: ray.rllib.models.modelv2.ModelV2, dist_class: Type[ray.rllib.models.torch.torch_action_dist.TorchDistributionWrapper], train_batch: ray.rllib.policy.sample_batch.SampleBatch) → Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor, List[Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]]
Constructs the loss function.
- Parameters
model – The Model to calculate the loss for.
dist_class – The action distribution class.
train_batch – The training data.
- Returns
Loss tensor (or list of loss tensors) computed from the input batch.
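To illustrate how the three arguments typically fit together, here is a minimal, dependency-free sketch of a policy-gradient style loss. It is not RLlib code: `ToyModel`, `ToyCategorical`, and the batch keys `"obs"`, `"actions"`, and `"advantages"` are hypothetical stand-ins, and plain Python floats replace torch tensors. A real override would presumably receive a ModelV2, a TorchDistributionWrapper subclass, and a SampleBatch, and return a torch.Tensor.

```python
import math
from typing import Dict, List, Sequence


class ToyModel:
    """Hypothetical stand-in for a model: maps observations to action logits."""

    def __init__(self, weights: Sequence[Sequence[float]]):
        # One row of per-feature weights for each discrete action.
        self.weights = weights

    def __call__(self, train_batch: Dict[str, list]) -> List[List[float]]:
        # Linear layer without bias: logits[i][a] = dot(weights[a], obs[i]).
        return [
            [sum(w * x for w, x in zip(row, obs)) for row in self.weights]
            for obs in train_batch["obs"]
        ]


class ToyCategorical:
    """Hypothetical stand-in for an action distribution class."""

    def __init__(self, logits: List[List[float]], model: ToyModel):
        self.logits = logits

    def logp(self, actions: Sequence[int]) -> List[float]:
        # Log-softmax of each logits row, evaluated at the action taken.
        out = []
        for row, action in zip(self.logits, actions):
            log_z = math.log(sum(math.exp(logit) for logit in row))
            out.append(row[action] - log_z)
        return out


def loss(model, dist_class, train_batch) -> float:
    """Same argument order as the documented method, on toy types (assumption)."""
    logits = model(train_batch)              # forward pass on the batch
    action_dist = dist_class(logits, model)  # wrap logits in a distribution
    log_probs = action_dist.logp(train_batch["actions"])
    advantages = train_batch["advantages"]
    # Policy-gradient objective: negative mean of log-prob times advantage.
    weighted = [lp * adv for lp, adv in zip(log_probs, advantages)]
    return -sum(weighted) / len(weighted)


# Illustrative data only; these values are made up for the example.
batch = {
    "obs": [[1.0, 0.0], [0.0, 1.0]],
    "actions": [0, 1],
    "advantages": [1.0, 2.0],
}
toy_loss = loss(ToyModel([[0.0, 0.0], [0.0, 0.0]]), ToyCategorical, batch)
```

With all-zero weights both actions are equally likely, so each log-probability is -log(2) and the loss reduces to 1.5 · log(2). The policy-gradient objective here is just one possible choice; the documented method only requires that a loss tensor (or a list of them) be returned.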