快速入门指南

快速入门指南#

基础#

导入 mlx.core 并创建一个 array

>> import mlx.core as mx
>> a = mx.array([1, 2, 3, 4])
>> a.shape
[4]
>> a.dtype
int32
>> b = mx.array([1.0, 2.0, 3.0, 4.0])
>> b.dtype
float32

MLX 中的操作是惰性的。MLX 操作的输出不会立即计算,直到需要它们时才会计算。要强制求值一个数组,请使用 eval()。在某些情况下,数组会自动求值。例如,使用 array.item() 查看标量,打印数组,或将 array 数组转换为 numpy.ndarray 都会自动求值该数组。

>> c = a + b    # c not yet evaluated
>> mx.eval(c)  # evaluates c
>> c = a + b
>> print(c)     # Also evaluates c
array([2, 4, 6, 8], dtype=float32)
>> c = a + b
>> import numpy as np
>> np.array(c)   # Also evaluates c
array([2., 4., 6., 8.], dtype=float32)

有关更多详细信息,请参阅惰性求值页面。

函数和图变换#

MLX 提供了标准的函数变换,例如 grad()vmap()。变换可以任意组合。例如,允许使用 grad(vmap(grad(fn))) (或任何其他组合)。

>> x = mx.array(0.0)
>> mx.sin(x)
array(0, dtype=float32)
>> mx.grad(mx.sin)(x)
array(1, dtype=float32)
>> mx.grad(mx.grad(mx.sin))(x)
array(-0, dtype=float32)

其他梯度变换包括用于向量-雅可比乘积的 vjp() 和用于雅可比-向量乘积的 jvp()

使用 value_and_grad() 可以高效地计算函数的输出以及相对于函数输入的梯度。