From 3081df326f541d20303cf93c4fce317cfcd71d1c Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Sat, 19 Sep 2020 08:52:27 -0400 Subject: [PATCH] Rely on PyO3 type conversion for HashMap and HashSet Since PyO3 0.12.0 the trait implementations for converting from hashbrown's HashMap and HashSet types. [1] This commit leverage this so we do not have to internally convert these objects to and from python types. [1] https://github.com/PyO3/pyo3/pull/1114 --- Cargo.toml | 2 +- src/digraph.rs | 81 ++++++++++++++++---------------------------------- src/graph.rs | 26 +++++----------- src/lib.rs | 42 +++++--------------------- 4 files changed, 41 insertions(+), 110 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 51b1755a7..568895717 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ rayon = "1.4" [dependencies.pyo3] version = "0.12.1" -features = ["extension-module"] +features = ["extension-module", "hashbrown"] [dependencies.hashbrown] version = "0.9" diff --git a/src/digraph.rs b/src/digraph.rs index 22d4fb9c1..e2690d0d4 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -856,7 +856,7 @@ impl PyDiGraph { /// specified node. /// :rtype: dict #[text_signature = "(node, /)"] - pub fn adj(&mut self, py: Python, node: usize) -> PyResult { + pub fn adj(&mut self, node: usize) -> HashMap { let index = NodeIndex::new(node); let neighbors = self.graph.neighbors(index); let mut out_map: HashMap = HashMap::new(); @@ -869,11 +869,7 @@ impl PyDiGraph { let edge_w = self.graph.edge_weight(edge.unwrap()); out_map.insert(neighbor.index(), edge_w.unwrap()); } - let out_dict = PyDict::new(py); - for (index, value) in out_map { - out_dict.set_item(index, value)?; - } - Ok(out_dict.into()) + out_map } /// Get the index and data for either the parent or children of a node. @@ -895,10 +891,9 @@ impl PyDiGraph { #[text_signature = "(node, direction, /)"] pub fn adj_direction( &mut self, - py: Python, node: usize, direction: bool, - ) -> PyResult { + ) -> PyResult> { let index = NodeIndex::new(node); let dir = if direction { petgraph::Direction::Incoming @@ -930,11 +925,7 @@ impl PyDiGraph { let edge_w = self.graph.edge_weight(edge); out_map.insert(neighbor.index(), edge_w.unwrap()); } - let out_dict = PyDict::new(py); - for (index, value) in out_map { - out_dict.set_item(index, value)?; - } - Ok(out_dict.into()) + Ok(out_map) } /// Get the index and edge data for all parents of a node. @@ -1253,36 +1244,16 @@ impl PyDiGraph { &mut self, py: Python, other: &PyDiGraph, - node_map: PyObject, + node_map: HashMap, node_map_func: Option, edge_map_func: Option, - ) -> PyResult { + ) -> PyResult> { let mut new_node_map: HashMap = HashMap::new(); - let node_map_dict = node_map.cast_as::(py)?; - let mut node_map_hashmap: HashMap = - HashMap::default(); - for (k, v) in node_map_dict.iter() { - node_map_hashmap.insert(k.extract()?, v.extract()?); - } - - fn node_weight_callable( - py: Python, - node_map: &Option, - node: &PyObject, - ) -> PyResult { - match node_map { - Some(node_map) => { - let res = node_map.call1(py, (node,))?; - Ok(res.to_object(py)) - } - None => Ok(node.clone_ref(py)), - } - } // TODO: Reimplement this without looping over the graphs // Loop over other nodes add add to self graph for node in other.graph.node_indices() { - let new_index = self.graph.add_node(node_weight_callable( + let new_index = self.graph.add_node(weight_transform_callable( py, &node_map_func, &other.graph[node], @@ -1290,30 +1261,16 @@ impl PyDiGraph { new_node_map.insert(node, new_index); } - fn edge_weight_callable( - py: Python, - edge_map: &Option, - edge: &PyObject, - ) -> PyResult { - match edge_map { - Some(edge_map) => { - let res = edge_map.call1(py, (edge,))?; - Ok(res.to_object(py)) - } - None => Ok(edge.clone_ref(py)), - } - } - // loop over other edges and add to self graph for edge in other.graph.edge_references() { let new_p_index = new_node_map.get(&edge.source()).unwrap(); let new_c_index = new_node_map.get(&edge.target()).unwrap(); let weight = - edge_weight_callable(py, &edge_map_func, edge.weight())?; + weight_transform_callable(py, &edge_map_func, edge.weight())?; self.graph.add_edge(*new_p_index, *new_c_index, weight); } // Add edges from map - for (this_index, (index, weight)) in node_map_hashmap.iter() { + for (this_index, (index, weight)) in node_map.iter() { let new_index = new_node_map.get(&NodeIndex::new(*index)).unwrap(); self.graph.add_edge( NodeIndex::new(*this_index), @@ -1321,14 +1278,26 @@ impl PyDiGraph { weight.clone_ref(py), ); } - let out_dict = PyDict::new(py); - for (orig_node, new_node) in new_node_map.iter() { - out_dict.set_item(orig_node.index(), new_node.index())?; + Ok(new_node_map.iter().map(|(old, new)| (old.index(), new.index())).collect()) + } +} + +fn weight_transform_callable( + py: Python, + edge_map: &Option, + edge: &PyObject, +) -> PyResult { + match edge_map { + Some(edge_map) => { + let res = edge_map.call1(py, (edge,))?; + Ok(res.to_object(py)) } - Ok(out_dict.into()) + None => Ok(edge.clone_ref(py)), } } + + #[pyproto] impl PyMappingProtocol for PyDiGraph { /// Return the number of nodes in the graph diff --git a/src/graph.rs b/src/graph.rs index 1f0e29e71..f604b4ec0 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -639,7 +639,7 @@ impl PyGraph { /// edge with the specified node. /// :rtype: dict #[text_signature = "(node, /)"] - pub fn adj(&mut self, py: Python, node: usize) -> PyResult { + pub fn adj(&mut self, node: usize) -> PyResult> { let index = NodeIndex::new(node); let neighbors = self.graph.neighbors(index); let mut out_map: HashMap = HashMap::new(); @@ -649,11 +649,7 @@ impl PyGraph { let edge_w = self.graph.edge_weight(edge.unwrap()); out_map.insert(neighbor.index(), edge_w.unwrap()); } - let out_dict = PyDict::new(py); - for (index, value) in out_map { - out_dict.set_item(index, value)?; - } - Ok(out_dict.into()) + Ok(out_map) } /// Get the degree for a node @@ -843,17 +839,11 @@ impl PyGraph { &mut self, py: Python, other: &PyGraph, - node_map: PyObject, + node_map: HashMap, node_map_func: Option, edge_map_func: Option, - ) -> PyResult { + ) -> PyResult> { let mut new_node_map: HashMap = HashMap::new(); - let node_map_dict = node_map.cast_as::(py)?; - let mut node_map_hashmap: HashMap = - HashMap::default(); - for (k, v) in node_map_dict.iter() { - node_map_hashmap.insert(k.extract()?, v.extract()?); - } fn node_weight_callable( py: Python, @@ -903,7 +893,7 @@ impl PyGraph { self.graph.add_edge(*new_p_index, *new_c_index, weight); } // Add edges from map - for (this_index, (index, weight)) in node_map_hashmap.iter() { + for (this_index, (index, weight)) in node_map.iter() { let new_index = new_node_map.get(&NodeIndex::new(*index)).unwrap(); self.graph.add_edge( NodeIndex::new(*this_index), @@ -915,7 +905,7 @@ impl PyGraph { for (orig_node, new_node) in new_node_map.iter() { out_dict.set_item(orig_node.index(), new_node.index())?; } - Ok(out_dict.into()) + Ok(new_node_map.iter().map(|(old, new)| (old.index(), new.index())).collect()) } } @@ -926,7 +916,7 @@ impl PyMappingProtocol for PyGraph { Ok(self.graph.node_count()) } fn __getitem__(&'p self, idx: usize) -> PyResult<&'p PyObject> { - match self.graph.node_weight(NodeIndex::new(idx as usize)) { + match self.graph.node_weight(NodeIndex::new(idx)) { Some(data) => Ok(data), None => Err(PyIndexError::new_err("No node found for index")), } @@ -935,7 +925,7 @@ impl PyMappingProtocol for PyGraph { fn __setitem__(&'p mut self, idx: usize, value: PyObject) -> PyResult<()> { let data = match self .graph - .node_weight_mut(NodeIndex::new(idx as usize)) + .node_weight_mut(NodeIndex::new(idx)) { Some(node_data) => node_data, None => { diff --git a/src/lib.rs b/src/lib.rs index 1b40e9f26..016cf168c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,7 +36,7 @@ use hashbrown::{HashMap, HashSet}; use pyo3::create_exception; use pyo3::exceptions::{PyException, PyValueError}; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList, PySet}; +use pyo3::types::{PyDict, PyList}; use pyo3::wrap_pyfunction; use pyo3::wrap_pymodule; use pyo3::Python; @@ -307,7 +307,7 @@ fn bfs_successors( /// :rtype: list #[pyfunction] #[text_signature = "(graph, node, /)"] -fn ancestors(py: Python, graph: &digraph::PyDiGraph, node: usize) -> PyObject { +fn ancestors(graph: &digraph::PyDiGraph, node: usize) -> HashSet { let index = NodeIndex::new(node); let mut out_set: HashSet = HashSet::new(); let reverse_graph = Reversed(graph); @@ -317,13 +317,7 @@ fn ancestors(py: Python, graph: &digraph::PyDiGraph, node: usize) -> PyObject { out_set.insert(n_int); } out_set.remove(&node); - let set = PySet::empty(py).expect("Failed to construct empty set"); - { - for val in out_set { - set.add(val).expect("Failed to add to set"); - } - } - set.into() + out_set } /// Return the descendants of a node in a graph. @@ -341,10 +335,9 @@ fn ancestors(py: Python, graph: &digraph::PyDiGraph, node: usize) -> PyObject { #[pyfunction] #[text_signature = "(graph, node, /)"] fn descendants( - py: Python, graph: &digraph::PyDiGraph, node: usize, -) -> PyObject { +) -> HashSet { let index = NodeIndex::new(node); let mut out_set: HashSet = HashSet::new(); let res = algo::dijkstra(graph, index, None, |_| 1); @@ -353,13 +346,7 @@ fn descendants( out_set.insert(n_int); } out_set.remove(&node); - let set = PySet::empty(py).expect("Failed to construct empty set"); - { - for val in out_set { - set.add(val).expect("Failed to add to set"); - } - } - set.into() + out_set } /// Get the lexicographical topological sorted nodes from the provided DAG @@ -457,10 +444,7 @@ fn lexicographical_topological_sort( /// :rtype: dict #[pyfunction] #[text_signature = "(graph, /)"] -fn graph_greedy_color( - py: Python, - graph: &graph::PyGraph, -) -> PyResult { +fn graph_greedy_color(graph: &graph::PyGraph) -> PyResult> { let mut colors: HashMap = HashMap::new(); let mut node_vec: Vec = graph.graph.node_indices().collect(); let mut sort_map: HashMap = HashMap::new(); @@ -487,11 +471,7 @@ fn graph_greedy_color( } colors.insert(u_index.index(), count); } - let out_dict = PyDict::new(py); - for (index, color) in colors { - out_dict.set_item(index, color)?; - } - Ok(out_dict.into()) + Ok(colors) } /// Return the shortest path lengths between ever pair of nodes that has a @@ -1506,11 +1486,3 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pymodule!(generators))?; Ok(()) } - -#[cfg(test)] -mod tests { - #[test] - fn it_works() { - assert_eq!(2 + 2, 4); - } -}