mlx.nn.GroupNorm

目录

mlx.nn.GroupNorm#

class GroupNorm(num_groups: int, dims: int, eps: float = 1e-05, affine: bool = True, pytorch_compatible: bool = False)#

对输入应用组归一化(Group Normalization)[1]。

计算与层归一化(Layer Norm)相同的归一化,即

\[y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,\]

其中 \(\gamma\)\(\beta\) 是逐特征维度学习的参数,分别初始化为 1 和 0。然而,均值和方差是跨空间维度和每组特征计算的。具体来说,输入沿特征维度被分成 num_groups 组。

特征维度被认为是最后一个维度,除第一个维度外的所有先行维度都被视为空间维度。

[1]: https://arxiv.org/abs/1803.08494

参数:
  • num_groups (int) – 将特征分组的数量

  • dims (int) – 需要进行归一化的输入的特征维度

  • eps (float) – 用于数值稳定性的微小加法常数

  • affine (bool) – 如果为 True,则在归一化后学习并应用一个仿射变换。

  • pytorch_compatible (bool) – 如果为 True,则以与 PyTorch 相同的顺序/分组执行组归一化。

方法