mlx.optimizers.join_schedules

mlx.optimizers.join_schedules#

join_schedules(schedules: List[Callable], boundaries: List[int]) 可调用对象#

连接多个调度器以创建一个新的调度器。

参数:
  • schedules (list(Callable)) – 调度器列表。调度器 \(i+1\) 接收一个步数计数,表示自第 \(i\) 个边界以来的步数。

  • boundaries (list(int)) – 整数列表,长度为 len(schedules) - 1,指示何时在调度器之间进行切换。

示例

>>> linear = optim.linear_schedule(0, 1e-1, steps=10)
>>> cosine = optim.cosine_decay(1e-1, 200)
>>> lr_schedule = optim.join_schedules([linear, cosine], [10])
>>> optimizer = optim.Adam(learning_rate=lr_schedule)
>>> optimizer.learning_rate
array(0.0, dtype=float32)
>>> for _ in range(12): optimizer.update({}, {})
...
>>> optimizer.learning_rate
array(0.0999938, dtype=float32)