mlx.nn.MultiHeadAttention

mlx.nn.MultiHeadAttention#

class MultiHeadAttention(dims: int, num_heads: int, query_input_dims: int | None = None, key_input_dims: int | None = None, value_input_dims: int | None = None, value_dims: int | None = None, value_output_dims: int | None = None, bias: bool = False )#

实现带有多头的缩放点积注意力。

给定查询、键和值输入,MultiHeadAttention 通过根据输入查询和键的相似性聚合输入值的信息来生成新的值。

所有输入以及输出默认进行无偏置的线性投影。

MultiHeadAttention 还接受一个可选的加性注意力掩码,该掩码应可与 (batch, num_heads, # queries, # keys) 进行广播。掩码应在不应被注意到的位置包含 -inf 或非常大的负数。

参数:
  • dims (int) – 模型维度。这也是查询、键、值和输出的默认值。

  • num_heads (int) – 要使用的注意力头数量。

  • query_input_dims (int, 可选) – 查询的输入维度。默认值:dims

  • key_input_dims (int, 可选) – 键的输入维度。默认值:dims

  • value_input_dims (int, 可选) – 值的输入维度。默认值:key_input_dims

  • value_dims (int, 可选) – 投影后值的维度。默认值:dims

  • value_output_dims (int, 可选) – 新值将投影到的维度。默认值:dims

  • bias (bool, 可选) – 是否在投影中使用偏置。默认值:False

方法

create_additive_causal_mask(N[, dtype])