1use crate::canonical::canonical_form;
11use crate::expr::Expr;
12use crate::rules::graph::{ReductionGraph, ReductionPath};
13use crate::rules::registry::ReductionOverhead;
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt;
16
/// Outcome of comparing a primitive rule's overhead with a composite path's overhead.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ComparisonStatus {
    /// The composite path's overhead is no larger than the primitive rule's on the
    /// compared output-size fields.
    Dominated,
    /// The composite path's overhead is not shown to be bounded by the primitive rule's.
    NotDominated,
    /// The overheads could not be compared (e.g. non-polynomial expressions).
    Unknown,
}
27
28#[derive(Debug, Clone)]
30pub struct DominatedRule {
31 pub source_name: &'static str,
32 pub source_variant: BTreeMap<String, String>,
33 pub target_name: &'static str,
34 pub target_variant: BTreeMap<String, String>,
35 pub primitive_overhead: ReductionOverhead,
36 pub dominating_path: ReductionPath,
37 pub composed_overhead: ReductionOverhead,
38 pub comparable_fields: Vec<String>,
39}
40
41impl DominatedRule {
42 pub fn source_display(&self) -> String {
43 format_problem_variant(self.source_name, &self.source_variant)
44 }
45
46 pub fn target_display(&self) -> String {
47 format_problem_variant(self.target_name, &self.target_variant)
48 }
49}
50
51impl fmt::Display for DominatedRule {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 write!(f, "{} -> {}", self.source_display(), self.target_display())
54 }
55}
56
57#[derive(Debug, Clone)]
59pub struct UnknownComparison {
60 pub source_name: &'static str,
61 pub source_variant: BTreeMap<String, String>,
62 pub target_name: &'static str,
63 pub target_variant: BTreeMap<String, String>,
64 pub candidate_path: ReductionPath,
65 pub reason: String,
66}
67
68impl UnknownComparison {
69 pub fn source_display(&self) -> String {
70 format_problem_variant(self.source_name, &self.source_variant)
71 }
72
73 pub fn target_display(&self) -> String {
74 format_problem_variant(self.target_name, &self.target_variant)
75 }
76}
77
78impl fmt::Display for UnknownComparison {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 write!(f, "{} -> {}", self.source_display(), self.target_display())
81 }
82}
83
/// Formats a problem name together with its variant, e.g. `Name {graph: "SimpleGraph"}`.
///
/// Entries appear in the map's (sorted) key order as `key: value`. Values use `Debug`
/// formatting, so they are quoted and escaped. An empty variant yields the bare name.
pub fn format_problem_variant(name: &str, variant: &BTreeMap<String, String>) -> String {
    let mut entries = variant
        .iter()
        .map(|(key, value)| format!("{key}: {value:?}"));

    match entries.next() {
        None => name.to_string(),
        Some(first) => {
            let body = entries.fold(first, |mut acc, entry| {
                acc.push_str(", ");
                acc.push_str(&entry);
                acc
            });
            format!("{name} {{{body}}}")
        }
    }
}
96
/// A single term `coeff * Π var^exponent`; exponents are real-valued.
#[derive(Clone, Debug)]
struct Monomial {
    /// Numeric coefficient of the term.
    coeff: f64,
    /// Variable name -> exponent.
    vars: BTreeMap<&'static str, f64>,
}

impl Monomial {
    /// The constant term `c` (no variables).
    fn constant(c: f64) -> Self {
        Monomial {
            coeff: c,
            vars: BTreeMap::new(),
        }
    }

    /// The term `1 * name^1`.
    fn variable(name: &'static str) -> Self {
        Monomial {
            coeff: 1.0,
            vars: BTreeMap::from([(name, 1.0)]),
        }
    }

