mlx.optimizers.AdamW#
- class AdamW(learning_rate: float | Callable[[array], array], betas: List[float] = [0.9, 0.999], eps: float = 1e-08, weight_decay: float = 0.01, bias_correction: bool = False)#
AdamW 优化器 [1]。我们使用 weight_decay (\(\lambda\)) 值更新权重。
[1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay regularization. ICLR 2019.
\[\begin{split}m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\ v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\ w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)\end{split}\]- 参数:
learning_rate (float 或 callable) – 学习率 \(\alpha\)。
betas (Tuple[float, float], 可选) – 用于计算梯度及其平方的滑动平均的系数 \((\beta_1, \beta_2)\)。默认值:
(0.9, 0.999)
eps (float, 可选) – 添加到分母中的项 \(\epsilon\),以提高数值稳定性。默认值:
1e-8
weight_decay (float, 可选) – 权重衰减 \(\lambda\)。默认值:
0.01
。bias_correction (bool, 可选) – 如果设置为
True
,则应用偏差校正。默认值:False
方法
__init__
(learning_rate[, betas, eps, ...])apply_single
(gradient, parameter, state)通过修改传递给 Adam 的参数来执行 AdamW 参数更新。