mlx.nn.init.glorot_normal

mlx.nn.init.glorot_normal#

glorot_normal(dtype: Dtype = mlx.core.float32) Callable[[array, float], array]#

一个 Glorot 正态分布初始化器。

此初始化器从一个正态分布中采样,其标准差根据输入单元数 (fan_in) 和输出单元数 (fan_out) 计算得到,公式如下:

\[\sigma = \gamma \sqrt{\frac{2.0}{\text{fan\_in} + \text{fan\_out}}}\]

更多详情请参阅原始文献:Understanding the difficulty of training deep feedforward neural networks

参数:

dtype (Dtype, optional) – 数组的数据类型。默认值: float32

返回:

一个初始化器,返回一个与输入具有相同形状的数组,填充了从 Glorot 正态分布中采样的值。

返回类型:

Callable[[array, float], array]

示例

>>> init_fn = nn.init.glorot_normal()
>>> init_fn(mx.zeros((2, 2)))
array([[0.191107, 1.61278],
       [-0.150594, -0.363207]], dtype=float32)
>>> init_fn(mx.zeros((2, 2)), gain=4.0)
array([[1.89613, -4.53947],
       [4.48095, 0.995016]], dtype=float32)