mlx.nn.init.he_uniform#
- he_uniform(dtype: Dtype = mlx.core.float32) Callable[[array, Literal['fan_in', 'fan_out'], float], array] #
一个 He Uniform (Kaiming Uniform) 初始化器。
此初始化器从一个均匀分布中采样,其范围根据输入 (
fan_in
) 或输出 (fan_out
) 单元的数量计算得出,公式如下:\[\sigma = \gamma \sqrt{\frac{3.0}{\text{风扇}}}\]其中 \(\text{fan}\) 是当
mode
为"fan_in"
时的输入单元数量,或当mode
为"fan_out"
时的输出单元数量。更多详情请参阅原始参考文献:Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
- 参数:
dtype (Dtype, optional) – 数组的数据类型。默认值:
float32
。- 返回:
一个初始化器,它返回一个与输入形状相同的数组,其中填充了从 He Uniform 分布中采样的值。
- 返回类型:
示例
>>> init_fn = nn.init.he_uniform() >>> init_fn(mx.zeros((2, 2))) # uses fan_in array([[0.0300242, -0.0184009], [0.793615, 0.666329]], dtype=float32) >>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5) array([[-1.64331, -2.16506], [1.08619, 5.79854]], dtype=float32)