Skip to content

Commit

Permalink
Rely on PyO3 type conversion for HashMap and HashSet
Browse files Browse the repository at this point in the history
Since PyO3 0.12.0 the trait implementations for converting from
hashbrown's HashMap and HashSet types. [1] This commit leverage
this so we do not have to internally convert these objects to and from
python types.

[1] PyO3/pyo3#1114
  • Loading branch information
mtreinish committed Sep 19, 2020
1 parent c207e4f commit 3081df3
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 110 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Expand Up @@ -26,7 +26,7 @@ rayon = "1.4"

[dependencies.pyo3]
version = "0.12.1"
features = ["extension-module"]
features = ["extension-module", "hashbrown"]

[dependencies.hashbrown]
version = "0.9"
Expand Down
81 changes: 25 additions & 56 deletions src/digraph.rs
Expand Up @@ -856,7 +856,7 @@ impl PyDiGraph {
/// specified node.
/// :rtype: dict
#[text_signature = "(node, /)"]
pub fn adj(&mut self, py: Python, node: usize) -> PyResult<PyObject> {
pub fn adj(&mut self, node: usize) -> HashMap<usize, &PyObject> {
let index = NodeIndex::new(node);
let neighbors = self.graph.neighbors(index);
let mut out_map: HashMap<usize, &PyObject> = HashMap::new();
Expand All @@ -869,11 +869,7 @@ impl PyDiGraph {
let edge_w = self.graph.edge_weight(edge.unwrap());
out_map.insert(neighbor.index(), edge_w.unwrap());
}
let out_dict = PyDict::new(py);
for (index, value) in out_map {
out_dict.set_item(index, value)?;
}
Ok(out_dict.into())
out_map
}

/// Get the index and data for either the parent or children of a node.
Expand All @@ -895,10 +891,9 @@ impl PyDiGraph {
#[text_signature = "(node, direction, /)"]
pub fn adj_direction(
&mut self,
py: Python,
node: usize,
direction: bool,
) -> PyResult<PyObject> {
) -> PyResult<HashMap<usize, &PyObject>> {
let index = NodeIndex::new(node);
let dir = if direction {
petgraph::Direction::Incoming
Expand Down Expand Up @@ -930,11 +925,7 @@ impl PyDiGraph {
let edge_w = self.graph.edge_weight(edge);
out_map.insert(neighbor.index(), edge_w.unwrap());
}
let out_dict = PyDict::new(py);
for (index, value) in out_map {
out_dict.set_item(index, value)?;
}
Ok(out_dict.into())
Ok(out_map)
}

/// Get the index and edge data for all parents of a node.
Expand Down Expand Up @@ -1253,82 +1244,60 @@ impl PyDiGraph {
&mut self,
py: Python,
other: &PyDiGraph,
node_map: PyObject,
node_map: HashMap<usize, (usize, PyObject)>,
node_map_func: Option<PyObject>,
edge_map_func: Option<PyObject>,
) -> PyResult<PyObject> {
) -> PyResult<HashMap<usize, usize>> {
let mut new_node_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
let node_map_dict = node_map.cast_as::<PyDict>(py)?;
let mut node_map_hashmap: HashMap<usize, (usize, PyObject)> =
HashMap::default();
for (k, v) in node_map_dict.iter() {
node_map_hashmap.insert(k.extract()?, v.extract()?);
}

fn node_weight_callable(
py: Python,
node_map: &Option<PyObject>,
node: &PyObject,
) -> PyResult<PyObject> {
match node_map {
Some(node_map) => {
let res = node_map.call1(py, (node,))?;
Ok(res.to_object(py))
}
None => Ok(node.clone_ref(py)),
}
}

// TODO: Reimplement this without looping over the graphs
// Loop over other nodes add add to self graph
for node in other.graph.node_indices() {
let new_index = self.graph.add_node(node_weight_callable(
let new_index = self.graph.add_node(weight_transform_callable(
py,
&node_map_func,
&other.graph[node],
)?);
new_node_map.insert(node, new_index);
}

fn edge_weight_callable(
py: Python,
edge_map: &Option<PyObject>,
edge: &PyObject,
) -> PyResult<PyObject> {
match edge_map {
Some(edge_map) => {
let res = edge_map.call1(py, (edge,))?;
Ok(res.to_object(py))
}
None => Ok(edge.clone_ref(py)),
}
}

// loop over other edges and add to self graph
for edge in other.graph.edge_references() {
let new_p_index = new_node_map.get(&edge.source()).unwrap();
let new_c_index = new_node_map.get(&edge.target()).unwrap();
let weight =
edge_weight_callable(py, &edge_map_func, edge.weight())?;
weight_transform_callable(py, &edge_map_func, edge.weight())?;
self.graph.add_edge(*new_p_index, *new_c_index, weight);
}
// Add edges from map
for (this_index, (index, weight)) in node_map_hashmap.iter() {
for (this_index, (index, weight)) in node_map.iter() {
let new_index = new_node_map.get(&NodeIndex::new(*index)).unwrap();
self.graph.add_edge(
NodeIndex::new(*this_index),
*new_index,
weight.clone_ref(py),
);
}
let out_dict = PyDict::new(py);
for (orig_node, new_node) in new_node_map.iter() {
out_dict.set_item(orig_node.index(), new_node.index())?;
Ok(new_node_map.iter().map(|(old, new)| (old.index(), new.index())).collect())
}
}

fn weight_transform_callable(
py: Python,
edge_map: &Option<PyObject>,
edge: &PyObject,
) -> PyResult<PyObject> {
match edge_map {
Some(edge_map) => {
let res = edge_map.call1(py, (edge,))?;
Ok(res.to_object(py))
}
Ok(out_dict.into())
None => Ok(edge.clone_ref(py)),
}
}



#[pyproto]
impl PyMappingProtocol for PyDiGraph {
/// Return the number of nodes in the graph
Expand Down
26 changes: 8 additions & 18 deletions src/graph.rs
Expand Up @@ -639,7 +639,7 @@ impl PyGraph {
/// edge with the specified node.
/// :rtype: dict
#[text_signature = "(node, /)"]
pub fn adj(&mut self, py: Python, node: usize) -> PyResult<PyObject> {
pub fn adj(&mut self, node: usize) -> PyResult<HashMap<usize, &PyObject>> {
let index = NodeIndex::new(node);
let neighbors = self.graph.neighbors(index);
let mut out_map: HashMap<usize, &PyObject> = HashMap::new();
Expand All @@ -649,11 +649,7 @@ impl PyGraph {
let edge_w = self.graph.edge_weight(edge.unwrap());
out_map.insert(neighbor.index(), edge_w.unwrap());
}
let out_dict = PyDict::new(py);
for (index, value) in out_map {
out_dict.set_item(index, value)?;
}
Ok(out_dict.into())
Ok(out_map)
}

/// Get the degree for a node
Expand Down Expand Up @@ -843,17 +839,11 @@ impl PyGraph {
&mut self,
py: Python,
other: &PyGraph,
node_map: PyObject,
node_map: HashMap<usize, (usize, PyObject)>,
node_map_func: Option<PyObject>,
edge_map_func: Option<PyObject>,
) -> PyResult<PyObject> {
) -> PyResult<HashMap<usize, usize>> {
let mut new_node_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
let node_map_dict = node_map.cast_as::<PyDict>(py)?;
let mut node_map_hashmap: HashMap<usize, (usize, PyObject)> =
HashMap::default();
for (k, v) in node_map_dict.iter() {
node_map_hashmap.insert(k.extract()?, v.extract()?);
}

fn node_weight_callable(
py: Python,
Expand Down Expand Up @@ -903,7 +893,7 @@ impl PyGraph {
self.graph.add_edge(*new_p_index, *new_c_index, weight);
}
// Add edges from map
for (this_index, (index, weight)) in node_map_hashmap.iter() {
for (this_index, (index, weight)) in node_map.iter() {
let new_index = new_node_map.get(&NodeIndex::new(*index)).unwrap();
self.graph.add_edge(
NodeIndex::new(*this_index),
Expand All @@ -915,7 +905,7 @@ impl PyGraph {
for (orig_node, new_node) in new_node_map.iter() {
out_dict.set_item(orig_node.index(), new_node.index())?;
}
Ok(out_dict.into())
Ok(new_node_map.iter().map(|(old, new)| (old.index(), new.index())).collect())
}
}

Expand All @@ -926,7 +916,7 @@ impl PyMappingProtocol for PyGraph {
Ok(self.graph.node_count())
}
fn __getitem__(&'p self, idx: usize) -> PyResult<&'p PyObject> {
match self.graph.node_weight(NodeIndex::new(idx as usize)) {
match self.graph.node_weight(NodeIndex::new(idx)) {
Some(data) => Ok(data),
None => Err(PyIndexError::new_err("No node found for index")),
}
Expand All @@ -935,7 +925,7 @@ impl PyMappingProtocol for PyGraph {
fn __setitem__(&'p mut self, idx: usize, value: PyObject) -> PyResult<()> {
let data = match self
.graph
.node_weight_mut(NodeIndex::new(idx as usize))
.node_weight_mut(NodeIndex::new(idx))
{
Some(node_data) => node_data,
None => {
Expand Down
42 changes: 7 additions & 35 deletions src/lib.rs
Expand Up @@ -36,7 +36,7 @@ use hashbrown::{HashMap, HashSet};
use pyo3::create_exception;
use pyo3::exceptions::{PyException, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PySet};
use pyo3::types::{PyDict, PyList};
use pyo3::wrap_pyfunction;
use pyo3::wrap_pymodule;
use pyo3::Python;
Expand Down Expand Up @@ -307,7 +307,7 @@ fn bfs_successors(
/// :rtype: list
#[pyfunction]
#[text_signature = "(graph, node, /)"]
fn ancestors(py: Python, graph: &digraph::PyDiGraph, node: usize) -> PyObject {
fn ancestors(graph: &digraph::PyDiGraph, node: usize) -> HashSet<usize> {
let index = NodeIndex::new(node);
let mut out_set: HashSet<usize> = HashSet::new();
let reverse_graph = Reversed(graph);
Expand All @@ -317,13 +317,7 @@ fn ancestors(py: Python, graph: &digraph::PyDiGraph, node: usize) -> PyObject {
out_set.insert(n_int);
}
out_set.remove(&node);
let set = PySet::empty(py).expect("Failed to construct empty set");
{
for val in out_set {
set.add(val).expect("Failed to add to set");
}
}
set.into()
out_set
}

/// Return the descendants of a node in a graph.
Expand All @@ -341,10 +335,9 @@ fn ancestors(py: Python, graph: &digraph::PyDiGraph, node: usize) -> PyObject {
#[pyfunction]
#[text_signature = "(graph, node, /)"]
fn descendants(
py: Python,
graph: &digraph::PyDiGraph,
node: usize,
) -> PyObject {
) -> HashSet<usize> {
let index = NodeIndex::new(node);
let mut out_set: HashSet<usize> = HashSet::new();
let res = algo::dijkstra(graph, index, None, |_| 1);
Expand All @@ -353,13 +346,7 @@ fn descendants(
out_set.insert(n_int);
}
out_set.remove(&node);
let set = PySet::empty(py).expect("Failed to construct empty set");
{
for val in out_set {
set.add(val).expect("Failed to add to set");
}
}
set.into()
out_set
}

/// Get the lexicographical topological sorted nodes from the provided DAG
Expand Down Expand Up @@ -457,10 +444,7 @@ fn lexicographical_topological_sort(
/// :rtype: dict
#[pyfunction]
#[text_signature = "(graph, /)"]
fn graph_greedy_color(
py: Python,
graph: &graph::PyGraph,
) -> PyResult<PyObject> {
fn graph_greedy_color(graph: &graph::PyGraph) -> PyResult<HashMap<usize, usize>> {
let mut colors: HashMap<usize, usize> = HashMap::new();
let mut node_vec: Vec<NodeIndex> = graph.graph.node_indices().collect();
let mut sort_map: HashMap<NodeIndex, usize> = HashMap::new();
Expand All @@ -487,11 +471,7 @@ fn graph_greedy_color(
}
colors.insert(u_index.index(), count);
}
let out_dict = PyDict::new(py);
for (index, color) in colors {
out_dict.set_item(index, color)?;
}
Ok(out_dict.into())
Ok(colors)
}

/// Return the shortest path lengths between ever pair of nodes that has a
Expand Down Expand Up @@ -1506,11 +1486,3 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pymodule!(generators))?;
Ok(())
}

#[cfg(test)]
mod tests {
#[test]
fn it_works() {
assert_eq!(2 + 2, 4);
}
}

0 comments on commit 3081df3

Please sign in to comment.