"""Data structure: Dependency graph."""
from functools import partial
from typing import (
Any,
Callable,
IO,
ItemsView,
Iterable,
Iterator,
List,
MutableMapping,
Sequence,
Set,
cast,
)
from mode.utils.typing import Counter
from mode.utils.types.graphs import DependencyGraphT, GraphFormatterT, _T
from .formatter import GraphFormatter
class CycleError(Exception):
    """Error signalling that a graph expected to be acyclic has a cycle."""
class DependencyGraph(DependencyGraphT):
    """A directed acyclic graph of objects and their dependencies.

    Supports a robust topological sort
    to detect the order in which they must be handled.

    Takes an optional iterator of ``(obj, dependencies)``
    tuples to build the graph from.

    Warning:
        Does not support cycle detection.
    """

    # Adjacency mapping: node -> list of the nodes it depends on
    # (i.e. the targets of its outgoing edges).
    adjacent: MutableMapping

    def __init__(self,
                 it: Iterable = None,
                 formatter: GraphFormatterT[_T] = None) -> None:
        """Create the graph, optionally populating it from ``it``.

        Arguments:
            it (Iterable): Optional iterable of ``(obj, dependencies)``
                tuples, passed to :meth:`update`.
            formatter (GraphFormatterT): Formatter used by :meth:`to_dot`.
                Falls back to a new ``GraphFormatter`` when not given.
        """
        self.formatter = formatter or GraphFormatter()
        self.adjacent = {}
        if it is not None:
            self.update(it)

    def add_arc(self, obj: _T) -> None:
        """Add an object to the graph.

        Does nothing if ``obj`` is already a node; existing edges are kept.
        """
        self.adjacent.setdefault(obj, [])

    def add_edge(self, A: _T, B: _T) -> None:
        """Add an edge from object ``A`` to object ``B``.

        I.e. ``A`` depends on ``B``.

        Raises:
            KeyError: If ``A`` has not been added with :meth:`add_arc`.
                ``B`` is not checked and is not added as a node.
        """
        self[A].append(B)

    def connect(self, graph: DependencyGraphT[_T]) -> None:
        """Add nodes from another graph.

        Entries of ``graph.adjacent`` replace entries with the same key
        in this graph.  The adjacency lists are not copied, so both graphs
        refer to the same list objects afterwards.
        """
        self.adjacent.update(graph.adjacent)

    def topsort(self) -> Sequence:
        """Sort the graph topologically.

        Nodes are first grouped into strongly connected components, the
        graph of components is then sorted, and every component is
        represented in the result by its first member only.

        Returns:
            List: of objects in the order in which they must be handled
                (an object appears after the objects it depends on).

        Raises:
            KeyError: If an edge points to an object that is not a node
                of the graph.
        """
        graph = DependencyGraph()
        components = self._tarjan72()

        # Map every node to the component (tuple of nodes) containing it.
        NC = {
            node: component for component in components for node in component
        }
        for component in components:
            graph.add_arc(component)
        for node in self:
            node_c = NC[node]
            for successor in self[node]:
                successor_c = NC[successor]
                # Edges inside one component would be self-loops in the
                # component graph; skip them.
                if node_c != successor_c:
                    graph.add_edge(node_c, successor_c)
        return [t[0] for t in graph._khan62()]

    def valency_of(self, obj: _T) -> int:
        """Return the valency (degree) of a vertex in the graph.

        This is the number of outgoing edges of ``obj`` plus, recursively,
        the valency of every object those edges point to.  Objects that
        are not nodes of the graph have valency ``0``.

        Note:
            The recursion follows edges without tracking visited nodes,
            so it does not terminate normally on a graph with a cycle.
        """
        try:
            successors = self[obj]
        except KeyError:
            return 0
        return len(successors) + sum(
            self.valency_of(node) for node in successors)

    def update(self, it: Iterable) -> None:
        """Update graph with data from a list of ``(obj, deps)`` tuples.

        All ``obj`` entries are added as nodes first, so that edges can
        then be added regardless of the order of the tuples.
        """
        tups = list(it)  # materialize: the data is traversed twice.
        for obj, _ in tups:
            self.add_arc(obj)
        for obj, deps in tups:
            for dep in deps:
                self.add_edge(obj, dep)

    def edges(self) -> Iterable:
        """Return generator yielding every node with outgoing edges."""
        return (obj for obj, adj in self.items() if adj)

    def _khan62(self) -> Sequence:
        """Perform Kahn's simple topological sort algorithm from '62.

        See https://en.wikipedia.org/wiki/Topological_sorting

        Returns:
            List: of nodes, ordered so that a node comes after
                the nodes its edges point to.
        """
        # Number of incoming edges still unprocessed, per node.
        count: Counter[Any] = Counter()
        result = []

        for node in self:
            for successor in self[node]:
                count[successor] += 1
        ready = [node for node in self if not count[node]]

        while ready:
            node = ready.pop()
            result.append(node)

            for successor in self[node]:
                count[successor] -= 1
                if count[successor] == 0:
                    ready.append(successor)
        # Nodes were emitted dependents-first; flip to dependencies-first.
        result.reverse()
        return result

    def _tarjan72(self) -> Sequence:
        """Perform Tarjan's algorithm to find strongly connected components.

        See Also:
            :wikipedia:`Tarjan%27s_strongly_connected_components_algorithm`

        Returns:
            List: of components, each a tuple of nodes.
        """
        result: List = []
        stack: List = []
        # Maps node -> lowest visit index reachable from it.  Must be a
        # mapping (not a list) since it is keyed by arbitrary node objects.
        low: MutableMapping[Any, int] = {}

        def visit(node: Any) -> None:
            if node in low:
                return
            num = len(low)
            low[node] = num
            stack_pos = len(stack)
            stack.append(node)

            for successor in self[node]:
                visit(successor)
                low[node] = min(low[node], low[successor])

            if num == low[node]:
                # ``node`` is the root of a component: everything pushed
                # on the stack since ``node`` belongs to it.
                component = tuple(stack[stack_pos:])
                stack[stack_pos:] = []
                result.append(component)
                for item in component:
                    # Larger than any visit index, so a finished component
                    # can no longer lower another node's ``low`` value.
                    low[item] = len(self)

        for node in self:
            visit(node)

        return result

    def to_dot(self, fh: IO,
               *,
               formatter: GraphFormatterT[_T] = None) -> None:
        """Convert the graph to DOT format.

        Arguments:
            fh (IO): A file, or a file-like object to write the graph to.
            formatter (GraphFormatterT): Custom graph formatter to use.
                Defaults to the formatter the graph was created with.
        """
        seen: Set = set()
        draw = formatter or self.formatter
        write = partial(print, file=fh)  # noqa: T101

        def if_not_seen(fun: Callable[[Any], str], obj: Any) -> None:
            # Write the node declaration only once per distinct label.
            label = draw.label(obj)
            if label not in seen:
                write(fun(obj))
                seen.add(label)

        write(draw.head())
        for obj, adjacent in self.items():
            if not adjacent:
                if_not_seen(draw.terminal_node, obj)
            for req in adjacent:
                if_not_seen(draw.node, obj)
                write(draw.edge(obj, req))
        write(draw.tail())

    def __iter__(self) -> Iterator:
        """Iterate over the nodes of the graph."""
        return iter(self.adjacent)

    def __getitem__(self, node: _T) -> Any:
        """Return the list of objects ``node`` has edges to.

        Raises:
            KeyError: If ``node`` is not in the graph.
        """
        return self.adjacent[node]

    def __len__(self) -> int:
        """Return the number of nodes in the graph."""
        return len(self.adjacent)

    def __contains__(self, obj: _T) -> bool:
        """Return :const:`True` if ``obj`` is a node of the graph."""
        return obj in self.adjacent

    def items(self) -> ItemsView:
        """Return a view of ``(node, adjacent_nodes)`` pairs."""
        return cast(ItemsView, self.adjacent.items())

    def __repr__(self) -> str:
        """Return one indented dependency tree per node, newline-joined."""
        return '\n'.join(self._repr_node(N) for N in self)

    def _repr_node(self, obj: _T,
                   level: int = 1,
                   fmt: str = '{0}({1})') -> str:
        """Format ``obj`` and, recursively, the objects it has edges to.

        Arguments:
            obj: The object to describe.
            level (int): Indentation depth (in spaces) for direct children.
            fmt (str): Format string receiving the object and its valency.
        """
        output = [fmt.format(obj, self.valency_of(obj))]
        if obj in self:
            for other in self[obj]:
                d = fmt.format(other, self.valency_of(other))
                output.append(' ' * level + d)
                # The recursive call repeats the line for ``other`` itself
                # as its first line; drop it and keep only descendants.
                output.extend(
                    self._repr_node(other, level + 1).split('\n')[1:])
        return '\n'.join(output)