mlx.core.random.categorical

目录

mlx.core.random.categorical#

categorical(logits: array, axis: int = -1, shape: Sequence[int] | None = None, num_samples: int | None = None, key: array | None = None, stream: None | Stream | Device = None) array#

从分类分布中采样。

值是从 logits 中指定的未归一化值所表示的分类分布中采样的。注意,shapenum_samples 中最多只能指定一个。如果两者都为 None,则输出的形状与移除了 axis 维度的 logits 相同。

参数:
  • logits (array) – 未归一化的分类分布。

  • axis (int, optional) – 指定分布的轴。默认值: -1

  • shape (list(int), optional) – 输出的形状。该形状必须与移除了 axis 维度的 logits.shape 广播兼容。默认值: None

  • num_samples (int, optional) – 从 logits 中每个分类分布抽取的样本数量。输出将在最后一个维度包含 num_samples。默认值: None

  • key (array, optional) – PRNG 密钥。默认值: None

返回:

形状为 shape 的输出数组,类型为 uint32

返回类型:

array