编译#

MLX 提供了 compile() 函数变换,它可以编译计算图。函数编译通过合并通用操作和融合特定操作来生成更小的图。在许多情况下,这可以显著提高运行时性能和内存使用效率。

开始使用 compile() 很简单,但对于更复杂的图和高级用法,需要注意一些边缘情况。

编译基础#

我们从一个简单的示例开始

def fun(x, y):
    return mx.exp(-x) + y

x = mx.array(1.0)
y = mx.array(2.0)

# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))

# Compile the function
compiled_fun = mx.compile(fun)

# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))

常规函数和编译后的函数的输出在数值精度上是一致的。

第一次调用编译后的函数时,MLX 会构建计算图、优化它并生成和编译代码。这个过程可能相对较慢。然而,MLX 会缓存编译后的函数,因此多次调用同一个编译函数不会触发新的编译。这意味着您通常应该编译计划多次使用的函数。

def fun(x, y):
    return mx.exp(-x) + y

x = mx.array(1.0)
y = mx.array(2.0)

compiled_fun = mx.compile(fun)

# Compiled here
compiled_fun(x, y)

# Not compiled again
compiled_fun(x, y)

# Not compiled again
mx.compile(fun)(x, y)

需要注意一些重要情况,它们可能导致函数重新编译

  • 改变形状或维度数量

  • 改变任何输入的类型

  • 改变函数的输入数量

在某些情况下,只有部分编译栈会重新运行(例如改变形状时),而在其他情况下,完整的编译栈会重新运行(例如改变类型时)。总的来说,您应该避免过于频繁地编译函数。

另一个需要注意的模式是编译那些频繁创建和销毁的函数。例如,在循环中编译匿名函数时就可能发生这种情况

a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
    mx.compile(lambda x: mx.exp(mx.abs(x)))(a)

示例加速#

mlx.nn.gelu() 是一个非线性激活函数,常用于基于 Transformer 的模型。其实现涉及多个一元和二元逐元素操作

def gelu(x):
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2

如果您将此函数用于小型数组,它将受开销限制。如果您将其用于大型数组,它将受内存带宽限制。然而,gelu 中的所有操作都可以使用 compile() 融合到一个单独的内核中。这可以显著加速这两种情况。

让我们比较常规函数和编译后函数的运行时。我们将使用以下计时辅助函数,它会进行热身并处理同步

import time

def timeit(fun, x):
    # warm up
    for _ in range(10):
        mx.eval(fun(x))

    tic = time.perf_counter()
    for _ in range(100):
        mx.eval(fun(x))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - tic) / 100
    print(f"Time per iteration {tpi:.3f} (ms)")

现在创建一个数组,并对两个函数进行基准测试

x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)

在 M1 Max 上,时间分别为 15.5 毫秒和 3.1 毫秒。编译后的 gelu 快了五倍。

调试#

首次调用编译函数时,它会使用占位符输入进行追踪。这意味着您不能在编译函数内部评估数组(例如打印其内容)。

@mx.compile
def fun(x):
    z = -x
    print(z)  # Crash
    return mx.exp(z)

fun(mx.array(5.0))

对于调试,检查数组会很有帮助。一种方法是使用 disable_compile() 函数或 MLX_DISABLE_COMPILE 标志全局禁用编译。例如,即使 fun 是编译后的函数,以下代码也是可以的

@mx.compile
def fun(x):
    z = -x
    print(z) # Okay
    return mx.exp(z)

mx.disable_compile()
fun(mx.array(5.0))

纯函数#

编译后的函数旨在成为纯函数;也就是说,它们不应该有副作用。例如

state = []

@mx.compile
def fun(x, y):
    z = x + y
    state.append(z)
    return mx.exp(z)

fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)

第一次调用 fun 后,state 列表将持有一个占位符数组。占位符不包含任何数据;它仅用于构建计算图。打印此类数组会导致程序崩溃。

您有两种方法来处理这个问题。第一种方法是简单地将 state 作为输出返回

state = []

@mx.compile
def fun(x, y):
   z = x + y
   state.append(z)
   return mx.exp(z), state

 _, state = fun(mx.array(1.0), mx.array(2.0))
 # Prints [array(3, dtype=float32)]
 print(state)

在某些情况下,返回更新后的状态可能很不方便。因此,compile() 有一个参数用于捕获隐式输出

from functools import partial

state = []

# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
    z = x + y
    state.append(z)
    return mx.exp(z), state

fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)

这对于编译包含更新数组容器的函数特别有用,这在训练 mlx.nn.Module 参数时很常见。

编译后的函数还会将不在参数列表中的任何输入视为常量。例如

state = [mx.array(1.0)]

