mlx.nn.Module.freeze

mlx.nn.Module.freeze#

Module.freeze(*, recurse: bool = True, keys: str | List[str] |None = None, strict: bool = False) Module#

冻结模块的参数或其中一部分。冻结参数意味着不对其计算梯度。

此函数是幂等的,即冻结一个已冻结的模型不会产生任何操作。

示例

例如,仅训练 Transformer 中的注意力参数

model = nn.Transformer()
model.freeze()
model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
参数:
  • recurse (bool, 可选) – 如果为 True,则同时冻结子模块的参数。默认值: True

  • keys (strlist[str], 可选) – 如果提供,则仅冻结这些参数,否则冻结模块的所有参数。例如,通过调用 module.freeze(keys="bias") 冻结所有偏置。

  • strict (bool, 可选) – 如果设置为 True,则验证传递的键是否存在。默认值: False

返回:

冻结参数后的模块实例。