mlx.nn.average_gradients

mlx.nn.average_gradients#

average_gradients(gradients: Any, group: Group | None = None, all_reduce_size: int = 33554432, communication_type: Dtype | None = None)#

在传入的组中平均分布式进程的梯度。

此助手可以将几个小数组的梯度连接起来,以便进行一次大型 all reduce 调用,从而获得更好的网络性能。

参数
  • gradients (Any) – 包含梯度的 Python 树(它在所有进程中应具有相同的结构)

  • group (可选[Group]) – 用于平均梯度的进程组。如果设置为 None,则使用全局组。默认值:None

  • all_reduce_size (int) – 将数组分组,直到其字节大小超过此数值。每组数组执行一个通信步骤。如果小于或等于 0,则禁用数组分组。默认值:32MiB

  • communication_type (可选[Dtype]) – 如果提供,在执行通信之前转换为此类型。通常转换为较小的浮点类型以减少通信大小。默认值:None