mlx.optimizers.clip_grad_norm#
- clip_grad_norm(grads, max_norm)#
限制梯度的全局范数。
此函数确保梯度的全局范数不超过
max_norm
。如果梯度的范数大于max_norm
,它会按比例缩小梯度。示例
>>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])} >>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0) >>> print(clipped_grads) {"w1": mx.array([...]), "w2": mx.array([...])}