转换为 NumPy 和其他框架#
MLX 数组支持使用以下方法在其他框架之间进行转换:
让我们将数组转换为 NumPy 并转回 MLX。
import mlx.core as mx
import numpy as np
a = mx.arange(3)
b = np.array(a) # copy of a
c = mx.array(b) # copy of b
注意
由于 NumPy 不支持 bfloat16
数组,因此您需要先转换为 float16
或 float32
:np.array(a.astype(mx.float32))
。否则,您将收到类似以下错误:Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.
默认情况下,NumPy 会将数据复制到一个新数组中。可以通过创建数组视图来避免这种情况。
a = mx.arange(3)
a_view = np.array(a, copy=False)
print(a_view.flags.owndata) # False
a_view[0] = 1
print(a[0].item()) # 1
注意
类型为 float64
的 NumPy 数组将默认转换为类型为 float32
的 MLX 数组。
NumPy 数组视图是一个普通的 NumPy 数组,但它不拥有自己的内存。这意味着对视图的写入会反映到原始数组中。
虽然这在防止数组复制方面非常强大,但需要注意的是,对数组内存的外部更改无法反映在梯度中。
让我们通过一个例子来演示这一点。
def f(x):
x_view = np.array(x, copy=False)
x_view[:] *= x_view # modify memory without telling mx
return x.sum()
x = mx.array([3.0])
y, df = mx.value_and_grad(f)(x)
print("f(x) = x² =", y.item()) # 9.0
print("f'(x) = 2x !=", df.item()) # 1.0
函数 f
通过内存视图间接修改数组 x
。然而,这种修改并未反映在梯度中,正如最后一行输出的 1.0
所示,它仅代表求和操作的梯度。对 x
进行平方的操作发生在 MLX 外部,这意味着没有包含任何梯度信息。值得注意的是,在数组转换和复制过程中也会出现类似的问题。例如,一个定义为 mx.array(np.array(x)**2).sum()
的函数也会导致不正确的梯度,即使没有在 MLX 内存上执行原地操作。
PyTorch#
警告
PyTorch 对 memoryview
的支持是实验性的,对于多维数组可能会中断。目前建议先转换为 NumPy。
PyTorch 支持 buffer protocol,但需要显式的 memoryview
。
import mlx.core as mx
import torch
a = mx.arange(3)
b = torch.tensor(memoryview(a))
c = mx.array(b.numpy())
从 PyTorch 张量转回 MLX 数组必须通过中间的 NumPy 数组使用 numpy()
完成。
JAX#
JAX 完全支持 buffer protocol。
import mlx.core as mx
import jax.numpy as jnp
a = mx.arange(3)
b = jnp.array(a)
c = mx.array(b)
TensorFlow#
TensorFlow 支持 buffer protocol,但需要显式的 memoryview
。
import mlx.core as mx
import tensorflow as tf
a = mx.arange(3)
b = tf.constant(memoryview(a))
c = mx.array(b)