ray.rllib.utils.numpy.flatten_inputs_to_1d_tensor

ray.rllib.utils.numpy.flatten_inputs_to_1d_tensor(inputs: Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor, dict, tuple], spaces_struct: Optional[Union[gymnasium.spaces.Space, dict, tuple]] = None, time_axis: bool = False) → Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]

Flattens arbitrary input structs according to the given spaces struct.

Returns a single tensor in which every batch item (and, with a time axis, every timestep) is a 1D vector built from the different input components’ values.

Thereby:

  • Boxes (any shape) get flattened to (B, [T]?, -1). Note that image boxes are not treated differently from other types of Boxes and get flattened as well.

  • Discrete (int) values are one-hot’d, e.g. a batch of [1, 0, 3] (B=3 with Discrete(4) space) results in [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]].

  • MultiDiscrete values are multi-one-hot’d, e.g. a batch of [[0, 2], [1, 4]] (B=2 with MultiDiscrete([2, 5]) space) results in [[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 1]].
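The per-space rules above can be reproduced with plain NumPy. The following is a minimal sketch under stated assumptions, not RLlib's implementation: the helper names `one_hot` and `multi_one_hot` are hypothetical, and only `numpy` is assumed to be installed.

```python
import numpy as np


def one_hot(values, depth):
    """One-hot encode an integer array of shape (B,) into shape (B, depth)."""
    values = np.asarray(values)
    out = np.zeros((values.shape[0], depth), dtype=np.float32)
    out[np.arange(values.shape[0]), values] = 1.0
    return out


def multi_one_hot(values, nvec):
    """One-hot each column of a (B, len(nvec)) array and concatenate the results."""
    values = np.asarray(values)
    return np.concatenate(
        [one_hot(values[:, i], n) for i, n in enumerate(nvec)], axis=-1
    )


# Discrete(4), batch [1, 0, 3] -> shape (3, 4)
discrete = one_hot([1, 0, 3], depth=4)

# MultiDiscrete([2, 5]), batch [[0, 2], [1, 4]] -> shape (2, 2 + 5)
multi = multi_one_hot([[0, 2], [1, 4]], nvec=[2, 5])

# Box of shape (4, 4, 3), e.g. a small image, B=2 -> shape (2, 48)
box = np.zeros((2, 4, 4, 3), dtype=np.float32).reshape(2, -1)
```

Each component ends up with shape (B, k), so the components can be concatenated along the last axis into one (B, n) array.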

Parameters
  • inputs – The inputs to be flattened.

  • spaces_struct – The structure of the spaces behind the inputs.

  • time_axis – Whether all inputs have a time-axis (after the batch axis). If True, will keep not only the batch axis (0th), but the time axis (1st) as-is and flatten everything from the 2nd axis up.

Returns

A single 1D tensor resulting from concatenating all flattened/one-hot’d input components. Depending on the time_axis flag, the shape is (B, n) or (B, T, n).
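To make the (B, T, n) output shape concrete, here is a hedged NumPy sketch of one way to keep the batch and time axes while flattening everything else. It is an illustration only, not the library's code; the component arrays `box` and `discrete_one_hot` are hypothetical stand-ins for a Box(shape=(2, 2)) input and an already one-hot'd Discrete(4) input.

```python
import numpy as np

B, T = 2, 3

# Hypothetical Box component with shape (B, T, 2, 2).
box = np.arange(B * T * 4, dtype=np.float32).reshape(B, T, 2, 2)

# Hypothetical Discrete(4) component, already one-hot'd to shape (B, T, 4).
indices = np.array([[0, 1, 2], [3, 2, 1]])
discrete_one_hot = np.eye(4, dtype=np.float32)[indices]

# Fold the time axis into the batch axis and flatten each component to 2D.
parts = [x.reshape(B * T, -1) for x in (box, discrete_one_hot)]

# Concatenate along the feature axis, then restore the (B, T, n) layout.
merged = np.concatenate(parts, axis=-1).reshape(B, T, -1)

print(merged.shape)  # (2, 3, 8)
```

Here n = 4 + 4 = 8: four values from the flattened Box plus four from the one-hot vector, for every (batch, time) position.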

Examples

>>> # B=2
>>> import numpy as np
>>> from ray.rllib.utils.numpy import flatten_inputs_to_1d_tensor
>>> from gymnasium.spaces import Discrete, Box
>>> out = flatten_inputs_to_1d_tensor(
...     {
...         "a": np.array([1, 0]),
...         "b": np.array([[[0.0], [0.1]], [[1.0], [1.1]]]),
...     },
...     # Box needs explicit bounds; the -10.0/10.0 limits are illustrative.
...     spaces_struct=dict(a=Discrete(2), b=Box(-10.0, 10.0, shape=(2, 1)))
... )
>>> print(out)
[[0.0, 1.0,  0.0, 0.1], [1.0, 0.0,  1.0, 1.1]]  # B=2 n=4
>>> # B=2; T=2
>>> out = flatten_inputs_to_1d_tensor(
...     (np.array([[1, 0], [0, 1]]),
...      np.array([[[0.0, 0.1], [1.0, 1.1]], [[2.0, 2.1], [3.0, 3.1]]])),
...     spaces_struct=tuple([Discrete(2), Box(-10.0, 10.0, shape=(2, ))]),
...     time_axis=True
... )
>>> print(out)
[[[0.0, 1.0, 0.0, 0.1], [1.0, 0.0, 1.0, 1.1]],
 [[1.0, 0.0, 2.0, 2.1], [0.0, 1.0, 3.0, 3.1]]]  # B=2 T=2 n=4