MLX 中的自定义扩展#

你可以在 CPU 或 GPU 上使用自定义操作来扩展 MLX。本指南通过一个简单的例子解释了如何做到这一点。

示例介绍#

假设你想要一个操作,它接受两个数组 xy,分别通过系数 alphabeta 进行缩放,然后将它们相加得到结果 z = alpha * x + beta * y。你可以直接在 MLX 中实现这一点。

import mlx.core as mx

def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
    return alpha * x + beta * y

这个函数执行该操作,同时将实现和函数变换留给 MLX 处理。

然而,你可能想要自定义底层实现,或许是为了让它更快。在本教程中,我们将介绍如何添加自定义扩展。它将涵盖:

  • MLX 库的结构。

  • 实现一个 CPU 操作。

  • 使用 Metal 实现一个 GPU 操作。

  • 添加 vjpjvp 函数变换。

  • 构建自定义扩展并将其绑定到 Python。

操作和原语#

MLX 中的操作构建计算图。原语 (Primitive) 提供了评估和变换计算图的规则。让我们先详细讨论一下操作。

操作#

操作是对数组进行操作的前端函数。它们在 C++ API (操作) 中定义,并通过 Python API (操作) 进行绑定。

我们想要一个名为 axpby() 的操作,它接受两个数组 xy,以及两个标量 alphabeta。以下是如何在 C++ 中定义它:

/**
*  Scale and sum two vectors element-wise
*  z = alpha * x + beta * y
*
*  Use NumPy-style broadcasting between x and y
*  Inputs are upcasted to floats if needed
**/
array axpby(
    const array& x, // Input array x
    const array& y, // Input array y
    const float alpha, // Scaling factor for x
    const float beta, // Scaling factor for y
    StreamOrDevice s = {} // Stream on which to schedule the operation
);

最简单的实现方式是使用现有的操作:

array axpby(
    const array& x, // Input array x
    const array& y, // Input array y
    const float alpha, // Scaling factor for x
    const float beta, // Scaling factor for y
    StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
    // Scale x and y on the provided stream
    auto ax = multiply(array(alpha), x, s);
    auto by = multiply(array(beta), y, s);

    // Add and return
    return add(ax, by, s);
}

操作本身不包含处理数据的实现,也不包含变换的规则。相反,它们是易于使用的接口,利用 Primitive (原语) 构建块。

原语#

一个 Primitive (原语) 是 array 的计算图的一部分。它定义了如何根据输入数组创建输出数组。此外,Primitive 包含在 CPU 或 GPU 上运行的方法,以及用于函数变换的方法,例如 vjpjvp。让我们回到我们的例子,以便更具体地说明。

class Axpby : public Primitive {
  public:
    explicit Axpby(Stream stream, float alpha, float beta)
        : Primitive(stream), alpha_(alpha), beta_(beta){};

    /**
    * A primitive must know how to evaluate itself on the CPU/GPU
    * for the given inputs and populate the output array.
    *
    * To avoid unnecessary allocations, the evaluation function
    * is responsible for allocating space for the array.
    */
    void eval_cpu(
        const std::vector<array>& inputs,
        std::vector<array>& outputs) override;
    void eval_gpu(
        const std::vector<array>& inputs,
        std::vector<array>& outputs) override;

    /** The Jacobian-vector product. */
    std::vector<array> jvp(
        const std::vector<array>& primals,
        const std::vector<array>& tangents,
        const std::vector<int>& argnums) override;

    /** The vector-Jacobian product. */
    std::vector<array> vjp(
        const std::vector<array>& primals,
        const std::vector<array>& cotangents,
        const std::vector<int>& argnums,
        const std::vector<array>& outputs) override;

    /**
    * The primitive must know how to vectorize itself across
    * the given axes. The output is a pair containing the array
    * representing the vectorized computation and the axis which
    * corresponds to the output vectorized dimension.
    */
    virtual std::pair<std::vector<array>, std::vector<int>> vmap(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) override;

    /** Print the primitive. */
    void print(std::ostream& os) override {
        os << "Axpby";
    }

    /** Equivalence check **/
    bool is_equivalent(const Primitive& other) const override;

  private:
    float alpha_;
    float beta_;
};

