mlx.core.fast.metal_kernel

mlx.core.fast.metal_kernel#

metal_kernel(name: str, input_names: Sequence[str], output_names: Sequence[str], source: str, header: str = '', ensure_row_contiguous: bool = True, atomic_outputs: bool = False) object#

一个从源代码字符串定义的 jit 编译的自定义 Metal 内核。

完整文档:自定义 Metal 内核

参数:
  • name (str) – 内核名称。

  • input_names (List[str]) – 函数签名中输入的参数名称。

  • output_names (List[str]) – 函数签名中输出的参数名称。

  • source (str) – 源代码。这是 Metal 中函数的主体,函数签名将自动生成。

  • header (str) – 在主函数之前包含的头文件源代码。对于应该位于主函数体之外的辅助函数或包含文件很有用。

  • ensure_row_contiguous (bool) – 是否在内核运行前确保输入是行连续的。默认值:True

  • atomic_outputs (bool) – 是否在函数签名中使用原子输出,例如 device atomic<float>。默认值:False

返回:

可调用对象 metal_kernel

示例

def exp_elementwise(a: mx.array):
    source = '''
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    '''

    kernel = mx.fast.metal_kernel(
        name="myexp",
        input_names=["inp"],
        output_names=["out"],
        source=source
    )
    outputs = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        verbose=True,
    )
    return outputs[0]

a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))