mlx.nn.losses.margin_ranking_loss#
- class margin_ranking_loss(inputs1: array, inputs2: array, targets: array, margin: float = 0.0, reduction: Literal['none', 'mean', 'sum'] = 'none')#
计算 margin ranking loss。该损失函数接受输入 \(x_1\)、\(x_2\) 和标签 \(y\) (包含 1 或 -1)。
损失函数由以下公式给出:
\[\text{loss} = \max (0, -y * (x_1 - x_2) + \text{margin})\]其中 \(y\) 代表
targets
,\(x_1\) 代表inputs1
,\(x_2\) 代表inputs2
。- 参数:
- 返回值:
计算得到的 margin ranking loss。
- 返回值类型:
示例
>>> import mlx.core as mx >>> import mlx.nn as nn >>> targets = mx.array([1, 1, -1]) >>> inputs1 = mx.array([-0.573409, -0.765166, -0.0638]) >>> inputs2 = mx.array([0.75596, 0.225763, 0.256995]) >>> loss = nn.losses.margin_ranking_loss(inputs1, inputs2, targets) >>> loss array(0.773433, dtype=float32)