mlx.core.gather_mm#
- gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: None | Stream | Device = None) array #
带有矩阵级聚集的矩阵乘法。
对操作数使用给定索引执行聚集操作,然后对两个数组执行(可能是批量)矩阵乘法。此操作比显式应用
take()
再接matmul()
更高效。索引
lhs_indices
和rhs_indices
分别包含沿a
和b
的批处理维度(即除最后两个维度之外的所有维度)的展平索引。对于形状为
(A1, A2, ..., AS, M, K)
的a
,lhs_indices
包含范围[0, A1 * A2 * ... * AS)
内的索引。对于形状为
(B1, B2, ..., BS, M, K)
的b
,rhs_indices
包含范围[0, B1 * B2 * ... * BS)
内的索引。如果只传递一个索引且该索引已排序,可以传递
sorted_indices
标志以可能获得更快的实现。