mlx.nn.RoPE

目录

mlx.nn.RoPE#

class RoPE(dims: int, traditional: bool = False, base: float = 10000, scale: float = 1.0)#

实现旋转位置编码。

传统实现旋转特征维度中连续的元素对,而默认实现为了效率旋转步长为特征维度一半的元素对。

更多详情请参阅 RoFormer: Enhanced Transformer with Rotary Position Embedding

参数:
  • dims (int) – 需要旋转的特征维度。如果输入特征大于 dims,则其余部分保持不变。

  • traditional (bool, 可选) – 如果设置为 True,选择传统实现,效率略低。默认值: False

  • base (float, 可选) – 用于计算位置编码中每个维度的角频率的基数。默认值: 10000

  • scale (float, 可选) – 用于缩放位置的比例因子。默认值: 1.0

方法