mlx.nn.init.he_normal

目录

mlx.nn.init.he_normal#

he_normal(dtype: Dtype = mlx.core.float32) Callable[[array, Literal['fan_in', 'fan_out'], float], array]#

构建一个 He 正态分布初始化器。

此初始化器从一个正态分布中采样,其标准差根据输入单元数 (fan_in) 或输出单元数 (fan_out) 计算,公式如下:

\[\sigma = \gamma \frac{1}{\sqrt{\text{fan}}}\]

其中 \(\text{fan}\) 是当 mode"fan_in" 时表示输入单元数,当 mode"fan_out" 时表示输出单元数。

更多详细信息请参阅原始文献:Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification

参数:

dtype (Dtype, optional) – 数组的数据类型。默认为 mx.float32。

返回:

一个初始化器,它返回一个与输入具有相同形状的数组,并填充从 He 正态分布中采样的值。

返回类型:

Callable[[array, str, float], array]

示例

>>> init_fn = nn.init.he_normal()
>>> init_fn(mx.zeros((2, 2)))  # uses fan_in
array([[-1.25211, 0.458835],
       [-0.177208, -0.0137595]], dtype=float32)
>>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5)
array([[5.6967, 4.02765],
       [-4.15268, -2.75787]], dtype=float32)