快速入门指南#
基础#
导入 mlx.core
并创建一个 array
>> import mlx.core as mx
>> a = mx.array([1, 2, 3, 4])
>> a.shape
[4]
>> a.dtype
int32
>> b = mx.array([1.0, 2.0, 3.0, 4.0])
>> b.dtype
float32
MLX 中的操作是惰性的。MLX 操作的输出不会立即计算,直到需要它们时才会计算。要强制求值一个数组,请使用 eval()
。在某些情况下,数组会自动求值。例如,使用 array.item()
查看标量,打印数组,或将 array
数组转换为 numpy.ndarray
都会自动求值该数组。
>> c = a + b # c not yet evaluated
>> mx.eval(c) # evaluates c
>> c = a + b
>> print(c) # Also evaluates c
array([2, 4, 6, 8], dtype=float32)
>> c = a + b
>> import numpy as np
>> np.array(c) # Also evaluates c
array([2., 4., 6., 8.], dtype=float32)
有关更多详细信息,请参阅惰性求值页面。
函数和图变换#
MLX 提供了标准的函数变换,例如 grad()
和 vmap()
。变换可以任意组合。例如,允许使用 grad(vmap(grad(fn)))
(或任何其他组合)。
>> x = mx.array(0.0)
>> mx.sin(x)
array(0, dtype=float32)
>> mx.grad(mx.sin)(x)
array(1, dtype=float32)
>> mx.grad(mx.grad(mx.sin))(x)
array(-0, dtype=float32)
其他梯度变换包括用于向量-雅可比乘积的 vjp()
和用于雅可比-向量乘积的 jvp()
。
使用 value_and_grad()
可以高效地计算函数的输出以及相对于函数输入的梯度。