延迟计算#
为何使用延迟计算#
在 MLX 中执行操作时,实际并未发生计算。取而代之的是记录了一个计算图。实际的计算只在执行 eval()
时才会发生。
MLX 使用延迟计算,因为它具有一些很好的特性,其中一些特性我们在下面进行描述。
变换计算图#
延迟计算允许我们在不实际执行任何计算的情况下记录计算图。这对于 grad()
和 vmap()
等函数变换以及图优化很有用。
目前,MLX 不会编译和重新运行计算图。它们都是动态生成的。然而,延迟计算使得将来更容易集成编译以提升性能。
只计算所需内容#
在 MLX 中,您无需过多担心计算从未使用过的输出。例如
def fun(x):
a = fun1(x)
b = expensive_fun(a)
return a, b
y, _ = fun(x)
在这里,我们从未实际计算 expensive_fun
的输出。但请谨慎使用这种模式,因为 expensive_fun
的图仍然会构建,这会产生一些相关的开销。
类似地,延迟计算有助于节省内存,同时保持代码简洁。假设您有一个继承自 mlx.nn.Module
的非常大的模型 Model
。您可以使用 model = Model()
实例化此模型。通常,这会将所有权重初始化为 float32
,但在执行 eval()
之前,初始化实际上不会计算任何内容。如果您使用 float16
权重更新模型,则最大内存消耗将是在使用即时计算时所需内存的一半。
由于延迟计算,在 MLX 中实现这种模式非常简单
model = Model() # no memory used yet
model.load_weights("weights_fp16.safetensors")
何时进行评估#
一个常见的问题是何时使用 eval()
。其中的权衡在于避免图变得过大以及批量处理足够多的有用工作。
例如
for _ in range(100):
a = a + b
mx.eval(a)
b = b * 2
mx.eval(b)
这是一个糟糕的主意,因为每次图评估都会产生固定的开销。另一方面,会存在随计算图大小增长的轻微开销,因此极大的图(尽管计算上是正确的)可能会耗费很高成本。
幸运的是,MLX 能够很好地处理各种大小的计算图:每次评估从几十个操作到数千个操作都可以接受。
大多数数值计算都有一个迭代外循环(例如随机梯度下降中的迭代)。使用 eval()
的一个自然且通常效率很高的地方是在这个外循环的每一次迭代中。
这里有一个具体的例子
for batch in dataset:
# Nothing has been evaluated yet
loss, grad = value_and_grad_fn(model, batch)
# Still nothing has been evaluated
optimizer.update(model, grad)
# Evaluate the loss and the new parameters which will
# run the full gradient computation and optimizer update
mx.eval(loss, model.parameters())
需要注意的一个重要行为是图何时会被隐式评估。任何时候您 print
一个数组,将其转换为 numpy.ndarray
,或通过 memoryview
访问其内存时,图都会被评估。通过 save()
(或任何其他 MLX 保存函数)保存数组时也会评估数组。
对标量数组调用 array.item()
也会评估它。在上面的例子中,打印损失 (print(loss)
) 或将损失标量添加到列表中 (losses.append(loss.item())
) 都会导致图评估。如果这些行位于 mx.eval(loss, model.parameters())
之前,则这将是部分评估,仅计算前向传播。
此外,对一个数组或一组数组多次调用 eval()
是完全没问题的。这实际上是一个空操作。
警告
使用标量数组进行控制流将导致评估。
这里有一个例子
def fun(x):
h, y = first_layer(x)
if y > 0: # An evaluation is done here!
z = second_layer_a(h)
else:
z = second_layer_b(h)
return z
使用数组进行控制流应谨慎进行。上面的例子可行,甚至可以与梯度变换一起使用。但是,如果评估过于频繁,这可能会非常低效。