Skip to content

Commit 05390bb

Browse files
GiggleLiuclaude
andcommitted
fix: filter variant nodes by graph_type in find_cheapest_path
Address Copilot review feedback: - Add filter_by_graph_type() to constrain Dijkstra start/end nodes by the graph_type parameter that was previously ignored - Clarify find_best_entry fallback doc (intentional for export pipeline) - Fix test_find_direct_path to not depend on HashMap iteration order Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 053efb8 commit 05390bb

2 files changed

Lines changed: 45 additions & 9 deletions

File tree

src/rules/graph.rs

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -275,13 +275,36 @@ impl ReductionGraph {
275275
.collect()
276276
}
277277

278+
/// Filter variant nodes by graph type.
279+
///
280+
/// If `graph_type` is non-empty, keeps only nodes whose `variant["graph"]`
281+
/// matches. Nodes without a `"graph"` key (e.g., QUBO, ILP) pass through
282+
/// unconditionally, since graph type is not part of their variant space.
283+
fn filter_by_graph_type(&self, nodes: &[NodeIndex], graph_type: &str) -> Vec<NodeIndex> {
284+
if graph_type.is_empty() {
285+
return nodes.to_vec();
286+
}
287+
nodes
288+
.iter()
289+
.copied()
290+
.filter(|&idx| {
291+
let node = &self.nodes[self.graph[idx]];
292+
match node.variant.get("graph") {
293+
Some(g) => g == graph_type,
294+
None => true,
295+
}
296+
})
297+
.collect()
298+
}
299+
278300
/// Find the cheapest path using a custom cost function.
279301
///
280302
/// Uses Dijkstra's algorithm on the variant-level graph.
281303
///
282304
/// # Arguments
283-
/// - `source`: (problem_name, graph_type) for source — used to look up the variant node
284-
/// - `target`: (problem_name, graph_type) for target
305+
/// - `source`: `(problem_name, graph_type)` — if `graph_type` is non-empty,
306+
/// only variant nodes with a matching `"graph"` key are used as start points.
307+
/// - `target`: `(problem_name, graph_type)` — same filtering for destinations.
285308
/// - `input_size`: Initial problem size for cost calculations
286309
/// - `cost_fn`: Custom cost function for path optimization
287310
///
@@ -294,17 +317,23 @@ impl ReductionGraph {
294317
input_size: &ProblemSize,
295318
cost_fn: &C,
296319
) -> Option<ReductionPath> {
297-
// Find source nodes matching the name (we try all variant nodes for that name)
298-
let src_nodes = self.name_to_nodes.get(source.0)?;
299-
let dst_nodes = self.name_to_nodes.get(target.0)?;
320+
let all_src = self.name_to_nodes.get(source.0)?;
321+
let all_dst = self.name_to_nodes.get(target.0)?;
322+
323+
let src_nodes = self.filter_by_graph_type(all_src, source.1);
324+
let dst_nodes = self.filter_by_graph_type(all_dst, target.1);
325+
326+
if src_nodes.is_empty() || dst_nodes.is_empty() {
327+
return None;
328+
}
300329

301330
// Build set of target node indices for quick lookup
302331
let dst_set: HashSet<NodeIndex> = dst_nodes.iter().copied().collect();
303332

304333
let mut best_path: Option<(f64, ReductionPath)> = None;
305334

306335
// Try from each source node
307-
for &src_idx in src_nodes {
336+
for &src_idx in &src_nodes {
308337
let mut costs: HashMap<NodeIndex, f64> = HashMap::new();
309338
let mut sizes: HashMap<NodeIndex, ProblemSize> = HashMap::new();
310339
let mut prev: HashMap<NodeIndex, (NodeIndex, petgraph::graph::EdgeIndex)> =
@@ -678,8 +707,11 @@ impl ReductionGraph {
678707
///
679708
/// First tries an exact match on the source variant. If no exact match is found,
680709
/// falls back to a name-only match (returning the first entry whose source and
681-
/// target names match). This allows looking up overhead for specific variants
682-
/// (e.g., `K3`) when only the general variant (e.g., `KN`) is registered.
710+
/// target names match). This is intentional: specific variants (e.g., `K3`) may
711+
/// not have their own `#[reduction]` entry, but the general variant (`KN`) covers
712+
/// them with the same overhead polynomial. The fallback is safe because cross-name
713+
/// reductions share the same overhead regardless of source variant; it is only
714+
/// used by the JSON export pipeline (`export::lookup_overhead`).
683715
pub fn find_best_entry(
684716
&self,
685717
source_name: &str,

src/unit_tests/reduction_graph.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,11 @@ fn test_find_direct_path() {
133133

134134
let paths = graph.find_paths::<MaximumIndependentSet<SimpleGraph, i32>, MinimumVertexCover<SimpleGraph, i32>>();
135135
assert!(!paths.is_empty());
136-
assert_eq!(paths[0].len(), 1);
136+
assert!(
137+
paths.iter().any(|p| p.len() == 1),
138+
"Should contain a direct (1-step) path, got lengths: {:?}",
139+
paths.iter().map(|p| p.len()).collect::<Vec<_>>()
140+
);
137141
}
138142

139143
#[test]

0 commit comments

Comments
 (0)