mlx.nn.init.glorot_uniform

mlx.nn.init.glorot_uniform#

glorot_uniform(dtype: Dtype = mlx.core.float32) Callable[[array, float], array]#

Glorot 均匀分布初始化器。

该初始化器从均匀分布中采样,其范围根据输入单元数 (fan_in) 和输出单元数 (fan_out) 计算得出,公式如下:

\[\sigma = \gamma \sqrt{\frac{6.0}{\text{fan\_in} + \text{fan\_out}}}\]

更多详细信息请参阅原始参考文献:理解训练深度前馈神经网络的难度

参数:

dtype (Dtype, 可选) – 数组的数据类型。默认值:float32

返回值:

一个初始化器,它返回一个与输入具有相同形状的数组,并填充从 Glorot 均匀分布中采样的值。

返回类型:

Callable[[array, float], array]

示例

>>> init_fn = nn.init.glorot_uniform()
>>> init_fn(mx.zeros((2, 2)))
array([[0.223404, -0.890597],
       [-0.379159, -0.776856]], dtype=float32)
>>> init_fn(mx.zeros((2, 2)), gain=4.0)
array([[-1.90041, 3.02264],
       [-0.912766, 4.12451]], dtype=float32)