随机数#
MLX 中的随机采样函数默认使用隐式的全局 PRNG 状态。然而,所有函数都接受一个可选的 key
关键字参数,以便在需要更精细控制或显式状态管理时使用。
例如,你可以使用以下方法生成随机数:
for _ in range(3):
print(mx.random.uniform())
这将打印一系列独特的伪随机数。或者,你可以显式设置 key
key = mx.random.key(0)
for _ in range(3):
print(mx.random.uniform(key=key))
这将导致在每次迭代中产生相同的伪随机数。
遵循 JAX 的 PRNG 设计,我们使用了 Threefry 的可分割版本,它是一种基于计数器的 PRNG。
|
生成伯努利随机值。 |
|
从分类分布中采样。 |
|
从标准 Gumbel 分布中采样。 |
|
从种子获取 PRNG key。 |
|
生成正态分布的随机数。 |
|
给定均值和协方差,生成联合正态分布的随机样本。 |
|
从给定区间生成随机整数。 |
|
为全局 PRNG 设置种子。 |
|
将 PRNG key 分割成子 key。 |
|
从截断正态分布生成值。 |
|
生成均匀分布的随机数。 |
|
从拉普拉斯分布中采样数。 |
|
生成随机排列或对数组的元素进行排列。 |