Custom Metal Kernels#

MLX supports writing custom Metal kernels through the Python and C++ APIs.

Simple Example#

Let's write a custom kernel that computes exp elementwise:

import mlx.core as mx

def exp_elementwise(a: mx.array):
    source = """
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    """

    kernel = mx.fast.metal_kernel(
        name="myexp",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )
    outputs = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )
    return outputs[0]

a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))

Note

We are only required to pass the body of the Metal kernel in source.

The full function signature will be generated using:

  • The shapes/dtypes of inputs

    In the above, a is an mx.array of type mx.float16 and we pass it with the key inp, so the signature will include const device float16_t* inp. If inp_shape, inp_strides or inp_ndim are present in source, they will also be added to the signature for convenience.

  • The list of output_dtypes

    In the above, out is an mx.array of type mx.float16, so we add device float16_t* out.

  • Template parameters passed using template

    In the above, template=[("T", mx.float32)] adds a template <typename T> to the function and instantiates it as custom_kernel_myexp_float<float>. Template parameters can be mx.core.Dtype, int or bool (see the short sketch after this list).

  • Metal attributes used in source, such as [[thread_position_in_grid]]

    These are added as function arguments. All of the attributes defined in Table 5.8 of the Metal Shading Language Specification are supported.
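For instance, integer and boolean template parameters are passed the same way as dtypes. Here is a minimal, hypothetical sketch (the kernel and all names are illustrative, not part of the example above), where N is instantiated as a compile-time constant:

# Hypothetical kernel: each thread sums a group of N consecutive inputs
source = """
    uint elem = thread_position_in_grid.x;
    T acc = T(0);
    for (int i = 0; i < N; ++i) {
        acc += inp[elem * N + i];
    }
    out[elem] = acc;
"""

kernel = mx.fast.metal_kernel(
    name="sum_groups",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)
a = mx.arange(16, dtype=mx.float32)
outputs = kernel(
    inputs=[a],
    # "T" is a dtype parameter; "N" is an int baked in at compile time
    template=[("T", mx.float32), ("N", 4)],
    grid=(4, 1, 1),
    threadgroup=(4, 1, 1),
    output_shapes=[(4,)],
    output_dtypes=[mx.float32],
)
print(outputs[0])  # [6, 22, 38, 54]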

Putting all of this together, the generated function signature for myexp is as follows:

template <typename T>
[[kernel]] void custom_kernel_myexp_float(
  const device float16_t* inp [[buffer(0)]],
  device float16_t* out [[buffer(1)]],
  uint3 thread_position_in_grid [[thread_position_in_grid]]) {

        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);

}

template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

Note: grid and threadgroup are parameters to the Metal dispatchThreads function. This means we will launch mx.prod(grid) threads, subdivided into threadgroup-sized threadgroups. For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
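As a rough illustration of that decomposition (purely illustrative; dispatchThreads also handles grids that are not exact multiples of the threadgroup size by launching partial edge groups):

import math

grid = (4096, 1, 1)
threadgroup = (256, 1, 1)

# Total threads launched is the product of the grid dimensions
total_threads = math.prod(grid)  # 4096

# Metal subdivides them into about ceil(grid[i] / threadgroup[i])
# threadgroups along each axis
threadgroups = tuple(math.ceil(g / t) for g, t in zip(grid, threadgroup))
print(total_threads, threadgroups)  # 4096 (16, 1, 1)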

verbose=True 传递给 mx.fast.metal_kernel.__call__ 将打印生成的代码,用于调试。

Using Shape/Strides#

mx.fast.metal_kernel supports an argument ensure_row_contiguous which defaults to True. This will copy the mx.array inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous. Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims when indexing.

If we want to avoid this copy, metal_kernel automatically passes a_shape, a_strides and a_ndim for any input array a if any of them are present in source. We can then use MLX's built-in indexing utilities to fetch the right elements for each thread.

Let's convert myexp above to support arbitrarily strided arrays without relying on a copy from ensure_row_contiguous:

def exp_elementwise(a: mx.array):
    source = """
        uint elem = thread_position_in_grid.x;
        // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
        uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
        T tmp = inp[loc];
        // Output arrays are always row contiguous
        out[elem] = metal::exp(tmp);
    """

    kernel = mx.fast.metal_kernel(
        name="myexp_strided",
        input_names=["inp"],
        output_names=["out"],
        source=source
    )
    outputs = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        ensure_row_contiguous=False,
    )
    return outputs[0]

a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
# make non-contiguous
a = a[::2]
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
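For intuition, elem_to_loc converts a row-major linear index into a memory offset using the array's shape and strides. A Python sketch of the equivalent logic (illustrative only, not the actual MLX source):

def elem_to_loc(elem, shape, strides, ndim):
    # Peel off the coordinate along each dim, innermost first,
    # and scale it by that dim's (possibly non-contiguous) stride
    loc = 0
    for i in range(ndim - 1, -1, -1):
        loc += (elem % shape[i]) * strides[i]
        elem //= shape[i]
    return loc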

Complex Example#

Let's implement a more complex example: grid_sample in "bilinear" mode.

We'll start with the following MLX implementation using standard ops:

def grid_sample_ref(x, grid):
    N, H_in, W_in, _ = x.shape
    # Map normalized grid coords in [-1, 1] to input pixel coords
    ix = ((grid[..., 0] + 1) * W_in - 1) / 2
    iy = ((grid[..., 1] + 1) * H_in - 1) / 2

    ix_nw = mx.floor(ix).astype(mx.int32)
    iy_nw = mx.floor(iy).astype(mx.int32)

    ix_ne = ix_nw + 1
    iy_ne = iy_nw

    ix_sw = ix_nw
    iy_sw = iy_nw + 1

    ix_se = ix_nw + 1
    iy_se = iy_nw + 1

    # Bilinear interpolation weights for the four corner pixels
    nw = (ix_se - ix)    * (iy_se - iy)
    ne = (ix    - ix_sw) * (iy_sw - iy)
    sw = (ix_ne - ix)    * (iy    - iy_ne)
    se = (ix    - ix_nw) * (iy    - iy_nw)

    I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
    I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
    I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
    I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

    # Mask out corners that fall outside the input image
    mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
    mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
    mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
    mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

    I_nw *= mask_nw[..., None]
    I_ne *= mask_ne[..., None]
    I_sw *= mask_sw[..., None]
    I_se *= mask_se[..., None]

    output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

    return output
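A quick sanity check of the reference (the shapes are arbitrary; keeping the grid coordinates away from the borders avoids the out-of-bounds corner lookups that the masks then zero out):

x = mx.random.normal(shape=(2, 8, 8, 3))
grid = mx.random.uniform(low=-0.5, high=0.5, shape=(2, 4, 4, 2))
out = grid_sample_ref(x, grid)
print(out.shape)  # (2, 4, 4, 3)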

Now let's use mx.custom_function together with mx.fast.metal_kernel to write a fast GPU kernel for both the forward and backward passes.

First we'll implement the forward pass as a fused kernel:

import numpy as np

@mx.custom_function
def grid_sample(x, grid):

    assert x.ndim == 4, "`x` must be 4D."
    assert grid.ndim == 4, "`grid` must be 4D."

    B, _, _, C = x.shape
    _, gN, gM, D = grid.shape
    out_shape = (B, gN, gM, C)

    assert D == 2, "Last dim of `grid` must be size 2."

    source = """
        uint elem = thread_position_in_grid.x;
        int H = x_shape[1];
        int W = x_shape[2];
        int C = x_shape[3];
        int gH = grid_shape[1];
        int gW = grid_shape[2];

        int w_stride = C;
        int h_stride = W * w_stride;
        int b_stride = H * h_stride;

        // Each output element reads one (x, y) coordinate pair from `grid`
        uint grid_idx = elem / C * 2;
        float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
        float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

        int ix_nw = floor(ix);
        int iy_nw = floor(iy);

        int ix_ne = ix_nw + 1;
        int iy_ne = iy_nw;

        int ix_sw = ix_nw;
        int iy_sw = iy_nw + 1;

        int ix_se = ix_nw + 1;
        int iy_se = iy_nw + 1;

        T nw = (ix_se - ix)    * (iy_se - iy);
        T ne = (ix    - ix_sw) * (iy_sw - iy);
        T sw = (ix_ne - ix)    * (iy    - iy_ne);
        T se = (ix    - ix_nw) * (iy    - iy_nw);

        int batch_idx = elem / C / gH / gW * b_stride;
        int channel_idx = elem % C;
        int base_idx = batch_idx + channel_idx;

        T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
        T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
        T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
        T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

        // Zero out contributions from out-of-bounds corners
        I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
        I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
        I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
        I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

        out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
    """
    kernel = mx.fast.metal_kernel(
        name="grid_sample",
        input_names=["x", "grid"],
        output_names=["out"],
        source=source,
    )
    outputs = kernel(
        inputs=[x, grid],
        template=[("T", x.dtype)],
        output_shapes=[out_shape],
        output_dtypes=[x.dtype],
        grid=(np.prod(out_shape), 1, 1),
        threadgroup=(256, 1, 1),
    )
    return outputs[0]
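A minimal check of the fused kernel against the reference implementation above (the tolerance is a judgment call for float32):

x = mx.random.normal(shape=(2, 8, 8, 3))
grid = mx.random.uniform(low=-0.5, high=0.5, shape=(2, 4, 4, 2))
assert mx.allclose(grid_sample(x, grid), grid_sample_ref(x, grid), atol=1e-5)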

For a reasonably sized input such as:

x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)

we see a big performance improvement on an M1 Max:

55.7ms -> 6.7ms => 8x speed up
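One rough way to reproduce such a measurement (a sketch; mx.eval forces the lazy computation, and the first call is excluded to avoid measuring compilation):

import time

x = mx.random.normal(shape=(8, 1024, 1024, 64))
grid = mx.random.uniform(low=-0.5, high=0.5, shape=(8, 256, 256, 2))

mx.eval(grid_sample(x, grid))  # warm-up: compile and run once

tic = time.perf_counter()
for _ in range(10):
    mx.eval(grid_sample(x, grid))
toc = time.perf_counter()
print(f"{(toc - tic) / 10 * 1e3:.1f} ms per call")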

Grid Sample VJP#

Since we decorated grid_sample with mx.custom_function, we can now define its custom vjp transform so MLX can differentiate it.

The backwards pass requires atomically updating x_grad/grid_grad and so requires a few extra mx.fast.metal_kernel features:

  • init_value=0

    Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.

  • atomic_outputs=True

    Designate all of the kernel outputs as atomic in the function signature. This means we can use Metal's atomic features to update the x_grad and grid_grad arrays from multiple threadgroups concurrently. See section 6.15 of the Metal Shading Language Specification for more details.

We can then implement the backwards pass as follows:

@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
    x, grid = primals
    B, _, _, C = x.shape
    _, gN, gM, D = grid.shape

    assert D == 2, "Last dim of `grid` must be size 2."

    source = """
        uint elem = thread_position_in_grid.x;
        int H = x_shape[1];
        int W = x_shape[2];
        int C = x_shape[3];
        // Pad C to the nearest larger simdgroup size multiple
        int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

        int gH = grid_shape[1];
        int gW = grid_shape[2];

        int w_stride = C;
        int h_stride = W * w_stride;
        int b_stride = H * h_stride;

        uint grid_idx = elem / C_padded * 2;
        float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
        float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

        int ix_nw = floor(ix);
        int iy_nw = floor(iy);

        int ix_ne = ix_nw + 1;
        int iy_ne = iy_nw;

        int ix_sw = ix_nw;
        int iy_sw = iy_nw + 1;

        int ix_se = ix_nw + 1;
        int iy_se = iy_nw + 1;

        T nw = (ix_se - ix)    * (iy_se - iy);
        T ne = (ix    - ix_sw) * (iy_sw - iy);
        T sw = (ix_ne - ix)    * (iy    - iy_ne);
        T se = (ix    - ix_nw) * (iy    - iy_nw);

        int batch_idx = elem / C_padded / gH / gW * b_stride;
        int channel_idx = elem % C_padded;
        int base_idx = batch_idx + channel_idx;

        T gix = T(0);
        T giy = T(0);
        if (channel_idx < C) {
            int cot_index = elem / C_padded * C + channel_idx;
            T cot = cotangent[cot_index];
            if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
                int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

                T I_nw = x[offset];
                gix -= I_nw * (iy_se - iy) * cot;
                giy -= I_nw * (ix_se - ix) * cot;
            }
            if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
                int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

                T I_ne = x[offset];
                gix += I_ne * (iy_sw - iy) * cot;
                giy -= I_ne * (ix - ix_sw) * cot;
            }
            if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
                int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

                T I_sw = x[offset];
                gix -= I_sw * (iy - iy_ne) * cot;
                giy += I_sw * (ix_ne - ix) * cot;
            }
            if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
                int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

                T I_se = x[offset];
                gix += I_se * (iy - iy_nw) * cot;
                giy += I_se * (ix - ix_nw) * cot;
            }
        }

        T gix_mult = W / 2;
        T giy_mult = H / 2;

        // Reduce across each simdgroup first.
        // This is much faster than relying purely on atomics.
        gix = simd_sum(gix);
        giy = simd_sum(giy);

        if (thread_index_in_simdgroup == 0) {
            atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
            atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
        }
    """
    kernel = mx.fast.metal_kernel(
        name="grid_sample_grad",
        input_names=["x", "grid", "cotangent"],
        output_names=["x_grad", "grid_grad"],
        source=source,
        atomic_outputs=True,
    )
    # pad the output channels to simd group size
    # so that our `simd_sum`s don't overlap.
    simdgroup_size = 32
    C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
    grid_size = B * gN * gM * C_padded
    outputs = kernel(
        inputs=[x, grid, cotangent],
        template=[("T", x.dtype)],
        output_shapes=[x.shape, grid.shape],
        output_dtypes=[x.dtype, x.dtype],
        grid=(grid_size, 1, 1),
        threadgroup=(256, 1, 1),
        init_value=0,
    )
    return outputs[0], outputs[1]
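With the vjp registered, MLX's autodiff transforms route through our kernel. A minimal sketch comparing its gradients against those of the reference implementation (the tolerance is loose to allow for the nondeterministic ordering of the atomic accumulation):

x = mx.random.normal(shape=(2, 8, 8, 3))
grid = mx.random.uniform(low=-0.5, high=0.5, shape=(2, 4, 4, 2))
cotan = mx.ones((2, 4, 4, 3))

_, vjps = mx.vjp(grid_sample, [x, grid], [cotan])
_, vjps_ref = mx.vjp(grid_sample_ref, [x, grid], [cotan])
for g, g_ref in zip(vjps, vjps_ref):
    assert mx.allclose(g, g_ref, atol=1e-4)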

There's an even larger speed up for the vjp:

676.4ms -> 16.7ms => 40x speed up