模块#
- class Module#
使用 MLX 构建神经网络的基类。
mlx.nn.layers
中提供的所有层都继承此类,您的模型也应如此。一个
Module
可以包含其他Module
实例,或以任意嵌套的 Python 列表或字典形式包含mlx.core.array
实例。然后,通过使用mlx.nn.Module.parameters()
,Module
允许递归提取所有mlx.core.array
实例。此外,
Module
具有可训练和不可训练(称为“冻结”)参数的概念。当使用mlx.nn.value_and_grad()
时,梯度仅针对可训练参数返回。模块中的所有数组都是可训练的,除非通过调用freeze()
将它们添加到“冻结”集合中。import mlx.core as mx import mlx.nn as nn class MyMLP(nn.Module): def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16): super().__init__() self.in_proj = nn.Linear(in_dims, hidden_dims) self.out_proj = nn.Linear(hidden_dims, out_dims) def __call__(self, x): x = self.in_proj(x) x = mx.maximum(x, 0) return self.out_proj(x) model = MyMLP(2, 1) # All the model parameters are created but since MLX is lazy by # default, they are not evaluated yet. Calling `mx.eval` actually # allocates memory and initializes the parameters. mx.eval(model.parameters()) # Setting a parameter to a new value is as simply as accessing that # parameter and assigning a new array to it. model.in_proj.weight = model.in_proj.weight * 2 mx.eval(model.parameters())
属性
布尔值,指示模型是否处于训练模式。
模块的状态字典
方法
Module.apply
(map_fn[, filter_fn])使用提供的
map_fn
映射所有参数,并立即使用映射后的参数更新模块。Module.apply_to_modules
(apply_fn)对该实例中的所有模块(包括此实例)应用函数。
返回此 Module 实例的直接后代。
将模型设置为评估模式。
Module.filter_and_map
(filter_fn[, map_fn, ...])使用
filter_fn
递归过滤模块的内容,即仅选择filter_fn
返回 true 的键和值。Module.freeze
(*[, recurse, keys, strict])冻结 Module 的参数或其中一部分。
返回不包含其他模块的子模块。
Module.load_weights
(file_or_weights[, strict])从
.npz
文件、.safetensors
文件或列表更新模型的权重。返回包含此实例中所有模块的列表。
返回包含此实例中所有模块及其点表示名称的列表。
递归返回此 Module 的所有
mlx.core.array
成员,作为字典和列表的字典。Module.save_weights
(file)将模型的权重保存到文件。
Module.set_dtype
(dtype[, predicate])设置模块参数的数据类型 (dtype)。
Module.train
([mode])将模型设置为进入或退出训练模式。
递归返回此 Module 的所有非冻结
mlx.core.array
成员,作为字典和列表的字典。Module.unfreeze
(*[, recurse, keys, strict])解冻 Module 的参数或其中一部分。
Module.update
(parameters)用提供的字典和列表的字典中的参数替换此 Module 的参数。
Module.update_modules
(modules)用提供的字典和列表的字典中的子模块替换此
Module
实例的子模块。