mlx.core.unflatten# unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: None | Stream | Device = None) → array# 将数组的某个轴展平(unflatten)到指定形状。 参数: a (array) – 输入数组。 axis (int) – 要展平(unflatten)的轴。 shape (tuple(int)) – 要展平(unflatten)成的形状。最多一个条目可以是 -1,此时将推断相应的尺寸。 stream (Stream, optional) – 流或设备。默认为 None,此时使用默认设备的默认流。 返回值: 展平(unflatten)后的数组。 返回类型: array 示例 >>> a = mx.array([1, 2, 3, 4]) >>> mx.unflatten(a, 0, (2, -1)) array([[1, 2], [3, 4]], dtype=int32)