函数变换#
MLX 使用可组合的函数变换进行自动微分、向量化和计算图优化。要查看完整的函数变换列表,请参阅API 文档。
可组合函数变换背后的关键思想是,每个变换都返回一个可以进一步变换的函数。
下面是一个简单的示例
>>> dfdx = mx.grad(mx.sin)
>>> dfdx(mx.array(mx.pi))
array(-1, dtype=float32)
>>> mx.cos(mx.array(mx.pi))
array(-1, dtype=float32)
grad()
对 sin()
的输出只是另一个函数。在这种情况下,它是正弦函数的梯度,恰好是余弦函数。要获得二阶导数,您可以这样做
>>> d2fdx2 = mx.grad(mx.grad(mx.sin))
>>> d2fdx2(mx.array(mx.pi / 2))
array(-1, dtype=float32)
>>> mx.sin(mx.array(mx.pi / 2))
array(1, dtype=float32)
对 grad()
的输出使用 grad()
总是可以的。您可以继续获取高阶导数。
任何 MLX 函数变换都可以按任何顺序组合到任何深度。有关自动微分和自动向量化的更多信息,请参阅以下部分。有关 compile()
的更多信息,请参阅编译文档。
自动微分#
MLX 中的自动微分作用于函数,而不是隐式图。
注意
如果您是从 PyTorch 转到 MLX,则不再需要像 backward
、zero_grad
和 detach
这样的函数,也不需要像 requires_grad
这样的属性。
最基本的例子是计算标量值函数的梯度,正如我们上面看到的。您可以使用 grad()
和 value_and_grad()
函数来计算更复杂函数的梯度。默认情况下,这些函数会计算关于第一个参数的梯度
def loss_fn(w, x, y):
return mx.mean(mx.square(w * x - y))
w = mx.array(1.0)
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to w:
grad_fn = mx.grad(loss_fn)
dloss_dw = grad_fn(w, x, y)
# Prints array(-1, dtype=float32)
print(dloss_dw)
# To get the gradient with respect to x we can do:
grad_fn = mx.grad(loss_fn, argnums=1)
dloss_dx = grad_fn(w, x, y)
# Prints array([-1, 1], dtype=float32)
print(dloss_dx)
一种获取损失和梯度的方法是先调用 loss_fn
,然后调用 grad_fn
,但这可能导致大量的重复工作。相反,您应该使用 value_and_grad()
。继续上面的例子
# Computes the gradient of loss_fn with respect to w:
loss_and_grad_fn = mx.value_and_grad(loss_fn)
loss, dloss_dw = loss_and_grad_fn(w, x, y)
# Prints array(1, dtype=float32)
print(loss)
# Prints array(-1, dtype=float32)
print(dloss_dw)
您还可以计算关于任意嵌套的 Python 数组容器(特别是 list
、tuple
或 dict
中的任何一个)的梯度。
假设我们想在上面的例子中使用一个权重和一个偏置参数。一个不错的做法是以下方法
def loss_fn(params, x, y):
w, b = params["weight"], params["bias"]
h = w * x + b
return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to both the
# weight and bias:
grad_fn = mx.grad(loss_fn)
grads = grad_fn(params, x, y)
# Prints
# {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
print(grads)
请注意,参数的树形结构在梯度中得到了保留。
在某些情况下,您可能希望阻止梯度通过函数的一部分传播。为此,您可以使用 stop_gradient()
。
自动向量化#
使用 vmap()
来自动化复杂函数的向量化。为了清晰起见,我们将通过一个基本且刻意的例子,但对于手动难以优化的更复杂函数,vmap()
功能强大。
警告
某些操作尚不支持 vmap()
。如果您遇到类似 ValueError: Primitive's vmap not implemented.
的错误,请提交一个问题并附上您的函数。我们将优先处理它。
一种朴素的向量集加法方法是使用循环
xs = mx.random.uniform(shape=(4096, 100))
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
相反,您可以使用 vmap()
自动向量化加法运算
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
in_axes
参数可用于指定对相应输入的哪些维度进行向量化。类似地,使用 out_axes
指定向量化轴应位于输出的哪个位置。
让我们对这两个不同版本进行计时
import timeit
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
在 M1 Max 上,朴素版本总共耗时 5.639
秒,而向量化版本仅耗时 0.024
秒,快了 200 多倍。
当然,这个操作相当刻意。更好的方法是直接执行 xs + ys.T
,但对于更复杂的函数,vmap()
非常方便。