神经网络#
在 MLX 中编写任意复杂的神经网络,只需使用 mlx.core.array
和 mlx.core.value_and_grad()
。然而,这要求用户一遍又一遍地编写相同的简单神经网络操作,并手动显式地处理所有参数状态和初始化。
模块 mlx.nn
通过提供一种直观的方式来组合神经网络层、初始化参数、冻结参数以进行微调等,解决了这个问题。
神经网络快速入门#
import mlx.core as mx
import mlx.nn as nn
class MLP(nn.Module):
def __init__(self, in_dims: int, out_dims: int):
super().__init__()
self.layers = [
nn.Linear(in_dims, 128),
nn.Linear(128, 128),
nn.Linear(128, out_dims),
]
def __call__(self, x):
for i, l in enumerate(self.layers):
x = mx.maximum(x, 0) if i > 0 else x
x = l(x)
return x
# The model is created with all its parameters but nothing is initialized
# yet because MLX is lazily evaluated
mlp = MLP(2, 10)
# We can access its parameters by calling mlp.parameters()
params = mlp.parameters()
print(params["layers"][0]["weight"].shape)
# Printing a parameter will cause it to be evaluated and thus initialized
print(params["layers"][0])
# We can also force evaluate all parameters to initialize the model
mx.eval(mlp.parameters())
# A simple loss function.
# NOTE: It doesn't matter how it uses the mlp model. It currently captures
# it from the local scope. It could be a positional argument or a
# keyword argument.
def l2_loss(x, y):
y_hat = mlp(x)
return (y_hat - y).square().mean()
# Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
模块类#
任何神经网络库的核心都是 Module
类。在 MLX 中,Module
类是 mlx.core.array
或 Module
实例的容器。它的主要功能是提供一种递归地 访问 和 更新 其参数及其子模块参数的方法。
参数#
模块的参数是类型为 mlx.core.array
的任何公共成员(其名称不应以 _
开头)。它可以任意嵌套在其他 Module
实例或列表和字典中。
Module.parameters()
可用于提取一个嵌套字典,其中包含模块及其子模块的所有参数。
一个 Module
还可以跟踪“冻结”参数。有关详细信息,请参阅 Module.freeze()
方法。通过 mlx.nn.value_and_grad()
返回的梯度将是相对于这些可训练参数的。
更新参数#
MLX 模块允许访问和更新单个参数。然而,大多数时候我们需要更新模块参数的大部分子集。此操作通过 Module.update()
完成。
检查模块#
查看模型架构的最简单方法是打印它。按照上面的示例,您可以使用以下代码打印 MLP
:
print(mlp)
这将显示
MLP(
(layers.0): Linear(input_dims=2, output_dims=128, bias=True)
(layers.1): Linear(input_dims=128, output_dims=128, bias=True)
(layers.2): Linear(input_dims=128, output_dims=10, bias=True)
)
要获取有关 Module
中数组的更详细信息,您可以在参数上使用 mlx.utils.tree_map()
。例如,要查看 Module
中所有参数的形状,请执行以下操作:
from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())
再例如,您可以使用以下代码计算 Module
中的参数数量:
from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
值和梯度#
使用 Module
并不排除使用 MLX 的高阶函数变换(mlx.core.value_and_grad()
、mlx.core.grad()
等)。然而,这些函数变换假设函数是纯函数,即参数应作为参数传递给要转换的函数。
使用 MLX 模块可以轻松实现这一模式
model = ...
def f(params, other_inputs):
model.update(params) # <---- Necessary to make the model use the passed parameters
return model(other_inputs)
f(model.trainable_parameters(), mx.zeros((10,)))
然而,mlx.nn.value_and_grad()
精确地提供了这种模式,并且只计算相对于模型可训练参数的梯度。
详细来说
它用一个调用
Module.update()
的函数包装传递的函数,以确保模型使用提供的参数。它调用
mlx.core.value_and_grad()
将函数转换为一个同时计算相对于传递参数的梯度的函数。它用一个函数包装返回的函数,该函数将可训练参数作为第一个参数传递给
mlx.core.value_and_grad()
返回的函数。
|
将传入的函数 |
|
根据谓词量化模块的子模块。 |
|
在传递的组中的分布式进程中平均梯度。 |
- 模块
模块
- mlx.nn.Module.training
- mlx.nn.Module.state
- mlx.nn.Module.apply
- mlx.nn.Module.apply_to_modules
- mlx.nn.Module.children
- mlx.nn.Module.eval
- mlx.nn.Module.filter_and_map
- mlx.nn.Module.freeze
- mlx.nn.Module.leaf_modules
- mlx.nn.Module.load_weights
- mlx.nn.Module.modules
- mlx.nn.Module.named_modules
- mlx.nn.Module.parameters
- mlx.nn.Module.save_weights
- mlx.nn.Module.set_dtype
- mlx.nn.Module.train
- mlx.nn.Module.trainable_parameters
- mlx.nn.Module.unfreeze
- mlx.nn.Module.update
- mlx.nn.Module.update_modules
- 层
- mlx.nn.ALiBi
- mlx.nn.AvgPool1d
- mlx.nn.AvgPool2d
- mlx.nn.AvgPool3d
- mlx.nn.BatchNorm
- mlx.nn.CELU
- mlx.nn.Conv1d
- mlx.nn.Conv2d
- mlx.nn.Conv3d
- mlx.nn.ConvTranspose1d
- mlx.nn.ConvTranspose2d
- mlx.nn.ConvTranspose3d
- mlx.nn.Dropout
- mlx.nn.Dropout2d
- mlx.nn.Dropout3d
- mlx.nn.Embedding
- mlx.nn.ELU
- mlx.nn.GELU
- mlx.nn.GLU
- mlx.nn.GroupNorm
- mlx.nn.GRU
- mlx.nn.HardShrink
- mlx.nn.HardTanh
- mlx.nn.Hardswish
- mlx.nn.InstanceNorm
- mlx.nn.LayerNorm
- mlx.nn.LeakyReLU
- mlx.nn.Linear
- mlx.nn.LogSigmoid
- mlx.nn.LogSoftmax
- mlx.nn.LSTM
- mlx.nn.MaxPool1d
- mlx.nn.MaxPool2d
- mlx.nn.MaxPool3d
- mlx.nn.Mish
- mlx.nn.MultiHeadAttention
- mlx.nn.PReLU
- mlx.nn.QuantizedEmbedding
- mlx.nn.QuantizedLinear
- mlx.nn.RMSNorm
- mlx.nn.ReLU
- mlx.nn.ReLU6
- mlx.nn.RNN
- mlx.nn.RoPE
- mlx.nn.SELU
- mlx.nn.Sequential
- mlx.nn.Sigmoid
- mlx.nn.SiLU
- mlx.nn.SinusoidalPositionalEncoding
- mlx.nn.Softmin
- mlx.nn.Softshrink
- mlx.nn.Softsign
- mlx.nn.Softmax
- mlx.nn.Softplus
- mlx.nn.Step
- mlx.nn.Tanh
- mlx.nn.Transformer
- mlx.nn.Upsample
- 函数
- mlx.nn.elu
- mlx.nn.celu
- mlx.nn.gelu
- mlx.nn.gelu_approx
- mlx.nn.gelu_fast_approx
- mlx.nn.glu
- mlx.nn.hard_shrink
- mlx.nn.hard_tanh
- mlx.nn.hardswish
- mlx.nn.leaky_relu
- mlx.nn.log_sigmoid
- mlx.nn.log_softmax
- mlx.nn.mish
- mlx.nn.prelu
- mlx.nn.relu
- mlx.nn.relu6
- mlx.nn.selu
- mlx.nn.sigmoid
- mlx.nn.silu
- mlx.nn.softmax
- mlx.nn.softmin
- mlx.nn.softplus
- mlx.nn.softshrink
- mlx.nn.step
- mlx.nn.tanh
- 损失函数
- mlx.nn.losses.binary_cross_entropy
- mlx.nn.losses.cosine_similarity_loss
- mlx.nn.losses.cross_entropy
- mlx.nn.losses.gaussian_nll_loss
- mlx.nn.losses.hinge_loss
- mlx.nn.losses.huber_loss
- mlx.nn.losses.kl_div_loss
- mlx.nn.losses.l1_loss
- mlx.nn.losses.log_cosh_loss
- mlx.nn.losses.margin_ranking_loss
- mlx.nn.losses.mse_loss
- mlx.nn.losses.nll_loss
- mlx.nn.losses.smooth_l1_loss
- mlx.nn.losses.triplet_loss
- 初始化器