mlx.utils.tree_flatten

mlx.utils.tree_flatten#

tree_flatten(tree: Any, prefix: str = '', is_leaf: Callable | None = None) Any#

将 Python 树展平为键值元组的列表。

键使用点表示法来定义任意深度和复杂度的树。

from mlx.utils import tree_flatten

print(tree_flatten([[[0]]]))
# [("0.0.0", 0)]

print(tree_flatten([[[0]]], ".hello"))
# [("hello.0.0.0", 0)]

注意

字典的键应该是合法的 Python 标识符。

参数:
  • tree (Any) – 要展平的 Python 树。

  • prefix (str) – 用于键的前缀。第一个字符总是被丢弃。

  • is_leaf (callable) – 一个可选的可调用对象,如果传入的对象被视为叶节点,则返回 True,否则返回 False。

返回:

Python 树的平面表示。

返回类型:

List[Tuple[str, Any]]