problemreductions/rules/
analysis.rs

1//! Analysis utilities for the reduction graph.
2//!
3//! Detects primitive reduction rules that are dominated by composite paths,
4//! using asymptotic normalization plus monomial-dominance comparison.
5//!
6//! This analysis is **sound but incomplete**: it reports `Dominated` only when
7//! the symbolic comparison is trustworthy, and `Unknown` when metadata is too
8//! weak to compare safely.
9
10use crate::canonical::canonical_form;
11use crate::expr::Expr;
12use crate::rules::graph::{ReductionGraph, ReductionPath};
13use crate::rules::registry::ReductionOverhead;
14use std::collections::BTreeMap;
15use std::fmt;
16
/// Result of comparing one primitive rule against one composite path.
///
/// The comparison is sound but incomplete: only `Dominated` is a proof. The other
/// two variants mean "dominance was not established".
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ComparisonStatus {
    /// The composite is asymptotically no worse than the primitive on every output
    /// field the two overheads have in common.
    Dominated,
    /// Dominance was not established. Either the symbolic check failed on a common
    /// field, or the overheads share no fields at all. This is not a proof that the
    /// composite is worse, because the monomial-dominance test is only sufficient.
    NotDominated,
    /// Cannot decide soundly, e.g. an expression could not be normalized into
    /// polynomial form or contains negative coefficients.
    Unknown,
}
27
28/// A primitive reduction rule proven dominated by a composite path.
29#[derive(Debug, Clone)]
30pub struct DominatedRule {
31    pub source_name: &'static str,
32    pub source_variant: BTreeMap<String, String>,
33    pub target_name: &'static str,
34    pub target_variant: BTreeMap<String, String>,
35    pub primitive_overhead: ReductionOverhead,
36    pub dominating_path: ReductionPath,
37    pub composed_overhead: ReductionOverhead,
38    pub comparable_fields: Vec<String>,
39}
40
41impl DominatedRule {
42    pub fn source_display(&self) -> String {
43        format_problem_variant(self.source_name, &self.source_variant)
44    }
45
46    pub fn target_display(&self) -> String {
47        format_problem_variant(self.target_name, &self.target_variant)
48    }
49}
50
51impl fmt::Display for DominatedRule {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        write!(f, "{} -> {}", self.source_display(), self.target_display())
54    }
55}
56
57/// A candidate comparison that could not be decided soundly.
58#[derive(Debug, Clone)]
59pub struct UnknownComparison {
60    pub source_name: &'static str,
61    pub source_variant: BTreeMap<String, String>,
62    pub target_name: &'static str,
63    pub target_variant: BTreeMap<String, String>,
64    pub candidate_path: ReductionPath,
65    pub reason: String,
66}
67
68impl UnknownComparison {
69    pub fn source_display(&self) -> String {
70        format_problem_variant(self.source_name, &self.source_variant)
71    }
72
73    pub fn target_display(&self) -> String {
74        format_problem_variant(self.target_name, &self.target_variant)
75    }
76}
77
78impl fmt::Display for UnknownComparison {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        write!(f, "{} -> {}", self.source_display(), self.target_display())
81    }
82}
83
/// Formats a problem name together with its variant map for display.
///
/// An empty variant yields just `name`. Otherwise the result is
/// `name {key: "value", ...}`, with entries in key order and values rendered
/// through `Debug` (so they appear quoted).
pub fn format_problem_variant(name: &str, variant: &BTreeMap<String, String>) -> String {
    if variant.is_empty() {
        name.to_string()
    } else {
        let entries: Vec<String> = variant
            .iter()
            .map(|(key, value)| format!("{key}: {value:?}"))
            .collect();
        format!("{name} {{{}}}", entries.join(", "))
    }
}
96
97// ────────── Polynomial normalization ──────────
98
/// A single polynomial term: `coeff * Π variable^exponent`.
#[derive(Debug, Clone)]
struct Monomial {
    coeff: f64,
    /// Exponent per variable name. The constructors only insert non-zero
    /// exponents; `mul` adds the exponents of shared variables.
    vars: BTreeMap<&'static str, f64>,
}

