mlx.core.random.multivariate_normal#
- multivariate_normal(mean: array, cov: array, shape: Sequence[int] = [], dtype: Dtype | None = float32, key: array | None = None, stream: None | Stream | Device = None) array #
给定均值和协方差,生成联合正态分布的随机样本。
矩阵
cov
必须是半正定的。如果不是,行为是未定义的。唯一支持的dtype
是float32
。- 参数:
mean (array) – 形状为
(..., n)
的数组,表示分布的均值。cov (array) – 形状为
(..., n, n)
的数组,表示分布的协方差矩阵。批处理形状...
必须与mean
的批处理形状兼容可广播。shape (list(int), 可选) – 输出形状必须与
mean.shape[:-1]
和cov.shape[:-2]
兼容可广播。如果为空,结果形状由mean
和cov
的批处理形状广播决定。默认值:[]
。dtype (Dtype, 可选) – 输出类型。默认值:
float32
。key (array, 可选) – PRNG 密钥。默认值:
None
。
- 返回:
随机值组成的输出数组。
- 返回类型: