mlx.core.block_masked_mm#
- block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: array | None = None, mask_lhs: array | None = None, mask_rhs: array | None = None, *, stream: None | Stream | Device = None) array #
具有块掩码的矩阵乘法。
执行两个数组的(可能批处理的)矩阵乘法,并选择性地屏蔽掉大小为
block_size x block_size
的块。假设
a
的形状为 (…, M, K),b
的形状为 (…, K, N)lhs_mask
的形状必须为 (…, \(\lceil\) M /block_size
\(\rceil\), \(\lceil\) K /block_size
\(\rceil\))rhs_mask
的形状必须为 (…, \(\lceil\) K /block_size
\(\rceil\), \(\lceil\) N /block_size
\(\rceil\))out_mask
的形状必须为 (…, \(\lceil\) M /block_size
\(\rceil\), \(\lceil\) N /block_size
\(\rceil\))
注意:目前仅支持
block_size=64
和block_size=32
。