mlx.nn.losses.triplet_loss#
- class triplet_loss(anchors: array, positives: array, negatives: array, axis: int = -1, p: int = 2, margin: float = 1.0, eps: float = 1e-06, reduction: Literal['none', 'mean', 'sum'] = 'none')#
计算给定锚点、正样本和负样本的 Triplet 损失。Margin 在数学公式中用 alpha 表示。
\[\max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)\]- 参数:
anchors (array) – 锚点样本。
positives (array) – 正样本。
negatives (array) – 负样本。
axis (int, 可选) – 分布轴。默认值:
-1
。p (int, 可选) – 成对距离的范数阶数。默认值:
2
。margin (float, 可选) – Triplet 损失的 margin。默认值为
1.0
。eps (float, 可选) – 用于防止数值不稳定的小正数常量。默认值为
1e-6
。reduction (str, 可选) – 指定应用于输出的归约方法:
'none'
|'mean'
|'sum'
。默认值:'none'
。
- 返回:
- 计算出的 Triplet 损失。如果 reduction 是 “none”,返回一个与输入形状相同的张量;
如果 reduction 是 “mean” 或 “sum”,返回一个标量张量。
- 返回类型: