mlx.core.matmul

目录

mlx.core.matmul#

matmul(a: array, b: array, /, *, stream: None | Stream | Device = None) array#

矩阵乘法。

执行两个数组的(可能批处理的)矩阵乘法。此函数支持超过二维数组的广播。

  • 如果第一个数组是一维的,则在其形状前添加一个 1 使其成为矩阵。类似地,如果第二个数组是一维的,则在其形状后添加一个 1 使其成为矩阵。在这两种情况下,结果中的单例维度都会被移除。

  • 如果数组的维度超过 2,则执行批处理矩阵乘法。矩阵乘积的矩阵维度是每个输入的最后两个维度。

  • 除每个输入的最后两个维度外,所有其他维度都使用标准的 numpy 风格广播语义进行广播。

参数:
  • a (array) – 输入数组或标量。

  • b (array) – 输入数组或标量。

返回:

ab 的矩阵乘积。

返回类型:

array