mlx.core.custom_function#
- class custom_function#
设置一个用于自定义梯度和 vmap 定义的函数。
此类旨在用作函数装饰器。实例是可调用对象,行为与包装的函数完全相同。然而,当使用函数变换时(例如使用
value_and_grad()
计算梯度),将使用通过custom_function.vjp()
、custom_function.jvp()
和custom_function.vmap()
定义的函数,而不是默认变换。请注意,所有自定义变换都是可选的。未定义的变换将回退到默认行为。
示例
import mlx.core as mx @mx.custom_function def f(x, y): return mx.sin(x) * y @f.vjp def f_vjp(primals, cotangent, output): x, y = primals return cotan * mx.cos(x) * y, cotan * mx.sin(x) @f.jvp def f_jvp(primals, tangents): x, y = primals dx, dy = tangents return dx * mx.cos(x) * y + dy * mx.sin(x) @f.vmap def f_vmap(inputs, axes): x, y = inputs ax, ay = axes if ay != ax and ax is not None: y = y.swapaxes(ay, ax) return mx.sin(x) * y, (ax or ay)
所有
custom_function
实例都表现为纯函数。换句话说,捕获的任何变量都将被视为常量,并且不会计算捕获数组的梯度。例如,import mlx.core as mx def g(x, y): @mx.custom_function def f(x): return x * y @f.vjp def f_vjp(x, dx, fx): # Note that we have only x, dx and fx and nothing with respect to y raise ValueError("Abort!") return f(x) x = mx.array(2.0) y = mx.array(3.0) print(g(x, y)) # prints 6.0 print(mx.grad(g)(x, y)) # Raises exception print(mx.grad(g, argnums=1)(x, y)) # prints 0.0
- __init__(self, f: Callable)#
方法
__init__
(self, f)jvp
(self, f)为包装的函数定义自定义 jvp。
vjp
(self, f)为包装的函数定义自定义 vjp。
vmap
(self, f)为包装的函数定义自定义向量化变换。