mlx.nn.losses.kl_div_loss

目录

mlx.nn.losses.kl_div_loss#

class kl_div_loss(inputs: array, targets: array, axis: int = -1, reduction: Literal['none', 'mean', 'sum'] = 'none')#

计算 Kullback-Leibler 散度损失。

reduction == 'none' 时计算以下内容

mx.exp(targets) * (targets - inputs).sum(axis)
参数:
  • inputs (array) – 预测分布的对数概率。

  • targets (array) – 目标分布的对数概率。

  • axis (int, 可选的) – 分布轴。默认值: -1

  • reduction (str, 可选的) – 指定应用于输出的归约方式: 'none' | 'mean' | 'sum'。默认值: 'none'

返回值:

计算得到的 Kullback-Leibler 散度损失。

返回类型:

array