mlx.nn.Module.filter_and_map

mlx.nn.Module.filter_and_map#

Module.filter_and_map(filter_fn: Callable[[Module, str, Any], bool], map_fn: Callable | None = None, is_leaf_fn: Callable[[Module, str, Any], bool] | None = None)#

使用 filter_fn 递归过滤模块的内容,即只选择 filter_fn 返回 true 的键和值。

这用于实现 parameters()trainable_parameters(),但它也可以用于提取模块参数的任意子集。

参数:
  • filter_fn (Callable) – 给定一个值、找到该值的键以及包含该值的模块,决定是保留该值还是丢弃它。

  • map_fn (Callable, optional) – 可选地在返回该值之前对其进行变换。

  • is_leaf_fn (Callable, optional) – 给定一个值、找到该值的键以及包含该值的模块,决定它是否是叶子节点。

返回:

一个字典,包含递归过滤后的模块内容