mlx.core.random.multivariate_normal#
- multivariate_normal(mean: array, cov: array, shape: Sequence[int] = [], dtype: Dtype | None = float32, key: array | None = None, stream: None | Stream | Device = None) array#
给定均值和协方差,生成联合正态分布的随机样本。
矩阵
cov必须是半正定的。如果不是,行为是未定义的。唯一支持的dtype是float32。- 参数:
mean (array) – 形状为
(..., n)的数组,表示分布的均值。cov (array) – 形状为
(..., n, n)的数组,表示分布的协方差矩阵。批处理形状...必须与mean的批处理形状兼容可广播。shape (list(int), 可选) – 输出形状必须与
mean.shape[:-1]和cov.shape[:-2]兼容可广播。如果为空,结果形状由mean和cov的批处理形状广播决定。默认值:[]。dtype (Dtype, 可选) – 输出类型。默认值:
float32。key (array, 可选) – PRNG 密钥。默认值:
None。
- 返回:
随机值组成的输出数组。
- 返回类型: