mlx.core.fast.scaled_dot_product_attention

mlx.core.fast.scaled_dot_product_attention#

scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: None | str | array = None, stream: None | Stream | Device = None) array#

多头注意力(multi-head attention)的一种快速实现:O = softmax(Q @ K.T, dim=-1) @ V

支持

注意:softmax 操作始终在 float32 中执行,无论输入精度如何。

注意:对于分组查询注意力和多查询注意力,输入 kv 不应预先平铺以匹配 q

尺寸定义如下:

  • B: 批量大小。

  • N_q: 查询头数量。

  • N_kv: 键和值头数量。

  • T_q: 每个示例的查询数量。

  • T_kv: 每个示例的键和值数量。

  • D: 每个头的维度。

参数:
  • q (array) – 查询,形状为 [B, N_q, T_q, D]

  • k (array) – 键,形状为 [B, N_kv, T_kv, D]

  • v (array) – 值,形状为 [B, N_kv, T_kv, D]

  • scale (float) – 查询的缩放因子(通常为 1.0 / sqrt(q.shape(-1)

  • mask (Union[None, str, array], optional) – 应用于查询-键得分的因果掩码(causal mask)、布尔掩码或加性掩码(additive mask)。掩码最多可以有 4 个维度,并且必须与形状 [B, N, T_q, T_kv] 广播兼容。如果给定加性掩码,其类型必须提升到 qkv 提升后的类型。

返回值:

输出数组。

返回类型:

array