Axpby 类派生自基础的 Primitive 类。Axpbyalphabeta 视为参数。然后,它通过 Axpby::eval_cpu()Axpby::eval_gpu() 提供了根据输入产生输出数组的实现。它还在 Axpby::jvp()Axpby::vjp()Axpby::vmap() 中提供了变换规则。

使用原语#

操作可以使用这个 Primitive (原语) 向计算图添加一个新的 array。可以通过提供其数据类型、形状、计算它的 Primitive 以及传递给原语的 array 输入来构造一个 array

现在让我们用我们的 Axpby 原语来重新实现我们的操作。

array axpby(
    const array& x, // Input array x
    const array& y, // Input array y
    const float alpha, // Scaling factor for x
    const float beta, // Scaling factor for y
    StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
    // Promote dtypes between x and y as needed
    auto promoted_dtype = promote_types(x.dtype(), y.dtype());

    // Upcast to float32 for non-floating point inputs x and y
    auto out_dtype = issubdtype(promoted_dtype, float32)
        ? promoted_dtype
        : promote_types(promoted_dtype, float32);

    // Cast x and y up to the determined dtype (on the same stream s)
    auto x_casted = astype(x, out_dtype, s);
    auto y_casted = astype(y, out_dtype, s);

    // Broadcast the shapes of x and y (on the same stream s)
    auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
    auto out_shape = broadcasted_inputs[0].shape();

    // Construct the array as the output of the Axpby primitive
    // with the broadcasted and upcasted arrays as inputs
    return array(
        /* const std::vector<int>& shape = */ out_shape,
        /* Dtype dtype = */ out_dtype,
        /* std::unique_ptr<Primitive> primitive = */
        std::make_shared<Axpby>(to_stream(s), alpha, beta),
        /* const std::vector<array>& inputs = */ broadcasted_inputs);
}

这个操作现在处理以下事项:

  1. 提升输入类型并确定输出数据类型。

  2. 广播输入并确定输出形状。

  3. 使用给定的流、alphabeta 构造原语 Axpby

  4. 使用原语和输入构造输出 array

实现原语#

当我们只调用操作本身时,不会发生计算。操作只构建计算图。当我们评估输出数组时,MLX 会调度计算图的执行,并根据用户指定的流/设备调用 Axpby::eval_cpu()Axpby::eval_gpu()

警告

当调用 Primitive::eval_cpu()Primitive::eval_gpu() 时,尚未为输出数组分配内存。这些函数的实现需要根据需要分配内存。

实现 CPU 后端#

让我们先从实现 Axpby::eval_cpu() 开始。

该方法将遍历输出数组的每个元素,找到 xy 对应的输入元素,并执行逐点操作。这在模板函数 axpby_impl() 中实现。

template <typename T>
void axpby_impl(
    const mx::array& x,
    const mx::array& y,
    mx::array& out,
    float alpha_,
    float beta_,
    mx::Stream stream) {
  out.set_data(mx::allocator::malloc(out.nbytes()));

  // Get the CPU command encoder and register input and output arrays
  auto& encoder = mx::cpu::get_command_encoder(stream);
  encoder.set_input_array(x);
  encoder.set_input_array(y);
  encoder.set_output_array(out);

  // Launch the CPU kernel
  encoder.dispatch([x_ptr = x.data<T>(),
                    y_ptr = y.data<T>(),
                    out_ptr = out.data<T>(),
                    size = out.size(),
                    shape = out.shape(),
                    x_strides = x.strides(),
                    y_strides = y.strides(),
                    alpha_,
                    beta_]() {

    // Cast alpha and beta to the relevant types
    T alpha = static_cast<T>(alpha_);
    T beta = static_cast<T>(beta_);

    // Do the element-wise operation for each output
    for (size_t out_idx = 0; out_idx < size; out_idx++) {
      // Map linear indices to offsets in x and y
      auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
      auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);

      // We allocate the output to be contiguous and regularly strided
      // (defaults to row major) and hence it doesn't need additional mapping
      out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
    }
  });
}

我们的实现应该适用于所有传入的浮点数数组。因此,我们添加了针对 float32float16bfloat16complex64 的分派。如果遇到意外类型,我们会抛出错误。

