mlx.nn.Module.load_weights

mlx.nn.Module.load_weights#

Module.load_weights(file_or_weights: str | List[Tuple[str, array]], strict: bool = True) Module#

.npz.safetensors 文件或列表中更新模型的权重。

参数:
  • file_or_weights (str or list(tuple(str, mx.array))) – 指向权重 .npz 文件(.npz.safetensors)的路径,或者由参数名和数组组成的对的列表。

  • strict (bool, optional) – 如果为 True,则检查提供的权重是否与模型的参数完全匹配。否则,只加载模型中实际包含的权重,并且不检查形状。默认值:True

返回值:

更新权重后的模块实例。

示例

import mlx.core as mx
import mlx.nn as nn
model = nn.Linear(10, 10)

# Load from file
model.load_weights("weights.npz")

# Load from .safetensors file
model.load_weights("weights.safetensors")

# Load from list
weights = [
    ("weight", mx.random.uniform(shape=(10, 10))),
    ("bias",  mx.zeros((10,))),
]
model.load_weights(weights)

# Missing weight
weights = [
    ("weight", mx.random.uniform(shape=(10, 10))),
]

# Raises a ValueError exception
model.load_weights(weights)

# Ok, only updates the weight but not the bias
model.load_weights(weights, strict=False)