mlx.core.value_and_grad

mlx.core.value_and_grad#

value_and_grad(fun: Callable, argnums: int | Sequence[int] | None = None, argnames: str | Sequence[str] = []) Callable#

返回一个计算 fun 的值和梯度的函数。

传递给 value_and_grad() 的函数应返回一个标量损失,或者一个元组,其中第一个元素是标量损失,其余元素可以是任何类型。

import mlx.core as mx

def mse(params, inputs, targets):
    outputs = forward(params, inputs)
    lvalue = (outputs - targets).square().mean()
    return lvalue

# Returns lvalue, dlvalue/dparams
lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)

def lasso(params, inputs, targets, a=1.0, b=1.0):
    outputs = forward(params, inputs)
    mse = (outputs - targets).square().mean()
    l1 = mx.abs(outputs - targets).mean()

    loss = a*mse + b*l1

    return loss, mse, l1

(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
参数:
  • fun (Callable) – 一个函数,接受可变数量的 arrayarray 树,并返回一个标量输出 array,或者一个元组,其第一个元素应为标量 array

  • argnums (intlist(int), 可选) – 指定 fun 的位置参数的索引(或多个索引),以计算相对于这些参数的梯度。如果既未提供 argnums 也未提供 argnames,则 argnums 默认为 0,表示 fun 的第一个参数。

  • argnames (strlist(str), 可选) – 指定 fun 的关键字参数,以计算相对于这些参数的梯度。默认为 [],因此默认情况下不对关键字参数计算梯度。

返回值:

一个函数,返回一个元组,其中第一个元素是 fun 的输出,第二个元素是相对于损失的梯度。

返回类型:

Callable