延迟计算#

为何使用延迟计算#

在 MLX 中执行操作时,实际并未发生计算。取而代之的是记录了一个计算图。实际的计算只在执行 eval() 时才会发生。

MLX 使用延迟计算,因为它具有一些很好的特性,其中一些特性我们在下面进行描述。

变换计算图#

延迟计算允许我们在不实际执行任何计算的情况下记录计算图。这对于 grad()vmap() 等函数变换以及图优化很有用。

目前,MLX 不会编译和重新运行计算图。它们都是动态生成的。然而,延迟计算使得将来更容易集成编译以提升性能。

只计算所需内容#

在 MLX 中,您无需过多担心计算从未使用过的输出。例如

def fun(x):
    a = fun1(x)
    b = expensive_fun(a)
    return a, b

y, _ = fun(x)

在这里,我们从未实际计算 expensive_fun 的输出。但请谨慎使用这种模式,因为 expensive_fun 的图仍然会构建,这会产生一些相关的开销。

类似地,延迟计算有助于节省内存,同时保持代码简洁。假设您有一个继承自 mlx.nn.Module 的非常大的模型 Model。您可以使用 model = Model() 实例化此模型。通常,这会将所有权重初始化为 float32,但在执行 eval() 之前,初始化实际上不会计算任何内容。如果您使用 float16 权重更新模型,则最大内存消耗将是在使用即时计算时所需内存的一半。

由于延迟计算,在 MLX 中实现这种模式非常简单

model = Model() # no memory used yet
model.load_weights("weights_fp16.safetensors")

何时进行评估#

一个常见的问题是何时使用 eval()。其中的权衡在于避免图变得过大以及批量处理足够多的有用工作。

例如

for _ in range(100):
     a = a + b
     mx.eval(a)
     b = b * 2
     mx.eval(b)

这是一个糟糕的主意,因为每次图评估都会产生固定的开销。另一方面,会存在随计算图大小增长的轻微开销,因此极大的图(尽管计算上是正确的)可能会耗费很高成本。

幸运的是,MLX 能够很好地处理各种大小的计算图:每次评估从几十个操作到数千个操作都可以接受。

大多数数值计算都有一个迭代外循环(例如随机梯度下降中的迭代)。使用 eval() 的一个自然且通常效率很高的地方是在这个外循环的每一次迭代中。

这里有一个具体的例子

for batch in dataset:

    # Nothing has been evaluated yet
    loss, grad = value_and_grad_fn(model, batch)

    # Still nothing has been evaluated
    optimizer.update(model, grad)

    # Evaluate the loss and the new parameters which will
    # run the full gradient computation and optimizer update
    mx.eval(loss, model.parameters())

需要注意的一个重要行为是图何时会被隐式评估。任何时候您 print 一个数组,将其转换为 numpy.ndarray,或通过 memoryview 访问其内存时,图都会被评估。通过 save()(或任何其他 MLX 保存函数)保存数组时也会评估数组。

对标量数组调用 array.item() 也会评估它。在上面的例子中,打印损失 (print(loss)) 或将损失标量添加到列表中 (losses.append(loss.item())) 都会导致图评估。如果这些行位于 mx.eval(loss, model.parameters()) 之前,则这将是部分评估,仅计算前向传播。

此外,对一个数组或一组数组多次调用 eval() 是完全没问题的。这实际上是一个空操作。

警告

使用标量数组进行控制流将导致评估。

这里有一个例子

def fun(x):
    h, y = first_layer(x)
    if y > 0:  # An evaluation is done here!
        z  = second_layer_a(h)
    else:
        z  = second_layer_b(h)
    return z

使用数组进行控制流应谨慎进行。上面的例子可行,甚至可以与梯度变换一起使用。但是,如果评估过于频繁,这可能会非常低效。