mlx.core.gather_qmm

目录

mlx.core.gather_qmm#

gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: array | None = None, rhs_indices: array | None = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: None | Stream | Device = None) array#

执行带矩阵级别收集的量化矩阵乘法。

此操作是 gather_mm() 的量化等效版本。类似于 gather_mm(),索引 lhs_indicesrhs_indices 分别包含 xw 沿批量维度(即除最后两个维度之外的所有维度)的扁平索引。

注意,`scales` 和 `biases` 必须与 `w` 具有相同的批量维度,因为它们表示相同的量化矩阵。

参数:
  • x (array) – 输入数组

  • w (array) – 以无符号整数打包的量化矩阵

  • scales (array) – 用于 `w` 中每 `group_size` 个元素的尺度

  • biases (array) – 用于 `w` 中每 `group_size` 个元素的偏置

  • lhs_indices (array, 可选) – `x` 的整数索引。默认值:`None`。

  • rhs_indices (array, 可选) – `w` 的整数索引。默认值:`None`。

  • transpose (bool, 可选) – 定义是否与 `w` 的转置矩阵相乘,即是执行 `x @ w.T` 还是 `x @ w`。默认值:`True`。

  • group_size (int, 可选) – `w` 中共享尺度和偏置的组大小。默认值:`64`。

  • bits (int, 可选) – `w` 中每个元素占用的位数。默认值:`4`。

  • sorted_indices (bool, 可选) – 如果传入的索引已排序,可能会实现更快的运算。默认值:`False`。

返回值:

在使用 `lhs_indices` 和 `rhs_indices` 收集之后,

`x` 与 `w` 相乘的结果。

返回类型:

array