mlx.nn.losses.cross_entropy

目录

mlx.nn.losses.cross_entropy#

class cross_entropy(logits: array, targets: array, weights: array | None = None, axis: int = -1, label_smoothing: 类型为 float = 0.0, reduction: Literal['none', 'mean', 'sum'] = 'none')#

计算交叉熵损失。

参数:
  • logits (数组) – 未归一化的对数几率(logits)。

  • targets (数组) – 真实值。可以是类别索引或每个类别的概率。如果 targets 是类别索引,则 targets 的形状应与 logits 形状匹配,但移除 axis 维度。如果 targets 是概率(或独热编码),则 targets 的形状应与 logits 形状相同。

  • weights (数组, 可选) – 每个目标的权重(可选)。默认值:None

  • axis (int, 可选) – 计算 softmax 的轴。默认值:-1

  • label_smoothing (float, 可选) – 标签平滑因子。默认值:0

  • reduction (str, 可选) – 指定应用于输出的归约方式:'none' | 'mean' | 'sum'。默认值:'none'

返回值:

计算得到的交叉熵损失。

返回类型:

数组

示例

>>> import mlx.core as mx
>>> import mlx.nn as nn
>>>
>>> # Class indices as targets
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([0, 1])
>>> nn.losses.cross_entropy(logits, targets)
array([0.0485873, 0.0485873], dtype=float32)
>>>
>>> # Probabilities (or one-hot vectors) as targets
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]])
>>> nn.losses.cross_entropy(logits, targets)
array([0.348587, 0.348587], dtype=float32)