mlx.core.gather_mm

目录

mlx.core.gather_mm#

gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: None | Stream | Device = None) array#

带有矩阵级聚集的矩阵乘法。

对操作数使用给定索引执行聚集操作,然后对两个数组执行(可能是批量)矩阵乘法。此操作比显式应用 take() 再接 matmul() 更高效。

索引 lhs_indicesrhs_indices 分别包含沿 ab 的批处理维度(即除最后两个维度之外的所有维度)的展平索引。

对于形状为 (A1, A2, ..., AS, M, K)alhs_indices 包含范围 [0, A1 * A2 * ... * AS) 内的索引。

对于形状为 (B1, B2, ..., BS, M, K)brhs_indices 包含范围 [0, B1 * B2 * ... * BS) 内的索引。

如果只传递一个索引且该索引已排序,可以传递 sorted_indices 标志以可能获得更快的实现。

参数:
  • a (array) – 输入数组。

  • b (array) – 输入数组。

  • lhs_indices (array, 可选) – a 的整数索引。默认值:None

  • rhs_indices (array, 可选) – b 的整数索引。默认值:None

  • sorted_indices (bool, 可选) – 如果传递的索引已排序,可能允许更快的实现。默认值:False

返回值:

输出数组。

返回类型:

array