mlx.nn.BatchNorm

目录

mlx.nn.BatchNorm#

class BatchNorm(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True)#

对 2D 或 3D 输入应用批量归一化。

计算公式如下:

\[y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,\]

其中 \(\gamma\)\(\beta\) 是按特征维度学习的参数,分别初始化为 1 和 0。

输入形状指定为 NCNLC,其中 N 是批量大小,C 是特征或通道数,L 是序列长度。输出与输入具有相同的形状。对于四维数组,形状为 NHWC,其中 HW 分别是高度和宽度。

有关批量归一化的更多信息,请参阅原始论文 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift。

参数:
  • num_features (int) – 要归一化的特征维度。

  • eps (float, 可选) – 用于数值稳定性的小常数。默认值:1e-5

  • momentum (float, 可选) – 更新运行均值和方差的动量。默认值:0.1

  • affine (bool, 可选) – 如果为 True,则在归一化后应用学习到的仿射变换。默认值:True

  • track_running_stats (bool, 可选) – 如果为 True,则跟踪运行均值和方差。默认值:True

示例

>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.random.normal((5, 4))
>>> bn = nn.BatchNorm(num_features=4, affine=True)
>>> output = bn(x)

方法

unfreeze(*args, **kwargs)

封装 unfreeze 方法,确保 running_mean 和 var 始终是冻结的参数。