mlx.core.fast.metal_kernel#
- metal_kernel(name: str, input_names: Sequence[str], output_names: Sequence[str], source: str, header: str = '', ensure_row_contiguous: bool = True, atomic_outputs: bool = False) object #
一个从源代码字符串定义的 jit 编译的自定义 Metal 内核。
完整文档:自定义 Metal 内核。
- 参数:
name (str) – 内核名称。
input_names (List[str]) – 函数签名中输入的参数名称。
output_names (List[str]) – 函数签名中输出的参数名称。
source (str) – 源代码。这是 Metal 中函数的主体,函数签名将自动生成。
header (str) – 在主函数之前包含的头文件源代码。对于应该位于主函数体之外的辅助函数或包含文件很有用。
ensure_row_contiguous (bool) – 是否在内核运行前确保输入是行连续的。默认值:
True
。atomic_outputs (bool) – 是否在函数签名中使用原子输出,例如
device atomic<float>
。默认值:False
。
- 返回:
可调用对象
metal_kernel
。
示例
def exp_elementwise(a: mx.array): source = ''' uint elem = thread_position_in_grid.x; T tmp = inp[elem]; out[elem] = metal::exp(tmp); ''' kernel = mx.fast.metal_kernel( name="myexp", input_names=["inp"], output_names=["out"], source=source ) outputs = kernel( inputs=[a], template=[("T", mx.float32)], grid=(a.size, 1, 1), threadgroup=(256, 1, 1), output_shapes=[a.shape], output_dtypes=[a.dtype], verbose=True, ) return outputs[0] a = mx.random.normal(shape=(4, 16)).astype(mx.float16) b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a))