mlx.core.fast.scaled_dot_product_attention#
- scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: None | str | array = None, stream: None | Stream | Device = None) array#
多头注意力(multi-head attention)的一种快速实现:
O = softmax(Q @ K.T, dim=-1) @ V。支持
注意:softmax 操作始终在
float32中执行,无论输入精度如何。注意:对于分组查询注意力和多查询注意力,输入
k和v不应预先平铺以匹配q。尺寸定义如下:
B: 批量大小。N_q: 查询头数量。N_kv: 键和值头数量。T_q: 每个示例的查询数量。T_kv: 每个示例的键和值数量。D: 每个头的维度。
- 参数:
q (array) – 查询,形状为
[B, N_q, T_q, D]。k (array) – 键,形状为
[B, N_kv, T_kv, D]。v (array) – 值,形状为
[B, N_kv, T_kv, D]。scale (float) – 查询的缩放因子(通常为
1.0 / sqrt(q.shape(-1))mask (Union[None, str, array], optional) – 应用于查询-键得分的因果掩码(causal mask)、布尔掩码或加性掩码(additive mask)。掩码最多可以有 4 个维度,并且必须与形状
[B, N, T_q, T_kv]广播兼容。如果给定加性掩码,其类型必须提升到q、k和v提升后的类型。
- 返回值:
输出数组。
- 返回类型: