problemreductions/rules/
graph.rs

1//! Runtime reduction graph for discovering and executing reduction paths.
2//!
3//! The graph uses variant-level nodes: each node is a unique `(problem_name, variant)` pair.
4//! Nodes are built in two phases:
5//! 1. From `VariantEntry` inventory (with complexity metadata)
6//! 2. From `ReductionEntry` inventory (fallback for backwards compatibility)
7//!
8//! Edges come exclusively from `#[reduction]` registrations via `inventory::iter::<ReductionEntry>`.
9//!
10//! This module implements:
11//! - Variant-level graph construction from `VariantEntry` and `ReductionEntry` inventory
12//! - Dijkstra's algorithm with custom cost functions for optimal paths
13//! - JSON export for documentation and visualization
14
15use crate::rules::cost::PathCostFn;
16use crate::rules::registry::{ReductionEntry, ReductionOverhead};
17use crate::rules::traits::DynReductionResult;
18use crate::types::ProblemSize;
19use ordered_float::OrderedFloat;
20use petgraph::algo::all_simple_paths;
21use petgraph::graph::{DiGraph, NodeIndex};
22use petgraph::visit::EdgeRef;
23use serde::Serialize;
24use std::any::Any;
25use std::cmp::Reverse;
26use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet};
27
28/// A source/target pair from the reduction graph, returned by
29/// [`ReductionGraph::outgoing_reductions`] and [`ReductionGraph::incoming_reductions`].
30#[derive(Debug, Clone)]
31pub struct ReductionEdgeInfo {
32    pub source_name: &'static str,
33    pub source_variant: BTreeMap<String, String>,
34    pub target_name: &'static str,
35    pub target_variant: BTreeMap<String, String>,
36    pub overhead: ReductionOverhead,
37}
38
39/// Internal edge data combining overhead and executable reduce function.
40#[derive(Clone)]
41pub(crate) struct ReductionEdgeData {
42    pub overhead: ReductionOverhead,
43    pub reduce_fn: fn(&dyn Any) -> Box<dyn DynReductionResult>,
44}
45
46/// JSON-serializable representation of the reduction graph.
47#[derive(Debug, Clone, Serialize)]
48pub(crate) struct ReductionGraphJson {
49    /// List of problem type nodes.
50    pub(crate) nodes: Vec<NodeJson>,
51    /// List of reduction edges.
52    pub(crate) edges: Vec<EdgeJson>,
53}
54
55impl ReductionGraphJson {
56    /// Get the source node of an edge.
57    #[cfg_attr(not(test), allow(dead_code))]
58    pub(crate) fn source_node(&self, edge: &EdgeJson) -> &NodeJson {
59        &self.nodes[edge.source]
60    }
61
62    /// Get the target node of an edge.
63    #[cfg_attr(not(test), allow(dead_code))]
64    pub(crate) fn target_node(&self, edge: &EdgeJson) -> &NodeJson {
65        &self.nodes[edge.target]
66    }
67}
68
69/// A node in the reduction graph JSON.
70#[derive(Debug, Clone, Serialize)]
71pub(crate) struct NodeJson {
72    /// Base problem name (e.g., "MaximumIndependentSet").
73    pub(crate) name: String,
74    /// Variant attributes as key-value pairs.
75    pub(crate) variant: BTreeMap<String, String>,
76    /// Category of the problem (e.g., "graph", "set", "optimization", "satisfiability", "specialized").
77    pub(crate) category: String,
78    /// Relative rustdoc path (e.g., "models/graph/maximum_independent_set").
79    pub(crate) doc_path: String,
80    /// Worst-case time complexity expression (empty if not declared).
81    pub(crate) complexity: String,
82}
83
/// Hashable `(problem name, variant)` key, used to deduplicate variant nodes
/// while the graph is being assembled.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct VariantRef {
    /// Base problem name.
    name: String,
    /// Variant attributes identifying the concrete variant.
    variant: BTreeMap<String, String>,
}
90
91/// A single output field in the reduction overhead.
92#[derive(Debug, Clone, Serialize)]
93pub(crate) struct OverheadFieldJson {
94    /// Output field name (e.g., "num_vars").
95    pub(crate) field: String,
96    /// Formula as a human-readable string (e.g., "num_vertices").
97    pub(crate) formula: String,
98}
99
100/// An edge in the reduction graph JSON.
101#[derive(Debug, Clone, Serialize)]
102pub(crate) struct EdgeJson {
103    /// Index into the `nodes` array for the source problem variant.
104    pub(crate) source: usize,
105    /// Index into the `nodes` array for the target problem variant.
106    pub(crate) target: usize,
107    /// Reduction overhead: output size as expressions of input size.
108    pub(crate) overhead: Vec<OverheadFieldJson>,
109    /// Relative rustdoc path for the reduction module.
110    pub(crate) doc_path: String,
111}
112
113/// A path through the variant-level reduction graph.
114#[derive(Debug, Clone)]
115pub struct ReductionPath {
116    /// Variant-level steps in the path.
117    pub steps: Vec<ReductionStep>,
118}
119
120impl ReductionPath {
121    /// Number of edges (reductions) in the path.
122    pub fn len(&self) -> usize {
123        if self.steps.is_empty() {
124            0
125        } else {
126            self.steps.len() - 1
127        }
128    }
129
130    /// Whether the path is empty.
131    pub fn is_empty(&self) -> bool {
132        self.steps.is_empty()
133    }
134
135    /// Source problem name.
136    pub fn source(&self) -> Option<&str> {
137        self.steps.first().map(|s| s.name.as_str())
138    }
139
140    /// Target problem name.
141    pub fn target(&self) -> Option<&str> {
142        self.steps.last().map(|s| s.name.as_str())
143    }
144
145    /// Name-level path (deduplicated consecutive same-name steps).
146    pub fn type_names(&self) -> Vec<&str> {
147        let mut names: Vec<&str> = Vec::new();
148        for step in &self.steps {
149            if names.last() != Some(&step.name.as_str()) {
150                names.push(&step.name);
151            }
152        }
153        names
154    }
155}
156
157impl std::fmt::Display for ReductionPath {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        let mut prev_name = "";
160        for step in &self.steps {
161            if step.name != prev_name {
162                if prev_name.is_empty() {
163                    write!(f, "{step}")?;
164                } else {
165                    write!(f, " → {step}")?;
166                }
167                prev_name = &step.name;
168            }
169        }
170        Ok(())
171    }
172}
173
174/// A node in a variant-level reduction path.
175#[derive(Debug, Clone, Serialize)]
176pub struct ReductionStep {
177    /// Problem name (e.g., "MaximumIndependentSet").
178    pub name: String,
179    /// Variant at this point (e.g., {"graph": "KingsSubgraph", "weight": "i32"}).
180    pub variant: BTreeMap<String, String>,
181}
182
183impl std::fmt::Display for ReductionStep {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        write!(f, "{}", self.name)?;
186        if !self.variant.is_empty() {
187            let vars: Vec<_> = self
188                .variant
189                .iter()
190                .map(|(k, v)| format!("{k}: {v:?}"))
191                .collect();
192            write!(f, " {{{}}}", vars.join(", "))?;
193        }
194        Ok(())
195    }
196}
197
/// Classify a problem's category from its module path.
///
/// Expected format: "problemreductions::models::<category>::<module_name>".
/// Returns the segment immediately after the first `models` segment, or
/// `"other"` when the path has fewer than three `::`-separated segments, has
/// no `models` segment, or ends at `models`.
pub(crate) fn classify_problem_category(module_path: &str) -> &str {
    // Fewer than three segments cannot carry a category.
    if module_path.split("::").nth(2).is_none() {
        return "other";
    }
    let mut segments = module_path
        .split("::")
        .skip_while(|&segment| segment != "models");
    // Step past the `models` segment itself (no-op when it is absent).
    segments.next();
    segments.next().unwrap_or("other")
}
211
/// Payload describing one `(problem name, variant)` node of the graph.
#[derive(Debug, Clone)]
struct VariantNode {
    /// Base problem name.
    name: &'static str,
    /// Variant attributes of this node.
    variant: BTreeMap<String, String>,
    /// Declared complexity expression; empty when none was registered.
    complexity: &'static str,
}
219
/// A problem variant reached during a neighborhood search.
#[derive(Debug, Clone)]
pub struct NeighborInfo {
    /// Base problem name.
    pub name: &'static str,
    /// Variant attributes of the reached node.
    pub variant: BTreeMap<String, String>,
    /// Number of edges between the starting node and this one.
    pub hops: usize,
}
230
/// Which edges a neighborhood search follows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraversalDirection {
    /// Follow edges forward: problems this one reduces to.
    Outgoing,
    /// Follow edges backward: problems that reduce to this one.
    Incoming,
    /// Follow edges regardless of orientation.
    Both,
}
241
/// One node of a neighborhood-search tree.
#[derive(Debug, Clone)]
pub struct NeighborTree {
    /// Base problem name.
    pub name: String,
    /// Variant attributes of this node.
    pub variant: BTreeMap<String, String>,
    /// Subtrees discovered from this node, ordered by name.
    pub children: Vec<NeighborTree>,
}
252
253/// Runtime graph of all registered reductions.
254///
255/// Uses variant-level nodes: each node is a unique `(problem_name, variant)` pair.
256/// All edges come from `inventory::iter::<ReductionEntry>` registrations.
257///
258/// The graph supports:
259/// - Auto-discovery of reductions from `inventory::iter::<ReductionEntry>`
260/// - Dijkstra with custom cost functions
261/// - Path finding by problem type or by name
262pub struct ReductionGraph {
263    /// Graph with node indices as node data, edge weights as ReductionEdgeData.
264    graph: DiGraph<usize, ReductionEdgeData>,
265    /// All variant nodes, indexed by position.
266    nodes: Vec<VariantNode>,
267    /// Map from base type name to all NodeIndex values for that name.
268    name_to_nodes: HashMap<&'static str, Vec<NodeIndex>>,
269}
270
271impl ReductionGraph {
272    /// Create a new reduction graph with all registered reductions from inventory.
273    pub fn new() -> Self {
274        let mut graph = DiGraph::new();
275        let mut nodes: Vec<VariantNode> = Vec::new();
276        let mut node_index: HashMap<VariantRef, NodeIndex> = HashMap::new();
277        let mut name_to_nodes: HashMap<&'static str, Vec<NodeIndex>> = HashMap::new();
278
279        // Helper to ensure a variant node exists in the graph.
280        let ensure_node = |name: &'static str,
281                           variant: BTreeMap<String, String>,
282                           complexity: &'static str,
283                           nodes: &mut Vec<VariantNode>,
284                           graph: &mut DiGraph<usize, ReductionEdgeData>,
285                           node_index: &mut HashMap<VariantRef, NodeIndex>,
286                           name_to_nodes: &mut HashMap<&'static str, Vec<NodeIndex>>|
287         -> NodeIndex {
288            let vref = VariantRef {
289                name: name.to_string(),
290                variant: variant.clone(),
291            };
292            if let Some(&idx) = node_index.get(&vref) {
293                idx
294            } else {
295                let node_id = nodes.len();
296                nodes.push(VariantNode {
297                    name,
298                    variant,
299                    complexity,
300                });
301                let idx = graph.add_node(node_id);
302                node_index.insert(vref, idx);
303                name_to_nodes.entry(name).or_default().push(idx);
304                idx
305            }
306        };
307
308        // Phase 1: Build nodes from VariantEntry inventory
309        for entry in inventory::iter::<crate::registry::VariantEntry> {
310            let variant = Self::variant_to_map(&entry.variant());
311            ensure_node(
312                entry.name,
313                variant,
314                entry.complexity,
315                &mut nodes,
316                &mut graph,
317                &mut node_index,
318                &mut name_to_nodes,
319            );
320        }
321
322        // Phase 2: Build edges from ReductionEntry inventory
323        for entry in inventory::iter::<ReductionEntry> {
324            let source_variant = Self::variant_to_map(&entry.source_variant());
325            let target_variant = Self::variant_to_map(&entry.target_variant());
326
327            // Nodes should already exist from Phase 1.
328            // Fall back to creating them with empty complexity for backwards compatibility.
329            let src_idx = ensure_node(
330                entry.source_name,
331                source_variant,
332                "",
333                &mut nodes,
334                &mut graph,
335                &mut node_index,
336                &mut name_to_nodes,
337            );
338            let dst_idx = ensure_node(
339                entry.target_name,
340                target_variant,
341                "",
342                &mut nodes,
343                &mut graph,
344                &mut node_index,
345                &mut name_to_nodes,
346            );
347
348            let overhead = entry.overhead();
349            if graph.find_edge(src_idx, dst_idx).is_none() {
350                graph.add_edge(
351                    src_idx,
352                    dst_idx,
353                    ReductionEdgeData {
354                        overhead,
355                        reduce_fn: entry.reduce_fn,
356                    },
357                );
358            }
359        }
360
361        Self {
362            graph,
363            nodes,
364            name_to_nodes,
365        }
366    }
367
368    /// Convert a variant slice to a BTreeMap.
369    /// Normalizes empty "graph" values to "SimpleGraph" for consistency.
370    pub fn variant_to_map(variant: &[(&str, &str)]) -> BTreeMap<String, String> {
371        variant
372            .iter()
373            .map(|(k, v)| {
374                let value = if *k == "graph" && v.is_empty() {
375                    "SimpleGraph".to_string()
376                } else {
377                    v.to_string()
378                };
379                (k.to_string(), value)
380            })
381            .collect()
382    }
383
384    /// Look up a variant node by name and variant map.
385    fn lookup_node(&self, name: &str, variant: &BTreeMap<String, String>) -> Option<NodeIndex> {
386        let nodes = self.name_to_nodes.get(name)?;
387        nodes
388            .iter()
389            .find(|&&idx| self.nodes[self.graph[idx]].variant == *variant)
390            .copied()
391    }
392
393    /// Find the cheapest path between two specific problem variants.
394    ///
395    /// Uses Dijkstra's algorithm on the variant-level graph from the exact
396    /// source variant node to the exact target variant node.
397    pub fn find_cheapest_path<C: PathCostFn>(
398        &self,
399        source: &str,
400        source_variant: &BTreeMap<String, String>,
401        target: &str,
402        target_variant: &BTreeMap<String, String>,
403        input_size: &ProblemSize,
404        cost_fn: &C,
405    ) -> Option<ReductionPath> {
406        let src = self.lookup_node(source, source_variant)?;
407        let dst = self.lookup_node(target, target_variant)?;
408        let node_path = self.dijkstra(src, dst, input_size, cost_fn)?;
409        Some(self.node_path_to_reduction_path(&node_path))
410    }
411
412    /// Core Dijkstra search on node indices.
413    fn dijkstra<C: PathCostFn>(
414        &self,
415        src: NodeIndex,
416        dst: NodeIndex,
417        input_size: &ProblemSize,
418        cost_fn: &C,
419    ) -> Option<Vec<NodeIndex>> {
420        let mut costs: HashMap<NodeIndex, f64> = HashMap::new();
421        let mut sizes: HashMap<NodeIndex, ProblemSize> = HashMap::new();
422        let mut prev: HashMap<NodeIndex, NodeIndex> = HashMap::new();
423        let mut heap = BinaryHeap::new();
424
425        costs.insert(src, 0.0);
426        sizes.insert(src, input_size.clone());
427        heap.push(Reverse((OrderedFloat(0.0), src)));
428
429        while let Some(Reverse((cost, node))) = heap.pop() {
430            if node == dst {
431                let mut path = vec![dst];
432                let mut current = dst;
433                while current != src {
434                    let &prev_node = prev.get(&current)?;
435                    path.push(prev_node);
436                    current = prev_node;
437                }
438                path.reverse();
439                return Some(path);
440            }
441
442            if cost.0 > *costs.get(&node).unwrap_or(&f64::INFINITY) {
443                continue;
444            }
445
446            let current_size = match sizes.get(&node) {
447                Some(s) => s.clone(),
448                None => continue,
449            };
450
451            for edge_ref in self.graph.edges(node) {
452                let overhead = &edge_ref.weight().overhead;
453                let next = edge_ref.target();
454
455                let edge_cost = cost_fn.edge_cost(overhead, &current_size);
456                let new_cost = cost.0 + edge_cost;
457                let new_size = overhead.evaluate_output_size(&current_size);
458
459                if new_cost < *costs.get(&next).unwrap_or(&f64::INFINITY) {
460                    costs.insert(next, new_cost);
461                    sizes.insert(next, new_size);
462                    prev.insert(next, node);
463                    heap.push(Reverse((OrderedFloat(new_cost), next)));
464                }
465            }
466        }
467
468        None
469    }
470
471    /// Convert a node index path to a `ReductionPath`.
472    fn node_path_to_reduction_path(&self, node_path: &[NodeIndex]) -> ReductionPath {
473        let steps = node_path
474            .iter()
475            .map(|&idx| {
476                let node = &self.nodes[self.graph[idx]];
477                ReductionStep {
478                    name: node.name.to_string(),
479                    variant: node.variant.clone(),
480                }
481            })
482            .collect();
483        ReductionPath { steps }
484    }
485
486    /// Find all simple paths between two specific problem variants.
487    ///
488    /// Uses `all_simple_paths` on the variant-level graph from the exact
489    /// source variant node to the exact target variant node.
490    pub fn find_all_paths(
491        &self,
492        source: &str,
493        source_variant: &BTreeMap<String, String>,
494        target: &str,
495        target_variant: &BTreeMap<String, String>,
496    ) -> Vec<ReductionPath> {
497        let src = match self.lookup_node(source, source_variant) {
498            Some(idx) => idx,
499            None => return vec![],
500        };
501        let dst = match self.lookup_node(target, target_variant) {
502            Some(idx) => idx,
503            None => return vec![],
504        };
505
506        let paths: Vec<Vec<NodeIndex>> = all_simple_paths::<
507            Vec<NodeIndex>,
508            _,
509            std::hash::RandomState,
510        >(&self.graph, src, dst, 0, None)
511        .collect();
512
513        paths
514            .iter()
515            .map(|p| self.node_path_to_reduction_path(p))
516            .collect()
517    }
518
519    /// Check if a direct reduction exists from S to T.
520    pub fn has_direct_reduction<S: crate::traits::Problem, T: crate::traits::Problem>(
521        &self,
522    ) -> bool {
523        self.has_direct_reduction_by_name(S::NAME, T::NAME)
524    }
525
526    /// Check if a direct reduction exists by name.
527    pub fn has_direct_reduction_by_name(&self, src: &str, dst: &str) -> bool {
528        let src_nodes = match self.name_to_nodes.get(src) {
529            Some(nodes) => nodes,
530            None => return false,
531        };
532        let dst_nodes = match self.name_to_nodes.get(dst) {
533            Some(nodes) => nodes,
534            None => return false,
535        };
536
537        let dst_set: HashSet<NodeIndex> = dst_nodes.iter().copied().collect();
538
539        for &src_idx in src_nodes {
540            for edge_ref in self.graph.edges(src_idx) {
541                if dst_set.contains(&edge_ref.target()) {
542                    return true;
543                }
544            }
545        }
546
547        false
548    }
549
550    /// Get all registered problem type names (base names).
551    pub fn problem_types(&self) -> Vec<&'static str> {
552        self.name_to_nodes.keys().copied().collect()
553    }
554
555    /// Get the number of registered problem types (unique base names).
556    pub fn num_types(&self) -> usize {
557        self.name_to_nodes.len()
558    }
559
560    /// Get the number of registered reductions (edges).
561    pub fn num_reductions(&self) -> usize {
562        self.graph.edge_count()
563    }
564
565    /// Get the number of variant-level nodes.
566    pub fn num_variant_nodes(&self) -> usize {
567        self.nodes.len()
568    }
569
570    /// Get the per-edge overhead expressions along a reduction path.
571    ///
572    /// Returns one `ReductionOverhead` per edge (i.e., `path.steps.len() - 1` items).
573    ///
574    /// Panics if any step in the path does not correspond to an edge in the graph.
575    pub fn path_overheads(&self, path: &ReductionPath) -> Vec<ReductionOverhead> {
576        if path.steps.len() <= 1 {
577            return vec![];
578        }
579
580        let node_indices: Vec<NodeIndex> = path
581            .steps
582            .iter()
583            .map(|step| {
584                self.lookup_node(&step.name, &step.variant)
585                    .unwrap_or_else(|| panic!("Node not found: {} {:?}", step.name, step.variant))
586            })
587            .collect();
588
589        node_indices
590            .windows(2)
591            .map(|pair| {
592                let edge_idx = self.graph.find_edge(pair[0], pair[1]).unwrap_or_else(|| {
593                    let src = &self.nodes[self.graph[pair[0]]];
594                    let dst = &self.nodes[self.graph[pair[1]]];
595                    panic!(
596                        "No edge from {} {:?} to {} {:?}",
597                        src.name, src.variant, dst.name, dst.variant
598                    )
599                });
600                self.graph[edge_idx].overhead.clone()
601            })
602            .collect()
603    }
604
605    /// Compose overheads along a path symbolically.
606    ///
607    /// Returns a single `ReductionOverhead` whose expressions map from the
608    /// source problem's size variables directly to the final target's size variables.
609    pub fn compose_path_overhead(&self, path: &ReductionPath) -> ReductionOverhead {
610        self.path_overheads(path)
611            .into_iter()
612            .reduce(|acc, oh| acc.compose(&oh))
613            .unwrap_or_default()
614    }
615
616    /// Get all variant maps registered for a problem name.
617    ///
618    /// Returns variants sorted deterministically: the "default" variant
619    /// (SimpleGraph, i32, etc.) comes first, then remaining variants
620    /// in lexicographic order.
621    pub fn variants_for(&self, name: &str) -> Vec<BTreeMap<String, String>> {
622        let mut variants: Vec<BTreeMap<String, String>> = self
623            .name_to_nodes
624            .get(name)
625            .map(|indices| {
626                indices
627                    .iter()
628                    .map(|&idx| self.nodes[self.graph[idx]].variant.clone())
629                    .collect()
630            })
631            .unwrap_or_default();
632        // Sort deterministically: default variant values (SimpleGraph, One, KN)
633        // sort first so callers can rely on variants[0] being the "base" variant.
634        variants.sort_by(|a, b| {
635            fn default_rank(v: &BTreeMap<String, String>) -> usize {
636                v.values()
637                    .filter(|val| !["SimpleGraph", "One", "KN"].contains(&val.as_str()))
638                    .count()
639            }
640            default_rank(a).cmp(&default_rank(b)).then_with(|| a.cmp(b))
641        });
642        variants
643    }
644
645    /// Get the complexity expression for a specific variant.
646    pub fn variant_complexity(
647        &self,
648        name: &str,
649        variant: &BTreeMap<String, String>,
650    ) -> Option<&'static str> {
651        let idx = self.lookup_node(name, variant)?;
652        let node = &self.nodes[self.graph[idx]];
653        if node.complexity.is_empty() {
654            None
655        } else {
656            Some(node.complexity)
657        }
658    }
659
660    /// Get all outgoing reductions from a problem (across all its variants).
661    pub fn outgoing_reductions(&self, name: &str) -> Vec<ReductionEdgeInfo> {
662        let Some(indices) = self.name_to_nodes.get(name) else {
663            return vec![];
664        };
665        let index_set: HashSet<NodeIndex> = indices.iter().copied().collect();
666        self.graph
667            .edge_references()
668            .filter(|e| index_set.contains(&e.source()))
669            .map(|e| {
670                let src = &self.nodes[self.graph[e.source()]];
671                let dst = &self.nodes[self.graph[e.target()]];
672                ReductionEdgeInfo {
673                    source_name: src.name,
674                    source_variant: src.variant.clone(),
675                    target_name: dst.name,
676                    target_variant: dst.variant.clone(),
677                    overhead: self.graph[e.id()].overhead.clone(),
678                }
679            })
680            .collect()
681    }
682
683    /// Get the problem size field names for a problem type.
684    ///
685    /// Derives size fields from the overhead expressions of reduction entries
686    /// where this problem appears as source or target. When the problem is a
687    /// source, its size fields are the input variables referenced in the overhead
688    /// expressions. When it's a target, its size fields are the output field names.
689    pub fn size_field_names(&self, name: &str) -> Vec<&'static str> {
690        let mut fields = std::collections::HashSet::new();
691        for entry in inventory::iter::<ReductionEntry> {
692            if entry.source_name == name {
693                // Source's size fields are the input variables of the overhead.
694                fields.extend(entry.overhead().input_variable_names());
695            }
696            if entry.target_name == name {
697                // Target's size fields are the output field names.
698                let overhead = entry.overhead();
699                fields.extend(overhead.output_size.iter().map(|(name, _)| *name));
700            }
701        }
702        let mut result: Vec<&'static str> = fields.into_iter().collect();
703        result.sort_unstable();
704        result
705    }
706
707    /// Get all incoming reductions to a problem (across all its variants).
708    pub fn incoming_reductions(&self, name: &str) -> Vec<ReductionEdgeInfo> {
709        let Some(indices) = self.name_to_nodes.get(name) else {
710            return vec![];
711        };
712        let index_set: HashSet<NodeIndex> = indices.iter().copied().collect();
713        self.graph
714            .edge_references()
715            .filter(|e| index_set.contains(&e.target()))
716            .map(|e| {
717                let src = &self.nodes[self.graph[e.source()]];
718                let dst = &self.nodes[self.graph[e.target()]];
719                ReductionEdgeInfo {
720                    source_name: src.name,
721                    source_variant: src.variant.clone(),
722                    target_name: dst.name,
723                    target_variant: dst.variant.clone(),
724                    overhead: self.graph[e.id()].overhead.clone(),
725                }
726            })
727            .collect()
728    }
729
730    /// Find all problems reachable within `max_hops` edges from a starting node.
731    ///
732    /// Returns neighbors sorted by (hops, name). The starting node itself is excluded.
733    /// If a node is reachable at multiple distances, it appears at the shortest distance only.
734    pub fn k_neighbors(
735        &self,
736        name: &str,
737        variant: &BTreeMap<String, String>,
738        max_hops: usize,
739        direction: TraversalDirection,
740    ) -> Vec<NeighborInfo> {
741        use std::collections::VecDeque;
742
743        let Some(start_idx) = self.lookup_node(name, variant) else {
744            return vec![];
745        };
746
747        let mut visited: HashSet<NodeIndex> = HashSet::new();
748        visited.insert(start_idx);
749        let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
750        queue.push_back((start_idx, 0));
751        let mut results: Vec<NeighborInfo> = Vec::new();
752
753        while let Some((node_idx, hops)) = queue.pop_front() {
754            if hops >= max_hops {
755                continue;
756            }
757
758            let directions: Vec<petgraph::Direction> = match direction {
759                TraversalDirection::Outgoing => vec![petgraph::Direction::Outgoing],
760                TraversalDirection::Incoming => vec![petgraph::Direction::Incoming],
761                TraversalDirection::Both => {
762                    vec![petgraph::Direction::Outgoing, petgraph::Direction::Incoming]
763                }
764            };
765
766            for dir in directions {
767                for neighbor_idx in self.graph.neighbors_directed(node_idx, dir) {
768                    if visited.insert(neighbor_idx) {
769                        let neighbor_node = &self.nodes[self.graph[neighbor_idx]];
770                        results.push(NeighborInfo {
771                            name: neighbor_node.name,
772                            variant: neighbor_node.variant.clone(),
773                            hops: hops + 1,
774                        });
775                        queue.push_back((neighbor_idx, hops + 1));
776                    }
777                }
778            }
779        }
780
781        results.sort_by(|a, b| a.hops.cmp(&b.hops).then_with(|| a.name.cmp(b.name)));
782        results
783    }
784
785    /// Build a tree of neighbors via BFS with parent tracking.
786    ///
787    /// Returns the children of the starting node as a forest of `NeighborTree` nodes.
788    /// Each node appears at most once (shortest-path tree). Children are sorted by name.
789    pub fn k_neighbor_tree(
790        &self,
791        name: &str,
792        variant: &BTreeMap<String, String>,
793        max_hops: usize,
794        direction: TraversalDirection,
795    ) -> Vec<NeighborTree> {
796        use std::collections::VecDeque;
797
798        let Some(start_idx) = self.lookup_node(name, variant) else {
799            return vec![];
800        };
801
802        let mut visited: HashSet<NodeIndex> = HashSet::new();
803        visited.insert(start_idx);
804
805        let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
806        queue.push_back((start_idx, 0));
807
808        // Map from node_idx -> children node indices
809        let mut node_children: HashMap<NodeIndex, Vec<NodeIndex>> = HashMap::new();
810
811        while let Some((node_idx, depth)) = queue.pop_front() {
812            if depth >= max_hops {
813                continue;
814            }
815
816            let directions: Vec<petgraph::Direction> = match direction {
817                TraversalDirection::Outgoing => vec![petgraph::Direction::Outgoing],
818                TraversalDirection::Incoming => vec![petgraph::Direction::Incoming],
819                TraversalDirection::Both => {
820                    vec![petgraph::Direction::Outgoing, petgraph::Direction::Incoming]
821                }
822            };
823
824            let mut children = Vec::new();
825            for dir in directions {
826                for neighbor_idx in self.graph.neighbors_directed(node_idx, dir) {
827                    if visited.insert(neighbor_idx) {
828                        children.push(neighbor_idx);
829                        queue.push_back((neighbor_idx, depth + 1));
830                    }
831                }
832            }
833            children.sort_by(|a, b| {
834                self.nodes[self.graph[*a]]
835                    .name
836                    .cmp(self.nodes[self.graph[*b]].name)
837            });
838            node_children.insert(node_idx, children);
839        }
840
841        // Recursively build NeighborTree from BFS parent map.
842        fn build(
843            idx: NodeIndex,
844            node_children: &HashMap<NodeIndex, Vec<NodeIndex>>,
845            nodes: &[VariantNode],
846            graph: &DiGraph<usize, ReductionEdgeData>,
847        ) -> NeighborTree {
848            let children = node_children
849                .get(&idx)
850                .map(|cs| {
851                    cs.iter()
852                        .map(|&c| build(c, node_children, nodes, graph))
853                        .collect()
854                })
855                .unwrap_or_default();
856            let node = &nodes[graph[idx]];
857            NeighborTree {
858                name: node.name.to_string(),
859                variant: node.variant.clone(),
860                children,
861            }
862        }
863
864        node_children
865            .get(&start_idx)
866            .map(|cs| {
867                cs.iter()
868                    .map(|&c| build(c, &node_children, &self.nodes, &self.graph))
869                    .collect()
870            })
871            .unwrap_or_default()
872    }
873}
874
875impl Default for ReductionGraph {
876    fn default() -> Self {
877        Self::new()
878    }
879}
880
881impl ReductionGraph {
882    /// Export the reduction graph as a JSON-serializable structure.
883    ///
884    /// Nodes and edges come directly from the variant-level graph.
885    pub(crate) fn to_json(&self) -> ReductionGraphJson {
886        use crate::registry::ProblemSchemaEntry;
887
888        // Build name -> module_path lookup from ProblemSchemaEntry inventory
889        let schema_modules: HashMap<&str, &str> = inventory::iter::<ProblemSchemaEntry>
890            .into_iter()
891            .map(|entry| (entry.name, entry.module_path))
892            .collect();
893
894        // Build sorted node list from the internal nodes
895        let mut json_nodes: Vec<(usize, NodeJson)> = self
896            .nodes
897            .iter()
898            .enumerate()
899            .map(|(i, node)| {
900                let (category, doc_path) = if let Some(&mod_path) = schema_modules.get(node.name) {
901                    (
902                        Self::category_from_module_path(mod_path),
903                        Self::doc_path_from_module_path(mod_path, node.name),
904                    )
905                } else {
906                    ("other".to_string(), String::new())
907                };
908                (
909                    i,
910                    NodeJson {
911                        name: node.name.to_string(),
912                        variant: node.variant.clone(),
913                        category,
914                        doc_path,
915                        complexity: node.complexity.to_string(),
916                    },
917                )
918            })
919            .collect();
920        json_nodes.sort_by(|a, b| (&a.1.name, &a.1.variant).cmp(&(&b.1.name, &b.1.variant)));
921
922        // Build old-index -> new-index mapping
923        let mut old_to_new: HashMap<usize, usize> = HashMap::new();
924        for (new_idx, (old_idx, _)) in json_nodes.iter().enumerate() {
925            old_to_new.insert(*old_idx, new_idx);
926        }
927
928        let nodes: Vec<NodeJson> = json_nodes.into_iter().map(|(_, n)| n).collect();
929
930        // Build edges from the graph
931        let mut edges: Vec<EdgeJson> = Vec::new();
932        for edge_ref in self.graph.edge_references() {
933            let src_node_id = self.graph[edge_ref.source()];
934            let dst_node_id = self.graph[edge_ref.target()];
935            let overhead = &edge_ref.weight().overhead;
936
937            let overhead_fields = overhead
938                .output_size
939                .iter()
940                .map(|(field, poly)| OverheadFieldJson {
941                    field: field.to_string(),
942                    formula: poly.to_string(),
943                })
944                .collect();
945
946            // Find the doc_path from the matching ReductionEntry
947            let src_name = self.nodes[src_node_id].name;
948            let dst_name = self.nodes[dst_node_id].name;
949            let src_variant = &self.nodes[src_node_id].variant;
950            let dst_variant = &self.nodes[dst_node_id].variant;
951
952            let doc_path = self.find_entry_doc_path(src_name, dst_name, src_variant, dst_variant);
953
954            edges.push(EdgeJson {
955                source: old_to_new[&src_node_id],
956                target: old_to_new[&dst_node_id],
957                overhead: overhead_fields,
958                doc_path,
959            });
960        }
961
962        // Sort edges for deterministic output
963        edges.sort_by(|a, b| {
964            (
965                &nodes[a.source].name,
966                &nodes[a.source].variant,
967                &nodes[a.target].name,
968                &nodes[a.target].variant,
969            )
970                .cmp(&(
971                    &nodes[b.source].name,
972                    &nodes[b.source].variant,
973                    &nodes[b.target].name,
974                    &nodes[b.target].variant,
975                ))
976        });
977
978        ReductionGraphJson { nodes, edges }
979    }
980
981    /// Find the doc_path for a reduction entry matching the given source/target.
982    fn find_entry_doc_path(
983        &self,
984        src_name: &str,
985        dst_name: &str,
986        src_variant: &BTreeMap<String, String>,
987        dst_variant: &BTreeMap<String, String>,
988    ) -> String {
989        for entry in inventory::iter::<ReductionEntry> {
990            if entry.source_name == src_name && entry.target_name == dst_name {
991                let entry_src = Self::variant_to_map(&entry.source_variant());
992                let entry_dst = Self::variant_to_map(&entry.target_variant());
993                if &entry_src == src_variant && &entry_dst == dst_variant {
994                    return Self::module_path_to_doc_path(entry.module_path);
995                }
996            }
997        }
998        String::new()
999    }
1000
1001    /// Export the reduction graph as a JSON string.
1002    pub fn to_json_string(&self) -> Result<String, serde_json::Error> {
1003        let json = self.to_json();
1004        serde_json::to_string_pretty(&json)
1005    }
1006
1007    /// Export the reduction graph to a JSON file.
1008    pub fn to_json_file(&self, path: &std::path::Path) -> std::io::Result<()> {
1009        let json_string = self
1010            .to_json_string()
1011            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
1012        std::fs::write(path, json_string)
1013    }
1014
1015    /// Convert a module path to a rustdoc relative path.
1016    ///
1017    /// E.g., `"problemreductions::rules::spinglass_qubo"` -> `"rules/spinglass_qubo/index.html"`.
1018    fn module_path_to_doc_path(module_path: &str) -> String {
1019        let stripped = module_path
1020            .strip_prefix("problemreductions::")
1021            .unwrap_or(module_path);
1022        format!("{}/index.html", stripped.replace("::", "/"))
1023    }
1024
1025    /// Extract the category from a module path.
1026    ///
1027    /// E.g., `"problemreductions::models::graph::maximum_independent_set"` -> `"graph"`.
1028    fn category_from_module_path(module_path: &str) -> String {
1029        classify_problem_category(module_path).to_string()
1030    }
1031
1032    /// Build the rustdoc path from a module path and problem name.
1033    ///
1034    /// E.g., `"problemreductions::models::graph::maximum_independent_set"`, `"MaximumIndependentSet"`
1035    /// -> `"models/graph/struct.MaximumIndependentSet.html"`.
1036    fn doc_path_from_module_path(module_path: &str, name: &str) -> String {
1037        let stripped = module_path
1038            .strip_prefix("problemreductions::")
1039            .unwrap_or(module_path);
1040        if let Some(parent) = stripped.rsplit_once("::").map(|(p, _)| p) {
1041            format!("{}/struct.{}.html", parent.replace("::", "/"), name)
1042        } else {
1043            format!("struct.{}.html", name)
1044        }
1045    }
1046
1047    /// Find the best matching `ReductionEntry` for a (source_name, target_name) pair
1048    /// given the caller's current source variant.
1049    ///
1050    /// First tries an exact match on the source variant. If no exact match is found,
1051    /// falls back to a name-only match (returning the first entry whose source and
1052    /// target names match). This is intentional: specific variants (e.g., `K3`) may
1053    /// not have their own `#[reduction]` entry, but the general variant (`KN`) covers
1054    /// them with the same overhead expression. The fallback is safe because cross-name
1055    /// reductions share the same overhead regardless of source variant; it is only
1056    /// used by the JSON export pipeline (`export::lookup_overhead`).
1057    pub fn find_best_entry(
1058        &self,
1059        source_name: &str,
1060        target_name: &str,
1061        current_variant: &BTreeMap<String, String>,
1062    ) -> Option<MatchedEntry> {
1063        let mut fallback: Option<MatchedEntry> = None;
1064
1065        for entry in inventory::iter::<ReductionEntry> {
1066            if entry.source_name != source_name || entry.target_name != target_name {
1067                continue;
1068            }
1069
1070            let entry_source = Self::variant_to_map(&entry.source_variant());
1071            let entry_target = Self::variant_to_map(&entry.target_variant());
1072
1073            // Exact match on source variant — return immediately
1074            if current_variant == &entry_source {
1075                return Some(MatchedEntry {
1076                    source_variant: entry_source,
1077                    target_variant: entry_target,
1078                    overhead: entry.overhead(),
1079                });
1080            }
1081
1082            // Remember the first name-only match as a fallback
1083            if fallback.is_none() {
1084                fallback = Some(MatchedEntry {
1085                    source_variant: entry_source,
1086                    target_variant: entry_target,
1087                    overhead: entry.overhead(),
1088                });
1089            }
1090        }
1091
1092        fallback
1093    }
1094}
1095
1096/// A matched reduction entry returned by [`ReductionGraph::find_best_entry`].
1097pub struct MatchedEntry {
1098    /// The entry's source variant.
1099    pub source_variant: BTreeMap<String, String>,
1100    /// The entry's target variant.
1101    pub target_variant: BTreeMap<String, String>,
1102    /// The overhead of the reduction.
1103    pub overhead: ReductionOverhead,
1104}
1105
1106/// A composed reduction chain produced by [`ReductionGraph::reduce_along_path`].
1107///
1108/// Holds the intermediate reduction results from executing a multi-step
1109/// reduction path. Provides access to the final target problem and
1110/// solution extraction back to the source problem space.
1111pub struct ReductionChain {
1112    steps: Vec<Box<dyn DynReductionResult>>,
1113}
1114
1115impl ReductionChain {
1116    /// Get the final target problem as a type-erased reference.
1117    pub fn target_problem_any(&self) -> &dyn Any {
1118        self.steps
1119            .last()
1120            .expect("ReductionChain has no steps")
1121            .target_problem_any()
1122    }
1123
1124    /// Get a typed reference to the final target problem.
1125    ///
1126    /// Panics if the actual target type does not match `T`.
1127    pub fn target_problem<T: 'static>(&self) -> &T {
1128        self.target_problem_any()
1129            .downcast_ref::<T>()
1130            .expect("ReductionChain target type mismatch")
1131    }
1132
1133    /// Extract a solution from target space back to source space.
1134    pub fn extract_solution(&self, target_solution: &[usize]) -> Vec<usize> {
1135        self.steps
1136            .iter()
1137            .rev()
1138            .fold(target_solution.to_vec(), |sol, step| {
1139                step.extract_solution_dyn(&sol)
1140            })
1141    }
1142}
1143
1144impl ReductionGraph {
1145    /// Execute a reduction path on a source problem instance.
1146    ///
1147    /// Looks up each edge's `reduce_fn`, chains them, and returns the
1148    /// resulting [`ReductionChain`]. The source must be passed as `&dyn Any`
1149    /// (use `&problem as &dyn Any` or pass a concrete reference directly).
1150    ///
1151    /// # Example
1152    ///
1153    /// ```text
1154    /// let chain = graph.reduce_along_path(&path, &source_problem)?;
1155    /// let target: &QUBO<f64> = chain.target_problem();
1156    /// let source_solution = chain.extract_solution(&target_solution);
1157    /// ```
1158    pub fn reduce_along_path(
1159        &self,
1160        path: &ReductionPath,
1161        source: &dyn Any,
1162    ) -> Option<ReductionChain> {
1163        if path.steps.len() < 2 {
1164            return None;
1165        }
1166        // Collect edge reduce_fns
1167        let mut edge_fns = Vec::new();
1168        for window in path.steps.windows(2) {
1169            let src = self.lookup_node(&window[0].name, &window[0].variant)?;
1170            let dst = self.lookup_node(&window[1].name, &window[1].variant)?;
1171            let edge_idx = self.graph.find_edge(src, dst)?;
1172            edge_fns.push(self.graph[edge_idx].reduce_fn);
1173        }
1174        // Execute the chain
1175        let mut steps: Vec<Box<dyn DynReductionResult>> = Vec::new();
1176        let step = (edge_fns[0])(source);
1177        steps.push(step);
1178        for edge_fn in &edge_fns[1..] {
1179            let step = {
1180                let prev_target = steps.last().unwrap().target_problem_any();
1181                edge_fn(prev_target)
1182            };
1183            steps.push(step);
1184        }
1185        Some(ReductionChain { steps })
1186    }
1187}
1188
1189#[cfg(test)]
1190#[path = "../unit_tests/rules/graph.rs"]
1191mod tests;
1192
1193#[cfg(test)]
1194#[path = "../unit_tests/rules/reduction_path_parity.rs"]
1195mod reduction_path_parity_tests;
1196
1197#[cfg(all(test, feature = "ilp-solver"))]
1198#[path = "../unit_tests/rules/maximumindependentset_ilp.rs"]
1199mod maximumindependentset_ilp_path_tests;
1200
1201#[cfg(all(test, feature = "ilp-solver"))]
1202#[path = "../unit_tests/rules/minimumvertexcover_ilp.rs"]
1203mod minimumvertexcover_ilp_path_tests;
1204
1205#[cfg(test)]
1206#[path = "../unit_tests/rules/maximumindependentset_qubo.rs"]
1207mod maximumindependentset_qubo_path_tests;
1208
1209#[cfg(test)]
1210#[path = "../unit_tests/rules/minimumvertexcover_qubo.rs"]
1211mod minimumvertexcover_qubo_path_tests;