mlx.core.flatten

目录

mlx.core.flatten#

flatten(a: array, /, start_axis: int = 0, end_axis: int = -1, *, stream: None | Stream | Device = None) array#

展平一个数组。

展平的轴将在 start_axisend_axis 之间(包含两端)。支持负轴。将负轴转换为正轴后,超出有效范围的轴将被限定在有效值内,start_axis 变为 0end_axis 变为 ndim - 1

参数:
  • a (array) – 输入数组。

  • start_axis (int, 可选) – 第一个要展平的维度。默认为 0

  • end_axis (int, 可选) – 最后一个要展平的维度。默认为 -1

  • stream (Stream, 可选) – 流或设备。默认为 None,此时使用默认设备的默认流。

返回值:

展平后的数组。

返回类型:

array

示例

>>> a = mx.array([[1, 2], [3, 4]])
>>> mx.flatten(a)
array([1, 2, 3, 4], dtype=int32)
>>>
>>> mx.flatten(a, start_axis=0, end_axis=-1)
array([1, 2, 3, 4], dtype=int32)