神经网络#

在 MLX 中编写任意复杂的神经网络,只需使用 mlx.core.arraymlx.core.value_and_grad()。然而,这要求用户一遍又一遍地编写相同的简单神经网络操作,并手动显式地处理所有参数状态和初始化。

模块 mlx.nn 通过提供一种直观的方式来组合神经网络层、初始化参数、冻结参数以进行微调等,解决了这个问题。

神经网络快速入门#

import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int):
        super().__init__()

        self.layers = [
            nn.Linear(in_dims, 128),
            nn.Linear(128, 128),
            nn.Linear(128, out_dims),
        ]

    def __call__(self, x):
        for i, l in enumerate(self.layers):
            x = mx.maximum(x, 0) if i > 0 else x
            x = l(x)
        return x

# The model is created with all its parameters but nothing is initialized
# yet because MLX is lazily evaluated
mlp = MLP(2, 10)

# We can access its parameters by calling mlp.parameters()
params = mlp.parameters()
print(params["layers"][0]["weight"].shape)

# Printing a parameter will cause it to be evaluated and thus initialized
print(params["layers"][0])

# We can also force evaluate all parameters to initialize the model
mx.eval(mlp.parameters())

# A simple loss function.
# NOTE: It doesn't matter how it uses the mlp model. It currently captures
#       it from the local scope. It could be a positional argument or a
#       keyword argument.
def l2_loss(x, y):
    y_hat = mlp(x)
    return (y_hat - y).square().mean()

# Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)

模块类#

任何神经网络库的核心都是 Module 类。在 MLX 中,Module 类是 mlx.core.arrayModule 实例的容器。它的主要功能是提供一种递归地 访问更新 其参数及其子模块参数的方法。

参数#

模块的参数是类型为 mlx.core.array 的任何公共成员(其名称不应以 _ 开头)。它可以任意嵌套在其他 Module 实例或列表和字典中。

Module.parameters() 可用于提取一个嵌套字典,其中包含模块及其子模块的所有参数。

一个 Module 还可以跟踪“冻结”参数。有关详细信息,请参阅 Module.freeze() 方法。通过 mlx.nn.value_and_grad() 返回的梯度将是相对于这些可训练参数的。

更新参数#

MLX 模块允许访问和更新单个参数。然而,大多数时候我们需要更新模块参数的大部分子集。此操作通过 Module.update() 完成。

检查模块#

查看模型架构的最简单方法是打印它。按照上面的示例,您可以使用以下代码打印 MLP

print(mlp)

这将显示

MLP(
  (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
  (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
  (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
)

要获取有关 Module 中数组的更详细信息,您可以在参数上使用 mlx.utils.tree_map()。例如,要查看 Module 中所有参数的形状,请执行以下操作:

from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())

再例如,您可以使用以下代码计算 Module 中的参数数量:

from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))

值和梯度#

使用 Module 并不排除使用 MLX 的高阶函数变换(mlx.core.value_and_grad()mlx.core.grad() 等)。然而,这些函数变换假设函数是纯函数,即参数应作为参数传递给要转换的函数。

使用 MLX 模块可以轻松实现这一模式

model = ...

def f(params, other_inputs):
    model.update(params)  # <---- Necessary to make the model use the passed parameters
    return model(other_inputs)

f(model.trainable_parameters(), mx.zeros((10,)))

然而,mlx.nn.value_and_grad() 精确地提供了这种模式,并且只计算相对于模型可训练参数的梯度。

详细来说

  • 它用一个调用 Module.update() 的函数包装传递的函数,以确保模型使用提供的参数。

  • 它调用 mlx.core.value_and_grad() 将函数转换为一个同时计算相对于传递参数的梯度的函数。

  • 它用一个函数包装返回的函数,该函数将可训练参数作为第一个参数传递给 mlx.core.value_and_grad() 返回的函数。

value_and_grad(model, fn)

将传入的函数 fn 转换为一个计算 fn 关于模型可训练参数的梯度及其值的函数。

quantize(model[, group_size, bits, ...])

根据谓词量化模块的子模块。

average_gradients(gradients[, group, ...])

在传递的组中的分布式进程中平均梯度。