数组索引#
大多数情况下,索引 MLX array
的工作方式与索引 NumPy numpy.ndarray
相同。有关其工作原理的更多详细信息,请参阅 NumPy 文档。
例如,您可以使用普通整数和切片 (slice
) 来索引数组
>>> arr = mx.arange(10)
>>> arr[3]
array(3, dtype=int32)
>>> arr[-2] # negative indexing works
array(8, dtype=int32)
>>> arr[2:8:2] # start, stop, stride
array([2, 4, 6], dtype=int32)
对于多维数组,...
或 Ellipsis
语法在 NumPy 中也有效
>>> arr = mx.arange(8).reshape(2, 2, 2)
>>> arr[:, :, 0]
array(3, dtype=int32)
array([[0, 2],
[4, 6]], dtype=int32
>>> arr[..., 0]
array([[0, 2],
[4, 6]], dtype=int32
您可以使用 None
进行索引以创建新轴
>>> arr = mx.arange(8)
>>> arr.shape
[8]
>>> arr[None].shape
[1, 8]
>>> arr = mx.arange(10)
>>> idx = mx.array([5, 7])
>>> arr[idx]
array([5, 7], dtype=int32)
混合使用整数、slice
、...
和 array
索引的工作方式与 NumPy 中完全相同。
其他可能对数组索引有用的函数包括 take()
和 take_along_axis()
。
与 NumPy 的区别#
注意
MLX 索引与 NumPy 索引在两个重要方面不同
索引不执行边界检查。越界索引的行为是未定义的。
尚不支持基于布尔掩码的索引。
不进行边界检查的原因是异常无法从 GPU 传播。在启动内核之前对数组索引进行边界检查将非常低效。
MLX 将来可能会支持使用布尔掩码进行索引。通常,MLX 对输出*形状*依赖于输入*数据*的操作支持有限。MLX 尚不支持的此类操作的其他示例包括 numpy.nonzero()
和单输入的 numpy.where()
。
原地更新#
在 MLX 中,对索引数组进行原地更新是可能的。例如
>>> a = mx.array([1, 2, 3])
>>> a[2] = 0
>>> a
array([1, 2, 0], dtype=int32)
正如在 NumPy 中一样,原地更新将反映在对同一数组的所有引用中
>>> a = mx.array([1, 2, 3])
>>> b = a
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 0], dtype=int32)
允许对使用原地更新的函数进行转换,并且它们按预期工作。例如
def fun(x, idx):
x[idx] = 2.0
return x.sum()
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
print(dfdx) # Prints: array([1, 0, 1], dtype=float32)
在上述示例中,dfdx
将具有正确的梯度,即在 idx
处为零,其他地方为一。