Skip to content

Commit

Permalink
Loosen trait constraints and simplify structure for longest_path (#1195)
Browse files Browse the repository at this point in the history
In the recently merged #1192 a new generic DAG longest_path function was
added to rustworkx-core. However, the trait bounds on the function were
a bit tighter than they needed to be. The traits were forcing NodeId to
be of a NodeIndex type and this wasn't really required. The only
requirement that the NodeId type can be put on a hashmap and do a
partial compare (that implements Hash, Eq, and PartialOrd). Also the
IntoNeighborsDirected wasn't required because it's methods weren't ever
used. This commit loosens the traits bounds to facilitate this. At the
same time this also simplifies the code structure a bit to reduce the
separation of the rust code structure in the rustworkx crate using
longest_path().
  • Loading branch information
mtreinish committed May 18, 2024
1 parent 8e81911 commit 0ec113b
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 84 deletions.
25 changes: 11 additions & 14 deletions rustworkx-core/src/dag_algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use std::cmp::Eq;
use std::hash::Hash;

use hashbrown::HashMap;

use petgraph::algo;
use petgraph::graph::NodeIndex;
use petgraph::visit::{
EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers,
Visitable,
EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNodeIdentifiers, Visitable,
};
use petgraph::Directed;

Expand Down Expand Up @@ -51,7 +53,6 @@ type LongestPathResult<G, T, E> = Result<Option<(Vec<NodeId<G>>, T)>, E>;
/// # Example
/// ```
/// use petgraph::graph::DiGraph;
/// use petgraph::graph::NodeIndex;
/// use petgraph::Directed;
/// use rustworkx_core::dag_algo::longest_path;
///
Expand All @@ -69,14 +70,10 @@ type LongestPathResult<G, T, E> = Result<Option<(Vec<NodeId<G>>, T)>, E>;
/// ```
pub fn longest_path<G, F, T, E>(graph: G, mut weight_fn: F) -> LongestPathResult<G, T, E>
where
G: GraphProp<EdgeType = Directed>
+ IntoNodeIdentifiers
+ IntoNeighborsDirected
+ IntoEdgesDirected
+ Visitable
+ GraphBase<NodeId = NodeIndex>,
G: GraphProp<EdgeType = Directed> + IntoNodeIdentifiers + IntoEdgesDirected + Visitable,
F: FnMut(G::EdgeRef) -> Result<T, E>,
T: Num + Zero + PartialOrd + Copy,
<G as GraphBase>::NodeId: Hash + Eq + PartialOrd,
{
let mut path: Vec<NodeId<G>> = Vec::new();
let nodes = match algo::toposort(graph, None) {
Expand All @@ -88,20 +85,20 @@ where
return Ok(Some((path, T::zero())));
}

let mut dist: HashMap<NodeIndex, (T, NodeIndex)> = HashMap::with_capacity(nodes.len()); // Stores the distance and the previous node
let mut dist: HashMap<G::NodeId, (T, G::NodeId)> = HashMap::with_capacity(nodes.len()); // Stores the distance and the previous node

// Iterate over nodes in topological order
for node in nodes {
let parents = graph.edges_directed(node, petgraph::Direction::Incoming);
let mut incoming_path: Vec<(T, NodeIndex)> = Vec::new(); // Stores the distance and the previous node for each parent
let mut incoming_path: Vec<(T, G::NodeId)> = Vec::new(); // Stores the distance and the previous node for each parent
for p_edge in parents {
let p_node = p_edge.source();
let weight: T = weight_fn(p_edge)?;
let length = dist[&p_node].0 + weight;
incoming_path.push((length, p_node));
}
// Determine the maximum distance and corresponding parent node
let max_path: (T, NodeIndex) = incoming_path
let max_path: (T, G::NodeId) = incoming_path
.into_iter()
.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
.unwrap_or((T::zero(), node)); // If there are no incoming edges, the distance is zero
Expand All @@ -114,7 +111,7 @@ where
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap();
let mut v = *first;
let mut u: Option<NodeIndex> = None;
let mut u: Option<G::NodeId> = None;
// Backtrack from this node to find the path
while u.map_or(true, |u| u != v) {
path.push(v);
Expand Down
64 changes: 0 additions & 64 deletions src/dag_algo/longest_path.rs

This file was deleted.

57 changes: 51 additions & 6 deletions src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
// License for the specific language governing permissions and limitations
// under the License.

mod longest_path;

use super::DictMap;
use hashbrown::{HashMap, HashSet};
use indexmap::IndexSet;
Expand All @@ -22,6 +20,7 @@ use std::collections::BinaryHeap;
use super::iterators::NodeIndices;
use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph};

use rustworkx_core::dag_algo::longest_path as core_longest_path;
use rustworkx_core::traversal::dfs_edges;

use pyo3::exceptions::PyValueError;
Expand All @@ -32,8 +31,54 @@ use pyo3::Python;
use petgraph::algo;
use petgraph::graph::NodeIndex;
use petgraph::prelude::*;
use petgraph::stable_graph::EdgeReference;
use petgraph::visit::NodeCount;

use num_traits::{Num, Zero};

/// Calculate the longest path in a directed acyclic graph (DAG).
///
/// This function interfaces with the Python `PyDiGraph` object to compute the longest path
/// using the provided weight function.
///
/// # Arguments
/// * `graph`: Reference to a `PyDiGraph` object.
/// * `weight_fn`: A callable that takes the source node index, target node index, and the weight
/// object and returns the weight of the edge as a `PyResult<T>`.
///
/// # Type Parameters
/// * `F`: Type of the weight function.
/// * `T`: The type of the edge weight. Must implement `Num`, `Zero`, `PartialOrd`, and `Copy`.
///
/// # Returns
/// * `PyResult<(Vec<G::NodeId>, T)>` representing the longest path as a sequence of node indices and its total weight.
fn longest_path<F, T>(graph: &digraph::PyDiGraph, mut weight_fn: F) -> PyResult<(Vec<usize>, T)>
where
F: FnMut(usize, usize, &PyObject) -> PyResult<T>,
T: Num + Zero + PartialOrd + Copy,
{
let dag = &graph.graph;

// Create a new weight function that matches the required signature
let edge_cost = |edge_ref: EdgeReference<'_, PyObject>| -> Result<T, PyErr> {
let source = edge_ref.source().index();
let target = edge_ref.target().index();
let weight = edge_ref.weight();
weight_fn(source, target, weight)
};

let (path, path_weight) = match core_longest_path(dag, edge_cost) {
Ok(Some((path, path_weight))) => (
path.into_iter().map(NodeIndex::index).collect(),
path_weight,
),
Ok(None) => return Err(DAGHasCycle::new_err("The graph contains a cycle")),
Err(e) => return Err(e),
};

Ok((path, path_weight))
}

/// Return a pair of [`petgraph::Direction`] values corresponding to the "forwards" and "backwards"
/// direction of graph traversal, based on whether the graph is being traved forwards (following
/// the edges) or backward (reversing along edges). The order of returns is (forwards, backwards).
Expand Down Expand Up @@ -82,7 +127,7 @@ pub fn dag_longest_path(
}
};
Ok(NodeIndices {
nodes: longest_path::longest_path(graph, edge_weight_callable)?.0,
nodes: longest_path(graph, edge_weight_callable)?.0,
})
}

Expand Down Expand Up @@ -121,7 +166,7 @@ pub fn dag_longest_path_length(
None => Ok(1),
}
};
let (_, path_weight) = longest_path::longest_path(graph, edge_weight_callable)?;
let (_, path_weight) = longest_path(graph, edge_weight_callable)?;
Ok(path_weight)
}

Expand Down Expand Up @@ -163,7 +208,7 @@ pub fn dag_weighted_longest_path(
Ok(float_res)
};
Ok(NodeIndices {
nodes: longest_path::longest_path(graph, edge_weight_callable)?.0,
nodes: longest_path(graph, edge_weight_callable)?.0,
})
}

Expand Down Expand Up @@ -204,7 +249,7 @@ pub fn dag_weighted_longest_path_length(
}
Ok(float_res)
};
let (_, path_weight) = longest_path::longest_path(graph, edge_weight_callable)?;
let (_, path_weight) = longest_path(graph, edge_weight_callable)?;
Ok(path_weight)
}

Expand Down

0 comments on commit 0ec113b

Please sign in to comment.