mlx.core.block_masked_mm

mlx.core.block_masked_mm#

block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: array | None = None, mask_lhs: array | None = None, mask_rhs: array | None = None, *, stream: None | Stream | Device = None) array#

具有块掩码的矩阵乘法。

执行两个数组的(可能批处理的)矩阵乘法,并选择性地屏蔽掉大小为 block_size x block_size 的块。

假设 a 的形状为 (…, M, K),b 的形状为 (…, K, N)

  • lhs_mask 的形状必须为 (…, \(\lceil\) M / block_size \(\rceil\), \(\lceil\) K / block_size \(\rceil\))

  • rhs_mask 的形状必须为 (…, \(\lceil\) K / block_size \(\rceil\), \(\lceil\) N / block_size \(\rceil\))

  • out_mask 的形状必须为 (…, \(\lceil\) M / block_size \(\rceil\), \(\lceil\) N / block_size \(\rceil\))

注意:目前仅支持 block_size=64block_size=32

参数:
  • a (array) – 输入数组或标量。

  • b (array) – 输入数组或标量。

  • block_size (int) – 要屏蔽的块大小。必须是 3264。默认值:64

  • mask_out (array, optional) – 输出的掩码。默认值:None

  • mask_lhs (array, optional) – a 的掩码。默认值:None

  • mask_rhs (array, optional) – b 的掩码。默认值:None

返回:

输出数组。

返回类型:

array