mlx.utils.tree_map#
- tree_map(fn: Callable, tree: Any, *rest: Any, is_leaf: Callable | None = None) Any#
将
fn应用于 Python 树tree的叶子,并返回包含结果的新集合。如果提供了
rest,则假定每个项都是tree的超集,并且相应的叶子作为额外的 positional arguments 提供给fn。在这方面,tree_map()更类似于itertools.starmap()而不是map()。关键字参数
is_leaf决定tree中哪些元素被视为叶子,类似于tree_flatten()。import mlx.nn as nn from mlx.utils import tree_map model = nn.Linear(10, 10) print(model.parameters().keys()) # dict_keys(['weight', 'bias']) # square the parameters model.update(tree_map(lambda x: x*x, model.parameters()))
- 参数:
fn (callable) – 处理树叶子的函数。
tree (Any) – 将被迭代的主要 Python 树。
rest (tuple[Any]) – 与
tree一起迭代的额外树。is_leaf (callable, optional) – 一个可选的可调用对象,如果传入的对象被视为叶子则返回
True,否则返回False。
- 返回值:
一个 Python 树,其中包含由
fn返回的新值。