mlx.core.random.truncated_normal

mlx.core.random.truncated_normal#

truncated_normal(lower: scalar | array, upper: scalar | array, shape: Sequence[int] | None = None, dtype: Dtype | None = float32, key: array | None = None, stream: None | Stream | Device = None) array#

从截断正态分布生成值。

这些值从域 (lower, upper) 上的截断正态分布中采样。边界 lowerupper 可以是标量或数组,并且必须能够广播到 shape

参数:
  • lower (scalararray) – 域的下界。

  • upper (scalararray) – 域的上界。

  • shape (list(int), optional) – 输出的形状。默认值: ()

  • dtype (Dtype, optional) – 输出的数据类型。默认值: float32

  • key (array, optional) – PRNG 密钥。默认值: None

返回值:

随机值的输出数组。

返回类型:

array