Source code for treelib.tree

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2011
# Brett Alistair Kromkamp - brettkromkamp@gmail.com
# Copyright (C) 2012-2025
# Xiaming Chen - chenxm35@gmail.com
# and other contributors.
# All rights reserved.
"""
Tree structure in `treelib`.

The :class:`Tree` object defines the tree-like structure based on :class:`Node` objects.
Trees provide hierarchical data organization with efficient operations for traversal,
modification, search, and visualization.

A new tree can be created from scratch without any parameter or a shallow/deep copy of another tree.
When deep=True, a deepcopy operation is performed on feeding tree parameter and more memory
is required to create the tree.

Key Features:
    - O(1) node lookup and access
    - Multiple traversal modes (depth-first, breadth-first, zigzag)
    - Flexible tree modification (add, remove, move, paste)
    - Rich display options with customizable formatting
    - JSON/dict export and import capabilities
    - Subtree operations and filtering
    - Tree metrics and analysis tools
"""
from __future__ import print_function, unicode_literals

try:
    from builtins import str as text
except ImportError:
    from __builtin__ import str as text  # type: ignore

import codecs
import json
import sys
import uuid
from copy import deepcopy
from typing import Any, Callable, List, Optional, Union, cast

from six import iteritems, python_2_unicode_compatible

try:
    from StringIO import StringIO  # type: ignore
except ImportError:
    from io import StringIO

from .exceptions import (
    DuplicatedNodeIdError,
    InvalidLevelNumber,
    LinkPastRootNodeError,
    LoopError,
    MultipleRootError,
    NodeIDAbsentError,
)
from .node import Node

if sys.version_info >= (3, 9):
    StrList = list[str]
    StrLList = list[list[str]]
    NodeList = list[Node]
else:
    StrList = List[str]  # Python 3.8 and earlier
    StrLList = List[List[str]]
    NodeList = List[Node]


__author__ = "chenxm"