    /// Product of two monomials: coefficients multiply, exponents of shared variables add.
    fn mul(&self, other: &Monomial) -> Monomial {
        let vars = other
            .vars
            .iter()
            .fold(self.vars.clone(), |mut acc, (&name, &exp)| {
                *acc.entry(name).or_insert(0.0) += exp;
                acc
            });
        Monomial {
            coeff: self.coeff * other.coeff,
            vars,
        }
    }
}
131
132#[derive(Debug, Clone)]
134struct NormalizedPoly {
135 terms: Vec<Monomial>,
136}
137
138impl NormalizedPoly {
139 fn add(mut self, other: NormalizedPoly) -> NormalizedPoly {
140 self.terms.extend(other.terms);
141 self
142 }
143
144 fn mul(&self, other: &NormalizedPoly) -> NormalizedPoly {
145 let mut terms = Vec::new();
146 for a in &self.terms {
147 for b in &other.terms {
148 terms.push(a.mul(b));
149 }
150 }
151 NormalizedPoly { terms }
152 }
153
154 fn has_negative_coefficients(&self) -> bool {
156 self.terms.iter().any(|m| m.coeff < -1e-15)
157 }
158}
159
160fn normalize_polynomial(expr: &Expr) -> Result<NormalizedPoly, String> {
166 match expr {
167 Expr::Const(c) => Ok(NormalizedPoly {
168 terms: vec![Monomial::constant(*c)],
169 }),
170 Expr::Var(v) => Ok(NormalizedPoly {
171 terms: vec![Monomial::variable(v)],
172 }),
173 Expr::Add(a, b) => {
174 let pa = normalize_polynomial(a)?;
175 let pb = normalize_polynomial(b)?;
176 Ok(pa.add(pb))
177 }
178 Expr::Mul(a, b) => {
179 let pa = normalize_polynomial(a)?;
180 let pb = normalize_polynomial(b)?;
181 Ok(pa.mul(&pb))
182 }
183 Expr::Pow(base, exp) => {
184 if let Expr::Const(c) = exp.as_ref() {
185 if *c < 0.0 {
186 return Err(format!("negative exponent: {c}"));
187 }
188 let pb = normalize_polynomial(base)?;
189 if pb.terms.len() == 1 {
191 let m = &pb.terms[0];
192 let coeff = m.coeff.powf(*c);
193 let vars: BTreeMap<_, _> = m.vars.iter().map(|(&v, &e)| (v, e * c)).collect();
194 return Ok(NormalizedPoly {
195 terms: vec![Monomial { coeff, vars }],
196 });
197 }
198 let n = *c as usize;
200 if c.fract().abs() < 1e-10 {
201 if n == 0 {
202 return Ok(NormalizedPoly {
203 terms: vec![Monomial::constant(1.0)],
204 });
205 }
206 let mut result = pb.clone();
207 for _ in 1..n {
208 result = result.mul(&pb);
209 }
210 return Ok(result);
211 }
212 Err(format!(
213 "non-integer power of multi-term polynomial: ({base})^{c}"
214 ))
215 } else {
216 Err(format!("variable exponent: ({base})^({exp})"))
217 }
218 }
219 Expr::Exp(_) => Err("exp() not supported".into()),
220 Expr::Log(_) => Err("log() not supported".into()),
221 Expr::Sqrt(_) => Err("sqrt() not supported".into()),
222 Expr::Factorial(_) => Err("factorial() not supported".into()),
223 }
224}
225
226fn prepare_expr_for_comparison(expr: &Expr) -> Expr {
227 canonical_form(expr).unwrap_or_else(|_| expr.clone())
228}
229
230fn monomial_dominated_by(small: &Monomial, big: &Monomial) -> bool {
237 for (&var, &exp_small) in &small.vars {
238 let exp_big = big.vars.get(var).copied().unwrap_or(0.0);
239 if exp_small > exp_big + 1e-10 {
240 return false;
241 }
242 }
243 true
244}
245
246fn poly_leq(a: &NormalizedPoly, b: &NormalizedPoly) -> bool {
251 let b_positive: Vec<&Monomial> = b.terms.iter().filter(|m| m.coeff > 1e-15).collect();
252
253 for a_term in &a.terms {
254 if a_term.coeff <= 1e-15 {
255 continue; }
257 let dominated = b_positive
258 .iter()
259 .any(|b_term| monomial_dominated_by(a_term, b_term));
260 if !dominated {
261 return false;
262 }
263 }
264 true
265}
266
267pub fn compare_overhead(
276 primitive: &ReductionOverhead,
277 composite: &ReductionOverhead,
278) -> ComparisonStatus {
279 let comp_map: std::collections::HashMap<&str, &Expr> = composite
280 .output_size
281 .iter()
282 .map(|(name, expr)| (*name, expr))
283 .collect();
284
285 let mut any_common = false;
286
287 for (field, prim_expr) in &primitive.output_size {
288 let Some(comp_expr) = comp_map.get(field) else {
289 continue;
290 };
291 any_common = true;
292
293 let primitive_prepared = prepare_expr_for_comparison(prim_expr);
294 let composite_prepared = prepare_expr_for_comparison(comp_expr);
295
296 if primitive_prepared == composite_prepared {
297 continue;
298 }
299
300 let primitive_poly = match normalize_polynomial(&primitive_prepared) {
301 Ok(p) => p,
302 Err(_) => return ComparisonStatus::Unknown,
303 };
304 let composite_poly = match normalize_polynomial(&composite_prepared) {
305 Ok(p) => p,
306 Err(_) => return ComparisonStatus::Unknown,
307 };
308
309 if primitive_poly.has_negative_coefficients() || composite_poly.has_negative_coefficients()
311 {
312 return ComparisonStatus::Unknown;
313 }
314
315 if !poly_leq(&composite_poly, &primitive_poly) {
317 return ComparisonStatus::NotDominated;
318 }
319 }
320
321 if any_common {
322 ComparisonStatus::Dominated
323 } else {
324 ComparisonStatus::NotDominated
325 }
326}
327
328pub fn find_dominated_rules(
344 graph: &ReductionGraph,
345) -> (Vec<DominatedRule>, Vec<UnknownComparison>) {
346 const MAX_PATHS_PER_EDGE: usize = 1024;
347 const MAX_INTERMEDIATE_NODES: usize = 6;
348
349 let mut dominated = Vec::new();
350 let mut unknown = Vec::new();
351
352 for edge_info in all_edges(graph) {
353 let paths = graph.find_paths_up_to_mode_bounded(
354 edge_info.source_name,
355 &edge_info.source_variant,
356 edge_info.target_name,
357 &edge_info.target_variant,
358 crate::rules::graph::ReductionMode::Witness,
359 MAX_PATHS_PER_EDGE,
360 Some(MAX_INTERMEDIATE_NODES),
361 );
362
363 let mut best_dominating: Option<(ReductionPath, ReductionOverhead, Vec<String>)> = None;
364
365 for path in paths {
366 if path.len() <= 1 {
367 continue; }
369
370 let composed = graph.compose_path_overhead(&path);
371
372 match compare_overhead(&edge_info.overhead, &composed) {
373 ComparisonStatus::Dominated => {
374 let comparable_fields = common_fields(&edge_info.overhead, &composed);
375 let is_better = match &best_dominating {
376 None => true,
377 Some((best_path, _, _)) => path.len() < best_path.len(),
378 };
379 if is_better {
380 best_dominating = Some((path, composed, comparable_fields));
381 }
382 }
383 ComparisonStatus::Unknown => {
384 unknown.push(UnknownComparison {
385 source_name: edge_info.source_name,
386 source_variant: edge_info.source_variant.clone(),
387 target_name: edge_info.target_name,
388 target_variant: edge_info.target_variant.clone(),
389 candidate_path: path,
390 reason: "expression comparison returned Unknown".into(),
391 });
392 }
393 ComparisonStatus::NotDominated => {}
394 }
395 }
396
397 if let Some((path, composed, fields)) = best_dominating {
398 dominated.push(DominatedRule {
399 source_name: edge_info.source_name,
400 source_variant: edge_info.source_variant.clone(),
401 target_name: edge_info.target_name,
402 target_variant: edge_info.target_variant.clone(),
403 primitive_overhead: edge_info.overhead.clone(),
404 dominating_path: path,
405 composed_overhead: composed,
406 comparable_fields: fields,
407 });
408 }
409 }
410
411 dominated.sort_by(|a, b| {
413 (
414 format_problem_variant(a.source_name, &a.source_variant),
415 format_problem_variant(a.target_name, &a.target_variant),
416 a.dominating_path.len(),
417 )
418 .cmp(&(
419 format_problem_variant(b.source_name, &b.source_variant),
420 format_problem_variant(b.target_name, &b.target_variant),
421 b.dominating_path.len(),
422 ))
423 });
424 unknown.sort_by(|a, b| {
425 (
426 format_problem_variant(a.source_name, &a.source_variant),
427 format_problem_variant(a.target_name, &a.target_variant),
428 )
429 .cmp(&(
430 format_problem_variant(b.source_name, &b.source_variant),
431 format_problem_variant(b.target_name, &b.target_variant),
432 ))
433 });
434
435 (dominated, unknown)
436}
437
438fn common_fields(a: &ReductionOverhead, b: &ReductionOverhead) -> Vec<String> {
440 let b_fields: std::collections::HashSet<&str> = b.output_size.iter().map(|(n, _)| *n).collect();
441 a.output_size
442 .iter()
443 .filter(|&(f, _)| b_fields.contains(f))
444 .map(|(f, _)| f.to_string())
445 .collect()
446}
447
448fn all_edges(graph: &ReductionGraph) -> Vec<crate::rules::graph::ReductionEdgeInfo> {
450 let mut edges = Vec::new();
451 for name in graph.problem_types() {
452 edges.extend(graph.outgoing_reductions(name));
453 }
454 edges
455}
456
457#[derive(Debug, Clone)]
461pub struct ConnectivityReport {
462 pub total_types: usize,
464 pub total_reductions: usize,
466 pub isolated: Vec<IsolatedProblem>,
468 pub components: Vec<Vec<&'static str>>,
471}
472
/// A problem type with no neighbouring type in the reduction graph.
#[derive(Clone, Debug)]
pub struct IsolatedProblem {
    /// Problem type name.
    pub name: &'static str,
    /// Number of registered variants of this type.
    pub num_variants: usize,
    /// Each variant paired with its complexity expression rendered as a string, if known.
    pub variant_complexities: Vec<(BTreeMap<String, String>, Option<String>)>,
}
481
482pub fn check_connectivity(graph: &ReductionGraph) -> ConnectivityReport {
484 let mut types = graph.problem_types();
485 types.sort();
486
487 let mut adj: BTreeMap<&str, BTreeSet<&str>> = BTreeMap::new();
489 for &name in &types {
490 adj.entry(name).or_default();
491 for edge in graph.outgoing_reductions(name) {
492 adj.entry(name).or_default().insert(edge.target_name);
493 adj.entry(edge.target_name).or_default().insert(name);
494 }
495 }
496
497 let mut visited: BTreeSet<&str> = BTreeSet::new();
499 let mut components: Vec<Vec<&str>> = Vec::new();
500
501 for &name in &types {
502 if visited.contains(name) {
503 continue;
504 }
505 let mut component = Vec::new();
506 let mut queue = std::collections::VecDeque::new();
507 queue.push_back(name);
508 visited.insert(name);
509
510 while let Some(current) = queue.pop_front() {
511 component.push(current);
512 if let Some(neighbors) = adj.get(current) {
513 for &neighbor in neighbors {
514 if visited.insert(neighbor) {
515 queue.push_back(neighbor);
516 }
517 }
518 }
519 }
520 component.sort();
521 components.push(component);
522 }
523
524 components.sort_by_key(|c| std::cmp::Reverse(c.len()));
525
526 let isolated: Vec<IsolatedProblem> = types
527 .iter()
528 .copied()
529 .filter(|name| adj.get(name).is_some_and(|n| n.is_empty()))
530 .map(|name| {
531 let variants = graph.variants_for(name);
532 let variant_complexities = variants
533 .iter()
534 .map(|v| {
535 let c = graph.variant_complexity(name, v).map(|e| e.to_string());
536 (v.clone(), c)
537 })
538 .collect();
539 IsolatedProblem {
540 name,
541 num_variants: variants.len(),
542 variant_complexities,
543 }
544 })
545 .collect();
546
547 ConnectivityReport {
548 total_types: types.len(),
549 total_reductions: graph.num_reductions(),
550 isolated,
551 components,
552 }
553}
554
/// Why a problem type is not reachable from the reachability search source.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum UnreachableReason {
    /// Listed in the analysis's table of polynomial-time problems.
    InP,
    /// Listed in the analysis's set of intermediate problems.
    Intermediate,
    /// Has neither outgoing nor incoming reductions.
    Orphan,
    /// Has reductions but matches none of the other categories.
    MissingProofChain,
}
567
568#[derive(Debug, Clone)]
570pub struct UnreachableProblem {
571 pub name: &'static str,
572 pub reason: UnreachableReason,
573 pub outgoing_count: usize,
574 pub incoming_count: usize,
575}
576
577#[derive(Debug, Clone)]
579pub struct ReachabilityReport {
580 pub total_types: usize,
582 pub reachable: BTreeMap<&'static str, usize>,
584 pub unreachable: Vec<UnreachableProblem>,
586}
587
588impl ReachabilityReport {
589 pub fn missing_proof_chains(&self) -> Vec<&UnreachableProblem> {
591 self.unreachable
592 .iter()
593 .filter(|p| p.reason == UnreachableReason::MissingProofChain)
594 .collect()
595 }
596}
597
598pub fn check_reachability_from_3sat(graph: &ReductionGraph) -> ReachabilityReport {
602 const SOURCE: &str = "KSatisfiability";
603
604 let mut types = graph.problem_types();
605 types.sort();
606
607 let mut adj: BTreeMap<&str, BTreeSet<&str>> = BTreeMap::new();
609 for &name in &types {
610 adj.entry(name).or_default();
611 for edge in graph.outgoing_reductions(name) {
612 adj.entry(name).or_default().insert(edge.target_name);
613 }
614 }
615
616 let mut reachable: BTreeMap<&'static str, usize> = BTreeMap::new();
618 let mut queue: std::collections::VecDeque<(&str, usize)> = std::collections::VecDeque::new();
619 reachable.insert(SOURCE, 0);
620 queue.push_back((SOURCE, 0));
621
622 while let Some((current, hops)) = queue.pop_front() {
623 if let Some(neighbors) = adj.get(current) {
624 for &neighbor in neighbors {
625 if !reachable.contains_key(neighbor) {
626 reachable.insert(neighbor, hops + 1);
627 queue.push_back((neighbor, hops + 1));
628 }
629 }
630 }
631 }
632
633 let p_time_checks: &[(&str, Option<(&str, &str)>)] = &[
635 ("MaximumMatching", None),
636 ("KSatisfiability", Some(("k", "K2"))),
637 ("KColoring", Some(("graph", "SimpleGraph"))),
638 ];
639
640 let intermediate_names: &[&str] = &["Factoring"];
641
642 let mut unreachable_problems: Vec<UnreachableProblem> = Vec::new();
643
644 for &name in &types {
645 if reachable.contains_key(name) {
646 continue;
647 }
648
649 let out_count = graph.outgoing_reductions(name).len();
650 let in_count = graph.incoming_reductions(name).len();
651
652 if out_count == 0 && in_count == 0 {
654 unreachable_problems.push(UnreachableProblem {
655 name,
656 reason: UnreachableReason::Orphan,
657 outgoing_count: 0,
658 incoming_count: 0,
659 });
660 continue;
661 }
662
663 let is_p = p_time_checks.iter().any(|(pname, variant_check)| {
665 if *pname != name {
666 return false;
667 }
668 match variant_check {
669 None => true,
670 Some((key, val)) => {
671 let variants = graph.variants_for(name);
672 variants.len() == 1 && variants[0].get(*key).map(|s| s.as_str()) == Some(*val)
673 }
674 }
675 });
676 if is_p {
677 unreachable_problems.push(UnreachableProblem {
678 name,
679 reason: UnreachableReason::InP,
680 outgoing_count: out_count,
681 incoming_count: in_count,
682 });
683 continue;
684 }
685
686 if intermediate_names.contains(&name) {
688 unreachable_problems.push(UnreachableProblem {
689 name,
690 reason: UnreachableReason::Intermediate,
691 outgoing_count: out_count,
692 incoming_count: in_count,
693 });
694 continue;
695 }
696
697 unreachable_problems.push(UnreachableProblem {
699 name,
700 reason: UnreachableReason::MissingProofChain,
701 outgoing_count: out_count,
702 incoming_count: in_count,
703 });
704 }
705
706 ReachabilityReport {
707 total_types: types.len(),
708 reachable,
709 unreachable: unreachable_problems,
710 }
711}
712
// Unit tests for this module; compiled only under `cfg(test)` and loaded from a separate
// file via the `#[path]` attribute.
#[cfg(test)]
#[path = "../unit_tests/rules/analysis.rs"]
mod tests;