ray.rllib.policy.eager_tf_policy_v2.EagerTFPolicyV2.maybe_add_time_dimension
- EagerTFPolicyV2.maybe_add_time_dimension(input_dict: Dict[str, Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]], seq_lens: Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor], framework: str = None)
Adds a time dimension to the values in input_dict for recurrent RLModules.
- Parameters
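input_dict – The input dict, mapping keys to tensors (NumPy arrays, JAX arrays, TensorFlow tensors, or PyTorch tensors).
seq_lens – The length of each sequence in the batch.
framework – The framework to use for adding the time dimension. If None, defaults to the framework of the policy.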
- Returns
The input dict, possibly with an added time dimension.
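The reshape behind adding a time dimension can be sketched in plain NumPy. This is a minimal illustration, not RLlib's implementation: the helper add_time_dim_sketch is hypothetical, and it assumes every array in the dict is a batch-major, zero-padded batch of shape [B * T, ...], where B = len(seq_lens) and T is the padded sequence length. Each array is folded into [B, T, ...], or [T, B, ...] when time_major=True.

```python
from typing import Dict

import numpy as np


def add_time_dim_sketch(
    input_dict: Dict[str, np.ndarray],
    seq_lens: np.ndarray,
    time_major: bool = False,
) -> Dict[str, np.ndarray]:
    """Hypothetical helper: fold a padded [B * T, ...] batch into [B, T, ...].

    Assumes every array is already zero-padded so that each of the
    B = len(seq_lens) sequences occupies exactly T consecutive rows.
    """
    num_seqs = len(seq_lens)
    out = {}
    for key, value in input_dict.items():
        value = np.asarray(value)
        if value.shape[0] % num_seqs != 0:
            raise ValueError(
                f"{key!r}: batch size {value.shape[0]} is not a multiple "
                f"of the number of sequences ({num_seqs})"
            )
        # Derive the padded time length T from the padded batch size.
        max_seq_len = value.shape[0] // num_seqs
        folded = value.reshape((num_seqs, max_seq_len) + value.shape[1:])
        if time_major:
            # Swap batch and time axes: [B, T, ...] -> [T, B, ...].
            folded = np.swapaxes(folded, 0, 1)
        out[key] = folded
    return out


# Two sequences of lengths 2 and 3, both padded to T = 3:
# rows 0-1 hold sequence 0 (row 2 is padding), rows 3-5 hold sequence 1.
seq_lens = np.array([2, 3])
obs = np.arange(6 * 4, dtype=np.float32).reshape(6, 4)
batch = add_time_dim_sketch({"obs": obs}, seq_lens)
print(batch["obs"].shape)  # (2, 3, 4)
```

Deriving T from the padded batch size rather than from max(seq_lens) lets the sketch handle batches padded to a fixed length larger than the longest sequence, as long as every sequence is padded to the same length.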