Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add random_ordered_tree and forest_str #4294

Merged
54 changes: 50 additions & 4 deletions networkx/generators/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _helper(paths, root, B):
# > method of generating uniformly distributed random labelled trees.
#
@py_random_state(1)
def random_tree(n, seed=None):
def random_tree(n, seed=None, create_using=None):
"""Returns a uniformly random tree on `n` nodes.

Parameters
Expand Down Expand Up @@ -184,11 +184,57 @@ def random_tree(n, seed=None):
*n* nodes, the tree is chosen uniformly at random from the set of
all trees on *n* nodes.

Example
-------
>>> import networkx as nx
Erotemic marked this conversation as resolved.
Show resolved Hide resolved
>>> tree = nx.random_tree(n=10, seed=0)
>>> print(nx.forest_str(tree, sources=[0]))
╙── 0
├── 3
└── 4
├── 6
│   ├── 1
│   ├── 2
│   └── 7
│   └── 8
│   └── 5
└── 9

>>> import networkx as nx
Erotemic marked this conversation as resolved.
Show resolved Hide resolved
>>> tree = nx.random_tree(n=10, seed=0, create_using=nx.DiGraph)
>>> print(nx.forest_str(tree))
╙── 0
├─╼ 3
└─╼ 4
├─╼ 6
│   ├─╼ 1
│   ├─╼ 2
│   └─╼ 7
│   └─╼ 8
│   └─╼ 5
└─╼ 9
"""
if n == 0:
raise nx.NetworkXPointlessConcept("the null graph is not a tree")
# Cannot create a Prüfer sequence unless `n` is at least two.
if n == 1:
return nx.empty_graph(1)
sequence = [seed.choice(range(n)) for i in range(n - 2)]
return nx.from_prufer_sequence(sequence)
utree = nx.empty_graph(1, create_using)
else:
sequence = [seed.choice(range(n)) for i in range(n - 2)]
utree = nx.from_prufer_sequence(sequence)

if create_using is None:
tree = utree
else:
tree = nx.empty_graph(0, create_using)
if tree.is_directed():
# Use a arbitrary root node and dfs to define edge directions
edges = nx.dfs_edges(utree, source=0)
else:
edges = utree.edges

# Populate the specified graph type
tree.add_nodes_from(utree.nodes)
tree.add_edges_from(edges)

return tree
1 change: 1 addition & 0 deletions networkx/readwrite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
from networkx.readwrite.gexf import *
from networkx.readwrite.nx_shp import *
from networkx.readwrite.json_graph import *
from networkx.readwrite.text import *
315 changes: 315 additions & 0 deletions networkx/readwrite/tests/test_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
import pytest
import networkx as nx
from textwrap import dedent


def test_directed_tree_str():
# Create a directed forest with labels
graph = nx.balanced_tree(r=2, h=2, create_using=nx.DiGraph)
for node in graph.nodes:
graph.nodes[node]["label"] = "node_" + chr(ord("a") + node)

node_target = dedent(
"""
╙── 0
├─╼ 1
│   ├─╼ 3
│   └─╼ 4
└─╼ 2
├─╼ 5
└─╼ 6
"""
).strip()

label_target = dedent(
"""
╙── node_a
├─╼ node_b
│   ├─╼ node_d
│   └─╼ node_e
└─╼ node_c
├─╼ node_f
└─╼ node_g
"""
).strip()

# Basic node case
ret = nx.forest_str(graph, with_labels=False)
print(ret)
dschult marked this conversation as resolved.
Show resolved Hide resolved
assert ret == node_target

# Basic label case
ret = nx.forest_str(graph, with_labels=True)
print(ret)
assert ret == label_target

# Custom write function case
lines = []
ret = nx.forest_str(graph, write=lines.append, with_labels=False)
assert ret is None
assert lines == node_target.split("\n")

# Smoke test to ensure passing the print function works. To properly test
# this case we would need to capture stdout. (for potential reference
# implementation see :class:`ubelt.util_stream.CaptureStdout`)
ret = nx.forest_str(graph, write=print)
assert ret is None


def test_empty_graph():
assert nx.forest_str(nx.DiGraph()) == "╙"
assert nx.forest_str(nx.Graph()) == "╙"


def test_directed_multi_tree_forest():
tree1 = nx.balanced_tree(r=2, h=2, create_using=nx.DiGraph)
tree2 = nx.balanced_tree(r=2, h=2, create_using=nx.DiGraph)
forest = nx.disjoint_union_all([tree1, tree2])
ret = nx.forest_str(forest)
print(ret)

target = dedent(
"""
╟── 0
╎   ├─╼ 1
╎   │   ├─╼ 3
╎   │   └─╼ 4
╎   └─╼ 2
╎   ├─╼ 5
╎   └─╼ 6
╙── 7
├─╼ 8
│   ├─╼ 10
│   └─╼ 11
└─╼ 9
├─╼ 12
└─╼ 13
"""
).strip()
assert ret == target

