mlx.core.random.multivariate_normal

mlx.core.random.multivariate_normal#

multivariate_normal(mean: array, cov: array, shape: Sequence[int] = [], dtype: Dtype | None = float32, key: array | None = None, stream: None | Stream | Device = None) array#

给定均值和协方差,生成联合正态分布的随机样本。

矩阵 cov 必须是半正定的。如果不是,行为是未定义的。唯一支持的 dtypefloat32

参数:
  • mean (array) – 形状为 (..., n) 的数组,表示分布的均值。

  • cov (array) – 形状为 (..., n, n) 的数组,表示分布的协方差矩阵。批处理形状 ... 必须与 mean 的批处理形状兼容可广播。

  • shape (list(int), 可选) – 输出形状必须与 mean.shape[:-1]cov.shape[:-2] 兼容可广播。如果为空,结果形状由 meancov 的批处理形状广播决定。默认值:[]

  • dtype (Dtype, 可选) – 输出类型。默认值:float32

  • key (array, 可选) – PRNG 密钥。默认值:None

返回:

随机值组成的输出数组。

返回类型:

array