mlx.nn.Module.load_weights
- Module.load_weights(file_or_weights: str | List[Tuple[str, array]], strict: bool = True) → Module
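Update the model's weights from a .npz or .safetensors file, or from a list of (name, array) pairs.
- Parameters:
  - file_or_weights (str or list(tuple(str, array))) – The path to a .npz or .safetensors weights file, or a list of pairs of parameter names and arrays.
  - strict (bool, optional) – If True, check that the provided weights exactly match the parameters of the model; for example, a missing parameter raises a ValueError. If False, only the provided weights that the model actually contains are loaded. Default: True.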
- Returns:
  The module instance after updating the weights.
Example
```python
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(10, 10)

# Load from a .npz file (assumes weights.npz exists, e.g. written by model.save_weights)
model.load_weights("weights.npz")

# Load from a .safetensors file
model.load_weights("weights.safetensors")

# Load from a list of (name, array) pairs
weights = [
    ("weight", mx.random.uniform(shape=(10, 10))),
    ("bias", mx.zeros((10,))),
]
model.load_weights(weights)

# Missing weight: the list has no "bias" entry
weights = [
    ("weight", mx.random.uniform(shape=(10, 10))),
]

# With strict=True (the default) this raises a ValueError
try:
    model.load_weights(weights)
except ValueError as err:
    print(err)

# OK: only updates the weight, not the bias
model.load_weights(weights, strict=False)
```