void Axpby::eval_cpu(
    const std::vector<mx::array>& inputs,
    std::vector<mx::array>& outputs) {
  auto& x = inputs[0];
  auto& y = inputs[1];
  auto& out = outputs[0];

  // Dispatch to the correct dtype
  if (out.dtype() == mx::float32) {
    return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
  } else if (out.dtype() == mx::float16) {
    return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
  } else if (out.dtype() == mx::bfloat16) {
    return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
  } else if (out.dtype() == mx::complex64) {
    return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
  } else {
    throw std::runtime_error(
        "Axpby is only supported for floating point types.");
  }
}

只实现这些就足以在 CPU 流上运行 axpby() 操作!如果你不打算在 GPU 上运行该操作或在包含 Axpby 的计算图上使用变换,你可以在这里停止实现原语。

实现 GPU 后端#

Apple Silicon 设备使用 Metal 着色语言访问其 GPU,MLX 中的 GPU 内核使用 Metal 编写。

注意

如果你是 Metal 的新手,这里有一些有用的资源:

让我们保持 GPU 内核简单。我们将启动与输出中元素数量完全相同的线程。每个线程将从 xy 中选取它需要的元素,执行逐点操作,并更新其在输出中被分配的元素。

template <typename T>
[[kernel]] void axpby_general(
        device const T* x [[buffer(0)]],
        device const T* y [[buffer(1)]],
        device T* out [[buffer(2)]],
        constant const float& alpha [[buffer(3)]],
        constant const float& beta [[buffer(4)]],
        constant const int* shape [[buffer(5)]],
        constant const int64_t* x_strides [[buffer(6)]],
        constant const int64_t* y_strides [[buffer(7)]],
        constant const int& ndim [[buffer(8)]],
        uint index [[thread_position_in_grid]]) {
    // Convert linear indices to offsets in array
    auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
    auto y_offset = elem_to_loc(index, shape, y_strides, ndim);

    // Do the operation and update the output
    out[index] =
        static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}

然后,我们需要为所有浮点类型实例化此模板,并为每个实例化赋予唯一的宿主名称,以便我们能够识别它。

instantiate_kernel("axpby_general_float32", axpby_general, float)
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)

确定内核、设置输入、解析网格维度以及分派到 GPU 的逻辑都包含在 Axpby::eval_gpu() 中,如下所示。

/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
  const std::vector<array>& inputs,
  std::vector<array>& outputs) {
    // Prepare inputs
    assert(inputs.size() == 2);
    auto& x = inputs[0];
    auto& y = inputs[1];
    auto& out = outputs[0];

    // Each primitive carries the stream it should execute on
    // and each stream carries its device identifiers
    auto& s = stream();
    // We get the needed metal device using the stream
    auto& d = metal::device(s.device);

    // Allocate output memory
    out.set_data(allocator::malloc(out.nbytes()));

    // Resolve name of kernel
    std::ostringstream kname;
    kname << "axpby_" << "general_" << type_to_name(out);

    // Make sure the metal library is available
    d.register_library("mlx_ext");

    // Make a kernel from this metal library
    auto kernel = d.get_kernel(kname.str(), "mlx_ext");

    // Prepare to encode kernel
    auto& compute_encoder = d.get_command_encoder(s.index);
    compute_encoder.set_compute_pipeline_state(kernel);

    // Kernel parameters are registered with buffer indices corresponding to
    // those in the kernel declaration at axpby.metal
    int ndim = out.ndim();
    size_t nelem = out.size();

    // Encode input arrays to kernel
    compute_encoder.set_input_array(x, 0);
    compute_encoder.set_input_array(y, 1);

    // Encode output arrays to kernel
    compute_encoder.set_output_array(out, 2);

    // Encode alpha and beta
    compute_encoder.set_bytes(alpha_, 3);
    compute_encoder.set_bytes(beta_, 4);

    // Encode shape, strides and ndim
    compute_encoder.set_vector_bytes(x.shape(), 5);
    compute_encoder.set_vector_bytes(x.strides(), 6);
    compute_encoder.set_bytes(y.strides(), 7);
    compute_encoder.set_bytes(ndim, 8);

    // We launch 1 thread for each input and make sure that the number of
    // threads in any given threadgroup is not higher than the max allowed
    size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());

    // Fix the 3D size of each threadgroup (in terms of threads)
    MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);

    // Fix the 3D size of the launch grid (in terms of threads)
    MTL::Size grid_dims = MTL::Size(nelem, 1, 1);

    // Launch the grid with the given number of threads divided among
    // the given threadgroups
    compute_encoder.dispatch_threads(grid_dims, group_dims);
}

