mlx.nn.losses.gaussian_nll_loss#
- class gaussian_nll_loss(inputs: array, targets: array, vars: array, full: bool = False, eps: float = 1e-06, reduction: Literal['none', 'mean', 'sum'] = 'mean')#
计算高斯分布的负对数似然损失。
损失由以下公式给出:
\[\frac{1}{2}\left(\log\left(\max\left(\text{vars}, \ \epsilon\right)\right) + \frac{\left(\text{inputs} - \text{targets} \right)^2} {\max\left(\text{vars}, \ \epsilon \right)}\right) + \text{const.}\]其中
inputs
是预测均值,vars
是预测方差。