mlx.nn.losses.triplet_loss

目录

mlx.nn.losses.triplet_loss#

class triplet_loss(anchors: array, positives: array, negatives: array, axis: int = -1, p: int = 2, margin: float = 1.0, eps: float = 1e-06, reduction: Literal['none', 'mean', 'sum'] = 'none')#

计算给定锚点、正样本和负样本的 Triplet 损失。Margin 在数学公式中用 alpha 表示。

\[\max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)\]
参数:
  • anchors (array) – 锚点样本。

  • positives (array) – 正样本。

  • negatives (array) – 负样本。

  • axis (int, 可选) – 分布轴。默认值: -1

  • p (int, 可选) – 成对距离的范数阶数。默认值: 2

  • margin (float, 可选) – Triplet 损失的 margin。默认值为 1.0

  • eps (float, 可选) – 用于防止数值不稳定的小正数常量。默认值为 1e-6

  • reduction (str, 可选) – 指定应用于输出的归约方法:'none' | 'mean' | 'sum'。默认值: 'none'

返回:

计算出的 Triplet 损失。如果 reduction 是 “none”,返回一个与输入形状相同的张量;

如果 reduction 是 “mean” 或 “sum”,返回一个标量张量。

返回类型:

array