@mx.compile
def fun(x):
    return x + state[0]

# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))

# Update state
state[0] = mx.array(5.0)

# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))

为了让状态的改变反映在 fun 的输出中,您同样有两种选择。第一种选择是简单地将 state 作为输入传递给函数。在某些情况下,这可能很不方便。因此,compile() 也有一个参数用于捕获隐式输入

from functools import partial
state = [mx.array(1.0)]

# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
    return x + state[0]

# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))

# Update state
state[0] = mx.array(5.0)

# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))

编译训练图#

本节将通过一个常见设置的简单示例来演示如何使用 compile():使用带有状态的 mlx.optimizers.Optimizer 训练 mlx.nn.Module 模型。我们将展示如何使用 compile() 编译完整的正向、反向和更新过程。

首先,这里是没有进行任何编译的简单示例

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))

# 0, 1 targets
y = mx.array([0, 1, 0, 1])

# Simple linear model
model = nn.Linear(10, 1)

# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

def loss_fn(model, x, y):
    logits = model(x).squeeze()
    return nn.losses.binary_cross_entropy(logits, y)

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Perform 10 steps of gradient descent
for it in range(10):
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)

要编译更新过程,我们可以将其全部放入一个函数中,并使用适当的输入和输出捕获进行编译。这是同一个示例,但经过了编译

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial

# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))

# 0, 1 targets
y = mx.array([0, 1, 0, 1])

# Simple linear model
model = nn.Linear(10, 1)

# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

def loss_fn(model, x, y):
    logits = model(x).squeeze()
    return nn.losses.binary_cross_entropy(logits, y)

# The state that will be captured as input and output
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss

# Perform 10 steps of gradient descent
for it in range(10):
    loss = step(x, y)
    # Evaluate the model and optimizer state
    mx.eval(state)
    print(loss)

注意

如果您使用执行随机采样的模块,例如 mlx.nn.Dropout(),请确保也将 mx.random.state 包含在 compile() 捕获的 state 中,即 state = [model.state, optimizer.state, mx.random.state]

注意

有关编译完整训练图的更多示例,请查看 MLX Examples GitHub 仓库。

使用编译进行变换#

在 MLX 中,函数变换是可组合的。您可以将任何函数变换应用于任何其他函数变换的输出。有关更多信息,请参阅 函数变换 文档。

编译变换后的函数工作正常

grad_fn = mx.grad(mx.exp)

compiled_grad_fn = mx.compile(grad_fn)

# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))

# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))

注意

为了尽可能多地进行编译,编译函数的变换默认不会被编译。要编译变换后的函数,只需将其通过 compile()

您也可以编译那些本身就调用了编译函数的函数。一个好的实践是编译最外层的函数,这样 compile() 就有最大的机会来优化计算图

@mx.compile
def inner(x):
    return mx.exp(-mx.abs(x))

def outer(x):
    inner(inner(x))

# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)

无形编译#

当编译函数的输入形状改变时,函数会被重新编译。您可以通过在调用 compile() 时指定 shapeless=True 来只编译函数一次,并将其运行在具有可变形状的输入上。在这种情况下,输入的形状改变不会导致函数重新编译。

def fun(x, y):
    return mx.abs(x + y)

compiled_fun = mx.compile(fun, shapeless=True)

x = mx.array(1.0)
y = mx.array(-2.0)

# Firt call compiles the function
print(compiled_fun(x, y))

# Second call with different shapes
# does not recompile the function
x = mx.array([1.0, -6.0])
y = mx.array([-2.0, 3.0])
print(compiled_fun(x, y))

请谨慎使用无形编译。由于形状改变不会触发编译,任何依赖于输入形状的图将无法按预期工作。依赖形状的计算很常见,有时也很难检测。例如

def fun(x):
    return x.reshape(x.shape[0] * x.shape[1], -1)

compiled_fun = mx.compile(fun, shapeless=True)

x = mx.random.uniform(shape=(2, 3, 4))

out = compiled_fun(x)

x = mx.random.uniform(shape=(5, 5, 3))

# Error, can't reshape (5, 5, 3) to (6, -1)
out = compiled_fun(x)

第二次调用 compiled_fun 失败是因为调用了 reshape(),它使用了第一次调用时 x 的静态形状。我们可以通过使用 flatten() 来避免硬编码 x 的形状来解决此问题

def fun(x):
    return x.flatten(0, 1)

compiled_fun = mx.compile(fun, shapeless=True)

x = mx.random.uniform(shape=(2, 3, 4))

out = compiled_fun(x)

x = mx.random.uniform(shape=(5, 5, 3))

# Ok
out = compiled_fun(x)