导出函数#

MLX 提供了一套 API,用于将函数导出到文件或从文件导入函数。这使得您可以在一个 MLX 前端(例如 Python)中编写的计算,在另一个 MLX 前端(例如 C++)中运行。

本指南将通过一些示例介绍 MLX 导出 API 的基础知识。要查看完整的函数列表,请查阅 API 文档

导出基础#

让我们从一个简单的示例开始

def fun(x, y):
  return x + y

x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("add.mlxfn", fun, x, y)

要导出函数,需要提供函数可以调用的示例输入数组。数据本身不重要,但数组的形状和类型很重要。在上面的示例中,我们导出了带有两个 float32 标量数组的函数 fun。然后我们可以导入并运行该函数

add_fun = mx.import_function("add.mlxfn")

out, = add_fun(mx.array(1.0), mx.array(2.0))
# Prints: array(3, dtype=float32)
print(out)

out, = add_fun(mx.array(1.0), mx.array(3.0))
# Prints: array(4, dtype=float32)
print(out)

# Raises an exception
add_fun(mx.array(1), mx.array(3.0))

# Raises an exception
add_fun(mx.array([1.0, 2.0]), mx.array(3.0))

注意,第三次和第四次调用 add_fun 会引发异常,因为输入的形状和类型与我们导出函数时使用的示例输入的形状和类型不同。

另请注意,尽管原始函数 fun 返回单个输出数组,但导入的函数始终返回一个包含一个或多个数组的元组。

传递给 export_function() 以及导入函数的输入可以指定为可变位置参数,或作为数组的元组

def fun(x, y):
  return x + y

x = mx.array(1.0)
y = mx.array(1.0)

# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)

# Same as above
mx.export_function("add.mlxfn", fun, (x, y))

imported_fun = mx.import_function("add.mlxfn")

# Ok
out, = imported_fun(x, y)

# Also ok
out, = imported_fun((x, y))

您可以将示例输入作为位置参数或关键字参数传递给函数。如果您使用关键字参数导出函数,那么在调用导入函数时也必须使用相同的关键字参数。

def fun(x, y):
  return x + y

# One argument to fun is positional, the other is a kwarg
mx.export_function("add.mlxfn", fun, x, y=y)

imported_fun = mx.import_function("add.mlxfn")

# Ok
out, = imported_fun(x, y=y)

# Also ok
out, = imported_fun((x,), {"y": y})

# Raises since the keyword argument is missing
out, = imported_fun(x, y)

# Raises since the keyword argument has the wrong key
out, = imported_fun(x, z=y)

导出模块#

可以将 mlx.nn.Module 导出,导出的函数中可以包含或不包含其参数。下面是一个示例

model = nn.Linear(4, 4)
mx.eval(model.parameters())

def call(x):
   return model(x)

mx.export_function("model.mlxfn", call, mx.zeros(4))

在上面的示例中,导出了 mlx.nn.Linear 模块。其参数也保存到 model.mlxfn 文件中。

注意

对于导出函数中包含的数组,请特别注意确保它们已被求值。导出的计算图将包含产生这些包含输入所需的计算。

如果上面的示例缺少 mx.eval(model.parameters(),导出的函数将包含 mlx.nn.Module 参数的随机初始化过程。

如果您只想导出 Module.__call__ 函数而不包含参数,请将参数作为输入传递给 call 包装器

model = nn.Linear(4, 4)
mx.eval(model.parameters())

def call(x, **params):
  # Set the model's parameters to the input parameters
  model.update(tree_unflatten(list(params.items())))
  return model(x)

params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

无形导出#

就像 compile() 一样,函数也可以针对动态形状的输入进行导出。向 export_function()exporter() 传递 shapeless=True 即可导出可用于可变形状输入的函数

mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")

# Ok
out, = imported_abs(mx.array(-1.0))

# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))

如果使用 shapeless=False(这是默认值),第二次调用 imported_abs 将因形状不匹配而引发异常。

无形导出与无形编译的工作原理相同,应谨慎使用。更多信息请参阅无形编译文档

导出多个追踪#

在某些情况下,函数会为不同的输入参数构建不同的计算图。一个简单的管理方法是针对每组输入导出到一个新文件。在许多情况下,这是一个不错的选择。但如果导出的函数包含大量重复的常量数据(例如 mlx.nn.Module 的参数),这种方法可能不是最优的。

MLX 中的导出 API 允许您使用 exporter() 创建一个导出上下文管理器,将同一函数的多个追踪导出到单个文件中

def fun(x, y=None):
    constant = mx.array(3.0)
    if y is not None:
      x += y
    return x + constant

with mx.exporter("fun.mlxfn", fun) as exporter:
    exporter(mx.array(1.0))
    exporter(mx.array(1.0), y=mx.array(0.0))

imported_function = mx.import_function("fun.mlxfn")

# Call the function with y=None
out, = imported_function(mx.array(1.0))
print(out)

# Call the function with y specified
out, = imported_function(mx.array(1.0), y=mx.array(1.0))
print(out)

在上面的示例中,函数的常量数据(即 constant)只保存了一次。

使用导入函数进行变换#

函数变换,如 grad()vmap()compile(),对导入的函数同样适用,就像对普通 Python 函数一样

def fun(x):
    return mx.sin(x)

x = mx.array(0.0)
mx.export_function("sine.mlxfn", fun, x)

imported_fun = mx.import_function("sine.mlxfn")

# Take the derivative of the imported function
dfdx = mx.grad(lambda x: imported_fun(x)[0])
# Prints: array(1, dtype=float32)
print(dfdx(x))

# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])

在 C++ 中导入函数#

在 C++ 中导入和运行函数与在 Python 中导入和运行函数基本相同。首先,请遵循说明来设置一个使用 MLX 作为库的简单 C++ 项目。

接下来,从 Python 导出一个简单函数

def fun(x, y):
    return mx.exp(x + y)

x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("fun.mlxfn", fun, x, y)

只需几行代码即可在 C++ 中导入并运行该函数

auto fun = mx::import_function("fun.mlxfn");

auto inputs = {mx::array(1.0), mx::array(1.0)};
auto outputs = fun(inputs);

// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;

导入的函数在 C++ 中也可以像在 Python 中一样进行变换。在 C++ 中调用导入函数时,对位置参数使用 std::vector<mx::array>,对关键字参数使用 std::map<std::string, mx::array>

更多示例#

这里有一些更完整的示例,演示如何从 Python 导出更复杂的函数,并在 C++ 中导入和运行它们