分布式通信#

MLX 支持分布式通信操作,允许在多台物理机器上分担训练或推理的计算成本。目前我们支持两种不同的通信后端:

  • MPI 一个功能齐全且成熟的分布式通信库

  • 我们自己的 ring 后端,它使用原生 TCP socket,对于 Thunderbolt 连接应该更快。

目前支持的所有操作及其文档可在API 文档中查看。

注意

有些操作可能尚不支持或未达到应有的速度。我们正在添加更多操作并调优现有操作,同时也在探索使用 MLX 在 Mac 上进行分布式计算的最佳方法。

入门#

MLX 中的分布式程序就像下面这样简单:

import mlx.core as mx

world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)

上面的程序在所有分布式进程中对数组 mx.ones(10) 进行求和。然而,当使用 python 运行此脚本时,只会启动一个进程,不会发生分布式通信。也就是说,当分布式组的大小为一时,mx.distributed 中的所有操作都是空操作。此特性使我们无需编写类似于以下代码中检查是否处于分布式环境的代码:

import mlx.core as mx

x = ...
world = mx.distributed.init()
# No need for the check we can simply do x = mx.distributed.all_sum(x)
if world.size() > 1:
    x = mx.distributed.all_sum(x)

运行分布式程序#

MLX 提供了 mlx.launch,这是一个启动分布式程序的辅助脚本。接着我们的初始示例,我们可以使用以下命令在本地主机上以 4 个进程运行它:

$ mlx.launch -n 4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

我们也可以通过提供远程主机的 IP 地址在这些主机上运行它(前提是脚本存在于所有主机上并且可以通过 ssh 访问它们):

$ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

有关使用 mlx.launch 的更多信息,请查阅专门的使用指南

选择后端#

您可以通过传递 {'any', 'ring', 'mpi'} 中的一个来选择调用 init() 时要使用的后端。当传递 any 时,MLX 将尝试初始化 ring 后端,如果失败则尝试 mpi 后端。如果两者都失败,则会创建一个单例组。

注意

分布式后端成功初始化后,如果在不带参数或将后端设置为 any 的情况下调用 init(),它将返回 相同的后端

以下示例旨在阐明 MLX 中的后端初始化逻辑:

# Case 1: Initialize MPI regardless if it was possible to initialize the ring backend
world = mx.distributed.init(backend="mpi")
world2 = mx.distributed.init()  # subsequent calls return the MPI backend!

# Case 2: Initialize any backend
world = mx.distributed.init(backend="any")  # equivalent to no arguments
world2 = mx.distributed.init()  # same as above

# Case 3: Initialize both backends at the same time
world_mpi = mx.distributed.init(backend="mpi")
world_ring = mx.distributed.init(backend="ring")
world_any = mx.distributed.init()  # same as MPI because it was initialized first!

训练示例#

在本节中,我们将调整一个 MLX 训练循环以支持数据并行分布式训练。也就是说,我们将在应用梯度到模型之前,先对一组主机上的梯度进行平均。

如果忽略模型、数据集和优化器的初始化,我们的训练循环代码片段如下所示:

model = ...
optimizer = ...
dataset = ...

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())

要在机器之间平均梯度,我们只需执行 all_sum() 并除以 Group 的大小。也就是说,我们需要使用以下函数对梯度进行 mlx.utils.tree_map() 操作。

def all_avg(x):
    return mx.distributed.all_sum(x) / mx.distributed.init().size()

将所有内容放在一起,我们的训练循环步骤如下所示,其他部分保持不变。

from mlx.utils import tree_map

def all_reduce_grads(grads):
    N = mx.distributed.init().size()
    if N == 1:
        return grads
    return tree_map(
        lambda x: mx.distributed.all_sum(x) / N,
        grads
    )

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = all_reduce_grads(grads)  # <--- This line was added
    optimizer.update(model, grads)
    return loss

使用 nn.average_gradients#

尽管上面的代码示例可以正常工作,但它为每个梯度执行一次通信。将多个梯度聚合在一起并执行更少的通信步骤效率会显著提高。

这就是 mlx.nn.average_gradients() 的目的。最终的代码与上面的示例几乎相同:

model = ...
optimizer = ...
dataset = ...

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = mlx.nn.average_gradients(grads) # <---- This line was added
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())

MPI 入门#

如果机器上安装了 MPI,MLX 已具备与之“对话”的能力。使用 MPI 启动分布式 MLX 程序可以像预期一样通过 mpirun 进行。然而,在以下示例中,我们将使用 mlx.launch --backend mpi,它会处理一些麻烦,例如为 mpirun 可执行文件和 libmpi.dyld 共享库设置绝对路径。

最简单的用法如下,假设使用本页开头的最小示例,其结果应该是:

$ mlx.launch --backend mpi -n 2 test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

