1use crate::rules::cost::PathCostFn;
16use crate::rules::registry::{
17 AggregateReduceFn, EdgeCapabilities, ReduceFn, ReductionEntry, ReductionOverhead,
18};
19use crate::rules::traits::{DynAggregateReductionResult, DynReductionResult};
20use crate::types::ProblemSize;
21use ordered_float::OrderedFloat;
22use petgraph::algo::all_simple_paths;
23use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex};
24use petgraph::visit::EdgeRef;
25use serde::Serialize;
26use std::any::Any;
27use std::cmp::Reverse;
28use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet};
29
30#[derive(Debug, Clone)]
33pub struct ReductionEdgeInfo {
34 pub source_name: &'static str,
35 pub source_variant: BTreeMap<String, String>,
36 pub target_name: &'static str,
37 pub target_variant: BTreeMap<String, String>,
38 pub overhead: ReductionOverhead,
39 pub capabilities: EdgeCapabilities,
40}
41
42#[derive(Clone)]
44pub(crate) struct ReductionEdgeData {
45 pub overhead: ReductionOverhead,
46 pub reduce_fn: Option<ReduceFn>,
47 pub reduce_aggregate_fn: Option<AggregateReduceFn>,
48 pub capabilities: EdgeCapabilities,
49}
50
51#[derive(Debug, Clone, Serialize)]
53pub(crate) struct ReductionGraphJson {
54 pub(crate) nodes: Vec<NodeJson>,
56 pub(crate) edges: Vec<EdgeJson>,
58}
59
60impl ReductionGraphJson {
61 #[cfg_attr(not(test), allow(dead_code))]
63 pub(crate) fn source_node(&self, edge: &EdgeJson) -> &NodeJson {
64 &self.nodes[edge.source]
65 }
66
67 #[cfg_attr(not(test), allow(dead_code))]
69 pub(crate) fn target_node(&self, edge: &EdgeJson) -> &NodeJson {
70 &self.nodes[edge.target]
71 }
72}
73
74#[derive(Debug, Clone, Serialize)]
76pub(crate) struct NodeJson {
77 pub(crate) name: String,
79 pub(crate) variant: BTreeMap<String, String>,
81 pub(crate) category: String,
83 pub(crate) doc_path: String,
85 pub(crate) complexity: String,
87}
88
/// Hashable lookup key for a variant node: problem name plus variant map.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct VariantRef {
    /// Problem type name.
    name: String,
    /// Variant key/value pairs.
    variant: BTreeMap<String, String>,
}
95
96#[derive(Debug, Clone, Serialize)]
98pub(crate) struct OverheadFieldJson {
99 pub(crate) field: String,
101 pub(crate) formula: String,
103}
104
105#[derive(Debug, Clone, Serialize)]
107pub(crate) struct EdgeJson {
108 pub(crate) source: usize,
110 pub(crate) target: usize,
112 pub(crate) overhead: Vec<OverheadFieldJson>,
114 pub(crate) doc_path: String,
116 pub(crate) witness: bool,
118 pub(crate) aggregate: bool,
120 pub(crate) turing: bool,
122}
123
124#[derive(Debug, Clone)]
126pub struct ReductionPath {
127 pub steps: Vec<ReductionStep>,
129}
130
131impl ReductionPath {
132 pub fn len(&self) -> usize {
134 if self.steps.is_empty() {
135 0
136 } else {
137 self.steps.len() - 1
138 }
139 }
140
141 pub fn is_empty(&self) -> bool {
143 self.steps.is_empty()
144 }
145
146 pub fn source(&self) -> Option<&str> {
148 self.steps.first().map(|s| s.name.as_str())
149 }
150
151 pub fn target(&self) -> Option<&str> {
153 self.steps.last().map(|s| s.name.as_str())
154 }
155
156 pub fn type_names(&self) -> Vec<&str> {
158 let mut names: Vec<&str> = Vec::new();
159 for step in &self.steps {
160 if names.last() != Some(&step.name.as_str()) {
161 names.push(&step.name);
162 }
163 }
164 names
165 }
166}
167
168impl std::fmt::Display for ReductionPath {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 let mut prev_name = "";
171 for step in &self.steps {
172 if step.name != prev_name {
173 if prev_name.is_empty() {
174 write!(f, "{step}")?;
175 } else {
176 write!(f, " → {step}")?;
177 }
178 prev_name = &step.name;
179 }
180 }
181 Ok(())
182 }
183}
184
185#[derive(Debug, Clone, Serialize)]
187pub struct ReductionStep {
188 pub name: String,
190 pub variant: BTreeMap<String, String>,
192}
193
194impl std::fmt::Display for ReductionStep {
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 write!(f, "{}", self.name)?;
197 if !self.variant.is_empty() {
198 let vars: Vec<_> = self
199 .variant
200 .iter()
201 .map(|(k, v)| format!("{k}: {v:?}"))
202 .collect();
203 write!(f, " {{{}}}", vars.join(", "))?;
204 }
205 Ok(())
206 }
207}
208
/// Extracts the problem category from a module path.
///
/// The category is the segment right after the first `models` segment, e.g.
/// `"problemreductions::models::graph::mis"` → `"graph"`. Returns `"other"`
/// when the path has fewer than three `::` segments, has no `models`
/// segment, or `models` is the last segment.
pub(crate) fn classify_problem_category(module_path: &str) -> &str {
    let segments: Vec<&str> = module_path.split("::").collect();
    if segments.len() < 3 {
        return "other";
    }
    segments
        .iter()
        .position(|&seg| seg == "models")
        .and_then(|pos| segments.get(pos + 1).copied())
        .unwrap_or("other")
}
222
/// Internal payload for one concrete problem variant.
#[derive(Clone, Debug)]
struct VariantNode {
    /// Problem type name.
    name: &'static str,
    /// Variant key/value pairs.
    variant: BTreeMap<String, String>,
    /// Complexity string from the variant registry; empty when the node was
    /// created only as a reduction endpoint.
    complexity: &'static str,
}
230
/// A node discovered by [`ReductionGraph::k_neighbors`].
#[derive(Clone, Debug)]
pub struct NeighborInfo {
    /// Problem type name.
    pub name: &'static str,
    /// Variant key/value pairs.
    pub variant: BTreeMap<String, String>,
    /// BFS distance (number of edges) from the start node.
    pub hops: usize,
}
241
/// Which edge directions a neighborhood traversal follows.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TraversalFlow {
    /// Follow edges from source to target.
    Outgoing,
    /// Follow edges backwards, from target to source.
    Incoming,
    /// Follow edges in both directions.
    Both,
}
252
/// Reduction mode used to restrict which edges a path search may use.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionMode {
    /// Edges whose `capabilities.witness` flag is set.
    Witness,
    /// Edges whose `capabilities.aggregate` flag is set.
    Aggregate,
    /// Edges whose `capabilities.turing` flag is set.
    Turing,
}
262
/// BFS tree node produced by [`ReductionGraph::k_neighbor_tree`].
#[derive(Clone, Debug)]
pub struct NeighborTree {
    /// Problem type name.
    pub name: String,
    /// Variant key/value pairs.
    pub variant: BTreeMap<String, String>,
    /// Nodes first discovered from this one, sorted by name.
    pub children: Vec<NeighborTree>,
}
273
274pub struct ReductionGraph {
284 graph: DiGraph<usize, ReductionEdgeData>,
286 nodes: Vec<VariantNode>,
288 name_to_nodes: HashMap<&'static str, Vec<NodeIndex>>,
290 default_variants: HashMap<String, BTreeMap<String, String>>,
292}
293
294impl ReductionGraph {
295 pub fn new() -> Self {
297 let mut graph = DiGraph::new();
298 let mut nodes: Vec<VariantNode> = Vec::new();
299 let mut node_index: HashMap<VariantRef, NodeIndex> = HashMap::new();
300 let mut name_to_nodes: HashMap<&'static str, Vec<NodeIndex>> = HashMap::new();
301
302 let ensure_node = |name: &'static str,
304 variant: BTreeMap<String, String>,
305 complexity: &'static str,
306 nodes: &mut Vec<VariantNode>,
307 graph: &mut DiGraph<usize, ReductionEdgeData>,
308 node_index: &mut HashMap<VariantRef, NodeIndex>,
309 name_to_nodes: &mut HashMap<&'static str, Vec<NodeIndex>>|
310 -> NodeIndex {
311 let vref = VariantRef {
312 name: name.to_string(),
313 variant: variant.clone(),
314 };
315 if let Some(&idx) = node_index.get(&vref) {
316 idx
317 } else {
318 let node_id = nodes.len();
319 nodes.push(VariantNode {
320 name,
321 variant,
322 complexity,
323 });
324 let idx = graph.add_node(node_id);
325 node_index.insert(vref, idx);
326 name_to_nodes.entry(name).or_default().push(idx);
327 idx
328 }
329 };
330
331 let mut default_variants: HashMap<String, BTreeMap<String, String>> = HashMap::new();
333
334 for entry in inventory::iter::<crate::registry::VariantEntry> {
336 let variant = Self::variant_to_map(&entry.variant());
337 ensure_node(
338 entry.name,
339 variant.clone(),
340 entry.complexity,
341 &mut nodes,
342 &mut graph,
343 &mut node_index,
344 &mut name_to_nodes,
345 );
346 if entry.is_default {
347 default_variants.insert(entry.name.to_string(), variant);
348 }
349 }
350
351 for entry in inventory::iter::<ReductionEntry> {
353 let source_variant = Self::variant_to_map(&entry.source_variant());
354 let target_variant = Self::variant_to_map(&entry.target_variant());
355
356 let src_idx = ensure_node(
359 entry.source_name,
360 source_variant,
361 "",
362 &mut nodes,
363 &mut graph,
364 &mut node_index,
365 &mut name_to_nodes,
366 );
367 let dst_idx = ensure_node(
368 entry.target_name,
369 target_variant,
370 "",
371 &mut nodes,
372 &mut graph,
373 &mut node_index,
374 &mut name_to_nodes,
375 );
376
377 let overhead = entry.overhead();
378 if graph.find_edge(src_idx, dst_idx).is_none() {
379 graph.add_edge(
380 src_idx,
381 dst_idx,
382 ReductionEdgeData {
383 overhead,
384 reduce_fn: entry.reduce_fn,
385 reduce_aggregate_fn: entry.reduce_aggregate_fn,
386 capabilities: entry.capabilities,
387 },
388 );
389 }
390 }
391
392 Self {
393 graph,
394 nodes,
395 name_to_nodes,
396 default_variants,
397 }
398 }
399
/// Converts a borrowed variant description into an owned, ordered map.
///
/// An empty value for the `graph` key is normalized to `"SimpleGraph"`.
/// If a key repeats, its last occurrence wins.
pub fn variant_to_map(variant: &[(&str, &str)]) -> BTreeMap<String, String> {
    let mut map = BTreeMap::new();
    for &(key, value) in variant {
        let value = match (key, value) {
            ("graph", "") => "SimpleGraph",
            _ => value,
        };
        map.insert(key.to_string(), value.to_string());
    }
    map
}
415
416 fn lookup_node(&self, name: &str, variant: &BTreeMap<String, String>) -> Option<NodeIndex> {
418 let nodes = self.name_to_nodes.get(name)?;
419 nodes
420 .iter()
421 .find(|&&idx| self.nodes[self.graph[idx]].variant == *variant)
422 .copied()
423 }
424
425 fn edge_supports_mode(edge: &ReductionEdgeData, mode: ReductionMode) -> bool {
426 match mode {
427 ReductionMode::Witness => edge.capabilities.witness,
428 ReductionMode::Aggregate => edge.capabilities.aggregate,
429 ReductionMode::Turing => edge.capabilities.turing,
430 }
431 }
432
433 fn node_path_supports_mode(&self, node_path: &[NodeIndex], mode: ReductionMode) -> bool {
434 node_path.windows(2).all(|pair| {
435 self.graph
436 .find_edge(pair[0], pair[1])
437 .is_some_and(|edge_idx| Self::edge_supports_mode(&self.graph[edge_idx], mode))
438 })
439 }
440
441 pub fn find_cheapest_path<C: PathCostFn>(
446 &self,
447 source: &str,
448 source_variant: &BTreeMap<String, String>,
449 target: &str,
450 target_variant: &BTreeMap<String, String>,
451 input_size: &ProblemSize,
452 cost_fn: &C,
453 ) -> Option<ReductionPath> {
454 self.find_cheapest_path_mode(
455 source,
456 source_variant,
457 target,
458 target_variant,
459 ReductionMode::Witness,
460 input_size,
461 cost_fn,
462 )
463 }
464
465 #[allow(clippy::too_many_arguments)]
468 pub fn find_cheapest_path_mode<C: PathCostFn>(
469 &self,
470 source: &str,
471 source_variant: &BTreeMap<String, String>,
472 target: &str,
473 target_variant: &BTreeMap<String, String>,
474 mode: ReductionMode,
475 input_size: &ProblemSize,
476 cost_fn: &C,
477 ) -> Option<ReductionPath> {
478 let src = self.lookup_node(source, source_variant)?;
479 let dst = self.lookup_node(target, target_variant)?;
480 let node_path = self.dijkstra(src, dst, mode, input_size, cost_fn)?;
481 Some(self.node_path_to_reduction_path(&node_path))
482 }
483
484 fn dijkstra<C: PathCostFn>(
486 &self,
487 src: NodeIndex,
488 dst: NodeIndex,
489 mode: ReductionMode,
490 input_size: &ProblemSize,
491 cost_fn: &C,
492 ) -> Option<Vec<NodeIndex>> {
493 let mut costs: HashMap<NodeIndex, f64> = HashMap::new();
494 let mut sizes: HashMap<NodeIndex, ProblemSize> = HashMap::new();
495 let mut prev: HashMap<NodeIndex, NodeIndex> = HashMap::new();
496 let mut heap = BinaryHeap::new();
497
498 costs.insert(src, 0.0);
499 sizes.insert(src, input_size.clone());
500 heap.push(Reverse((OrderedFloat(0.0), src)));
501
502 while let Some(Reverse((cost, node))) = heap.pop() {
503 if node == dst {
504 let mut path = vec![dst];
505 let mut current = dst;
506 while current != src {
507 let &prev_node = prev.get(¤t)?;
508 path.push(prev_node);
509 current = prev_node;
510 }
511 path.reverse();
512 return Some(path);
513 }
514
515 if cost.0 > *costs.get(&node).unwrap_or(&f64::INFINITY) {
516 continue;
517 }
518
519 let current_size = match sizes.get(&node) {
520 Some(s) => s.clone(),
521 None => continue,
522 };
523
524 for edge_ref in self.graph.edges(node) {
525 if !Self::edge_supports_mode(edge_ref.weight(), mode) {
526 continue;
527 }
528 let overhead = &edge_ref.weight().overhead;
529 let next = edge_ref.target();
530
531 let edge_cost = cost_fn.edge_cost(overhead, ¤t_size);
532 let new_cost = cost.0 + edge_cost;
533 let new_size = overhead.evaluate_output_size(¤t_size);
534
535 if new_cost < *costs.get(&next).unwrap_or(&f64::INFINITY) {
536 costs.insert(next, new_cost);
537 sizes.insert(next, new_size);
538 prev.insert(next, node);
539 heap.push(Reverse((OrderedFloat(new_cost), next)));
540 }
541 }
542 }
543
544 None
545 }
546
547 fn node_path_to_reduction_path(&self, node_path: &[NodeIndex]) -> ReductionPath {
549 let steps = node_path
550 .iter()
551 .map(|&idx| {
552 let node = &self.nodes[self.graph[idx]];
553 ReductionStep {
554 name: node.name.to_string(),
555 variant: node.variant.clone(),
556 }
557 })
558 .collect();
559 ReductionPath { steps }
560 }
561
562 pub fn find_all_paths(
567 &self,
568 source: &str,
569 source_variant: &BTreeMap<String, String>,
570 target: &str,
571 target_variant: &BTreeMap<String, String>,
572 ) -> Vec<ReductionPath> {
573 self.find_all_paths_mode(
574 source,
575 source_variant,
576 target,
577 target_variant,
578 ReductionMode::Witness,
579 )
580 }
581
582 pub fn find_all_paths_mode(
585 &self,
586 source: &str,
587 source_variant: &BTreeMap<String, String>,
588 target: &str,
589 target_variant: &BTreeMap<String, String>,
590 mode: ReductionMode,
591 ) -> Vec<ReductionPath> {
592 let src = match self.lookup_node(source, source_variant) {
593 Some(idx) => idx,
594 None => return vec![],
595 };
596 let dst = match self.lookup_node(target, target_variant) {
597 Some(idx) => idx,
598 None => return vec![],
599 };
600
601 let paths: Vec<Vec<NodeIndex>> = all_simple_paths::<
602 Vec<NodeIndex>,
603 _,
604 std::hash::RandomState,
605 >(&self.graph, src, dst, 0, None)
606 .collect();
607
608 paths
609 .iter()
610 .filter(|p| self.node_path_supports_mode(p, mode))
611 .map(|p| self.node_path_to_reduction_path(p))
612 .collect()
613 }
614
615 pub fn find_paths_up_to(
620 &self,
621 source: &str,
622 source_variant: &BTreeMap<String, String>,
623 target: &str,
624 target_variant: &BTreeMap<String, String>,
625 limit: usize,
626 ) -> Vec<ReductionPath> {
627 self.find_paths_up_to_mode_bounded(
628 source,
629 source_variant,
630 target,
631 target_variant,
632 ReductionMode::Witness,
633 limit,
634 None,
635 )
636 }
637
638 pub fn find_paths_up_to_mode(
641 &self,
642 source: &str,
643 source_variant: &BTreeMap<String, String>,
644 target: &str,
645 target_variant: &BTreeMap<String, String>,
646 mode: ReductionMode,
647 limit: usize,
648 ) -> Vec<ReductionPath> {
649 self.find_paths_up_to_mode_bounded(
650 source,
651 source_variant,
652 target,
653 target_variant,
654 mode,
655 limit,
656 None,
657 )
658 }
659
660 #[allow(clippy::too_many_arguments)]
663 pub fn find_paths_up_to_mode_bounded(
664 &self,
665 source: &str,
666 source_variant: &BTreeMap<String, String>,
667 target: &str,
668 target_variant: &BTreeMap<String, String>,
669 mode: ReductionMode,
670 limit: usize,
671 max_intermediate_nodes: Option<usize>,
672 ) -> Vec<ReductionPath> {
673 let src = match self.lookup_node(source, source_variant) {
674 Some(idx) => idx,
675 None => return vec![],
676 };
677 let dst = match self.lookup_node(target, target_variant) {
678 Some(idx) => idx,
679 None => return vec![],
680 };
681
682 let paths: Vec<Vec<NodeIndex>> = all_simple_paths::<
683 Vec<NodeIndex>,
684 _,
685 std::hash::RandomState,
686 >(&self.graph, src, dst, 0, max_intermediate_nodes)
687 .take(limit)
688 .collect();
689
690 paths
691 .iter()
692 .filter(|p| self.node_path_supports_mode(p, mode))
693 .map(|p| self.node_path_to_reduction_path(p))
694 .collect()
695 }
696
697 pub fn has_direct_reduction<S: crate::traits::Problem, T: crate::traits::Problem>(
699 &self,
700 ) -> bool {
701 self.has_direct_reduction_by_name(S::NAME, T::NAME)
702 }
703
704 pub fn has_direct_reduction_by_name(&self, src: &str, dst: &str) -> bool {
706 let src_nodes = match self.name_to_nodes.get(src) {
707 Some(nodes) => nodes,
708 None => return false,
709 };
710 let dst_nodes = match self.name_to_nodes.get(dst) {
711 Some(nodes) => nodes,
712 None => return false,
713 };
714
715 let dst_set: HashSet<NodeIndex> = dst_nodes.iter().copied().collect();
716
717 for &src_idx in src_nodes {
718 for edge_ref in self.graph.edges(src_idx) {
719 if dst_set.contains(&edge_ref.target()) {
720 return true;
721 }
722 }
723 }
724
725 false
726 }
727
728 pub fn has_direct_reduction_by_name_mode(
730 &self,
731 src: &str,
732 dst: &str,
733 mode: ReductionMode,
734 ) -> bool {
735 let src_nodes = match self.name_to_nodes.get(src) {
736 Some(nodes) => nodes,
737 None => return false,
738 };
739 let dst_nodes = match self.name_to_nodes.get(dst) {
740 Some(nodes) => nodes,
741 None => return false,
742 };
743
744 let dst_set: HashSet<NodeIndex> = dst_nodes.iter().copied().collect();
745
746 for &src_idx in src_nodes {
747 for edge_ref in self.graph.edges(src_idx) {
748 if dst_set.contains(&edge_ref.target())
749 && Self::edge_supports_mode(edge_ref.weight(), mode)
750 {
751 return true;
752 }
753 }
754 }
755
756 false
757 }
758
759 pub fn has_direct_reduction_mode<S: crate::traits::Problem, T: crate::traits::Problem>(
761 &self,
762 mode: ReductionMode,
763 ) -> bool {
764 self.has_direct_reduction_by_name_mode(S::NAME, T::NAME, mode)
765 }
766
767 pub fn problem_types(&self) -> Vec<&'static str> {
769 self.name_to_nodes.keys().copied().collect()
770 }
771
772 pub fn num_types(&self) -> usize {
774 self.name_to_nodes.len()
775 }
776
777 pub fn num_reductions(&self) -> usize {
779 self.graph.edge_count()
780 }
781
782 pub fn num_variant_nodes(&self) -> usize {
784 self.nodes.len()
785 }
786
787 pub fn path_overheads(&self, path: &ReductionPath) -> Vec<ReductionOverhead> {
793 if path.steps.len() <= 1 {
794 return vec![];
795 }
796
797 let node_indices: Vec<NodeIndex> = path
798 .steps
799 .iter()
800 .map(|step| {
801 self.lookup_node(&step.name, &step.variant)
802 .unwrap_or_else(|| panic!("Node not found: {} {:?}", step.name, step.variant))
803 })
804 .collect();
805
806 node_indices
807 .windows(2)
808 .map(|pair| {
809 let edge_idx = self.graph.find_edge(pair[0], pair[1]).unwrap_or_else(|| {
810 let src = &self.nodes[self.graph[pair[0]]];
811 let dst = &self.nodes[self.graph[pair[1]]];
812 panic!(
813 "No edge from {} {:?} to {} {:?}",
814 src.name, src.variant, dst.name, dst.variant
815 )
816 });
817 self.graph[edge_idx].overhead.clone()
818 })
819 .collect()
820 }
821
822 pub fn compose_path_overhead(&self, path: &ReductionPath) -> ReductionOverhead {
827 self.path_overheads(path)
828 .into_iter()
829 .reduce(|acc, oh| acc.compose(&oh))
830 .unwrap_or_default()
831 }
832
833 pub fn variants_for(&self, name: &str) -> Vec<BTreeMap<String, String>> {
839 let mut variants: Vec<BTreeMap<String, String>> = self
840 .name_to_nodes
841 .get(name)
842 .map(|indices| {
843 indices
844 .iter()
845 .map(|&idx| self.nodes[self.graph[idx]].variant.clone())
846 .collect()
847 })
848 .unwrap_or_default();
849 variants.sort_by(|a, b| {
852 fn default_rank(v: &BTreeMap<String, String>) -> usize {
853 v.values()
854 .filter(|val| !["SimpleGraph", "One", "KN"].contains(&val.as_str()))
855 .count()
856 }
857 default_rank(a).cmp(&default_rank(b)).then_with(|| a.cmp(b))
858 });
859 variants
860 }
861
862 pub fn default_variant_for(&self, name: &str) -> Option<BTreeMap<String, String>> {
869 self.default_variants.get(name).cloned()
870 }
871
872 pub fn variant_complexity(
874 &self,
875 name: &str,
876 variant: &BTreeMap<String, String>,
877 ) -> Option<&'static str> {
878 let idx = self.lookup_node(name, variant)?;
879 let node = &self.nodes[self.graph[idx]];
880 if node.complexity.is_empty() {
881 None
882 } else {
883 Some(node.complexity)
884 }
885 }
886
887 pub fn outgoing_reductions(&self, name: &str) -> Vec<ReductionEdgeInfo> {
889 let Some(indices) = self.name_to_nodes.get(name) else {
890 return vec![];
891 };
892 let index_set: HashSet<NodeIndex> = indices.iter().copied().collect();
893 self.graph
894 .edge_references()
895 .filter(|e| index_set.contains(&e.source()))
896 .map(|e| {
897 let src = &self.nodes[self.graph[e.source()]];
898 let dst = &self.nodes[self.graph[e.target()]];
899 ReductionEdgeInfo {
900 source_name: src.name,
901 source_variant: src.variant.clone(),
902 target_name: dst.name,
903 target_variant: dst.variant.clone(),
904 overhead: self.graph[e.id()].overhead.clone(),
905 capabilities: self.graph[e.id()].capabilities,
906 }
907 })
908 .collect()
909 }
910
911 pub fn size_field_names(&self, name: &str) -> Vec<&'static str> {
918 let mut fields: std::collections::HashSet<&'static str> =
919 crate::registry::declared_size_fields(name)
920 .into_iter()
921 .collect();
922 for entry in inventory::iter::<ReductionEntry> {
923 if entry.source_name == name {
924 fields.extend(entry.overhead().input_variable_names());
926 }
927 if entry.target_name == name {
928 let overhead = entry.overhead();
930 fields.extend(overhead.output_size.iter().map(|(name, _)| *name));
931 }
932 }
933 let mut result: Vec<&'static str> = fields.into_iter().collect();
934 result.sort_unstable();
935 result
936 }
937
938 pub fn evaluate_path_overhead(
944 &self,
945 path: &ReductionPath,
946 input_size: &ProblemSize,
947 ) -> Option<ProblemSize> {
948 let mut current_size = input_size.clone();
949 for pair in path.steps.windows(2) {
950 let src = self.lookup_node(&pair[0].name, &pair[0].variant)?;
951 let dst = self.lookup_node(&pair[1].name, &pair[1].variant)?;
952 let edge_idx = self.graph.find_edge(src, dst)?;
953 let edge = &self.graph[edge_idx];
954 current_size = edge.overhead.evaluate_output_size(¤t_size);
955 }
956 Some(current_size)
957 }
958
959 pub fn compute_source_size(name: &str, instance: &dyn Any) -> ProblemSize {
966 let mut merged: Vec<(String, usize)> = Vec::new();
967 let mut seen: HashSet<String> = HashSet::new();
968
969 for entry in inventory::iter::<ReductionEntry> {
970 if entry.source_name == name {
971 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
972 (entry.source_size_fn)(instance)
973 }));
974 if let Ok(size) = result {
975 for (k, v) in size.components {
976 if seen.insert(k.clone()) {
977 merged.push((k, v));
978 }
979 }
980 }
981 }
982 }
983 ProblemSize { components: merged }
984 }
985
986 pub fn incoming_reductions(&self, name: &str) -> Vec<ReductionEdgeInfo> {
988 let Some(indices) = self.name_to_nodes.get(name) else {
989 return vec![];
990 };
991 let index_set: HashSet<NodeIndex> = indices.iter().copied().collect();
992 self.graph
993 .edge_references()
994 .filter(|e| index_set.contains(&e.target()))
995 .map(|e| {
996 let src = &self.nodes[self.graph[e.source()]];
997 let dst = &self.nodes[self.graph[e.target()]];
998 ReductionEdgeInfo {
999 source_name: src.name,
1000 source_variant: src.variant.clone(),
1001 target_name: dst.name,
1002 target_variant: dst.variant.clone(),
1003 overhead: self.graph[e.id()].overhead.clone(),
1004 capabilities: self.graph[e.id()].capabilities,
1005 }
1006 })
1007 .collect()
1008 }
1009
1010 pub fn k_neighbors(
1015 &self,
1016 name: &str,
1017 variant: &BTreeMap<String, String>,
1018 max_hops: usize,
1019 direction: TraversalFlow,
1020 ) -> Vec<NeighborInfo> {
1021 use std::collections::VecDeque;
1022
1023 let Some(start_idx) = self.lookup_node(name, variant) else {
1024 return vec![];
1025 };
1026
1027 let mut visited: HashSet<NodeIndex> = HashSet::new();
1028 visited.insert(start_idx);
1029 let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
1030 queue.push_back((start_idx, 0));
1031 let mut results: Vec<NeighborInfo> = Vec::new();
1032
1033 while let Some((node_idx, hops)) = queue.pop_front() {
1034 if hops >= max_hops {
1035 continue;
1036 }
1037
1038 let directions = match direction {
1039 TraversalFlow::Outgoing => vec![petgraph::Outgoing],
1040 TraversalFlow::Incoming => vec![petgraph::Incoming],
1041 TraversalFlow::Both => {
1042 vec![petgraph::Outgoing, petgraph::Incoming]
1043 }
1044 };
1045
1046 for dir in directions {
1047 for neighbor_idx in self.graph.neighbors_directed(node_idx, dir) {
1048 if visited.insert(neighbor_idx) {
1049 let neighbor_node = &self.nodes[self.graph[neighbor_idx]];
1050 results.push(NeighborInfo {
1051 name: neighbor_node.name,
1052 variant: neighbor_node.variant.clone(),
1053 hops: hops + 1,
1054 });
1055 queue.push_back((neighbor_idx, hops + 1));
1056 }
1057 }
1058 }
1059 }
1060
1061 results.sort_by(|a, b| a.hops.cmp(&b.hops).then_with(|| a.name.cmp(b.name)));
1062 results
1063 }
1064
1065 pub fn k_neighbor_tree(
1070 &self,
1071 name: &str,
1072 variant: &BTreeMap<String, String>,
1073 max_hops: usize,
1074 direction: TraversalFlow,
1075 ) -> Vec<NeighborTree> {
1076 use std::collections::VecDeque;
1077
1078 let Some(start_idx) = self.lookup_node(name, variant) else {
1079 return vec![];
1080 };
1081
1082 let mut visited: HashSet<NodeIndex> = HashSet::new();
1083 visited.insert(start_idx);
1084
1085 let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
1086 queue.push_back((start_idx, 0));
1087
1088 let mut node_children: HashMap<NodeIndex, Vec<NodeIndex>> = HashMap::new();
1090
1091 while let Some((node_idx, depth)) = queue.pop_front() {
1092 if depth >= max_hops {
1093 continue;
1094 }
1095
1096 let directions = match direction {
1097 TraversalFlow::Outgoing => vec![petgraph::Outgoing],
1098 TraversalFlow::Incoming => vec![petgraph::Incoming],
1099 TraversalFlow::Both => {
1100 vec![petgraph::Outgoing, petgraph::Incoming]
1101 }
1102 };
1103
1104 let mut children = Vec::new();
1105 for dir in directions {
1106 for neighbor_idx in self.graph.neighbors_directed(node_idx, dir) {
1107 if visited.insert(neighbor_idx) {
1108 children.push(neighbor_idx);
1109 queue.push_back((neighbor_idx, depth + 1));
1110 }
1111 }
1112 }
1113 children.sort_by(|a, b| {
1114 self.nodes[self.graph[*a]]
1115 .name
1116 .cmp(self.nodes[self.graph[*b]].name)
1117 });
1118 node_children.insert(node_idx, children);
1119 }
1120
1121 fn build(
1123 idx: NodeIndex,
1124 node_children: &HashMap<NodeIndex, Vec<NodeIndex>>,
1125 nodes: &[VariantNode],
1126 graph: &DiGraph<usize, ReductionEdgeData>,
1127 ) -> NeighborTree {
1128 let children = node_children
1129 .get(&idx)
1130 .map(|cs| {
1131 cs.iter()
1132 .map(|&c| build(c, node_children, nodes, graph))
1133 .collect()
1134 })
1135 .unwrap_or_default();
1136 let node = &nodes[graph[idx]];
1137 NeighborTree {
1138 name: node.name.to_string(),
1139 variant: node.variant.clone(),
1140 children,
1141 }
1142 }
1143
1144 node_children
1145 .get(&start_idx)
1146 .map(|cs| {
1147 cs.iter()
1148 .map(|&c| build(c, &node_children, &self.nodes, &self.graph))
1149 .collect()
1150 })
1151 .unwrap_or_default()
1152 }
1153}
1154
1155impl Default for ReductionGraph {
1156 fn default() -> Self {
1157 Self::new()
1158 }
1159}
1160
1161impl ReductionGraph {
1162 pub(crate) fn to_json(&self) -> ReductionGraphJson {
1166 use crate::registry::ProblemSchemaEntry;
1167
1168 let schema_modules: HashMap<&str, &str> = inventory::iter::<ProblemSchemaEntry>
1170 .into_iter()
1171 .map(|entry| (entry.name, entry.module_path))
1172 .collect();
1173
1174 let mut json_nodes: Vec<(usize, NodeJson)> = self
1176 .nodes
1177 .iter()
1178 .enumerate()
1179 .map(|(i, node)| {
1180 let (category, doc_path) = if let Some(&mod_path) = schema_modules.get(node.name) {
1181 (
1182 Self::category_from_module_path(mod_path),
1183 Self::doc_path_from_module_path(mod_path, node.name),
1184 )
1185 } else {
1186 ("other".to_string(), String::new())
1187 };
1188 (
1189 i,
1190 NodeJson {
1191 name: node.name.to_string(),
1192 variant: node.variant.clone(),
1193 category,
1194 doc_path,
1195 complexity: node.complexity.to_string(),
1196 },
1197 )
1198 })
1199 .collect();
1200 json_nodes.sort_by(|a, b| (&a.1.name, &a.1.variant).cmp(&(&b.1.name, &b.1.variant)));
1201
1202 let mut old_to_new: HashMap<usize, usize> = HashMap::new();
1204 for (new_idx, (old_idx, _)) in json_nodes.iter().enumerate() {
1205 old_to_new.insert(*old_idx, new_idx);
1206 }
1207
1208 let nodes: Vec<NodeJson> = json_nodes.into_iter().map(|(_, n)| n).collect();
1209
1210 let mut edges: Vec<EdgeJson> = Vec::new();
1212 for edge_ref in self.graph.edge_references() {
1213 let src_node_id = self.graph[edge_ref.source()];
1214 let dst_node_id = self.graph[edge_ref.target()];
1215 let overhead = &edge_ref.weight().overhead;
1216 let capabilities = edge_ref.weight().capabilities;
1217
1218 let overhead_fields = overhead
1219 .output_size
1220 .iter()
1221 .map(|(field, poly)| OverheadFieldJson {
1222 field: field.to_string(),
1223 formula: poly.to_string(),
1224 })
1225 .collect();
1226
1227 let src_name = self.nodes[src_node_id].name;
1229 let dst_name = self.nodes[dst_node_id].name;
1230 let src_variant = &self.nodes[src_node_id].variant;
1231 let dst_variant = &self.nodes[dst_node_id].variant;
1232
1233 let doc_path = self.find_entry_doc_path(src_name, dst_name, src_variant, dst_variant);
1234
1235 edges.push(EdgeJson {
1236 source: old_to_new[&src_node_id],
1237 target: old_to_new[&dst_node_id],
1238 overhead: overhead_fields,
1239 doc_path,
1240 witness: capabilities.witness,
1241 aggregate: capabilities.aggregate,
1242 turing: capabilities.turing,
1243 });
1244 }
1245
1246 edges.sort_by(|a, b| {
1248 (
1249 &nodes[a.source].name,
1250 &nodes[a.source].variant,
1251 &nodes[a.target].name,
1252 &nodes[a.target].variant,
1253 )
1254 .cmp(&(
1255 &nodes[b.source].name,
1256 &nodes[b.source].variant,
1257 &nodes[b.target].name,
1258 &nodes[b.target].variant,
1259 ))
1260 });
1261
1262 ReductionGraphJson { nodes, edges }
1263 }
1264
1265 fn find_entry_doc_path(
1267 &self,
1268 src_name: &str,
1269 dst_name: &str,
1270 src_variant: &BTreeMap<String, String>,
1271 dst_variant: &BTreeMap<String, String>,
1272 ) -> String {
1273 for entry in inventory::iter::<ReductionEntry> {
1274 if entry.source_name == src_name && entry.target_name == dst_name {
1275 let entry_src = Self::variant_to_map(&entry.source_variant());
1276 let entry_dst = Self::variant_to_map(&entry.target_variant());
1277 if &entry_src == src_variant && &entry_dst == dst_variant {
1278 return Self::module_path_to_doc_path(entry.module_path);
1279 }
1280 }
1281 }
1282 String::new()
1283 }
1284
1285 pub fn to_json_string(&self) -> Result<String, serde_json::Error> {
1287 let json = self.to_json();
1288 serde_json::to_string_pretty(&json)
1289 }
1290
1291 pub fn to_json_file(&self, path: &std::path::Path) -> std::io::Result<()> {
1293 let json_string = self
1294 .to_json_string()
1295 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
1296 std::fs::write(path, json_string)
1297 }
1298
/// Maps a module path to its rustdoc `index.html`, relative to the crate doc root.
///
/// A leading `problemreductions::` is stripped and `::` becomes `/`.
fn module_path_to_doc_path(module_path: &str) -> String {
    let relative = module_path
        .strip_prefix("problemreductions::")
        .unwrap_or(module_path);
    let mut doc = relative.replace("::", "/");
    doc.push_str("/index.html");
    doc
}
1308
1309 fn category_from_module_path(module_path: &str) -> String {
1313 classify_problem_category(module_path).to_string()
1314 }
1315
/// Rustdoc path of struct `name`, assumed to be re-exported from the
/// *parent* of `module_path`, relative to the crate doc root.
///
/// A leading `problemreductions::` is stripped. If no parent module remains,
/// the path is just `struct.{name}.html`.
fn doc_path_from_module_path(module_path: &str, name: &str) -> String {
    let relative = module_path
        .strip_prefix("problemreductions::")
        .unwrap_or(module_path);
    match relative.rsplit_once("::") {
        Some((parent, _)) => format!("{}/struct.{name}.html", parent.replace("::", "/")),
        None => format!("struct.{name}.html"),
    }
}
1330
1331 pub fn find_best_entry(
1338 &self,
1339 source_name: &str,
1340 source_variant: &BTreeMap<String, String>,
1341 target_name: &str,
1342 target_variant: &BTreeMap<String, String>,
1343 ) -> Option<MatchedEntry> {
1344 for entry in inventory::iter::<ReductionEntry> {
1345 if entry.source_name != source_name || entry.target_name != target_name {
1346 continue;
1347 }
1348
1349 let entry_source = Self::variant_to_map(&entry.source_variant());
1350 let entry_target = Self::variant_to_map(&entry.target_variant());
1351
1352 if source_variant == &entry_source && target_variant == &entry_target {
1354 return Some(MatchedEntry {
1355 source_variant: entry_source,
1356 target_variant: entry_target,
1357 overhead: entry.overhead(),
1358 });
1359 }
1360 }
1361
1362 None
1363 }
1364}
1365
1366pub struct MatchedEntry {
1368 pub source_variant: BTreeMap<String, String>,
1370 pub target_variant: BTreeMap<String, String>,
1372 pub overhead: ReductionOverhead,
1374}
1375
1376pub struct ReductionChain {
1382 steps: Vec<Box<dyn DynReductionResult>>,
1383}
1384
1385impl ReductionChain {
1386 pub fn target_problem_any(&self) -> &dyn Any {
1388 self.steps
1389 .last()
1390 .expect("ReductionChain has no steps")
1391 .target_problem_any()
1392 }
1393
1394 pub fn target_problem<T: 'static>(&self) -> &T {
1398 self.target_problem_any()
1399 .downcast_ref::<T>()
1400 .expect("ReductionChain target type mismatch")
1401 }
1402
1403 pub fn extract_solution(&self, target_solution: &[usize]) -> Vec<usize> {
1405 self.steps
1406 .iter()
1407 .rev()
1408 .fold(target_solution.to_vec(), |sol, step| {
1409 step.extract_solution_dyn(&sol)
1410 })
1411 }
1412}
1413
1414pub struct AggregateReductionChain {
1417 steps: Vec<Box<dyn DynAggregateReductionResult>>,
1418}
1419
1420impl AggregateReductionChain {
1421 pub fn target_problem_any(&self) -> &dyn Any {
1423 self.steps
1424 .last()
1425 .expect("AggregateReductionChain has no steps")
1426 .target_problem_any()
1427 }
1428
1429 pub fn target_problem<T: 'static>(&self) -> &T {
1433 self.target_problem_any()
1434 .downcast_ref::<T>()
1435 .expect("AggregateReductionChain target type mismatch")
1436 }
1437
1438 pub fn extract_value_dyn(&self, target_value: serde_json::Value) -> serde_json::Value {
1440 self.steps
1441 .iter()
1442 .rev()
1443 .fold(target_value, |value, step| step.extract_value_dyn(value))
1444 }
1445}
1446
1447struct WitnessBackedIdentityAggregateStep {
1448 inner: Box<dyn DynReductionResult>,
1449}
1450
1451impl DynAggregateReductionResult for WitnessBackedIdentityAggregateStep {
1452 fn target_problem_any(&self) -> &dyn Any {
1453 self.inner.target_problem_any()
1454 }
1455
1456 fn extract_value_dyn(&self, target_value: serde_json::Value) -> serde_json::Value {
1457 target_value
1458 }
1459}
1460
1461impl ReductionGraph {
1462 fn execute_aggregate_edge(
1463 &self,
1464 edge_idx: EdgeIndex,
1465 input: &dyn Any,
1466 ) -> Option<Box<dyn DynAggregateReductionResult>> {
1467 let edge = &self.graph[edge_idx];
1468 if !Self::edge_supports_mode(edge, ReductionMode::Aggregate) {
1469 return None;
1470 }
1471
1472 if let Some(edge_fn) = edge.reduce_aggregate_fn {
1473 return Some(edge_fn(input));
1474 }
1475
1476 if edge.capabilities.witness && edge.capabilities.aggregate {
1477 let edge_fn = edge.reduce_fn?;
1478 return Some(Box::new(WitnessBackedIdentityAggregateStep {
1479 inner: edge_fn(input),
1480 }));
1481 }
1482
1483 None
1484 }
1485
1486 pub fn reduce_along_path(
1500 &self,
1501 path: &ReductionPath,
1502 source: &dyn Any,
1503 ) -> Option<ReductionChain> {
1504 if path.steps.len() < 2 {
1505 return None;
1506 }
1507 let mut edge_fns = Vec::new();
1509 for window in path.steps.windows(2) {
1510 let src = self.lookup_node(&window[0].name, &window[0].variant)?;
1511 let dst = self.lookup_node(&window[1].name, &window[1].variant)?;
1512 let edge_idx = self.graph.find_edge(src, dst)?;
1513 if !Self::edge_supports_mode(&self.graph[edge_idx], ReductionMode::Witness) {
1514 return None;
1515 }
1516 edge_fns.push(self.graph[edge_idx].reduce_fn?);
1517 }
1518 let mut steps: Vec<Box<dyn DynReductionResult>> = Vec::new();
1520 let step = (edge_fns[0])(source);
1521 steps.push(step);
1522 for edge_fn in &edge_fns[1..] {
1523 let step = {
1524 let prev_target = steps.last().unwrap().target_problem_any();
1525 edge_fn(prev_target)
1526 };
1527 steps.push(step);
1528 }
1529 Some(ReductionChain { steps })
1530 }
1531
1532 pub fn reduce_aggregate_along_path(
1534 &self,
1535 path: &ReductionPath,
1536 source: &dyn Any,
1537 ) -> Option<AggregateReductionChain> {
1538 if path.steps.len() < 2 {
1539 return None;
1540 }
1541
1542 let mut edge_indices = Vec::new();
1543 for window in path.steps.windows(2) {
1544 let src = self.lookup_node(&window[0].name, &window[0].variant)?;
1545 let dst = self.lookup_node(&window[1].name, &window[1].variant)?;
1546 let edge_idx = self.graph.find_edge(src, dst)?;
1547 edge_indices.push(edge_idx);
1548 }
1549
1550 let mut steps: Vec<Box<dyn DynAggregateReductionResult>> = Vec::new();
1551 let step = self.execute_aggregate_edge(edge_indices[0], source)?;
1552 steps.push(step);
1553 for &edge_idx in &edge_indices[1..] {
1554 let step = {
1555 let prev_target = steps.last().unwrap().target_problem_any();
1556 self.execute_aggregate_edge(edge_idx, prev_target)?
1557 };
1558 steps.push(step);
1559 }
1560 Some(AggregateReductionChain { steps })
1561 }
1562}
1563
1564#[cfg(test)]
1565#[path = "../unit_tests/rules/graph.rs"]
1566mod tests;
1567
1568#[cfg(test)]
1569#[path = "../unit_tests/rules/reduction_path_parity.rs"]
1570mod reduction_path_parity_tests;
1571
1572#[cfg(all(test, feature = "ilp-solver"))]
1573#[path = "../unit_tests/rules/maximumindependentset_ilp.rs"]
1574mod maximumindependentset_ilp_path_tests;
1575
1576#[cfg(all(test, feature = "ilp-solver"))]
1577#[path = "../unit_tests/rules/minimumvertexcover_ilp.rs"]
1578mod minimumvertexcover_ilp_path_tests;
1579
1580#[cfg(test)]
1581#[path = "../unit_tests/rules/maximumindependentset_qubo.rs"]
1582mod maximumindependentset_qubo_path_tests;
1583
1584#[cfg(test)]
1585#[path = "../unit_tests/rules/minimumvertexcover_qubo.rs"]
1586mod minimumvertexcover_qubo_path_tests;