mlx.nn.init.he_uniform

目录

mlx.nn.init.he_uniform#

he_uniform(dtype: Dtype = mlx.core.float32) Callable[[array, Literal['fan_in', 'fan_out'], float], array]#

一个 He Uniform (Kaiming Uniform) 初始化器。

此初始化器从一个均匀分布中采样,其范围根据输入 (fan_in) 或输出 (fan_out) 单元的数量计算得出,公式如下:

\[\sigma = \gamma \sqrt{\frac{3.0}{\text{风扇}}}\]

其中 \(\text{fan}\) 是当 mode"fan_in" 时的输入单元数量,或当 mode"fan_out" 时的输出单元数量。

更多详情请参阅原始参考文献:Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification

参数:

dtype (Dtype, optional) – 数组的数据类型。默认值:float32

返回:

一个初始化器,它返回一个与输入形状相同的数组,其中填充了从 He Uniform 分布中采样的值。

返回类型:

Callable[[array, str, float], array]

示例

>>> init_fn = nn.init.he_uniform()
>>> init_fn(mx.zeros((2, 2)))  # uses fan_in
array([[0.0300242, -0.0184009],
       [0.793615, 0.666329]], dtype=float32)
>>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5)
array([[-1.64331, -2.16506],
       [1.08619, 5.79854]], dtype=float32)