mlx.utils.tree_map

目录

mlx.utils.tree_map#

tree_map(fn: Callable, tree: Any, *rest: Any, is_leaf: Callable | None = None) Any#

fn 应用于 Python 树 tree 的叶子,并返回包含结果的新集合。

如果提供了 rest,则假定每个项都是 tree 的超集,并且相应的叶子作为额外的 positional arguments 提供给 fn。在这方面,tree_map() 更类似于 itertools.starmap() 而不是 map()

关键字参数 is_leaf 决定 tree 中哪些元素被视为叶子,类似于 tree_flatten()

import mlx.nn as nn
from mlx.utils import tree_map

model = nn.Linear(10, 10)
print(model.parameters().keys())
# dict_keys(['weight', 'bias'])

# square the parameters
model.update(tree_map(lambda x: x*x, model.parameters()))
参数:
  • fn (callable) – 处理树叶子的函数。

  • tree (Any) – 将被迭代的主要 Python 树。

  • rest (tuple[Any]) – 与 tree 一起迭代的额外树。

  • is_leaf (callable, optional) – 一个可选的可调用对象,如果传入的对象被视为叶子则返回 True,否则返回 False

返回值:

一个 Python 树,其中包含由 fn 返回的新值。