mlx.core.jvp

目录

mlx.core.jvp#

jvp(fun: Callable, primals: list[array], tangents: list[array]) tuple[list[array], list[array]]#

计算雅可比向量积。

这计算函数 funprimals 处计算的雅可比矩阵与 tangents 的乘积。

参数:
  • fun (Callable) – 一个函数,接受可变数量的 array,并返回单个 arrayarray 列表。

  • primals (list(array)) – 用于计算雅可比矩阵的 array 列表。

  • tangents (list(array)) – 作为雅可比向量积中“向量”的 array 列表。tangents 的数量、形状和类型应与 fun 的输入(即 primals)相同。

返回值:

雅可比向量积的列表,其数量、形状和类型与 fun 的输入相同。

返回类型:

list(array)