impl Monomial {
    /// The constant monomial `c`, which has no variables.
    fn constant(c: f64) -> Self {
        Monomial {
            coeff: c,
            vars: BTreeMap::default(),
        }
    }

    /// The monomial `1 * name^1`.
    fn variable(name: &'static str) -> Self {
        Monomial {
            coeff: 1.0,
            vars: BTreeMap::from([(name, 1.0)]),
        }
    }

    /// Product of two monomials: coefficients multiply, per-variable exponents add.
    fn mul(&self, other: &Monomial) -> Monomial {
        let mut vars = self.vars.clone();
        other.vars.iter().for_each(|(&var, &exp)| {
            *vars.entry(var).or_insert(0.0) += exp;
        });
        Monomial {
            coeff: self.coeff * other.coeff,
            vars,
        }
    }
}
131
132/// A polynomial (sum of monomials) in normal form.
133#[derive(Debug, Clone)]
134struct NormalizedPoly {
135    terms: Vec<Monomial>,
136}
137
138impl NormalizedPoly {
139    fn add(mut self, other: NormalizedPoly) -> NormalizedPoly {
140        self.terms.extend(other.terms);
141        self
142    }
143
144    fn mul(&self, other: &NormalizedPoly) -> NormalizedPoly {
145        let mut terms = Vec::new();
146        for a in &self.terms {
147            for b in &other.terms {
148                terms.push(a.mul(b));
149            }
150        }
151        NormalizedPoly { terms }
152    }
153
154    /// True if any monomial has a negative coefficient.
155    fn has_negative_coefficients(&self) -> bool {
156        self.terms.iter().any(|m| m.coeff < -1e-15)
157    }
158}
159
160/// Normalize an expression into a sum of monomials.
161///
162/// Supports: constants, variables, addition, multiplication,
163/// and powers with non-negative constant exponents.
164/// Returns `Err` for exp, log, sqrt, division, and negative exponents.
165fn normalize_polynomial(expr: &Expr) -> Result<NormalizedPoly, String> {
166    match expr {
167        Expr::Const(c) => Ok(NormalizedPoly {
168            terms: vec![Monomial::constant(*c)],
169        }),
170        Expr::Var(v) => Ok(NormalizedPoly {
171            terms: vec![Monomial::variable(v)],
172        }),
173        Expr::Add(a, b) => {
174            let pa = normalize_polynomial(a)?;
175            let pb = normalize_polynomial(b)?;
176            Ok(pa.add(pb))
177        }
178        Expr::Mul(a, b) => {
179            let pa = normalize_polynomial(a)?;
180            let pb = normalize_polynomial(b)?;
181            Ok(pa.mul(&pb))
182        }
183        Expr::Pow(base, exp) => {
184            if let Expr::Const(c) = exp.as_ref() {
185                if *c < 0.0 {
186                    return Err(format!("negative exponent: {c}"));
187                }
188                let pb = normalize_polynomial(base)?;
189                // Single monomial: multiply exponents
190                if pb.terms.len() == 1 {
191                    let m = &pb.terms[0];
192                    let coeff = m.coeff.powf(*c);
193                    let vars: BTreeMap<_, _> = m.vars.iter().map(|(&v, &e)| (v, e * c)).collect();
194                    return Ok(NormalizedPoly {
195                        terms: vec![Monomial { coeff, vars }],
196                    });
197                }
198                // Multi-term polynomial raised to non-negative integer power
199                let n = *c as usize;
200                if c.fract().abs() < 1e-10 {
201                    if n == 0 {
202                        return Ok(NormalizedPoly {
203                            terms: vec![Monomial::constant(1.0)],
204                        });
205                    }
206                    let mut result = pb.clone();
207                    for _ in 1..n {
208                        result = result.mul(&pb);
209                    }
210                    return Ok(result);
211                }
212                Err(format!(
213                    "non-integer power of multi-term polynomial: ({base})^{c}"
214                ))
215            } else {
216                Err(format!("variable exponent: ({base})^({exp})"))
217            }
218        }
219        Expr::Exp(_) => Err("exp() not supported".into()),
220        Expr::Log(_) => Err("log() not supported".into()),
221        Expr::Sqrt(_) => Err("sqrt() not supported".into()),
222    }
223}
224
225fn prepare_expr_for_comparison(expr: &Expr) -> Expr {
226    canonical_form(expr).unwrap_or_else(|_| expr.clone())
227}
228
229// ────────── Monomial-dominance comparison ──────────
230
231/// Check if monomial `small` is asymptotically dominated by monomial `big`.
232///
233/// True iff for every variable in `small`, `big` has at least as large an exponent.
234/// This means `small` grows no faster than `big` as all variables → ∞.
235fn monomial_dominated_by(small: &Monomial, big: &Monomial) -> bool {
236    for (&var, &exp_small) in &small.vars {
237        let exp_big = big.vars.get(var).copied().unwrap_or(0.0);
238        if exp_small > exp_big + 1e-10 {
239            return false;
240        }
241    }
242    true
243}
244
245/// Check if polynomial `a` is asymptotically ≤ polynomial `b`.
246///
247/// True iff every positive-coefficient monomial in `a` is dominated by
248/// some positive-coefficient monomial in `b`.
249fn poly_leq(a: &NormalizedPoly, b: &NormalizedPoly) -> bool {
250    let b_positive: Vec<&Monomial> = b.terms.iter().filter(|m| m.coeff > 1e-15).collect();
251
252    for a_term in &a.terms {
253        if a_term.coeff <= 1e-15 {
254            continue; // zero or negative — can only make `a` smaller
255        }
256        let dominated = b_positive
257            .iter()
258            .any(|b_term| monomial_dominated_by(a_term, b_term));
259        if !dominated {
260            return false;
261        }
262    }
263    true
264}
265
266// ────────── Overhead comparison ──────────
267
268/// Compare two overheads across all common fields.
269///
270/// Returns `Dominated` if composite ≤ primitive on all common fields.
271/// Returns `NotDominated` if composite is worse on any common field.
272/// Returns `Unknown` if any common field's expressions cannot be normalized
273/// into a comparable polynomial form or contain negative coefficients.
274pub fn compare_overhead(
275    primitive: &ReductionOverhead,
276    composite: &ReductionOverhead,
277) -> ComparisonStatus {
278    let comp_map: std::collections::HashMap<&str, &Expr> = composite
279        .output_size
280        .iter()
281        .map(|(name, expr)| (*name, expr))
282        .collect();
283
284    let mut any_common = false;
285
286    for (field, prim_expr) in &primitive.output_size {
287        let Some(comp_expr) = comp_map.get(field) else {
288            continue;
289        };
290        any_common = true;
291
292        let primitive_prepared = prepare_expr_for_comparison(prim_expr);
293        let composite_prepared = prepare_expr_for_comparison(comp_expr);
294
295        if primitive_prepared == composite_prepared {
296            continue;
297        }
298
299        let primitive_poly = match normalize_polynomial(&primitive_prepared) {
300            Ok(p) => p,
301            Err(_) => return ComparisonStatus::Unknown,
302        };
303        let composite_poly = match normalize_polynomial(&composite_prepared) {
304            Ok(p) => p,
305            Err(_) => return ComparisonStatus::Unknown,
306        };
307
308        // Reject expressions with negative coefficients
309        if primitive_poly.has_negative_coefficients() || composite_poly.has_negative_coefficients()
310        {
311            return ComparisonStatus::Unknown;
312        }
313
314        // Check: composite ≤ primitive on this field
315        if !poly_leq(&composite_poly, &primitive_poly) {
316            return ComparisonStatus::NotDominated;
317        }
318    }
319
320    if any_common {
321        ComparisonStatus::Dominated
322    } else {
323        ComparisonStatus::NotDominated
324    }
325}
326
327// ────────── Main analysis ──────────
328
329/// Find all primitive reduction rules dominated by composite paths.
330///
331/// Returns a tuple of:
332/// - `Vec<DominatedRule>`: rules proven dominated by a composite path
333/// - `Vec<UnknownComparison>`: candidates that could not be decided
334///
335/// For each primitive rule (direct edge), enumerates all alternative paths,
336/// validates trustworthiness, composes overheads, and compares.
337/// Keeps only the best (shortest) dominating path per primitive rule.
338///
339/// Note: iterates the graph's coalesced edges rather than raw `inventory` entries.
340/// This is sound because `test_no_duplicate_primitive_rules_per_variant_pair` guards
341/// the invariant that at most one registration exists per (source_variant, target_variant) pair.
342pub fn find_dominated_rules(
343    graph: &ReductionGraph,
344) -> (Vec<DominatedRule>, Vec<UnknownComparison>) {
345    let mut dominated = Vec::new();
346    let mut unknown = Vec::new();
347
348    for edge_info in all_edges(graph) {
349        let paths = graph.find_all_paths(
350            edge_info.source_name,
351            &edge_info.source_variant,
352            edge_info.target_name,
353            &edge_info.target_variant,
354        );
355
356        let mut best_dominating: Option<(ReductionPath, ReductionOverhead, Vec<String>)> = None;
357
358        for path in paths {
359            if path.len() <= 1 {
360                continue; // skip the direct edge itself
361            }
362
363            let composed = graph.compose_path_overhead(&path);
364
365            match compare_overhead(&edge_info.overhead, &composed) {
366                ComparisonStatus::Dominated => {
367                    let comparable_fields = common_fields(&edge_info.overhead, &composed);
368                    let is_better = match &best_dominating {
369                        None => true,
370                        Some((best_path, _, _)) => path.len() < best_path.len(),
371                    };
372                    if is_better {
373                        best_dominating = Some((path, composed, comparable_fields));
374                    }
375                }
376                ComparisonStatus::Unknown => {
377                    unknown.push(UnknownComparison {
378                        source_name: edge_info.source_name,
379                        source_variant: edge_info.source_variant.clone(),
380                        target_name: edge_info.target_name,
381                        target_variant: edge_info.target_variant.clone(),
382                        candidate_path: path,
383                        reason: "expression comparison returned Unknown".into(),
384                    });
385                }
386                ComparisonStatus::NotDominated => {}
387            }
388        }
389
390        if let Some((path, composed, fields)) = best_dominating {
391            dominated.push(DominatedRule {
392                source_name: edge_info.source_name,
393                source_variant: edge_info.source_variant.clone(),
394                target_name: edge_info.target_name,
395                target_variant: edge_info.target_variant.clone(),
396                primitive_overhead: edge_info.overhead.clone(),
397                dominating_path: path,
398                composed_overhead: composed,
399                comparable_fields: fields,
400            });
401        }
402    }
403
404    // Deterministic output
405    dominated.sort_by(|a, b| {
406        (
407            format_problem_variant(a.source_name, &a.source_variant),
408            format_problem_variant(a.target_name, &a.target_variant),
409            a.dominating_path.len(),
410        )
411            .cmp(&(
412                format_problem_variant(b.source_name, &b.source_variant),
413                format_problem_variant(b.target_name, &b.target_variant),
414                b.dominating_path.len(),
415            ))
416    });
417    unknown.sort_by(|a, b| {
418        (
419            format_problem_variant(a.source_name, &a.source_variant),
420            format_problem_variant(a.target_name, &a.target_variant),
421        )
422            .cmp(&(
423                format_problem_variant(b.source_name, &b.source_variant),
424                format_problem_variant(b.target_name, &b.target_variant),
425            ))
426    });
427
428    (dominated, unknown)
429}
430
431/// Fields present in both overheads.
432fn common_fields(a: &ReductionOverhead, b: &ReductionOverhead) -> Vec<String> {
433    let b_fields: std::collections::HashSet<&str> = b.output_size.iter().map(|(n, _)| *n).collect();
434    a.output_size
435        .iter()
436        .filter(|&(f, _)| b_fields.contains(f))
437        .map(|(f, _)| f.to_string())
438        .collect()
439}
440
441/// Collect all edges from the reduction graph.
442fn all_edges(graph: &ReductionGraph) -> Vec<crate::rules::graph::ReductionEdgeInfo> {
443    let mut edges = Vec::new();
444    for name in graph.problem_types() {
445        edges.extend(graph.outgoing_reductions(name));
446    }
447    edges
448}
449
450#[cfg(test)]
451#[path = "../unit_tests/rules/analysis.rs"]
452mod tests;