mlx.nn.BatchNorm#
- class BatchNorm(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True)#
对 2D 或 3D 输入应用批量归一化。
计算公式如下:
\[y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,\]其中 \(\gamma\) 和 \(\beta\) 是按特征维度学习的参数,分别初始化为 1 和 0。
输入形状指定为
NC
或NLC
,其中N
是批量大小,C
是特征或通道数,L
是序列长度。输出与输入具有相同的形状。对于四维数组,形状为NHWC
,其中H
和W
分别是高度和宽度。有关批量归一化的更多信息,请参阅原始论文 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift。
- 参数:
示例
>>> import mlx.core as mx >>> import mlx.nn as nn >>> x = mx.random.normal((5, 4)) >>> bn = nn.BatchNorm(num_features=4, affine=True) >>> output = bn(x)
方法
unfreeze
(*args, **kwargs)封装 unfreeze 方法,确保 running_mean 和 var 始终是冻结的参数。