mlx.utils.tree_map_with_path

mlx.utils.tree_map_with_path#

tree_map_with_path(fn: Callable, tree: Any, *rest: Any, is_leaf: Callable | None = None, path: Any | None = None) Any#

fn 应用于 Python 树 tree 的路径和叶节点,并返回一个包含结果的新集合。

此函数与 tree_map() 相同,但 fn 将路径作为第一个参数,后跟其余的树节点。

参数:
  • fn (callable) – 处理树的叶节点的函数。

  • tree (Any) – 将被迭代处理的主要 Python 树。

  • rest (tuple[Any]) – 与 tree 一起迭代处理的额外树。

  • is_leaf (Optional[Callable]) – 一个可选的可调用对象,如果传入的对象被视为叶节点,则返回 True,否则返回 False

  • path (Optional[Any]) – 将添加到结果中的前缀。

返回:

一个包含 fn 返回的新值的 Python 树。

示例

>>> from mlx.utils import tree_map_with_path
>>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
>>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
model.0.w
model.0.b
model.1.w
model.1.b