mlx.core.tensordot

目录

mlx.core.tensordot#

tensordot(a: array, b: array, /, axes: int | list[Sequence[int]] = 2, *, stream: None | Stream | Device = None) array#

沿指定轴计算张量点积。

参数:
  • a (array) – 输入数组

  • b (array) – 输入数组

  • axes (intlist(list(int)), 可选) – 求和的维度数量。如果提供一个整型,则在 a 的最后 axes 个维度和 b 的前 axes 个维度上求和。如果提供一个列表的列表,则在 ab 的对应维度上求和。默认值:2。

返回值:

张量点积。

返回类型:

array