mlx.optimizers.clip_grad_norm

mlx.optimizers.clip_grad_norm#

clip_grad_norm(grads, max_norm)#

限制梯度的全局范数。

此函数确保梯度的全局范数不超过 max_norm。如果梯度的范数大于 max_norm,它会按比例缩小梯度。

示例

>>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])}
>>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0)
>>> print(clipped_grads)
{"w1": mx.array([...]), "w2": mx.array([...])}
参数:
  • grads (dict) – 包含梯度数组的字典。

  • max_norm (float) – 允许的最大梯度全局范数。

返回值:

可能已缩放的梯度和原始梯度范数。

返回类型:

(dict, float)