mlx.core.vjp

目录

mlx.core.vjp#

vjp(fun: Callable, primals: list[array], cotangents: list[array]) tuple[list[array], list[array]]#

计算向量-雅可比积。

计算 cotangents 与函数 funprimals 处评估的雅可比矩阵的乘积。

参数:
  • fun (Callable) – 一个函数,接受可变数量的 array,并返回单个 arrayarray 列表。

  • primals (list(array)) – 用于计算雅可比矩阵的 array 列表。

  • cotangents (list(array)) – array 列表,它们是向量-雅可比积中的“向量”。cotangents 的数量、形状和类型应与 fun 的输出一致。

返回:

向量-雅可比积的列表,其数量、形状和类型与 fun 的输出一致。

返回类型:

list(array)