mlx.core.grad

目录

mlx.core.grad#

grad(fun: Callable, argnums: int | Sequence[int] | None = None, argnames: str | Sequence[str] = []) Callable#

返回一个计算函数 fun 梯度的函数。

参数:
  • fun (Callable) – 一个函数,它接受可变数量的 arrayarray 树,并返回一个标量输出 array

  • argnums (intlist(int), 可选) – 指定 fun 的位置参数的索引(或索引列表),以计算相对于这些参数的梯度。如果既未提供 argnums 也未提供 argnames,则 argnums 默认为 0,表示 fun 的第一个参数。

  • argnames (strlist(str), 可选) – 指定 fun 的关键字参数,以计算相对于这些参数的梯度。它默认为 [],因此默认情况下不计算关键字参数的梯度。

返回值:

一个与 fun 具有相同输入参数并返回梯度(或多个梯度)的函数。

返回类型:

Callable