mlx.core.slice_update

mlx.core.slice_update#

slice_update(a: array, update: array, start_indices: array, axes: Sequence[int], *, stream: None | Stream | Device = None) array#

更新输入数组的子数组。

参数:
  • a (array) – 要更新的输入数组

  • update (array) – 更新数组。

  • start_indices (array) – 切片的起始索引位置。

  • axes (tuple(int)) – 对应于 start_indices 中的索引的轴。

返回值:

形状和类型与输入数组相同的输出数组。

返回类型:

array

示例

>>> a = mx.zeros((3, 3))
>>> mx.slice_update(a, mx.ones((1, 2)), start_indices=mx.array(1, 1), axes=(0, 1))
array([[0, 0, 0],
       [0, 1, 0],
       [0, 1, 0]], dtype=float32)