变换

变换#

eval(*args)

求值一个 array 或一个 array 树。

async_eval(*args)

异步求值一个 array 或一个 array 树。

compile(fun[, inputs, outputs, shapeless])

返回一个编译后的函数,其输出与 fun 相同。

custom_function

设置一个函数用于自定义梯度和 vmap 定义。

disable_compile()

全局禁用编译。

enable_compile()

全局启用编译。

grad(fun[, argnums, argnames])

返回计算 fun 梯度的函数。

value_and_grad(fun[, argnums, argnames])

返回计算 fun 值和梯度的函数。

jvp(fun, primals, tangents)

计算雅可比向量积。

vjp(fun, primals, cotangents)

计算向量雅可比积。

vmap(fun[, in_axes, out_axes])

返回 fun 的向量化版本。