mlx.core.custom_function

mlx.core.custom_function#

class custom_function#

设置一个用于自定义梯度和 vmap 定义的函数。

此类旨在用作函数装饰器。实例是可调用对象,行为与包装的函数完全相同。然而,当使用函数变换时(例如使用 value_and_grad() 计算梯度),将使用通过 custom_function.vjp()custom_function.jvp()custom_function.vmap() 定义的函数,而不是默认变换。

请注意,所有自定义变换都是可选的。未定义的变换将回退到默认行为。

示例

import mlx.core as mx

@mx.custom_function
def f(x, y):
    return mx.sin(x) * y

@f.vjp
def f_vjp(primals, cotangent, output):
    x, y = primals
    return cotan * mx.cos(x) * y, cotan * mx.sin(x)

@f.jvp
def f_jvp(primals, tangents):
  x, y = primals
  dx, dy = tangents
  return dx * mx.cos(x) * y + dy * mx.sin(x)

@f.vmap
def f_vmap(inputs, axes):
  x, y = inputs
  ax, ay = axes
  if ay != ax and ax is not None:
      y = y.swapaxes(ay, ax)
  return mx.sin(x) * y, (ax or ay)

所有 custom_function 实例都表现为纯函数。换句话说,捕获的任何变量都将被视为常量,并且不会计算捕获数组的梯度。例如,

import mlx.core as mx

def g(x, y):
  @mx.custom_function
  def f(x):
    return x * y

  @f.vjp
  def f_vjp(x, dx, fx):
    # Note that we have only x, dx and fx and nothing with respect to y
    raise ValueError("Abort!")

  return f(x)

x = mx.array(2.0)
y = mx.array(3.0)
print(g(x, y))                     # prints 6.0
print(mx.grad(g)(x, y))            # Raises exception
print(mx.grad(g, argnums=1)(x, y)) # prints 0.0
__init__(self, f: Callable)#

方法

__init__(self, f)

jvp(self, f)

为包装的函数定义自定义 jvp。

vjp(self, f)

为包装的函数定义自定义 vjp。

vmap(self, f)

为包装的函数定义自定义向量化变换。