mlx.utils.tree_map_with_path#
- tree_map_with_path(fn: Callable, tree: Any, *rest: Any, is_leaf: Callable | None = None, path: Any | None = None) Any #
将
fn
应用于 Python 树tree
的路径和叶节点,并返回一个包含结果的新集合。此函数与
tree_map()
相同,但fn
将路径作为第一个参数,后跟其余的树节点。- 参数:
fn (callable) – 处理树的叶节点的函数。
tree (Any) – 将被迭代处理的主要 Python 树。
rest (tuple[Any]) – 与
tree
一起迭代处理的额外树。is_leaf (Optional[Callable]) – 一个可选的可调用对象,如果传入的对象被视为叶节点,则返回
True
,否则返回False
。path (Optional[Any]) – 将添加到结果中的前缀。
- 返回:
一个包含
fn
返回的新值的 Python 树。
示例
>>> from mlx.utils import tree_map_with_path >>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]} >>> new_tree = tree_map_with_path(lambda path, _: print(path), tree) model.0.w model.0.b model.1.w model.1.b