mlx.nn.losses.cross_entropy#
- class cross_entropy(logits: array, targets: array, weights: array | None = None, axis: int = -1, label_smoothing: float = 0.0, reduction: Literal['none', 'mean', 'sum'] = 'none')#
Computes the cross entropy loss.
- Parameters:
  - logits (array) – The unnormalized logits.
  - targets (array) – The ground truth values. These can be class indices or probabilities for each class. If the targets are class indices, then the shape of targets should match the shape of logits with the axis dimension removed. If the targets are probabilities (or one-hot encoded), then the shape of targets should be the same as the shape of logits.
  - weights (array, optional) – Optional weights for each target. Default: None.
  - axis (int, optional) – The axis over which to compute softmax. Default: -1.
  - label_smoothing (float, optional) – Label smoothing factor. Default: 0.
  - reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'none'.
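As a sketch of how weights and reduction interact (an assumption based on the parameter descriptions above, not verified against the MLX source): the weights multiply the per-target losses element-wise, and the reduction then collapses the weighted result. In NumPy terms:

```python
import numpy as np

# Hypothetical per-example losses, as cross_entropy would return with
# reduction='none' (the values here are illustrative, not from MLX).
losses = np.array([0.0485873, 0.348587])
weights = np.array([1.0, 2.0])

weighted = losses * weights  # element-wise per-target weighting
mean_reduced = weighted.mean()  # what reduction='mean' would produce
sum_reduced = weighted.sum()    # what reduction='sum' would produce
```

With reduction='none' the weighted per-example losses are returned unreduced.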
- Returns:
The computed cross entropy loss.
- Return type:
array
Example
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>>
>>> # Class indices as targets
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([0, 1])
>>> nn.losses.cross_entropy(logits, targets)
array([0.0485873, 0.0485873], dtype=float32)
>>>
>>> # Probabilities (or one-hot vectors) as targets
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]])
>>> nn.losses.cross_entropy(logits, targets)
array([0.348587, 0.348587], dtype=float32)
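The values in the example above can be checked by hand: for class-index targets the cross entropy is logsumexp(logits) − logits[target], and for probability targets it is logsumexp(logits) − Σ targets·logits along the class axis. A NumPy sketch (NumPy used here only for verification; the MLX function itself operates on mx.array):

```python
import numpy as np

logits = np.array([[2.0, -1.0], [-1.0, 2.0]])
# logsumexp over the class axis (the normalizer of the softmax)
lse = np.log(np.exp(logits).sum(axis=-1))

# Class indices as targets: subtract the logit of the true class
targets = np.array([0, 1])
loss_idx = lse - logits[np.arange(2), targets]
print(loss_idx)  # ≈ [0.0485873, 0.0485873]

# Probabilities as targets: subtract the expected logit under targets
probs = np.array([[0.9, 0.1], [0.1, 0.9]])
loss_prob = lse - (probs * logits).sum(axis=-1)
print(loss_prob)  # ≈ [0.348587, 0.348587]
```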