mlx.nn.losses.cosine_similarity_loss

mlx.nn.losses.cosine_similarity_loss#

class cosine_similarity_loss(x1: array, x2: array, axis: int = 1, eps: float = 1e-08, reduction: Literal['none', 'mean', 'sum'] = 'none')#

计算两个输入之间的余弦相似度。

余弦相似度损失计算公式如下:

\[\frac{x_1 \cdot x_2}{\max(\|x_1\| \cdot \|x_2\|, \epsilon)}\]
参数:
  • x1 (mx.array) – 第一个输入集合。

  • x2 (mx.array) – 第二个输入集合。

  • axis (int, 可选) – 嵌入轴。默认值: 1

  • eps (float, 可选) – 用于数值稳定性的分母最小值。默认值: 1e-8

  • reduction (str, 可选) – 指定应用于输出的归约方式: 'none' | 'mean' | 'sum'。默认值: 'none'

返回值:

计算得到的余弦相似度损失。

返回值类型:

mx.array