现在,我们可以在 CPU 和 GPU 上调用 axpby() 操作了!

在继续之前,关于 MLX 和 Metal 有几点需要注意。MLX 会跟踪活动 command_buffer 以及与之关联的 MTLCommandBuffer。我们依赖于 d.get_command_encoder() 来获取活动的 Metal 计算命令编码器,而不是构建一个新的并在最后调用 compute_encoder->end_encoding()。MLX 将内核(计算管线)添加到活动的命令缓冲区,直到达到指定的限制或需要刷新命令缓冲区进行同步。

原语变换#

接下来,让我们在一个 Primitive 中添加变换的实现。这些变换可以构建在其他操作之上,包括我们刚刚定义的操作。

/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
        const std::vector<array>& primals,
        const std::vector<array>& tangents,
        const std::vector<int>& argnums) {
    // Forward mode diff that pushes along the tangents
    // The jvp transform on the primitive can be built with ops
    // that are scheduled on the same stream as the primitive

    // If argnums = {0}, we only push along x in which case the
    // jvp is just the tangent scaled by alpha
    // Similarly, if argnums = {1}, the jvp is just the tangent
    // scaled by beta
    if (argnums.size() > 1) {
        auto scale = argnums[0] == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, tangents[0].dtype());
        return {multiply(scale_arr, tangents[0], stream())};
    }
    // If argnums = {0, 1}, we take contributions from both
    // which gives us jvp = tangent_x * alpha + tangent_y * beta
    else {
        return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
    }
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
        const std::vector<array>& primals,
        const std::vector<array>& cotangents,
        const std::vector<int>& argnums,
        const std::vector<int>& /* unused */) {
    // Reverse mode diff
    std::vector<array> vjps;
    for (auto arg : argnums) {
        auto scale = arg == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, cotangents[0].dtype());
        vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
    }
    return vjps;
}

注意,不必完全定义变换就可以开始使用 Primitive

/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) {
    throw std::runtime_error("[Axpby] vmap not implemented.");
}

构建和绑定#

让我们先看看整体的目录结构。

extensions
├── axpby
│ ├── axpby.cpp
│ ├── axpby.h
│ └── axpby.metal
├── mlx_sample_extensions
│ └── __init__.py
├── bindings.cpp
├── CMakeLists.txt
└── setup.py
  • extensions/axpby/ 定义了 C++ 扩展库

  • extensions/mlx_sample_extensions 构建了相关的 Python 包的结构

  • extensions/bindings.cpp 为我们的操作提供了 Python 绑定

  • extensions/CMakeLists.txt 包含用于构建库和 Python 绑定的 CMake 规则

  • extensions/setup.py 包含用于构建和安装 Python 包的 setuptools 规则

绑定到 Python#

我们使用 nanobind 为 C++ 库构建 Python API。由于像 mlx.core.arraymlx.core.stream 等组件的绑定已经提供,添加我们的 axpby() 就很简单了。

NB_MODULE(_ext, m) {
     m.doc() = "Sample extension for MLX";

     m.def(
         "axpby",
         &axpby,
         "x"_a,
         "y"_a,
         "alpha"_a,
         "beta"_a,
         nb::kw_only(),
         "stream"_a = nb::none(),
         R"(
             Scale and sum two vectors element-wise
             ``z = alpha * x + beta * y``

             Follows numpy style broadcasting between ``x`` and ``y``
             Inputs are upcasted to floats if needed

             Args:
                 x (array): Input array.
                 y (array): Input array.
                 alpha (float): Scaling factor for ``x``.
                 beta (float): Scaling factor for ``y``.

             Returns:
                 array: ``alpha * x + beta * y``
         )");
 }

上述例子中的大部分复杂性来自于字面量名称和文档字符串等额外的细节。

警告

必须在导入上述 nanobind 模块定义的 mlx_sample_extensions 之前导入 mlx.core,以确保像 mlx.core.array 这样的 mlx.core 组件的类型转换器可用。

使用 CMake 构建#

构建 C++ 扩展库只需要 find_package(MLX CONFIG),然后将其链接到你的库即可。

