mlx.core.random.categorical#
- categorical(logits: array, axis: int = -1, shape: Sequence[int] | None = None, num_samples: int | None = None, key: array | None = None, stream: None | Stream | Device = None) array#
从分类分布中采样。
值是从
logits中指定的未归一化值所表示的分类分布中采样的。注意,shape或num_samples中最多只能指定一个。如果两者都为None,则输出的形状与移除了axis维度的logits相同。- 参数:
- 返回:
形状为
shape的输出数组,类型为uint32。- 返回类型: