mlx.core.random.categorical#
- categorical(logits: array, axis: int = -1, shape: Sequence[int] | None = None, num_samples: int | None = None, key: array | None = None, stream: None | Stream | Device = None) array #
从分类分布中采样。
值是从
logits
中指定的未归一化值所表示的分类分布中采样的。注意,shape
或num_samples
中最多只能指定一个。如果两者都为None
,则输出的形状与移除了axis
维度的logits
相同。- 参数:
- 返回:
形状为
shape
的输出数组,类型为uint32
。- 返回类型: