diff --git a/src/graph_impl/stable_graph/mod.rs b/src/graph_impl/stable_graph/mod.rs index bd852d85a..c62ff7533 100644 --- a/src/graph_impl/stable_graph/mod.rs +++ b/src/graph_impl/stable_graph/mod.rs @@ -560,6 +560,24 @@ where } } + /// Return an iterator over all the edges connecting `a` and `b`. + /// + /// - `Directed`: Outgoing edges from `a`. + /// - `Undirected`: All edges connected to `a`. + /// + /// Iterator element type is `EdgeReference`. + pub fn edges_connecting( + &self, + a: NodeIndex, + b: NodeIndex, + ) -> EdgesConnecting { + EdgesConnecting { + target_node: b, + edges: self.edges_directed(a, Direction::Outgoing), + ty: PhantomData, + } + } + /// Lookup if there is an edge from `a` to `b`. /// /// Computes in **O(e')** time, where **e'** is the number of edges @@ -1425,6 +1443,37 @@ where } } +/// Iterator over the multiple directed edges connecting a source node to a target node +#[derive(Debug, Clone)] +pub struct EdgesConnecting<'a, E: 'a, Ty, Ix: 'a = DefaultIx> +where + Ty: EdgeType, + Ix: IndexType, +{ + target_node: NodeIndex, + edges: Edges<'a, E, Ty, Ix>, + ty: PhantomData, +} + +impl<'a, E, Ty, Ix> Iterator for EdgesConnecting<'a, E, Ty, Ix> +where + Ty: EdgeType, + Ix: IndexType, +{ + type Item = EdgeReference<'a, E, Ix>; + + fn next(&mut self) -> Option> { + let target_node = self.target_node; + self.edges + .by_ref() + .find(|&edge| edge.node[1] == target_node) + } + fn size_hint(&self) -> (usize, Option) { + let (_, upper) = self.edges.size_hint(); + (0, upper) + } +} + fn swap_pair(mut x: [T; 2]) -> [T; 2] { x.swap(0, 1); x diff --git a/tests/stable_graph.rs b/tests/stable_graph.rs index 406f02bb1..ab104baa3 100644 --- a/tests/stable_graph.rs +++ b/tests/stable_graph.rs @@ -5,6 +5,8 @@ extern crate petgraph; #[macro_use] extern crate defmac; +use std::collections::HashSet; + use itertools::assert_equal; use petgraph::algo::{kosaraju_scc, min_spanning_tree, tarjan_scc}; use petgraph::dot::Dot; @@ -312,6 +314,70 @@ fn iterators_undir() { itertools::assert_equal(g.neighbors(c), vec![]); } +#[test] +fn iter_multi_edges() { + let mut gr = StableGraph::new(); + let a = gr.add_node("a"); + let b = gr.add_node("b"); + let c = gr.add_node("c"); + + let mut connecting_edges = HashSet::new(); + + gr.add_edge(a, a, ()); + connecting_edges.insert(gr.add_edge(a, b, ())); + gr.add_edge(a, c, ()); + gr.add_edge(c, b, ()); + connecting_edges.insert(gr.add_edge(a, b, ())); + gr.add_edge(b, a, ()); + + let mut iter = gr.edges_connecting(a, b); + + let edge_id = iter.next().unwrap().id(); + assert!(connecting_edges.contains(&edge_id)); + connecting_edges.remove(&edge_id); + + let edge_id = iter.next().unwrap().id(); + assert!(connecting_edges.contains(&edge_id)); + connecting_edges.remove(&edge_id); + + assert_eq!(None, iter.next()); + assert!(connecting_edges.is_empty()); +} + +#[test] +fn iter_multi_undirected_edges() { + let mut gr: StableUnGraph<_, _> = Default::default(); + let a = gr.add_node("a"); + let b = gr.add_node("b"); + let c = gr.add_node("c"); + + let mut connecting_edges = HashSet::new(); + + gr.add_edge(a, a, ()); + connecting_edges.insert(gr.add_edge(a, b, ())); + gr.add_edge(a, c, ()); + gr.add_edge(c, b, ()); + connecting_edges.insert(gr.add_edge(a, b, ())); + connecting_edges.insert(gr.add_edge(b, a, ())); + + let mut iter = gr.edges_connecting(a, b); + + let edge_id = iter.next().unwrap().id(); + assert!(connecting_edges.contains(&edge_id)); + connecting_edges.remove(&edge_id); + + let edge_id = iter.next().unwrap().id(); + assert!(connecting_edges.contains(&edge_id)); + connecting_edges.remove(&edge_id); + + let edge_id = iter.next().unwrap().id(); + assert!(connecting_edges.contains(&edge_id)); + connecting_edges.remove(&edge_id); + + assert_eq!(None, iter.next()); + assert!(connecting_edges.is_empty()); +} + #[test] fn dot() { let mut gr = StableGraph::new();