转换为 NumPy 和其他框架

转换为 NumPy 和其他框架#

MLX 数组支持使用以下方法在其他框架之间进行转换:

让我们将数组转换为 NumPy 并转回 MLX。

import mlx.core as mx
import numpy as np

a = mx.arange(3)
b = np.array(a) # copy of a
c = mx.array(b) # copy of b

注意

由于 NumPy 不支持 bfloat16 数组,因此您需要先转换为 float16float32np.array(a.astype(mx.float32))。否则,您将收到类似以下错误:Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.

默认情况下,NumPy 会将数据复制到一个新数组中。可以通过创建数组视图来避免这种情况。

a = mx.arange(3)
a_view = np.array(a, copy=False)
print(a_view.flags.owndata) # False
a_view[0] = 1
print(a[0].item()) # 1

注意

类型为 float64 的 NumPy 数组将默认转换为类型为 float32 的 MLX 数组。

NumPy 数组视图是一个普通的 NumPy 数组,但它不拥有自己的内存。这意味着对视图的写入会反映到原始数组中。

虽然这在防止数组复制方面非常强大,但需要注意的是,对数组内存的外部更改无法反映在梯度中。

让我们通过一个例子来演示这一点。

def f(x):
    x_view = np.array(x, copy=False)
    x_view[:] *= x_view # modify memory without telling mx
    return x.sum()

x = mx.array([3.0])
y, df = mx.value_and_grad(f)(x)
print("f(x) = x² =", y.item()) # 9.0
print("f'(x) = 2x !=", df.item()) # 1.0

函数 f 通过内存视图间接修改数组 x。然而,这种修改并未反映在梯度中,正如最后一行输出的 1.0 所示,它仅代表求和操作的梯度。对 x 进行平方的操作发生在 MLX 外部,这意味着没有包含任何梯度信息。值得注意的是,在数组转换和复制过程中也会出现类似的问题。例如,一个定义为 mx.array(np.array(x)**2).sum() 的函数也会导致不正确的梯度,即使没有在 MLX 内存上执行原地操作。

PyTorch#

警告

PyTorch 对 memoryview 的支持是实验性的,对于多维数组可能会中断。目前建议先转换为 NumPy。

PyTorch 支持 buffer protocol,但需要显式的 memoryview

import mlx.core as mx
import torch

a = mx.arange(3)
b = torch.tensor(memoryview(a))
c = mx.array(b.numpy())

从 PyTorch 张量转回 MLX 数组必须通过中间的 NumPy 数组使用 numpy() 完成。

JAX#

JAX 完全支持 buffer protocol。

import mlx.core as mx
import jax.numpy as jnp

a = mx.arange(3)
b = jnp.array(a)
c = mx.array(b)

TensorFlow#

TensorFlow 支持 buffer protocol,但需要显式的 memoryview

import mlx.core as mx
import tensorflow as tf

a = mx.arange(3)
b = tf.constant(memoryview(a))
c = mx.array(b)