模块

目录

模块#

class Module#

使用 MLX 构建神经网络的基类。

mlx.nn.layers 中提供的所有层都继承此类,您的模型也应如此。

一个 Module 可以包含其他 Module 实例,或以任意嵌套的 Python 列表或字典形式包含 mlx.core.array 实例。然后,通过使用 mlx.nn.Module.parameters()Module 允许递归提取所有 mlx.core.array 实例。

此外,Module 具有可训练和不可训练(称为“冻结”)参数的概念。当使用 mlx.nn.value_and_grad() 时,梯度仅针对可训练参数返回。模块中的所有数组都是可训练的,除非通过调用 freeze() 将它们添加到“冻结”集合中。

import mlx.core as mx
import mlx.nn as nn

class MyMLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
        super().__init__()

        self.in_proj = nn.Linear(in_dims, hidden_dims)
        self.out_proj = nn.Linear(hidden_dims, out_dims)

    def __call__(self, x):
        x = self.in_proj(x)
        x = mx.maximum(x, 0)
        return self.out_proj(x)

model = MyMLP(2, 1)

# All the model parameters are created but since MLX is lazy by
# default, they are not evaluated yet. Calling `mx.eval` actually
# allocates memory and initializes the parameters.
mx.eval(model.parameters())

# Setting a parameter to a new value is as simply as accessing that
# parameter and assigning a new array to it.
model.in_proj.weight = model.in_proj.weight * 2
mx.eval(model.parameters())

属性

Module.training

布尔值,指示模型是否处于训练模式。

Module.state

模块的状态字典

方法

Module.apply(map_fn[, filter_fn])

使用提供的 map_fn 映射所有参数,并立即使用映射后的参数更新模块。

Module.apply_to_modules(apply_fn)

对该实例中的所有模块(包括此实例)应用函数。

Module.children()

返回此 Module 实例的直接后代。

Module.eval()

将模型设置为评估模式。

Module.filter_and_map(filter_fn[, map_fn, ...])

使用 filter_fn 递归过滤模块的内容,即仅选择 filter_fn 返回 true 的键和值。

Module.freeze(*[, recurse, keys, strict])

冻结 Module 的参数或其中一部分。

Module.leaf_modules()

返回不包含其他模块的子模块。

Module.load_weights(file_or_weights[, strict])

.npz 文件、.safetensors 文件或列表更新模型的权重。

Module.modules()

返回包含此实例中所有模块的列表。

Module.named_modules()

返回包含此实例中所有模块及其点表示名称的列表。

Module.parameters()

递归返回此 Module 的所有 mlx.core.array 成员,作为字典和列表的字典。

Module.save_weights(file)

将模型的权重保存到文件。

Module.set_dtype(dtype[, predicate])

设置模块参数的数据类型 (dtype)。

Module.train([mode])

将模型设置为进入或退出训练模式。

Module.trainable_parameters()

递归返回此 Module 的所有非冻结 mlx.core.array 成员,作为字典和列表的字典。

Module.unfreeze(*[, recurse, keys, strict])

解冻 Module 的参数或其中一部分。

Module.update(parameters)

用提供的字典和列表的字典中的参数替换此 Module 的参数。

Module.update_modules(modules)

用提供的字典和列表的字典中的子模块替换此 Module 实例的子模块。