mlx.nn.losses.binary_cross_entropy

mlx.nn.losses.binary_cross_entropy#

class binary_cross_entropy(inputs: array, targets: array, weights: array | None = None, with_logits: bool = True, reduction: Literal['none', 'mean', 'sum'] = 'mean')#

计算二元交叉熵损失。

默认情况下,此函数接受 sigmoid 前的 logits,这样可以实现更快、更精确的损失计算。为了在 with_logits=False 时提高数值稳定性,损失计算会将输入概率(在对数空间中)裁剪到最小值 -100

参数:
  • inputs (array) – 预测值。如果 with_logitsTrue,则 inputs 是未归一化的 logits。否则,inputs 是概率。

  • targets (array) – 目标二元值,取值范围为 {0, 1}。

  • with_logits (bool, optional) – inputs 是否为 logits。默认为:True

  • weights (array, optional) – 可选的每个目标的权重。默认为:None

  • reduction (str, optional) – 指定应用于输出的归约方法:'none' | 'mean' | 'sum'。默认为:'mean'

返回值:

计算得到的二元交叉熵损失。

返回类型:

array

示例

>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
>>> targets = mx.array([0, 0, 1, 1])
>>> loss = nn.losses.binary_cross_entropy(logits, targets, reduction="mean")
>>> loss
array(0.539245, dtype=float32)
>>> probs = mx.array([0.1, 0.1, 0.4, 0.4])
>>> targets = mx.array([0, 0, 1, 1])
>>> loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction="mean")
>>> loss
array(0.510826, dtype=float32)