mlx.nn.Module.apply

mlx.nn.Module.apply#

Module.apply(map_fn: Callable[[array], array], filter_fn: Callable[[Module, str, Any], bool] | None = None) Module#

使用提供的 map_fn 映射所有参数,并立即使用映射后的参数更新模块。

例如,运行 model.apply(lambda x: x.astype(mx.float16)) 将所有参数转换为 16 位浮点数。

参数:
  • map_fn (Callable) – 将一个数组映射到另一个数组

  • filter_fn (Callable, optional) – 用于选择哪些数组进行映射的过滤器(默认值:Module.valid_parameter_filter())。

返回值:

更新参数后的模块实例。