From 3d2dfded4f5ab1d7e9db455e58c8beb37bc0ca24 Mon Sep 17 00:00:00 2001 From: At-sushi Date: Fri, 10 Jun 2022 16:07:34 +0900 Subject: [PATCH] Fix Floyd-Warshall algorithm behavior toward undirected graphs (#487) --- src/algo/floyd_warshall.rs | 9 +++-- tests/floyd_warshall.rs | 70 +++++++++++++++++++++++++++++++++++++- 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/src/algo/floyd_warshall.rs b/src/algo/floyd_warshall.rs index b22efdf61..bfe4d7eff 100644 --- a/src/algo/floyd_warshall.rs +++ b/src/algo/floyd_warshall.rs @@ -3,7 +3,9 @@ use std::collections::HashMap; use std::hash::Hash; use crate::algo::{BoundedMeasure, NegativeCycle}; -use crate::visit::{EdgeRef, IntoEdgeReferences, IntoNodeIdentifiers, NodeCompactIndexable}; +use crate::visit::{ + EdgeRef, GraphProp, IntoEdgeReferences, IntoNodeIdentifiers, NodeCompactIndexable, +}; #[allow(clippy::type_complexity, clippy::needless_range_loop)] /// \[Generic\] [Floyd–Warshall algorithm](https://en.wikipedia.org/wiki/Floyd%E2%80%93Warshall_algorithm) is an algorithm for all pairs shortest path problem @@ -81,7 +83,7 @@ pub fn floyd_warshall( mut edge_cost: F, ) -> Result, NegativeCycle> where - G: NodeCompactIndexable + IntoEdgeReferences + IntoNodeIdentifiers, + G: NodeCompactIndexable + IntoEdgeReferences + IntoNodeIdentifiers + GraphProp, G::NodeId: Eq + Hash, F: FnMut(G::EdgeRef) -> K, K: BoundedMeasure + Copy, @@ -94,6 +96,9 @@ where // init distances of paths with no intermediate nodes for edge in graph.edge_references() { dist[graph.to_index(edge.source())][graph.to_index(edge.target())] = edge_cost(edge); + if !graph.is_directed() { + dist[graph.to_index(edge.target())][graph.to_index(edge.source())] = edge_cost(edge); + } } // distance of each node to itself is 0(default value) diff --git a/tests/floyd_warshall.rs b/tests/floyd_warshall.rs index 6bea6443b..6e6ab68a2 100644 --- a/tests/floyd_warshall.rs +++ b/tests/floyd_warshall.rs @@ -1,5 +1,5 @@ use petgraph::algo::floyd_warshall; -use petgraph::{prelude::*, Directed, Graph}; +use petgraph::{prelude::*, Directed, Graph, Undirected}; use std::collections::HashMap; #[test] @@ -181,6 +181,74 @@ fn floyd_warshall_weighted() { } } +#[test] +fn floyd_warshall_weighted_undirected() { + let mut graph: Graph<(), (), Undirected> = Graph::new_undirected(); + let a = graph.add_node(()); + let b = graph.add_node(()); + let c = graph.add_node(()); + let d = graph.add_node(()); + + graph.extend_with_edges(&[(a, b), (a, c), (a, d), (b, d), (c, b), (c, d)]); + + let inf = std::i32::MAX; + let expected_res: HashMap<(NodeIndex, NodeIndex), i32> = [ + ((a, a), 0), + ((a, b), 1), + ((a, c), 3), + ((a, d), 3), + ((b, a), 1), + ((b, b), 0), + ((b, c), 2), + ((b, d), 2), + ((c, a), 3), + ((c, b), 2), + ((c, c), 0), + ((c, d), 2), + ((d, a), 3), + ((d, b), 2), + ((d, c), 2), + ((d, d), 0), + ] + .iter() + .cloned() + .collect(); + + let weight_map: HashMap<(NodeIndex, NodeIndex), i32> = [ + ((a, a), 0), + ((a, b), 1), + ((a, c), 4), + ((a, d), 10), + ((b, b), 0), + ((b, d), 2), + ((c, b), 2), + ((c, c), 0), + ((c, d), 2), + ] + .iter() + .cloned() + .collect(); + + let res = floyd_warshall(&graph, |edge| { + if let Some(weight) = weight_map.get(&(edge.source(), edge.target())) { + *weight + } else { + inf + } + }) + .unwrap(); + + let nodes = [a, b, c, d]; + for node1 in &nodes { + for node2 in &nodes { + assert_eq!( + res.get(&(*node1, *node2)).unwrap(), + expected_res.get(&(*node1, *node2)).unwrap() + ); + } + } +} + #[test] fn floyd_warshall_negative_cycle() { let mut graph: Graph<(), (), Directed> = Graph::new();