mlx.nn.GELU

目录

mlx.nn.GELU#

class GELU(approx='none')#

应用高斯误差线性单元。

\[\textrm{GELU}(x) = x * \Phi(x)\]

其中 \(\Phi(x)\) 是高斯累积分布函数(CDF)。

然而,如果 approx 设置为 'precise' 或 'fast',则分别应用以下公式:

\[\begin{split}\textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left((\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\ \textrm{GELUFast}(x) &= x * \sigma\left(1.702 * x\right)\end{split}\]

注意

为了与PyTorch API兼容,'tanh' 可以作为 'precise' 的别名。

有关函数式等效及其误差范围信息,请参阅 gelu()gelu_approx()gelu_fast_approx()

参数:

approx ('none' | 'precise' | 'fast') – 如果使用 GELU 近似,指定使用的类型。

方法