mlx.nn.init.normal

目录

mlx.nn.init.normal#

normal(mean: float = 0.0, std: float = 1.0, dtype: Dtype = mlx.core.float32) Callable[[array], array]#

一个初始化器,返回正态分布的样本。

参数:
  • mean (float, 可选) – 正态分布的均值。默认值:0.0

  • std (float, 可选) – 正态分布的标准差。默认值:1.0

  • dtype (Dtype, 可选) – 数组的数据类型。默认值:float32

返回:

一个初始化器,返回一个与输入具有相同形状的数组,并填充了正态分布的样本。

返回类型:

Callable[[array], array]

示例

>>> init_fn = nn.init.normal()
>>> init_fn(mx.zeros((2, 2)))
array([[-0.982273, -0.534422],
       [0.380709, 0.0645099]], dtype=float32)