损失函数

树工具

mlx.utils.tree_flatten

mlx.utils.tree_unflatten

mlx.utils.tree_map

mlx.utils.tree_map_with_path

mlx.utils.tree_reduce

C++ API 参考

延伸阅读

MLX 中的自定义扩展

Metal 调试器

自定义 Metal 内核

在 C++ 中使用 MLX

.rst

.pdf

损失函数#

binary_cross_entropy(inputs, targets[, ...])

计算二元交叉熵损失。

cosine_similarity_loss(x1, x2[, axis, eps, ...])

计算两个输入之间的余弦相似度损失。

cross_entropy(logits, targets[, weights, ...])

计算交叉熵损失。

gaussian_nll_loss(inputs, targets, vars[, ...])

计算高斯分布的负对数似然损失。

hinge_loss(inputs, targets[, reduction])

计算输入和目标之间的 Hinge 损失。

huber_loss(inputs, targets[, delta, reduction])

计算输入和目标之间的 Huber 损失。

kl_div_loss(inputs, targets[, axis, reduction])

计算 KL 散度损失。

l1_loss(predictions, targets[, reduction])