mlx.nn.GroupNorm#
- class GroupNorm(num_groups: int, dims: int, eps: float = 1e-05, affine: bool = True, pytorch_compatible: bool = False)#
对输入应用组归一化(Group Normalization)[1]。
计算与层归一化(Layer Norm)相同的归一化,即
\[y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,\]其中 \(\gamma\) 和 \(\beta\) 是逐特征维度学习的参数,分别初始化为 1 和 0。然而,均值和方差是跨空间维度和每组特征计算的。具体来说,输入沿特征维度被分成 num_groups 组。
特征维度被认为是最后一个维度,除第一个维度外的所有先行维度都被视为空间维度。
[1]: https://arxiv.org/abs/1803.08494
- 参数:
方法