mlx.core.unflatten

目录

mlx.core.unflatten#

unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: None | Stream | Device = None) array#

将数组的某个轴展平(unflatten)到指定形状。

参数:
  • a (array) – 输入数组。

  • axis (int) – 要展平(unflatten)的轴。

  • shape (tuple(int)) – 要展平(unflatten)成的形状。最多一个条目可以是 -1,此时将推断相应的尺寸。

  • stream (Stream, optional) – 流或设备。默认为 None,此时使用默认设备的默认流。

返回值:

展平(unflatten)后的数组。

返回类型:

array

示例

>>> a = mx.array([1, 2, 3, 4])
>>> mx.unflatten(a, 0, (2, -1))
array([[1, 2], [3, 4]], dtype=int32)