# Add library
add_library(mlx_ext)

# Add sources
target_sources(
    mlx_ext
    PUBLIC
    ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)

# Add include headers
target_include_directories(
    mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)

# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)

我们还需要构建附加的 Metal 库。为了方便,我们提供了一个 mlx_build_metallib() 函数,该函数可以根据源文件、头文件、目标路径等构建 .metallib 目标(该函数在 cmake/extension.cmake 中定义,并随 MLX 包自动导入)。

实际操作如下所示:

# Build metallib
if(MLX_BUILD_METAL)

mlx_build_metallib(
    TARGET mlx_ext_metallib
    TITLE mlx_ext
    SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
    INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
    OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)

add_dependencies(
    mlx_ext
    mlx_ext_metallib
)

endif()

最后,我们构建 nanobind 绑定。

nanobind_add_module(
  _ext
  NB_STATIC STABLE_ABI LTO NOMINSIZE
  NB_DOMAIN mlx
  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)

if(BUILD_SHARED_LIBS)
  target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()

使用 setuptools 构建#

按照上述描述设置好 CMake 构建规则后,我们可以使用 mlx.extension 中定义的构建工具。

from mlx import extension
from setuptools import setup

if __name__ == "__main__":
    setup(
        name="mlx_sample_extensions",
        version="0.0.0",
        description="Sample C++ and Metal extensions for MLX primitives.",
        ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
        cmdclass={"build_ext": extension.CMakeBuild},
        packages=["mlx_sample_extensions"],
        package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
        extras_require={"dev":[]},
        zip_safe=False,
        python_requires=">=3.8",
    )

注意

我们将 extensions/mlx_sample_extensions 视为包目录,即使它只包含一个 __init__.py,这是为了确保以下几点:

  • 必须在导入 _ext 之前导入 mlx.core

  • C++ 扩展库和 Metal 库与 Python 绑定位于同一位置,如果安装包,它们会一起被复制。

要构建该包,首先使用 pip install -r requirements.txt 安装构建依赖项。然后,你可以在开发时使用 python setup.py build_ext -j8 --inplace (在 extensions/ 目录下)进行原地构建。

这将产生以下目录结构:

extensions
├── mlx_sample_extensions
│ ├── __init__.py
│ ├── libmlx_ext.dylib # C++ 扩展库
│ ├── mlx_ext.metallib # Metal 库
│ └── _ext.cpython-3x-darwin.so # Python 绑定

当你尝试使用命令 python -m pip install . (在 extensions/ 目录下)进行安装时,该包将以与 extensions/mlx_sample_extensions 相同的结构进行安装,并且由于 C++ 和 Metal 库被指定为 package_data,它们将与 Python 绑定一起被复制。

用法#

如上所述安装扩展后,你应该能够简单地导入 Python 包,并像使用任何其他 MLX 操作一样使用它。

让我们看看一个简单的脚本及其结果。

import mlx.core as mx
from mlx_sample_extensions import axpby

a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)

print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c is correct: {mx.all(c == 6.0).item()}")

输出

c shape: [3, 4]
c dtype: float32
c is correct: True

结果#

让我们快速运行一个基准测试,看看我们新的 axpby 操作与我们最初定义的朴素 simple_axpby() 相比如何。

import mlx.core as mx
from mlx_sample_extensions import axpby
import time

def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
    return alpha * x + beta * y

M = 4096
N = 4096

x = mx.random.normal((M, N))
y = mx.random.normal((M, N))
alpha = 4.0
beta = 2.0

mx.eval(x, y)

def bench(f):
    # Warm up
    for i in range(5):
        z = f(x, y, alpha, beta)
        mx.eval(z)

    # Timed run
    s = time.time()
    for i in range(100):
        z = f(x, y, alpha, beta)
        mx.eval(z)
    e = time.time()
    return 1000 * (e - s) / 100

simple_time = bench(simple_axpby)
custom_time = bench(axpby)

print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")

结果是 Simple axpby: 1.559 ms | Custom axpby: 0.774 ms。我们立即看到了适度的改进!

这个操作现在可以很好地用于构建其他操作,在 mlx.nn.Module 调用中,以及作为 grad() 等图变换的一部分。

脚本#

下载代码

完整的示例代码可在 mlx 中找到。