[docs] @python_2_unicode_compatible class Tree(object): """ Hierarchical tree data structure. A Tree is a collection of Node objects organized in a hierarchical structure with exactly one root node and zero or more child nodes. Each node (except root) has exactly one parent, but can have multiple children. The tree provides efficient operations for: - Adding, removing, and moving nodes - Traversing the tree in different orders - Searching and filtering nodes - Displaying tree structure - Exporting to various formats Attributes: root (str): Identifier of the root node, or None if tree is empty. nodes (dict): Dictionary mapping node identifiers to Node objects. Constants: ROOT (int): Tree traversal starting from root level. DEPTH (int): Depth-first tree traversal mode. WIDTH (int): Breadth-first tree traversal mode. ZIGZAG (int): Zigzag tree traversal mode. Example: Creating and manipulating a tree:: tree = Tree() # Build structure tree.create_node("Company", "company") tree.create_node("Engineering", "eng", parent="company") tree.create_node("Sales", "sales", parent="company") # Add team members tree.create_node("Alice", "alice", parent="eng") tree.create_node("Bob", "bob", parent="eng") tree.create_node("Carol", "carol", parent="sales") # Display tree tree.show() # Find specific nodes eng_team = tree.children("eng") all_employees = [tree[node].tag for node in tree.expand_tree() if tree.level(node) > 1] # Tree operations tree.move_node("alice", "sales") # Alice moves to sales subtree = tree.subtree("eng") # Get engineering subtree """ #: ROOT, DEPTH, WIDTH, ZIGZAG constants : (ROOT, DEPTH, WIDTH, ZIGZAG) = list(range(4)) node_class = Node def __contains__(self, identifier: str) -> bool: """ Check if a node with the given identifier exists in this tree. Implements the 'in' operator for tree membership testing, providing a convenient way to check node existence without raising exceptions. Args: identifier: Node identifier to check for existence. Returns: bool: True if node exists in tree, False otherwise. Example: Checking node membership:: if "user123" in tree: print("User node exists") user = tree["user123"] else: print("User node not found") # More concise than try/except approach exists = "node_id" in tree # True or False """ return identifier in self.nodes.keys() def __init__( self, tree=None, deep: bool = False, node_class=None, identifier: Optional[str] = None, ) -> None: """ Initialize a new tree or copy another tree. Creates either an empty tree or a copy of an existing tree. When copying, you can choose between shallow copy (references to same nodes) or deep copy (completely independent nodes and data). Args: tree: Existing Tree object to copy from. If None, creates empty tree. deep: If True, perform deep copy of all nodes and their data. If False, creates shallow copy sharing node references. node_class: Custom Node class to use instead of default Node. Must be subclass of Node. identifier: Unique identifier for this tree instance. If None, generates UUID automatically. Example: Different ways to create trees:: # Empty tree tree1 = Tree() # Tree with custom identifier tree2 = Tree(identifier="my_tree") # Shallow copy of existing tree tree3 = Tree(tree1) # Deep copy with independent data tree4 = Tree(tree1, deep=True) # Custom node class class MyNode(Node): def __init__(self, tag, identifier=None): super().__init__(tag, identifier) self.custom_attr = "value" tree5 = Tree(node_class=MyNode) Raises: AssertionError: If node_class is not a subclass of Node. """ self._identifier: Optional[str] = None self._set_identifier(identifier) if node_class: assert issubclass(node_class, Node) self.node_class = node_class #: dictionary, identifier: Node object self._nodes = {} #: Get or set the identifier of the root. This attribute can be accessed and modified #: with ``.`` and ``=`` operator respectively. self.root: Optional[str] = None if tree is not None: self.root = tree.root for nid, node in iteritems(tree.nodes): new_node = deepcopy(node) if deep else node self._nodes[nid] = new_node if tree.identifier != self._identifier: new_node.clone_pointers(tree.identifier, self._identifier) def _clone( self, identifier: Optional[str] = None, with_tree: bool = False, deep: bool = False, ): """ Create a clone of this tree instance with optional content copying. This method is designed to be overridden by subclasses to enable proper cloning of extended tree classes. It provides the foundation for subtree and remove_subtree operations while maintaining polymorphic behavior. Args: identifier: Unique identifier for the new tree instance. If None, generates UUID automatically. with_tree: If True, copy all nodes from current tree. If False, create empty tree with same class. deep: If True and with_tree=True, perform deep copy of node data. If False, create shallow copy sharing node references. Returns: Tree: New tree instance of the same class as this tree. Example: Subclassing with custom clone behavior:: class EnhancedTree(Tree): def __init__(self, metadata=None, **kwargs): super().__init__(**kwargs) self.metadata = metadata or {} def _clone(self, identifier=None, with_tree=False, deep=False): return EnhancedTree( metadata=self.metadata.copy(), identifier=identifier, tree=self if with_tree else None, deep=deep ) # Custom tree operations preserve metadata enhanced = EnhancedTree(metadata={"version": "1.0"}) subtree = enhanced.subtree("node_id") # Preserves metadata """ return self.__class__(identifier=identifier, tree=self if with_tree else None, deep=deep) @property def identifier(self) -> Optional[str]: """ Get the unique identifier of this tree instance. Each tree has its own unique identifier used to distinguish it from other trees, especially when nodes exist in multiple trees simultaneously. This identifier is automatically generated if not provided during creation. Returns: str or None: The unique identifier of this tree instance. Example: Using tree identifiers:: tree1 = Tree(identifier="main_tree") tree2 = Tree() # Auto-generated identifier print(tree1.identifier) # "main_tree" print(tree2.identifier) # UUID string like "abc123..." # Used internally for node relationships node.predecessor(tree1.identifier) # Parent in tree1 """ return self._identifier def _set_identifier(self, nid: Optional[str]) -> None: """ Initialize tree identifier with given value or auto-generate one. Private method used during tree creation to set the unique identifier. If no identifier is provided, generates a UUID automatically to ensure uniqueness across tree instances. Args: nid: Desired tree identifier, or None to auto-generate. Note: This is an internal method used during tree initialization. The identifier should not be changed after tree creation. """ if nid is None: self._identifier = str(uuid.uuid1()) else: self._identifier = nid def __getitem__(self, key: str) -> Node: """ Get node by identifier using bracket notation. Provides convenient dictionary-style access to nodes. Raises exception if node doesn't exist, making it clear when invalid identifiers are used. For safer access that returns None instead of raising, use get_node(). Args: key: Node identifier to retrieve. Returns: Node: The node object with the specified identifier. Raises: NodeIDAbsentError: If no node with the given identifier exists. Example: Accessing nodes by identifier:: # Direct access (raises exception if not found) user_node = tree["user123"] profile_node = tree["user123_profile"] # Use with in operator for safety if "user123" in tree: user_node = tree["user123"] print(user_node.tag) """ try: return self._nodes[key] except KeyError: raise NodeIDAbsentError("Node '%s' is not in the tree" % key) def __len__(self) -> int: """ Get the total number of nodes in this tree. Returns the count of all nodes currently in the tree, including the root node. Useful for tree size analysis and iteration bounds. Returns: int: Total number of nodes in the tree. Example: Getting tree size:: tree_size = len(tree) print(f"Tree has {tree_size} nodes") # Empty tree check if len(tree) == 0: print("Tree is empty") # Use in comparisons if len(tree1) > len(tree2): print("Tree1 is larger") """ return len(self._nodes) def __str__(self) -> str: """ Get string representation of the tree structure. Returns a formatted tree visualization showing the hierarchical structure with default display settings. Useful for debugging, logging, and quick inspection of tree contents. Returns: str: Formatted tree structure as string. Example: Tree string representation:: tree = Tree() tree.create_node("Root", "root") tree.create_node("Child A", "a", parent="root") tree.create_node("Child B", "b", parent="root") print(str(tree)) # Output: # Root # ├── Child A # └── Child B # Use in logging logger.info(f"Current tree structure:\n{tree}") """ self._reader = "" def write(line): self._reader += line.decode("utf-8") + "\n" self.__print_backend(func=write) return self._reader def __print_backend( self, nid: Optional[str] = None, level: int = ROOT, idhidden: bool = True, filter: Optional[Callable[[Node], bool]] = None, key: Optional[Callable[[Node], Node]] = None, reverse: bool = False, line_type="ascii-ex", data_property=None, sorting: bool = True, func: Callable[[bytes], None] = print, ): """ Another implementation of printing tree using Stack Print tree structure in hierarchy style. For example: .. code-block:: bash Root |___ C01 | |___ C11 | |___ C111 | |___ C112 |___ C02 |___ C03 | |___ C31 A more elegant way to achieve this function using Stack structure, for constructing the Nodes Stack push and pop nodes with additional level info. UPDATE: the @key @reverse is present to sort node at each level. """ # Factory for proper get_label() function if data_property: if idhidden: def get_label(node: Node) -> str: return getattr(node.data, data_property) else: def get_label(node: Node) -> str: return "%s[%s]" % ( getattr(node.data, data_property), node.identifier, ) else: if idhidden: def get_label(node: Node) -> str: return node.tag else: def get_label(node: Node) -> str: return "%s[%s]" % (node.tag, node.identifier) # legacy ordering if sorting and key is None: def key(node): return node # iter with func for pre, node in self.__get(nid, level, filter, key, reverse, line_type, sorting): label = get_label(node) func("{0}{1}".format(pre, label).encode("utf-8")) def __get( self, nid: Optional[str], level: int, filter_: Optional[Callable[[Node], bool]], key: Optional[Callable[[Node], Node]], reverse: bool, line_type: str, sorting: bool, ): # default filter if filter_ is None: def filter_(node): return True # render characters dt = { "ascii": ("|", "|-- ", "+-- "), "ascii-ex": ("\u2502", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 "), "ascii-exr": ("\u2502", "\u251c\u2500\u2500 ", "\u2570\u2500\u2500 "), "ascii-em": ("\u2551", "\u2560\u2550\u2550 ", "\u255a\u2550\u2550 "), "ascii-emv": ("\u2551", "\u255f\u2500\u2500 ", "\u2559\u2500\u2500 "), "ascii-emh": ("\u2502", "\u255e\u2550\u2550 ", "\u2558\u2550\u2550 "), }[line_type] return self.__get_iter(nid, level, filter_, key, reverse, dt, [], sorting) def __get_iter( self, nid: Optional[str], level: int, filter_: Callable[[Node], bool], key: Optional[Callable[[Node], Node]], reverse: bool, dt, # tuple[str, str, str] is_last: list, sorting: bool, ): dt_vertical_line, dt_line_box, dt_line_corner = dt nid = cast(str, self.root if nid is None else nid) node = self[nid] if level == self.ROOT: yield "", node else: leading = "".join( map( lambda x: dt_vertical_line + " " * 3 if not x else " " * 4, is_last[0:-1], ) ) lasting = dt_line_corner if is_last[-1] else dt_line_box yield leading + lasting, node if filter_(node) and node.expanded: children = [self[i] for i in node.successors(self._identifier) if filter_(self[i])] idxlast = len(children) - 1 if sorting: if key: children.sort(key=key, reverse=reverse) elif reverse: children = list(reversed(children)) level += 1 for idx, child in enumerate(children): is_last.append(idx == idxlast) for item in self.__get_iter(child.identifier, level, filter_, key, reverse, dt, is_last, sorting): yield item is_last.pop() def __update_bpointer(self, nid, parent_id): """set self[nid].bpointer""" self[nid].set_predecessor(parent_id, self._identifier) def __update_fpointer(self, nid: str, child_id: str, mode: int): if nid is None: return else: self[nid].update_successors(child_id, mode, tree_id=self._identifier)
[docs] def add_node(self, node: Node, parent: Optional[Union[Node, str]] = None) -> None: """ Add an existing node object to the tree. Integrates a pre-created Node instance into the tree structure by establishing parent-child relationships and updating internal mappings. For creating and adding nodes simultaneously, use create_node() instead. Args: node: Existing Node instance to add to the tree. Must be an instance of the tree's node_class. parent: Parent node (Node object or identifier string), or None to make this node the root. If None, tree must be empty. Raises: OSError: If node is not an instance of the expected node class. DuplicatedNodeIdError: If node identifier already exists in tree. MultipleRootError: If parent is None but tree already has a root. NodeIDAbsentError: If parent identifier doesn't exist in tree. Example: Adding pre-created nodes:: # Create nodes first root_node = Node("Company", "company") dept_node = Node("Engineering", "eng") # Add to tree tree.add_node(root_node) # Root node tree.add_node(dept_node, parent="company") # Child node # Adding with Node object as parent mgr_node = Node("Manager", "mgr") tree.add_node(mgr_node, parent=dept_node) """ if not isinstance(node, self.node_class): raise OSError("First parameter must be object of {}".format(self.node_class)) if node.identifier in self._nodes: raise DuplicatedNodeIdError("Can't create node " "with ID '%s'" % node.identifier) pid = parent.identifier if isinstance(parent, self.node_class) else parent if pid is None: if self.root is not None: raise MultipleRootError("A tree takes one root merely.") else: self.root = node.identifier elif not self.contains(pid): raise NodeIDAbsentError("Parent node '%s' " "is not in the tree" % pid) pid = cast(str, pid) self._nodes.update({node.identifier: node}) self.__update_fpointer(pid, node.identifier, self.node_class.ADD) self.__update_bpointer(node.identifier, pid) node.set_initial_tree_id(cast(str, self._identifier))
[docs] def all_nodes(self) -> NodeList: """ Get all nodes in the tree as a list. Returns a list containing all Node objects currently in the tree. The order is not guaranteed. For ordered traversal, use expand_tree(). Useful for operations that need to process all nodes simultaneously. Returns: list[Node]: List of all Node objects in the tree. Example: Processing all nodes:: # Get all nodes all_nodes = tree.all_nodes() print(f"Total nodes: {len(all_nodes)}") # Process each node for node in all_nodes: print(f"Node: {node.tag}") # Filter nodes by property leaf_nodes = [node for node in tree.all_nodes() if node.is_leaf(tree.identifier)] """ return list(self._nodes.values())
[docs] def all_nodes_itr(self) -> Any: """ Get all nodes in the tree as an iterator. Returns an iterator over all Node objects in the tree, providing memory-efficient access for large trees. Useful when you need to process nodes one at a time without loading all into memory. Returns: Iterator[Node]: Iterator over all Node objects in the tree. Example: Memory-efficient node processing:: # Iterate without loading all nodes for node in tree.all_nodes_itr(): if node.tag.startswith("temp_"): process_temporary_node(node) # Use with filter operations filtered = filter(lambda n: n.data is not None, tree.all_nodes_itr()) Note: Added by William Rusnack for memory efficiency. """ return self._nodes.values()
[docs] def ancestor(self, nid, level=None): """ Get the ancestor node at a specific level above the given node. Traverses up the tree hierarchy from the specified node to find an ancestor at the desired level. If no level is specified, returns the immediate parent. Useful for navigating tree hierarchies. Args: nid: Identifier of the node to start from. level: Target level of the ancestor (0 = root). If None, returns immediate parent. Returns: str or None: Identifier of the ancestor node, or None if not found. Raises: NodeIDAbsentError: If nid doesn't exist in the tree. InvalidLevelNumber: If level is invalid (>= descendant level). Example: Finding ancestors:: # Get immediate parent parent_id = tree.ancestor("grandchild") # Get ancestor at specific level root_id = tree.ancestor("grandchild", level=0) # Root grandparent_id = tree.ancestor("grandchild", level=1) # Use in hierarchy navigation if tree.ancestor("node", level=0) == "root": print("Node is in main hierarchy") """ if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) descendant = self[nid] ascendant = self[nid]._predecessor[self._identifier] ascendant_level = self.level(ascendant) if level is None: return ascendant elif nid == self.root: return self[nid] elif level >= self.level(descendant.identifier): raise InvalidLevelNumber( "Descendant level (level %s) must be greater \ than its ancestor's level (level %s)" % (str(self.level(descendant.identifier)), level) ) while ascendant is not None: if ascendant_level == level: return self[ascendant] else: descendant = ascendant ascendant = self[descendant]._predecessor[self._identifier] ascendant_level = self.level(ascendant) return None
[docs] def children(self, nid: str) -> NodeList: """ Get all direct children of the specified node. Returns a list of Node objects that are immediate children of the given node. The order follows the insertion order unless the tree has been sorted. Returns empty list if node has no children. Args: nid: Identifier of the parent node. Returns: list[Node]: List of child Node objects. Empty if no children. Raises: NodeIDAbsentError: If nid doesn't exist in the tree. Example: Working with child nodes:: # Get all children child_nodes = tree.children("parent_id") print(f"Parent has {len(child_nodes)} children") # Process each child for child in tree.children("parent_id"): print(f"Child: {child.tag}") # Check if node has children if tree.children("node_id"): print("Node has children") else: print("Node is a leaf") """ return [self[i] for i in self.is_branch(nid)]
[docs] def contains(self, nid): """ Check if the tree contains a node with the given identifier. Determines whether a node with the specified identifier exists in this tree. Equivalent to using the 'in' operator but provided as an explicit method for clarity and consistency. Args: nid: Node identifier to check for existence. Returns: bool: True if node exists in tree, False otherwise. Example: Checking node existence:: # Explicit method call if tree.contains("user123"): print("User node exists") # Equivalent using 'in' operator if "user123" in tree: print("User node exists") # Use before operations if tree.contains("node_id"): tree.move_node("node_id", "new_parent") """ return True if nid in self._nodes else False
[docs] def create_node( self, tag: Optional[str] = None, identifier: Optional[str] = None, parent: Optional[Union[Node, str]] = None, data: Any = None, ) -> Node: """ Create and add a new node to the tree. This is the primary method for building tree structures. Creates a new node with the specified properties and attaches it to the given parent. If no parent is specified, the node becomes the root (only allowed if tree is empty). Args: tag: Human-readable label for display. If None, uses identifier. identifier: Unique ID for the node. If None, generates UUID automatically. parent: Parent node identifier, Node object, or None for root. Must be None if tree is empty, must exist if tree has nodes. data: Optional user data to associate with this node. Returns: Node: The newly created Node object. Raises: DuplicatedNodeIdError: If identifier already exists in the tree. MultipleRootError: If parent is None but tree already has a root. NodeIDAbsentError: If parent identifier doesn't exist in the tree. Example: Building a family tree:: tree = Tree() # Create root (no parent) tree.create_node("Grandpa", "grandpa") # Add children tree.create_node("Dad", "dad", parent="grandpa") tree.create_node("Uncle", "uncle", parent="grandpa") # Add grandchildren tree.create_node("Me", "me", parent="dad") tree.create_node("Sister", "sister", parent="dad") # Add node with custom data tree.create_node("Baby", "baby", parent="me", data={"age": 1, "cute": True}) """ node = self.node_class(tag=tag, identifier=identifier, data=data) self.add_node(node, parent) return node
[docs] def depth(self, node: Optional[Node] = None) -> int: """ Get the maximum depth of the tree or the level of a specific node. When called without arguments, returns the maximum depth of the entire tree (distance from root to deepest leaf). When called with a node, returns the level of that specific node (distance from root). Args: node: Node object or identifier string. If None, returns tree depth. If provided, returns the level of this specific node. Returns: int: Tree depth (max level) or node level. Root is at level 0. Raises: NodeIDAbsentError: If specified node doesn't exist in the tree. Example: Measuring tree dimensions:: # Get overall tree depth max_depth = tree.depth() print(f"Tree depth: {max_depth}") # Get specific node level node_level = tree.depth("some_node") node_level2 = tree.depth(tree["some_node"]) # Same result # Use for tree analysis if tree.depth() > 5: print("Tree is quite deep") """ ret = 0 if node is None: # Get maximum level of this tree leaves = self.leaves() for leave in leaves: level = self.level(leave.identifier) ret = level if level >= ret else ret else: # Get level of the given node if not isinstance(node, self.node_class): nid = node else: nid = node.identifier if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) ret = self.level(nid) return ret
[docs] def expand_tree( self, nid: Optional[str] = None, mode: int = DEPTH, filter: Optional[Callable[[Node], bool]] = None, key: Optional[Callable[[Node], Node]] = None, reverse: bool = False, sorting: bool = True, ): """ Traverse the tree and yield node identifiers in specified order. This is the primary method for tree iteration, providing flexible traversal with multiple algorithms, filtering, and sorting options. Essential for most tree processing operations. Args: nid: Starting node identifier. If None, starts from tree root. Must exist in the tree if specified. mode: Traversal algorithm to use: Tree.DEPTH (0) - Depth-first search (default) Tree.WIDTH (1) - Breadth-first search Tree.ZIGZAG (2) - ZigZag traversal (alternating levels) filter: Optional function to filter nodes during traversal. Takes Node object, returns bool. If False, node and its entire subtree are skipped. key: Sorting function for nodes at each level. Takes Node object, returns comparison key. If None and sorting=True, sorts by node. reverse: If True, reverse the sorting order at each level. sorting: If True, sort nodes at each level using key function. If False, preserve insertion order (key/reverse ignored). Yields: str: Node identifiers in traversal order. Raises: NodeIDAbsentError: If nid doesn't exist in the tree. ValueError: If mode is not a valid traversal constant. Example: Different traversal patterns:: tree = Tree() tree.create_node("A", "a") tree.create_node("B", "b", parent="a") tree.create_node("C", "c", parent="a") tree.create_node("D", "d", parent="b") tree.create_node("E", "e", parent="c") # Depth-first (default): A, B, D, C, E for node_id in tree.expand_tree(): print(f"DFS: {tree[node_id].tag}") # Breadth-first: A, B, C, D, E for node_id in tree.expand_tree(mode=Tree.WIDTH): print(f"BFS: {tree[node_id].tag}") # ZigZag traversal for node_id in tree.expand_tree(mode=Tree.ZIGZAG): print(f"ZigZag: {tree[node_id].tag}") # Filtered traversal (only nodes with vowels) vowel_filter = lambda node: any(v in node.tag.lower() for v in 'aeiou') for node_id in tree.expand_tree(filter=vowel_filter): print(f"Vowels: {tree[node_id].tag}") # Sorted by tag, reversed for node_id in tree.expand_tree(key=lambda x: x.tag, reverse=True): print(f"Sorted: {tree[node_id].tag}") # From specific subtree for node_id in tree.expand_tree(nid="b"): print(f"Subtree: {tree[node_id].tag}") Common Usage Patterns:: # Process all nodes for node_id in tree.expand_tree(): process_node(tree[node_id]) # Get all node tags tags = [tree[nid].tag for nid in tree.expand_tree()] # Find specific nodes matching_nodes = [nid for nid in tree.expand_tree() if tree[nid].tag.startswith("prefix")] # Level-order processing for node_id in tree.expand_tree(mode=Tree.WIDTH): level = tree.level(node_id) print(f"Level {level}: {tree[node_id].tag}") Performance: - Time complexity: O(n) where n is number of visited nodes - Memory: O(h) where h is tree height (for traversal stack) - Generator-based: memory efficient for large trees Note: This is a generator function - it yields results lazily. Perfect for large trees where you might not need to visit all nodes, or when memory efficiency is important. """ nid = cast(str, self.root if nid is None else nid) if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) filter = (lambda x: True) if (filter is None) else filter if filter(self[nid]): yield nid queue = [self[i] for i in self[nid].successors(self._identifier) if filter(self[i])] if mode in [self.DEPTH, self.WIDTH]: if sorting: queue.sort(key=key, reverse=reverse) while queue: yield queue[0].identifier expansion = [self[i] for i in queue[0].successors(self._identifier) if filter(self[i])] if sorting: expansion.sort(key=key, reverse=reverse) if mode is self.DEPTH: queue = expansion + queue[1:] # depth-first elif mode is self.WIDTH: queue = queue[1:] + expansion # width-first elif mode is self.ZIGZAG: # Suggested by Ilya Kuprik (ilya-spy@ynadex.ru). stack_fw: NodeList = [] queue.reverse() stack = stack_bw = queue direction = False while stack: expansion = [self[i] for i in stack[0].successors(self._identifier) if filter(self[i])] yield stack.pop(0).identifier if direction: expansion.reverse() stack_bw = expansion + stack_bw else: stack_fw = expansion + stack_fw if not stack: direction = not direction stack = stack_fw if direction else stack_bw else: raise ValueError("Traversal mode '{}' is not supported".format(mode))
[docs] def filter_nodes(self, func: Callable[[Node], bool]): """ Filter all nodes in the tree using a custom function. Applies the given function to every node and returns an iterator over nodes where the function returns True. Useful for finding specific types of nodes or nodes with certain properties. Args: func: Function that takes a Node object and returns bool. True means include the node in results. Returns: Iterator[Node]: Iterator over nodes where func returns True. Example: Filtering nodes by criteria:: # Find all leaf nodes leaves = list(tree.filter_nodes(lambda n: n.is_leaf(tree.identifier))) # Find nodes with specific data important_nodes = list(tree.filter_nodes( lambda n: hasattr(n.data, 'priority') and n.data.priority == 'high' )) # Find nodes by tag pattern temp_nodes = list(tree.filter_nodes( lambda n: n.tag.startswith('temp_') )) # Memory efficient processing for node in tree.filter_nodes(lambda n: n.data is not None): process_node_with_data(node) Note: Added by William Rusnack for flexible node filtering. """ return filter(func, self.all_nodes_itr())
[docs] def get_node(self, nid: Optional[str]) -> Optional[Node]: """ Safely get a node by identifier without raising exceptions. Provides null-safe access to nodes, returning None if the node doesn't exist instead of raising an exception. Preferred over bracket notation when node existence is uncertain. Args: nid: Node identifier to retrieve, or None. Returns: Node or None: The node object if found, None otherwise. Example: Safe node access:: # Safe access (no exception) node = tree.get_node("might_not_exist") if node is not None: print(f"Found: {node.tag}") else: print("Node not found") # Compare with bracket notation try: node = tree["might_not_exist"] # Raises exception except NodeIDAbsentError: node = None # Use in conditional logic user_node = tree.get_node("user123") if user_node and user_node.data.active: process_active_user(user_node) """ if nid is None or not self.contains(nid): return None return self._nodes[nid]
[docs] def is_branch(self, nid): """ Get the list of child node identifiers for the specified node. Returns the identifiers of all direct children of the given node. Despite the name suggesting a boolean, this method actually returns the list of child identifiers. Use children() for Node objects. Args: nid: Identifier of the parent node. Cannot be None. Returns: list[str]: List of child node identifiers. Empty if no children. Raises: OSError: If nid is None. NodeIDAbsentError: If nid doesn't exist in the tree. Example: Getting child identifiers:: # Get child IDs child_ids = tree.is_branch("parent_id") for child_id in child_ids: print(f"Child ID: {child_id}") # Check if node has children if tree.is_branch("node_id"): print("Node has children") # Use with node objects child_nodes = [tree[child_id] for child_id in tree.is_branch("parent")] """ if nid is None: raise OSError("First parameter can't be None") if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) try: fpointer = self[nid].successors(self._identifier) except KeyError: fpointer = [] return fpointer
[docs] def leaves(self, nid: Optional[str] = None) -> NodeList: """ Get all leaf nodes in the tree or subtree. Returns all nodes that have no children. If a starting node is specified, only considers leaves within that subtree. Leaf nodes are terminal nodes in the tree structure and often represent end points in hierarchies. Args: nid: Root of subtree to search. If None, searches entire tree. Returns: list[Node]: List of all leaf Node objects in the specified scope. Example: Finding leaf nodes:: # Get all leaves in tree all_leaves = tree.leaves() print(f"Tree has {len(all_leaves)} leaf nodes") # Get leaves in specific subtree subtree_leaves = tree.leaves("department_root") # Process leaf nodes for leaf in tree.leaves(): print(f"Leaf: {leaf.tag}") # Find deepest nodes max_level = max(tree.level(leaf.identifier) for leaf in tree.leaves()) deepest_leaves = [leaf for leaf in tree.leaves() if tree.level(leaf.identifier) == max_level] """ leaves = [] if nid is None: for node in self._nodes.values(): if node.is_leaf(self._identifier): leaves.append(node) else: for node in self.expand_tree(nid): if self[node].is_leaf(self._identifier): leaves.append(self[node]) return leaves
[docs] def level(self, nid, filter=None): """ Get the level (depth) of a node in the tree hierarchy. Returns the distance from the root to the specified node, where the root is at level 0. Optionally filters the path calculation by excluding certain nodes. Args: nid: Identifier of the node to get level for. filter: Optional function to filter nodes during path calculation. Takes Node object, returns bool. Filtered nodes are excluded. Returns: int: Level of the node (0 for root, 1 for root's children, etc.) Example: Getting node levels:: # Basic level calculation root_level = tree.level("root") # 0 child_level = tree.level("child") # 1 grandchild_level = tree.level("gc") # 2 # Use for tree analysis max_depth = max(tree.level(node.identifier) for node in tree.all_nodes()) # Filtered level calculation level = tree.level("node", filter=lambda n: n.tag != "skip") """ return len([n for n in self.rsearch(nid, filter)]) - 1
[docs] def move_node(self, source, destination): """ Move a node from its current parent to a new parent. Changes the parent-child relationship by moving the source node (and its entire subtree) from its current parent to the specified destination parent. This operation preserves the subtree structure while reorganizing the tree hierarchy. Args: source: Identifier of the node to move. destination: Identifier of the new parent node. Raises: NodeIDAbsentError: If source or destination node doesn't exist. LoopError: If moving would create a circular reference (destination is a descendant of source). Example: Reorganizing tree structure:: # Move employee between departments tree.move_node("employee123", "sales_dept") # Move entire department tree.move_node("engineering_dept", "new_division") # Prevent circular moves (this would raise LoopError) try: tree.move_node("parent", "child_of_parent") except LoopError: print("Cannot create circular reference") # Bulk reorganization employees = ["emp1", "emp2", "emp3"] for emp in employees: tree.move_node(emp, "new_manager") """ if not self.contains(source) or not self.contains(destination): raise NodeIDAbsentError elif self.is_ancestor(source, destination): raise LoopError parent = self[source].predecessor(self._identifier) self.__update_fpointer(parent, source, self.node_class.DELETE) self.__update_fpointer(destination, source, self.node_class.ADD) self.__update_bpointer(source, destination)
[docs] def is_ancestor(self, ancestor, grandchild): """ Check if one node is an ancestor of another node. Determines whether the ancestor node appears in the path from the grandchild node to the root. An ancestor is any node that appears above another node in the tree hierarchy. Args: ancestor: Identifier of the potential ancestor node. grandchild: Identifier of the potential descendant node. Returns: bool: True if ancestor is indeed an ancestor of grandchild, False otherwise. Example: Checking ancestry relationships:: # Direct parent-child is_parent = tree.is_ancestor("parent", "child") # True # Multi-level ancestry is_grandparent = tree.is_ancestor("grandparent", "grandchild") # True # Not related is_ancestor = tree.is_ancestor("sibling1", "sibling2") # False # Prevent circular moves if not tree.is_ancestor("node_to_move", "new_parent"): tree.move_node("node_to_move", "new_parent") else: print("Cannot create circular reference") """ parent = self[grandchild].predecessor(self._identifier) child = grandchild while parent is not None: if parent == ancestor: return True else: child = self[child].predecessor(self._identifier) parent = self[child].predecessor(self._identifier) return False
@property def nodes(self): """ Get the internal dictionary mapping node identifiers to Node objects. Returns the underlying dictionary that stores all nodes in the tree, mapping from node identifiers to Node instances. Useful for direct access to the node collection and bulk operations. Returns: dict[str, Node]: Dictionary mapping node IDs to Node objects. Example: Working with the nodes dictionary:: # Get all node identifiers all_ids = list(tree.nodes.keys()) # Get all node objects all_nodes = list(tree.nodes.values()) # Direct access to nodes dict nodes_dict = tree.nodes for node_id, node in nodes_dict.items(): print(f"ID: {node_id}, Tag: {node.tag}") # Bulk operations total_nodes = len(tree.nodes) has_node = "some_id" in tree.nodes """ return self._nodes
[docs] def parent(self, nid: str) -> Optional[Node]: """ Get the parent Node object of the specified node. Returns the immediate parent of the given node, or None if the node is the root or has no parent. Provides object-based access to the parent, unlike ancestor() which returns identifiers. Args: nid: Identifier of the node whose parent to retrieve. Returns: Node or None: Parent Node object, or None if node is root. Raises: NodeIDAbsentError: If nid doesn't exist in the tree. Example: Working with parent relationships:: # Get parent node parent_node = tree.parent("child_id") if parent_node: print(f"Parent: {parent_node.tag}") else: print("Node is root or orphaned") # Navigate up the hierarchy current = "leaf_node" while tree.parent(current): current = tree.parent(current).identifier print(f"Ancestor: {tree[current].tag}") # Check parent properties parent = tree.parent("employee") if parent and hasattr(parent.data, 'department'): print(f"Department: {parent.data.department}") """ if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) pid = self[nid].predecessor(self._identifier) if pid is None or not self.contains(pid): return None return self[pid]
[docs] def merge(self, nid: str, new_tree, deep: bool = False): """Patch @new_tree on current tree by pasting new_tree root children on current tree @nid node. Consider the following tree: >>> current.show() root ├── A └── B >>> new_tree.show() root2 ├── C └── D └── D1 Merging new_tree on B node: >>>current.merge('B', new_tree) >>>current.show() root ├── A └── B ├── C └── D └── D1 Note: if current tree is empty and nid is None, the new_tree root will be used as root on current tree. In all other cases new_tree root is not pasted. """ if new_tree.root is None: return if nid is None: if self.root is None: new_tree_root = new_tree[new_tree.root] self.add_node(new_tree_root) nid = new_tree.root else: raise ValueError('Must define "nid" under which new tree is merged.') for child in new_tree.children(new_tree.root): self.paste(nid=nid, new_tree=new_tree.subtree(child.identifier), deep=deep)
[docs] def paste(self, nid: str, new_tree, deep: bool = False): """ Paste a @new_tree to the original one by linking the root of new tree to given node (nid). Update: add @deep copy of pasted tree. """ assert isinstance(new_tree, Tree) if new_tree.root is None: return if nid is None: raise ValueError('Must define "nid" under which new tree is pasted.') if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) set_joint = set(new_tree._nodes) & set(self._nodes) # joint keys if set_joint: raise ValueError("Duplicated nodes %s exists." % list(map(text, set_joint))) for cid, node in iteritems(new_tree.nodes): if deep: node = deepcopy(new_tree[node]) self._nodes.update({cid: node}) node.clone_pointers(new_tree.identifier, self._identifier) self.__update_bpointer(new_tree.root, nid) self.__update_fpointer(nid, new_tree.root, self.node_class.ADD)
[docs] def paths_to_leaves(self) -> StrLList: """ Use this function to get the identifiers allowing to go from the root nodes to each leaf. :return: a list of list of identifiers, root being not omitted. For example: .. code-block:: python Harry |___ Bill |___ Jane | |___ Diane | |___ George | |___ Jill | |___ Mary | |___ Mark Expected result: .. code-block:: python [['harry', 'jane', 'diane', 'mary'], ['harry', 'jane', 'mark'], ['harry', 'jane', 'diane', 'george', 'jill'], ['harry', 'bill']] """ res = [] for leaf in self.leaves(): res.append([nid for nid in self.rsearch(leaf.identifier)][::-1]) return res
[docs] def remove_node(self, identifier: str) -> int: """Remove a node indicated by 'identifier' with all its successors. Return the number of removed nodes. """ if not self.contains(identifier): raise NodeIDAbsentError("Node '%s' " "is not in the tree" % identifier) parent = self[identifier].predecessor(self._identifier) # Remove node and its children removed = list(self.expand_tree(identifier)) for id_ in removed: if id_ == self.root: self.root = None self.__update_bpointer(id_, None) for cid in self[id_].successors(self._identifier) or []: self.__update_fpointer(id_, cid, self.node_class.DELETE) # Update parent info self.__update_fpointer(parent, identifier, self.node_class.DELETE) self.__update_bpointer(identifier, None) for id_ in removed: self.nodes.pop(id_) return len(removed)
[docs] def remove_subtree(self, nid: str, identifier: Optional[str] = None): """ Get a subtree with ``nid`` being the root. If nid is None, an empty tree is returned. For the original tree, this method is similar to `remove_node(self,nid)`, because given node and its children are removed from the original tree in both methods. For the returned value and performance, these two methods are different: * `remove_node` returns the number of deleted nodes; * `remove_subtree` returns a subtree of deleted nodes; You are always suggested to use `remove_node` if your only to delete nodes from a tree, as the other one need memory allocation to store the new tree. :return: a :class:`Tree` object. """ st = self._clone(identifier) if nid is None: return st if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) st.root = nid # in original tree, the removed nid will be unreferenced from its # parents children parent = self[nid].predecessor(self._identifier) removed = list(self.expand_tree(nid)) for id_ in removed: if id_ == self.root: self.root = None st._nodes.update({id_: self._nodes.pop(id_)}) st[id_].clone_pointers(cast(str, self._identifier), cast(str, st.identifier)) st[id_].reset_pointers(self._identifier) if id_ == nid: st[id_].set_predecessor(None, st.identifier) self.__update_fpointer(parent, nid, self.node_class.DELETE) return st
[docs] def rsearch(self, nid: str, filter: Optional[Callable[[Node], bool]] = None): """ Traverse the tree branch along the branch from nid to its ancestors (until root). :param filter: the function of one variable to act on the :class:`Node` object. """ if nid is None: return if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) filter = (lambda x: True) if (filter is None) else filter current = nid while current is not None: if filter(self[current]): yield current # subtree() hasn't update the bpointer current = self[current].predecessor(self._identifier) if self.root != current else None
[docs] def save2file( self, filename: str, nid: Optional[str] = None, level: int = ROOT, idhidden: bool = True, filter: Optional[Callable[[Node], bool]] = None, key: Optional[Callable[[Node], Node]] = None, reverse: bool = False, line_type: str = "ascii-ex", data_property: Optional[str] = None, sorting: bool = True, ): """ Save the tree into file for offline analysis. """ def _write_line(line, f): f.write(line + b"\n") def handler(x): return _write_line(x, open(filename, "ab")) self.__print_backend( nid, level, idhidden, filter, key, reverse, line_type, data_property, sorting, func=handler, )
[docs] def show( self, nid: Optional[str] = None, level: int = ROOT, idhidden: bool = True, filter: Optional[Callable[[Node], bool]] = None, key: Optional[Callable[[Node], Node]] = None, reverse: bool = False, line_type: str = "ascii-ex", data_property: Optional[str] = None, stdout: bool = True, sorting: bool = True, ): """ Display the tree structure in a beautiful hierarchical format. This method provides rich visualization options for tree structures, allowing customization of appearance, content filtering, and display properties. Perfect for debugging, documentation, and user interfaces. Args: nid: Starting node identifier. If None, starts from root. level: Starting level (0 for root). Mainly for internal use. idhidden: If True, hide node identifiers in output (cleaner display). If False, show both tag and identifier as "tag[id]". filter: Function to selectively display nodes. Takes Node object, returns bool. Filtered nodes and their subtrees are hidden. key: Sorting function for nodes at each level. Takes Node object, returns comparison key. If None and sorting=True, sorts by Node. reverse: If True, reverse the sorting order at each level. line_type: Visual style for tree connections. Options: 'ascii' - Simple |-- style 'ascii-ex' - Unicode ├── style (default) 'ascii-exr' - Unicode with rounded corners 'ascii-em' - Double-line Unicode ╠══ style 'ascii-emv' - Mixed vertical Unicode style 'ascii-emh' - Mixed horizontal Unicode style data_property: If specified, display this property from node.data instead of node.tag. Useful for rich data objects. stdout: If True, print to console. If False, return as string. sorting: If True, sort nodes at each level. If False, preserve insertion order (key and reverse are ignored). Returns: str or None: If stdout=False, returns formatted string. If stdout=True, prints to console and returns None. Example: Different display styles:: tree = Tree() tree.create_node("Root", "root") tree.create_node("Child A", "a", parent="root") tree.create_node("Child B", "b", parent="root") # Basic display tree.show() # Fancy style with IDs tree.show(idhidden=False, line_type="ascii-em") # Filtered display (only nodes starting with 'C') tree.show(filter=lambda x: x.tag.startswith('C')) # Sorted by tag, reversed tree.show(key=lambda x: x.tag, reverse=True) # Show custom data property tree.show(data_property="name") # If node.data.name exists # Return as string instead of printing tree_str = tree.show(stdout=False) Output styles:: ascii: Root |-- Child A |-- Child B ascii-ex: Root ├── Child A └── Child B ascii-em: Root ╠══ Child A ╚══ Child B Note: The tree display automatically handles complex hierarchies with proper indentation and connection lines. Very large trees may be truncated for performance reasons. """ self._reader = "" def write(line): self._reader += line.decode("utf-8") + "\n" try: self.__print_backend( nid, level, idhidden, filter, key, reverse, line_type, data_property, sorting, func=write, ) except NodeIDAbsentError: print("Tree is empty") if stdout: print(self._reader) else: return self._reader
[docs] def siblings(self, nid: str) -> NodeList: """ Return the siblings of given @nid. If @nid is root or there are no siblings, an empty list is returned. """ siblings = [] if nid != self.root: pid = self[nid].predecessor(self._identifier) siblings = [self[i] for i in self[pid].successors(self._identifier) if i != nid] return siblings
[docs] def size(self, level: Optional[int] = None) -> int: """ Get the number of nodes of the whole tree if @level is not given. Otherwise, the total number of nodes at specific level is returned. @param level The level number in the tree. It must be between [0, tree.depth]. Otherwise, InvalidLevelNumber exception will be raised. """ if level is None: return len(self._nodes) else: try: level = int(level) return len([node for node in self.all_nodes_itr() if self.level(node.identifier) == level]) except Exception: raise TypeError("level should be an integer instead of '%s'" % type(level))
[docs] def subtree(self, nid: str, identifier: Optional[str] = None): """ Return a shallow COPY of subtree with nid being the new root. If nid is None, return an empty tree. If you are looking for a deepcopy, please create a new tree with this shallow copy, e.g., .. code-block:: python new_tree = Tree(t.subtree(t.root), deep=True) This line creates a deep copy of the entire tree. """ st = self._clone(identifier) if nid is None: return st if not self.contains(nid): raise NodeIDAbsentError("Node '%s' is not in the tree" % nid) st.root = nid for node_n in self.expand_tree(nid): st._nodes.update({self[node_n].identifier: self[node_n]}) # define nodes parent/children in this tree # all pointers are the same as copied tree, except the root st[node_n].clone_pointers(cast(str, self._identifier), cast(str, st.identifier)) if node_n == nid: # reset root parent for the new tree st[node_n].set_predecessor(None, st.identifier) return st
[docs] def update_node(self, nid: str, **attrs) -> None: """ Update node's attributes. :param nid: the identifier of modified node :param attrs: attribute pairs recognized by Node object :return: None """ cn = self[nid] for attr, val in iteritems(attrs): if attr == "identifier": # Updating node id meets following contraints: # * Update node identifier property # * Update parent's followers # * Update children's parents # * Update tree registration of var _nodes # * Update tree root if necessary cn = self._nodes.pop(nid) setattr(cn, "identifier", val) self._nodes[val] = cn if cn.predecessor(self._identifier) is not None: self[cn.predecessor(self._identifier)].update_successors( nid, mode=self.node_class.REPLACE, replace=val, tree_id=self._identifier, ) for fp in cn.successors(self._identifier): self[fp].set_predecessor(val, self._identifier) if self.root == nid: self.root = val else: setattr(cn, attr, val)
[docs] def to_dict(self, nid=None, key=None, sort=True, reverse=False, with_data=False): """Transform the whole tree into a dict.""" nid = self.root if (nid is None) else nid ntag = self[nid].tag tree_dict = {ntag: {"children": []}} if with_data: tree_dict[ntag]["data"] = self[nid].data if self[nid].expanded: queue = [self[i] for i in self[nid].successors(self._identifier)] key = (lambda x: x) if (key is None) else key if sort: queue.sort(key=key, reverse=reverse) for elem in queue: tree_dict[ntag]["children"].append( self.to_dict(elem.identifier, with_data=with_data, sort=sort, reverse=reverse) ) if len(tree_dict[ntag]["children"]) == 0: tree_dict = self[nid].tag if not with_data else {ntag: {"data": self[nid].data}} return tree_dict
[docs] def to_json(self, with_data: bool = False, sort: bool = True, reverse: bool = False): """ Export the tree structure as a JSON string. Converts the entire tree into a JSON representation that can be easily saved, transmitted, or reconstructed. Perfect for data persistence, API responses, and cross-system integration. Args: with_data: If True, include node.data in the JSON output. If False, only include tree structure and tags. Default False for smaller output size. sort: If True, sort children at each level by node comparison. If False, preserve insertion order. reverse: If True, reverse the sorting order at each level. Only applies when sort=True. Returns: str: JSON string representation of the tree. Example: Basic JSON export:: tree = Tree() tree.create_node("Root", "root") tree.create_node("Child A", "a", parent="root") tree.create_node("Child B", "b", parent="root") # Basic structure only json_str = tree.to_json() print(json_str) # Output: {"Root": {"children": [{"Child A": {}}, {"Child B": {}}]}} # Include node data tree.create_node("Data Node", "data", parent="root", data={"type": "important", "value": 42}) json_with_data = tree.to_json(with_data=True) print(json_with_data) # Output includes: "data": {"type": "important", "value": 42} # Sorted output sorted_json = tree.to_json(sort=True, reverse=True) JSON Structure:: The output follows this hierarchical format: { "root_tag": { "children": [ { "child1_tag": { "children": [...], "data": {...} // if with_data=True } }, { "child2_tag": {} // leaf node } ], "data": {...} // if with_data=True } } Usage with APIs:: # Save to file with open('tree.json', 'w') as f: f.write(tree.to_json(with_data=True)) # Send via HTTP import requests response = requests.post('/api/trees', json={'tree': tree.to_json()}) # Pretty print import json formatted = json.dumps(json.loads(tree.to_json()), indent=2) print(formatted) Note: The JSON format preserves the complete tree structure and can be used to reconstruct equivalent trees. However, complex data objects in node.data must be JSON-serializable (no functions, custom classes without __dict__, etc.). """ return json.dumps(self.to_dict(with_data=with_data, sort=sort, reverse=reverse))
[docs] def to_graphviz( self, filename: Optional[str] = None, shape: str = "circle", graph: str = "digraph", filter=None, key=None, reverse: bool = False, sorting: bool = True, ): """Exports the tree in the dot format of the graphviz software""" nodes, connections = [], [] if self.nodes: for n in self.expand_tree( mode=self.WIDTH, filter=filter, key=key, reverse=reverse, sorting=sorting, ): nid = self[n].identifier label = str(self[n].tag).translate(str.maketrans({'"': r"\""})) state = '"{0}" [label="{1}", shape={2}]'.format(nid, label, shape) nodes.append(state) for c in self.children(nid): cid = c.identifier edge = "->" if graph == "digraph" else "--" connections.append(('"{0}" ' + edge + ' "{1}"').format(nid, cid)) # write nodes and connections to dot format is_plain_file = filename is not None if is_plain_file: f = codecs.open(cast(str, filename), "w", "utf-8") else: f = StringIO() f.write(graph + " tree {\n") for n in nodes: f.write("\t" + n + "\n") if len(connections) > 0: f.write("\n") for cns in connections: f.write("\t" + cns + "\n") f.write("}") if not is_plain_file: print(f.getvalue()) f.close()
[docs] @classmethod def from_map(cls, child_parent_dict, id_func=None, data_func=None): """ takes a dict with child:parent, and form a tree """ tree = Tree() if tree is None or tree.size() > 0: raise ValueError("need to pass in an empty tree") id_func = id_func if id_func else lambda x: x data_func = data_func if data_func else lambda x: None parent_child_dict = {} root_node = None for k, v in child_parent_dict.items(): if v is None and root_node is None: root_node = k elif v is None and root_node is not None: raise ValueError("invalid input, more than 1 child has no parent") else: if v in parent_child_dict: parent_child_dict[v].append(k) else: parent_child_dict[v] = [k] if root_node is None: raise ValueError("cannot find root") tree.create_node(root_node, id_func(root_node), data=data_func(root_node)) queue = [root_node] while len(queue) > 0: parent_node = queue.pop() for child in parent_child_dict.get(parent_node, []): tree.create_node( child, id_func(child), parent=id_func(parent_node), data=data_func(child), ) queue.append(child) return tree