mlx.nn.value_and_grad

mlx.nn.value_and_grad#

value_and_grad(model: Module, fn: Callable)#

将传入的函数 fn 转换为一个函数,该函数计算 fn 相对于模型可训练参数的梯度以及 fn 的值。

参数
  • model (Module) – 要计算其可训练参数梯度的模型

  • fn (Callable) – 要计算梯度的标量函数

返回

一个可调用对象,返回 fn 的值以及 fn 相对于 model 可训练参数的梯度