上述命令在同一(本地)机器上启动了两个进程,我们可以看到两个标准输出流。这些进程将全 1 数组发送给彼此并计算总和,然后打印出来。使用 mlx.launch -n 4 ... 启动将打印 4 等。

安装 MPI#

MPI 可以通过 Homebrew、Anaconda 包管理器或从源代码编译安装。我们的大多数测试是使用通过 Anaconda 包管理器安装的 openmpi 完成的,如下所示:

$ conda install conda-forge::openmpi

使用 Homebrew 安装可能需要指定 libmpi.dyld 的位置,以便 MLX 可以在运行时找到并加载它。这只需将 DYLD_LIBRARY_PATH 环境变量传递给 mpirun 即可实现,并且 mlx.launch 会自动完成此操作。

$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ # or simply
$ mlx.launch -n 2 test.py

设置远程主机#

如果远程主机可以通过 ssh 访问,MPI 可以自动连接到远程主机并在网络上建立通信。一个用于调试连接问题的良好清单如下:

  • ssh hostname 可以在所有机器之间工作,无需询问密码或主机确认

  • mpirun 在所有机器上都可访问。

  • 确保 MPI 使用的 hostname 与您在所有机器的 .ssh/config 文件中配置的 hostname 一致。

调优 MPI All Reduce#

注意

为了获得更快的 all reduce,可以考虑使用 ring 后端,无论是通过 Thunderbolt 连接还是以太网连接。

配置 MPI 在每个主机之间使用 N 个 TCP 连接以提高带宽,通过传递 --mca btl_tcp_links N

通过设置 --mca btl_tcp_if_include <iface> 强制 MPI 使用性能最好的网络接口,其中 <iface> 应为您要使用的接口。

Ring 入门#

ring 后端不依赖任何第三方库,因此始终可用。它使用 TCP socket,所以节点需要通过网络可达。顾名思义,节点呈环状连接,这意味着 rank 1 只能与 rank 0 和 rank 2 通信,rank 2 只能与 rank 1 和 rank 3 通信,依此类推。因此,ring 后端不支持带有任意发送者和接收者的 send()recv()

定义 Ring#

定义和使用 ring 最简单的方法是通过 JSON 主机文件和 mlx.launch 辅助脚本。对于每个节点,需要定义一个用于 ssh 登录以在该节点上运行命令的主机名,以及一个或多个该节点将监听连接的 IP 地址。

例如,下面的主机文件定义了一个 4 节点环。 hostname1 将是 rank 0,hostname2 是 rank 1,依此类推。

[
    {"ssh": "hostname1", "ips": ["123.123.123.1"]},
    {"ssh": "hostname2", "ips": ["123.123.123.2"]},
    {"ssh": "hostname3", "ips": ["123.123.123.3"]},
    {"ssh": "hostname4", "ips": ["123.123.123.4"]}
]

运行 mlx.launch --hostfile ring-4.json my_script.py 将 ssh 登录到每个节点,运行脚本,该脚本将在提供的每个 IP 地址上监听连接。具体来说,hostname1 将连接到 123.123.123.2 并接受来自 123.123.123.4 的连接,依此类推。

Thunderbolt Ring#

尽管 ring 后端即使在以太网上也可能比 MPI 具有优势,但其主要目的是利用 Thunderbolt ring 实现更高带宽的通信。手动设置此类 Thunderbolt ring 可能相对繁琐。为了简化这一过程,我们提供了 mlx.distributed_config 工具。

要使用 mlx.distributed_config,您的计算机需要通过以太网或 Wi-Fi 经 ssh 访问。然后,通过 Thunderbolt 连接线将它们连接起来,并按如下方式调用该工具:

mlx.distributed_config --verbose --hosts host1,host2,host3,host4

默认情况下,脚本会尝试发现 Thunderbolt ring,并为您提供配置每个节点的命令以及与 mlx.launch 一起使用的 hostfile.json 文件。如果节点上可以使用无密码的 sudo,则可以使用 --auto-setup 自动配置它们。

要在不配置任何内容的情况下验证您的连接,mlx.distributed_config 还可以使用 DOT 格式绘制 ring。

mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
dot -Tpng ring.dot >ring.png
open ring.png

如果您想手动完成此过程,步骤如下:

  • 禁用 Thunderbolt bridge 接口

  • 对于连接 rank i 和 rank i + 1 的连接线,找到节点 ii + 1 中与该连接线对应的接口。

  • 为对应的接口设置一个连接这两个节点的唯一子网。例如,如果连接线对应于节点 i 上的 en2 和节点 i + 1 上的 en2,那么我们可以分别为这两个节点分配 IP 地址 192.168.0.1192.168.0.2。更多详细信息可以查看工具脚本准备的命令。