mlx.utils.tree_map#
- tree_map(fn: Callable, tree: Any, *rest: Any, is_leaf: Callable | None = None) Any #
将
fn
应用于 Python 树tree
的叶子,并返回包含结果的新集合。如果提供了
rest
,则假定每个项都是tree
的超集,并且相应的叶子作为额外的 positional arguments 提供给fn
。在这方面,tree_map()
更类似于itertools.starmap()
而不是map()
。关键字参数
is_leaf
决定tree
中哪些元素被视为叶子,类似于tree_flatten()
。import mlx.nn as nn from mlx.utils import tree_map model = nn.Linear(10, 10) print(model.parameters().keys()) # dict_keys(['weight', 'bias']) # square the parameters model.update(tree_map(lambda x: x*x, model.parameters()))
- 参数:
fn (callable) – 处理树叶子的函数。
tree (Any) – 将被迭代的主要 Python 树。
rest (tuple[Any]) – 与
tree
一起迭代的额外树。is_leaf (callable, optional) – 一个可选的可调用对象,如果传入的对象被视为叶子则返回
True
,否则返回False
。
- 返回值:
一个 Python 树,其中包含由
fn
返回的新值。