随机数

随机数#

MLX 中的随机采样函数默认使用隐式的全局 PRNG 状态。然而,所有函数都接受一个可选的 key 关键字参数,以便在需要更精细控制或显式状态管理时使用。

例如,你可以使用以下方法生成随机数:

for _ in range(3):
  print(mx.random.uniform())

这将打印一系列独特的伪随机数。或者,你可以显式设置 key

key = mx.random.key(0)
for _ in range(3):
  print(mx.random.uniform(key=key))

这将导致在每次迭代中产生相同的伪随机数。

遵循 JAX 的 PRNG 设计,我们使用了 Threefry 的可分割版本,它是一种基于计数器的 PRNG。

bernoulli([p, shape, key, stream])

生成伯努利随机值。

categorical(logits[, axis, shape, ...])

从分类分布中采样。

gumbel([shape, dtype, key, stream])

从标准 Gumbel 分布中采样。

key(seed)

从种子获取 PRNG key。

normal([shape, dtype, loc, scale, key, stream])

生成正态分布的随机数。

multivariate_normal(mean, cov[, shape, ...])

给定均值和协方差,生成联合正态分布的随机样本。

randint(low, high[, shape, dtype, key, stream])

从给定区间生成随机整数。

seed(seed)

为全局 PRNG 设置种子。

split(key[, num, stream])

将 PRNG key 分割成子 key。

truncated_normal(lower, upper[, shape, ...])

从截断正态分布生成值。

uniform([low, high, shape, dtype, key, stream])

生成均匀分布的随机数。

laplace([shape, dtype, loc, scale, key, stream])

从拉普拉斯分布中采样数。

permutation(x[, axis, key, stream])

生成随机排列或对数组的元素进行排列。