mlx.core.take_along_axis

mlx.core.take_along_axis#

take_along_axis(a: array, /, indices: array, axis: int | None = None, *, stream: None | Stream | Device = None) array#

沿指定索引的轴获取值。

参数:
  • a (array) – 输入数组。

  • indices (array) – 索引数组。除了 axis 维度外,此数组应可与输入数组广播。

  • axis (intNone) – 输入数组中用于获取值的轴。如果 axis == None,则在索引操作之前,数组会被展平为 1D。

返回:

输出数组。

返回类型:

array