mlx.core.gather_qmm#
- gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: array | None = None, rhs_indices: array | None = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: None | Stream | Device = None) array #
执行带矩阵级别收集的量化矩阵乘法。
此操作是
gather_mm()
的量化等效版本。类似于gather_mm()
,索引lhs_indices
和rhs_indices
分别包含x
和w
沿批量维度(即除最后两个维度之外的所有维度)的扁平索引。注意,`scales` 和 `biases` 必须与 `w` 具有相同的批量维度,因为它们表示相同的量化矩阵。
- 参数:
x (array) – 输入数组
w (array) – 以无符号整数打包的量化矩阵
scales (array) – 用于 `w` 中每 `group_size` 个元素的尺度
biases (array) – 用于 `w` 中每 `group_size` 个元素的偏置
lhs_indices (array, 可选) – `x` 的整数索引。默认值:`None`。
rhs_indices (array, 可选) – `w` 的整数索引。默认值:`None`。
transpose (bool, 可选) – 定义是否与 `w` 的转置矩阵相乘,即是执行 `x @ w.T` 还是 `x @ w`。默认值:`True`。
group_size (int, 可选) – `w` 中共享尺度和偏置的组大小。默认值:`64`。
bits (int, 可选) – `w` 中每个元素占用的位数。默认值:`4`。
sorted_indices (bool, 可选) – 如果传入的索引已排序,可能会实现更快的运算。默认值:`False`。
- 返回值:
- 在使用 `lhs_indices` 和 `rhs_indices` 收集之后,
`x` 与 `w` 相乘的结果。
- 返回类型: