mlx.nn.losses.binary_cross_entropy#
- class binary_cross_entropy(inputs: array, targets: array, weights: array | None = None, with_logits: bool = True, reduction: Literal['none', 'mean', 'sum'] = 'mean')#
计算二元交叉熵损失。
默认情况下,此函数接受 sigmoid 前的 logits,这样可以实现更快、更精确的损失计算。为了在
with_logits=False
时提高数值稳定性,损失计算会将输入概率(在对数空间中)裁剪到最小值-100
。- 参数:
inputs (array) – 预测值。如果
with_logits
为True
,则inputs
是未归一化的 logits。否则,inputs
是概率。targets (array) – 目标二元值,取值范围为 {0, 1}。
with_logits (bool, optional) –
inputs
是否为 logits。默认为:True
。weights (array, optional) – 可选的每个目标的权重。默认为:
None
。reduction (str, optional) – 指定应用于输出的归约方法:
'none'
|'mean'
|'sum'
。默认为:'mean'
。
- 返回值:
计算得到的二元交叉熵损失。
- 返回类型:
示例
>>> import mlx.core as mx >>> import mlx.nn as nn
>>> logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291]) >>> targets = mx.array([0, 0, 1, 1]) >>> loss = nn.losses.binary_cross_entropy(logits, targets, reduction="mean") >>> loss array(0.539245, dtype=float32)
>>> probs = mx.array([0.1, 0.1, 0.4, 0.4]) >>> targets = mx.array([0, 0, 1, 1]) >>> loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction="mean") >>> loss array(0.510826, dtype=float32)