mlx.core.fast.rope

目录

mlx.core.fast.rope#

rope(a: array, dims: int, *, traditional: bool, base: float | None, scale: float, offset: int | array, freqs: array | None = None, stream: None | Stream |Device = None) array#

对输入应用旋转位置编码。

参数:
  • a (array) – 输入数组。

  • dims (int) – 需要旋转的特征维度。如果输入特征维度大于此值,其余部分将保持不变。

  • traditional (bool) – 如果设置为 True,则选择旋转连续维度的传统实现。

  • base (float, optional) – 用于计算位置编码中每个维度的角频率的基数。 basefreqs 中必须且只能有一个为 None

  • scale (float) – 用于缩放位置的比例。

  • offset (int or array) – 开始计算位置的偏移量。

  • freqs (array, optional) – 可选的用于 RoPE 的频率。如果设置此参数,base 参数必须为 None。默认值: None

返回值:

输出数组。

返回类型:

array