tree3 = nx.balanced_tree(r=2, h=2, create_using=nx.DiGraph)
forest = nx.disjoint_union_all([tree1, tree2, tree3])
ret = nx.forest_str(forest, sources=[0, 14, 7])
print(ret)

target = dedent(
"""
╟── 0
╎   ├─╼ 1
╎   │   ├─╼ 3
╎   │   └─╼ 4
╎   └─╼ 2
╎   ├─╼ 5
╎   └─╼ 6
╟── 14
╎   ├─╼ 15
╎   │   ├─╼ 17
╎   │   └─╼ 18
╎   └─╼ 16
╎   ├─╼ 19
╎   └─╼ 20
╙── 7
├─╼ 8
│   ├─╼ 10
│   └─╼ 11
└─╼ 9
├─╼ 12
└─╼ 13
"""
).strip()
assert ret == target

ret = nx.forest_str(forest, sources=[0, 14, 7], ascii_only=True)
print(ret)

target = dedent(
"""
+-- 0
:   |-> 1
:   |   |-> 3
:   |   L-> 4
:   L-> 2
:   |-> 5
:   L-> 6
+-- 14
:   |-> 15
:   |   |-> 17
:   |   L-> 18
:   L-> 16
:   |-> 19
:   L-> 20
+-- 7
|-> 8
|   |-> 10
|   L-> 11
L-> 9
|-> 12
L-> 13
"""
).strip()
assert ret == target


def test_undirected_multi_tree_forest():
tree1 = nx.balanced_tree(r=2, h=2, create_using=nx.Graph)
tree2 = nx.balanced_tree(r=2, h=2, create_using=nx.Graph)
tree2 = nx.relabel_nodes(tree2, {n: n + len(tree1) for n in tree2.nodes})
forest = nx.union(tree1, tree2)
ret = nx.forest_str(forest, sources=[0, 7])
print(ret)

target = dedent(
"""
╟── 0
╎   ├── 1
╎   │   ├── 3
╎   │   └── 4
╎   └── 2
╎   ├── 5
╎   └── 6
╙── 7
├── 8
│   ├── 10
│   └── 11
└── 9
├── 12
└── 13
"""
).strip()
assert ret == target

ret = nx.forest_str(forest, sources=[0, 7], ascii_only=True)
print(ret)

target = dedent(
"""
+-- 0
:   |-- 1
:   |   |-- 3
:   |   L-- 4
:   L-- 2
:   |-- 5
:   L-- 6
+-- 7
|-- 8
|   |-- 10
|   L-- 11
L-- 9
|-- 12
L-- 13
"""
).strip()
assert ret == target


def test_undirected_tree_str():
# Create a directed forest with labels
graph = nx.balanced_tree(r=2, h=2, create_using=nx.Graph)

# arbitrary starting point
nx.forest_str(graph)

node_target0 = dedent(
"""
╙── 0
├── 1
│   ├── 3
│   └── 4
└── 2
├── 5
└── 6
"""
).strip()

# defined starting point
ret = nx.forest_str(graph, sources=[0])
print(ret)
assert ret == node_target0

# defined starting point
node_target2 = dedent(
"""
╙── 2
├── 0
│   └── 1
│   ├── 3
│   └── 4
├── 5
└── 6
"""
).strip()
ret = nx.forest_str(graph, sources=[2])
print(ret)
assert ret == node_target2


def test_forest_str_errors():
ugraph = nx.complete_graph(3, create_using=nx.Graph)

with pytest.raises(nx.NetworkXNotImplemented):
nx.forest_str(ugraph)

dgraph = nx.complete_graph(3, create_using=nx.DiGraph)

with pytest.raises(nx.NetworkXNotImplemented):
nx.forest_str(dgraph)


def test_overspecified_sources():
"""
When sources are directly specified, we wont be able to determine when we
are in the last component, so there will always be a trailing, leftmost
pipe.
"""
graph = nx.disjoint_union_all(
[
nx.balanced_tree(r=2, h=1, create_using=nx.DiGraph),
nx.balanced_tree(r=1, h=2, create_using=nx.DiGraph),
nx.balanced_tree(r=2, h=1, create_using=nx.DiGraph),
]
)

# defined starting point
target1 = dedent(
"""
╟── 0
╎   ├─╼ 1
╎   └─╼ 2
╟── 3
╎   └─╼ 4
╎   └─╼ 5
╟── 6
╎   ├─╼ 7
╎   └─╼ 8
"""
).strip()

target2 = dedent(
"""
╟── 0
╎   ├─╼ 1
╎   └─╼ 2
╟── 3
╎   └─╼ 4
╎   └─╼ 5
╙── 6
├─╼ 7
└─╼ 8
"""
).strip()

lines = []
nx.forest_str(graph, write=lines.append, sources=graph.nodes)
got1 = chr(10).join(lines)
print("got1: ")
print(got1)

lines = []
nx.forest_str(graph, write=lines.append)
got2 = chr(10).join(lines)
print("got2: ")
print(got2)

assert got1 == target1
assert got2 == target2