Source code for mode.utils.graphs.graph

"""Data structure: Dependency graph."""
from functools import partial
from typing import (
    Any,
    Callable,
    IO,
    ItemsView,
    Iterable,
    Iterator,
    List,
    MutableMapping,
    Sequence,
    Set,
    cast,
)

from mode.utils.typing import Counter
from mode.utils.types.graphs import DependencyGraphT, GraphFormatterT, _T

from .formatter import GraphFormatter


class CycleError(Exception):
    """Signal that a graph expected to be acyclic contains a cycle.

    Plain :class:`Exception` subclass; it adds no attributes or behavior
    of its own.
    """


class DependencyGraph(DependencyGraphT):
    """A directed acyclic graph of objects and their dependencies.

    Supports a robust topological sort
    to detect the order in which they must be handled.

    Takes an optional iterator of ``(obj, dependencies)``
    tuples to build the graph from.

    The graph is stored as an adjacency mapping: each node maps to the
    list of nodes it depends on.

    Warning:
        Does not support cycle detection.
    """

    # Maps every node to the list of nodes it depends on.
    adjacent: MutableMapping

    def __init__(self,
                 it: Iterable = None,
                 formatter: GraphFormatterT[_T] = None) -> None:
        """Create the graph.

        Arguments:
            it: Optional iterable of ``(obj, dependencies)`` tuples used
                to populate the graph (see :meth:`update`).
            formatter: Formatter used by :meth:`to_dot` when none is
                passed explicitly.  Defaults to a new ``GraphFormatter``.
        """
        self.formatter = formatter or GraphFormatter()
        self.adjacent = {}
        if it is not None:
            self.update(it)

    def add_arc(self, obj: _T) -> None:
        """Add an object to the graph.

        Adding an object that is already present keeps its existing
        dependency list.
        """
        self.adjacent.setdefault(obj, [])

    def add_edge(self, A: _T, B: _T) -> None:
        """Add an edge from object ``A`` to object ``B``.

        I.e. ``A`` depends on ``B``.

        Raises:
            KeyError: if ``A`` has not been added to the graph.
        """
        self[A].append(B)

    def connect(self, graph: DependencyGraphT[_T]) -> None:
        """Add nodes from another graph.

        Nodes present in both graphs take the dependency list of
        ``graph``; the list objects are shared, not copied.
        """
        self.adjacent.update(graph.adjacent)

    def topsort(self) -> Sequence:
        """Sort the graph topologically.

        The graph is first condensed into its strongly connected
        components, and the resulting component graph is then sorted.
        Each component contributes only its first member to the result,
        so for nodes that form a cycle only one of them is returned.

        Returns:
            List: of objects in the order in which they must be handled
            (dependencies before the objects that depend on them).
        """
        graph = DependencyGraph()
        components = self._tarjan72()

        # Map every node to the component (tuple of nodes) it belongs to.
        NC = {
            node: component
            for component in components
            for node in component
        }
        for component in components:
            graph.add_arc(component)
        for node in self:
            node_c = NC[node]
            for successor in self[node]:
                successor_c = NC[successor]
                # Edges inside a component are dropped: the condensed
                # graph must not contain self-loops.
                if node_c != successor_c:
                    graph.add_edge(node_c, successor_c)
        return [t[0] for t in graph._khan62()]

    def valency_of(self, obj: _T) -> int:
        """Return the valency (degree) of a vertex in the graph.

        This is the number of direct dependencies of ``obj`` plus,
        recursively, the valency of each of them.  Objects that are not
        in the graph have valency 0.
        """
        try:
            sizes = [len(self[obj])]
        except KeyError:
            return 0
        for node in self[obj]:
            sizes.append(self.valency_of(node))
        return sum(sizes)

    def update(self, it: Iterable) -> None:
        """Update graph with data from a list of ``(obj, deps)`` tuples."""
        # Materialize first: ``it`` is walked twice and may be a one-shot
        # iterator.
        tups = list(it)
        # Register every object before adding edges, since add_edge
        # requires the source node to exist.
        for obj, _ in tups:
            self.add_arc(obj)
        for obj, deps in tups:
            for dep in deps:
                self.add_edge(obj, dep)

    def edges(self) -> Iterable:
        """Return generator yielding every node that has outgoing edges.

        Note that nodes are yielded, not ``(source, target)`` pairs.
        """
        return (obj for obj, adj in self.items() if adj)

    def _khan62(self) -> Sequence:
        """Perform Khan's simple topological sort algorithm from '62.

        (The algorithm is usually spelled *Kahn's* algorithm.)

        See https://en.wikipedia.org/wiki/Topological_sorting

        Returns:
            List: of nodes ordered so that every node appears after the
            nodes it depends on.
        """
        # Number of incoming edges (dependents) seen for every node.
        count: Counter[Any] = Counter()
        result = []

        for node in self:
            for successor in self[node]:
                count[successor] += 1
        # Start from the nodes nothing else depends on.
        ready = [node for node in self if not count[node]]

        while ready:
            node = ready.pop()
            result.append(node)

            for successor in self[node]:
                count[successor] -= 1
                if count[successor] == 0:
                    ready.append(successor)
        # Dependents were emitted first; flip so dependencies lead.
        result.reverse()
        return result

    def _tarjan72(self) -> Sequence:
        """Perform Tarjan's algorithm to find strongly connected components.

        See Also:
            :wikipedia:`Tarjan%27s_strongly_connected_components_algorithm`

        Returns:
            List: of tuples, each holding the nodes of one strongly
            connected component.  A component is emitted only after the
            components reachable from it.
        """
        result: List = []
        stack: List = []
        # Low-link number for every visited node.  This has to be a
        # mapping keyed by node: nodes are arbitrary hashable objects,
        # so indexing a list with them raises TypeError/IndexError.
        low: MutableMapping = {}

        def visit(node: Any) -> None:
            if node in low:
                return
            num = len(low)
            low[node] = num
            stack_pos = len(stack)
            stack.append(node)

            for successor in self[node]:
                visit(successor)
                low[node] = min(low[node], low[successor])

            # ``node`` is the root of a component when nothing reachable
            # from it leads back to an earlier node still on the stack.
            if num == low[node]:
                component = tuple(stack[stack_pos:])
                stack[stack_pos:] = []
                result.append(component)
                # Push finished nodes out of range so later visits never
                # treat them as being on the stack.
                for item in component:
                    low[item] = len(self)

        for node in self:
            visit(node)

        return result

    def to_dot(self, fh: IO, *, formatter: GraphFormatterT[_T] = None) -> None:
        """Convert the graph to DOT format.

        Arguments:
            fh (IO): A file, or a file-like object to write the graph to.
            formatter (GraphFormatterT): Custom graph formatter to use.
                Defaults to the formatter the graph was created with.
        """
        # Labels of nodes already written, so each is declared once.
        seen: Set = set()
        draw = formatter or self.formatter
        write = partial(print, file=fh)  # noqa: T101

        def if_not_seen(fun: Callable[[Any], str], obj: Any) -> None:
            label = draw.label(obj)
            if label not in seen:
                write(fun(obj))
                seen.add(label)

        write(draw.head())
        for obj, adjacent in self.items():
            if not adjacent:
                if_not_seen(draw.terminal_node, obj)
            for req in adjacent:
                if_not_seen(draw.node, obj)
                write(draw.edge(obj, req))
        write(draw.tail())

    def __iter__(self) -> Iterator:
        return iter(self.adjacent)

    def __getitem__(self, node: _T) -> Any:
        return self.adjacent[node]

    def __len__(self) -> int:
        return len(self.adjacent)

    def __contains__(self, obj: _T) -> bool:
        return obj in self.adjacent

    def items(self) -> ItemsView:
        """Return a view of ``(node, dependencies)`` pairs."""
        return cast(ItemsView, self.adjacent.items())

    def __repr__(self) -> str:
        return '\n'.join(self._repr_node(N) for N in self)

    def _repr_node(self,
                   obj: _T,
                   level: int = 1,
                   fmt: str = '{0}({1})') -> str:
        """Render ``obj`` and, indented below it, its dependency tree.

        Arguments:
            obj: Node to render.
            level: Nesting depth used to indent the node's dependencies.
            fmt: Format string receiving the node and its valency.
        """
        output = [fmt.format(obj, self.valency_of(obj))]
        if obj in self:
            for other in self[obj]:
                d = fmt.format(other, self.valency_of(other))
                output.append(' ' * level + d)
                # The recursive call renders ``other`` again as its first
                # line; skip it, that line was just appended above.
                output.extend(
                    self._repr_node(other, level + 1).split('\n')[1:])
        return '\n'.join(output)