mlx.nn.init.he_normal#
- he_normal(dtype: Dtype = mlx.core.float32) Callable[[array, Literal['fan_in', 'fan_out'], float], array] #
构建一个 He 正态分布初始化器。
此初始化器从一个正态分布中采样,其标准差根据输入单元数 (
fan_in
) 或输出单元数 (fan_out
) 计算,公式如下:\[\sigma = \gamma \frac{1}{\sqrt{\text{fan}}}\]其中 \(\text{fan}\) 是当
mode
为"fan_in"
时表示输入单元数,当mode
为"fan_out"
时表示输出单元数。更多详细信息请参阅原始文献:Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
- 参数:
dtype (Dtype, optional) – 数组的数据类型。默认为 mx.float32。
- 返回:
一个初始化器,它返回一个与输入具有相同形状的数组,并填充从 He 正态分布中采样的值。
- 返回类型:
示例
>>> init_fn = nn.init.he_normal() >>> init_fn(mx.zeros((2, 2))) # uses fan_in array([[-1.25211, 0.458835], [-0.177208, -0.0137595]], dtype=float32) >>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5) array([[5.6967, 4.02765], [-4.15268, -2.75787]], dtype=float32)