mlx.optimizers.cosine_decay

mlx.optimizers.cosine_decay#

cosine_decay(init: float, decay_steps: int, end: float = 0.0) Callable#

创建一个余弦衰减调度器。

参数:
  • init (float) – 初始值。

  • decay_steps (int) – 衰减的步数。对于超过 decay_steps 的步数,衰减值保持不变。

  • end (float, optional) – 衰减到的最终值。默认值: 0

示例

>>> lr_schedule = optim.cosine_decay(1e-1, 1000)
>>> optimizer = optim.SGD(learning_rate=lr_schedule)
>>> optimizer.learning_rate
array(0.1, dtype=float32)
>>>
>>> for _ in range(5): optimizer.update({}, {})
...
>>> optimizer.learning_rate
array(0.0999961, dtype=float32)