mlx.nn.Module.apply#
- Module.apply(map_fn: Callable[[array], array], filter_fn: Callable[[Module, str, Any], bool] | None = None) Module #
使用提供的
map_fn
映射所有参数,并立即使用映射后的参数更新模块。例如,运行
model.apply(lambda x: x.astype(mx.float16))
将所有参数转换为 16 位浮点数。- 参数:
map_fn (Callable) – 将一个数组映射到另一个数组
filter_fn (Callable, optional) – 用于选择哪些数组进行映射的过滤器(默认值:
Module.valid_parameter_filter()
)。
- 返回值:
更新参数后的模块实例。