Skip to main content

problemreductions/
big_o.rs

1//! Big-O asymptotic projection for canonical expressions.
2//!
3//! Takes the output of `canonical_form()` and projects it into an
4//! asymptotic growth class by dropping dominated terms and constant factors.
5
6use crate::canonical::canonical_form;
7use crate::expr::{AsymptoticAnalysisError, CanonicalizationError, Expr};
8
9#[derive(Clone, Debug)]
10struct ProjectedTerm {
11    expr: Expr,
12    negative: bool,
13}
14
15/// Compute the Big-O normal form of an expression.
16///
17/// This is a two-phase pipeline:
18/// 1. `canonical_form()` — exact symbolic simplification
19/// 2. Asymptotic projection — drop dominated terms and constant factors
20///
21/// Returns an expression representing the asymptotic growth class.
22pub fn big_o_normal_form(expr: &Expr) -> Result<Expr, AsymptoticAnalysisError> {
23    let canonical = canonical_form(expr).map_err(|e| match e {
24        CanonicalizationError::Unsupported(s) => AsymptoticAnalysisError::Unsupported(s),
25    })?;
26
27    project_big_o(&canonical)
28}
29
30/// Project a canonicalized expression into its Big-O growth class.
31fn project_big_o(expr: &Expr) -> Result<Expr, AsymptoticAnalysisError> {
32    // Decompose into additive terms
33    let mut terms = Vec::new();
34    collect_additive_terms(expr, &mut terms);
35
36    // Project each term: drop constant multiplicative factors
37    let mut projected: Vec<ProjectedTerm> = Vec::new();
38    for term in &terms {
39        if let Some(projected_term) = project_term(term)? {
40            projected.push(projected_term);
41        }
42        // Pure constants are dropped (asymptotically irrelevant)
43    }
44
45    // Remove dominated terms
46    let survivors = remove_dominated_terms(projected);
47
48    if survivors.is_empty() {
49        // All terms were constants → O(1)
50        return Ok(Expr::Const(1.0));
51    }
52
53    if let Some(negative) = survivors.iter().find(|term| term.negative) {
54        return Err(AsymptoticAnalysisError::Unsupported(format!(
55            "-1 * {}",
56            negative.expr
57        )));
58    }
59
60    // Deduplicate
61    let mut seen = std::collections::BTreeSet::new();
62    let mut deduped = Vec::new();
63    for term in survivors {
64        let key = term.expr.to_string();
65        if seen.insert(key) {
66            deduped.push(term);
67        }
68    }
69
70    // Rebuild sum
71    let mut result = deduped[0].expr.clone();
72    for term in &deduped[1..] {
73        result = result + term.expr.clone();
74    }
75
76    Ok(result)
77}
78
79fn collect_additive_terms(expr: &Expr, out: &mut Vec<Expr>) {
80    match expr {
81        Expr::Add(a, b) => {
82            collect_additive_terms(a, out);
83            collect_additive_terms(b, out);
84        }
85        other => out.push(other.clone()),
86    }
87}
88
89/// Project a single multiplicative term: strip constant factors.
90/// Returns None if the term is a pure constant.
91fn project_term(term: &Expr) -> Result<Option<ProjectedTerm>, AsymptoticAnalysisError> {
92    if term.constant_value().is_some() {
93        return Ok(None); // Pure constant → dropped
94    }
95
96    // Collect multiplicative factors
97    let mut factors = Vec::new();
98    collect_multiplicative_factors(term, &mut factors);
99
100    let mut coeff = 1.0;
101    let mut symbolic = Vec::new();
102    for factor in &factors {
103        if let Some(c) = factor.constant_value() {
104            coeff *= c;
105            continue;
106        }
107        if contains_negative_exponent(factor) {
108            return Err(AsymptoticAnalysisError::Unsupported(term.to_string()));
109        }
110        symbolic.push(factor.clone());
111    }
112
113    if symbolic.is_empty() {
114        return Ok(None);
115    }
116
117    let mut result = symbolic[0].clone();
118    for f in &symbolic[1..] {
119        result = result * f.clone();
120    }
121
122    Ok(Some(ProjectedTerm {
123        expr: result,
124        negative: coeff < 0.0,
125    }))
126}
127
128fn collect_multiplicative_factors(expr: &Expr, out: &mut Vec<Expr>) {
129    match expr {
130        Expr::Mul(a, b) => {
131            collect_multiplicative_factors(a, out);
132            collect_multiplicative_factors(b, out);
133        }
134        other => out.push(other.clone()),
135    }
136}
137
138/// Remove terms dominated by other terms using monomial comparison.
139///
140/// A term `t` is dominated if there exists another term `s` such that
141/// `t` grows no faster than `s` asymptotically.
142fn remove_dominated_terms(terms: Vec<ProjectedTerm>) -> Vec<ProjectedTerm> {
143    if terms.len() <= 1 {
144        return terms;
145    }
146
147    let mut survivors = Vec::new();
148    for (i, term) in terms.iter().enumerate() {
149        let is_dominated = terms
150            .iter()
151            .enumerate()
152            .any(|(j, other)| i != j && term_dominated_by(&term.expr, &other.expr));
153        if !is_dominated {
154            survivors.push(term.clone());
155        }
156    }
157    survivors
158}
159
160/// Check if `small` is asymptotically dominated by `big`.
161///
162/// Supports three comparison strategies:
163/// 1. Polynomial monomial exponent comparison (exact)
164/// 2. Exponential vs subexponential / base comparison (structural)
165/// 3. Numerical evaluation at two scales (for subexponential cross-class)
166fn term_dominated_by(small: &Expr, big: &Expr) -> bool {
167    // Case 1: Both pure polynomial monomials — use exponent comparison
168    let small_exps = extract_var_exponents(small);
169    let big_exps = extract_var_exponents(big);
170    if let (Some(ref se), Some(ref be)) = (small_exps, big_exps) {
171        return polynomial_dominated(se, be);
172    }
173
174    // Cross-class comparison: small's variables must be a subset of big's
175    let small_vars = small.variables();
176    let big_vars = big.variables();
177    if small_vars.is_empty() || big_vars.is_empty() || !small_vars.is_subset(&big_vars) {
178        return false;
179    }
180
181    // Case 2: Exponential comparison
182    let small_has_exp = has_exponential_growth(small);
183    let big_has_exp = has_exponential_growth(big);
184    match (small_has_exp, big_has_exp) {
185        (false, true) => return true,  // exponential dominates subexponential
186        (true, false) => return false, // subexponential can't dominate exponential
187        (true, true) => {
188            // Compare effective exponential bases
189            if let (Some(sb), Some(bb)) = (effective_exp_base(small), effective_exp_base(big)) {
190                if bb > sb * (1.0 + 1e-10) {
191                    return true;
192                }
193            }
194            return false;
195        }
196        (false, false) => {} // both subexponential, fall through
197    }
198
199    // Case 3: Both subexponential, same variables — numerical comparison
200    // Handles: poly vs poly*log, log vs log(log), poly vs log, etc.
201    if small_vars == big_vars {
202        return numerical_dominance_check(small, big, &small_vars);
203    }
204
205    false
206}
207
/// Monomial dominance test on exponent maps.
///
/// `se` is dominated by `be` when no exponent in `se` exceeds the matching
/// exponent in `be` (a variable missing from `be` counts as exponent 0) and at
/// least one variable has a strictly smaller exponent in `se` — including
/// variables that appear only in `be` with a positive exponent. Comparisons
/// use an absolute tolerance of `1e-15`.
fn polynomial_dominated(
    se: &std::collections::BTreeMap<&'static str, f64>,
    be: &std::collections::BTreeMap<&'static str, f64>,
) -> bool {
    // Any exponent of `se` above its counterpart rules dominance out.
    let exceeds_somewhere = se
        .iter()
        .any(|(var, &small_exp)| small_exp > be.get(var).copied().unwrap_or(0.0) + 1e-15);
    if exceeds_somewhere {
        return false;
    }

    // Strictly smaller on a variable that `se` mentions ...
    let lower_on_shared = se
        .iter()
        .any(|(var, &small_exp)| small_exp < be.get(var).copied().unwrap_or(0.0) - 1e-15);

    // ... or `be` carries an extra variable with a positive exponent.
    let extra_in_big = be
        .iter()
        .any(|(var, &big_exp)| !se.contains_key(var) && big_exp > 1e-15);

    lower_on_shared || extra_in_big
}
237
238/// Extract variable → exponent mapping from a monomial expression.
239/// Returns None for non-polynomial terms (exp, log, etc.).
240fn extract_var_exponents(expr: &Expr) -> Option<std::collections::BTreeMap<&'static str, f64>> {
241    use std::collections::BTreeMap;
242    let mut exps = BTreeMap::new();
243    extract_var_exponents_inner(expr, &mut exps)?;
244    Some(exps)
245}
246
247fn extract_var_exponents_inner(
248    expr: &Expr,
249    exps: &mut std::collections::BTreeMap<&'static str, f64>,
250) -> Option<()> {
251    match expr {
252        Expr::Var(name) => {
253            *exps.entry(name).or_insert(0.0) += 1.0;
254            Some(())
255        }
256        Expr::Pow(base, exp) => {
257            if let (Expr::Var(name), Some(e)) = (base.as_ref(), exp.constant_value()) {
258                if e < 0.0 {
259                    return None;
260                }
261                *exps.entry(name).or_insert(0.0) += e;
262                Some(())
263            } else {
264                None // Non-simple power
265            }
266        }
267        Expr::Mul(a, b) => {
268            extract_var_exponents_inner(a, exps)?;
269            extract_var_exponents_inner(b, exps)
270        }
271        Expr::Const(_) => Some(()), // Constants don't affect exponents
272        _ => None,                  // exp, log, sqrt → not a polynomial monomial
273    }
274}
275
276fn contains_negative_exponent(expr: &Expr) -> bool {
277    match expr {
278        Expr::Pow(_, exp) => exp.constant_value().is_some_and(|e| e < 0.0),
279        Expr::Mul(a, b) | Expr::Add(a, b) => {
280            contains_negative_exponent(a) || contains_negative_exponent(b)
281        }
282        Expr::Exp(arg) | Expr::Log(arg) | Expr::Sqrt(arg) | Expr::Factorial(arg) => {
283            contains_negative_exponent(arg)
284        }
285        Expr::Const(_) | Expr::Var(_) => false,
286    }
287}
288
289/// Check if an expression has exponential growth.
290///
291/// Returns true if the expression contains `exp(var_expr)` or `c^(var_expr)` where c > 1.
292fn has_exponential_growth(expr: &Expr) -> bool {
293    match expr {
294        Expr::Exp(arg) => !arg.variables().is_empty(),
295        Expr::Pow(base, exp) => {
296            base.constant_value().is_some_and(|c| c > 1.0) && !exp.variables().is_empty()
297        }
298        Expr::Mul(a, b) => has_exponential_growth(a) || has_exponential_growth(b),
299        _ => false,
300    }
301}
302
303/// Compute the effective exponential base for growth rate comparison.
304///
305/// For `c^(f(n))`, approximates the effective base as `c^(f(1))`.
306/// This works correctly for linear exponents (the common case in complexity expressions).
307fn effective_exp_base(expr: &Expr) -> Option<f64> {
308    match expr {
309        Expr::Exp(arg) => {
310            let vars = arg.variables();
311            if vars.is_empty() {
312                None
313            } else {
314                let size = unit_problem_size(&vars);
315                let rate = arg.eval(&size);
316                Some(std::f64::consts::E.powf(rate))
317            }
318        }
319        Expr::Pow(base, exp) => {
320            if let Some(c) = base.constant_value() {
321                let vars = exp.variables();
322                if c > 1.0 && !vars.is_empty() {
323                    let size = unit_problem_size(&vars);
324                    let exp_at_1 = exp.eval(&size);
325                    Some(c.powf(exp_at_1))
326                } else {
327                    None
328                }
329            } else {
330                None
331            }
332        }
333        Expr::Mul(a, b) => match (effective_exp_base(a), effective_exp_base(b)) {
334            (Some(ba), Some(bb)) => Some(ba * bb),
335            (Some(b), None) | (None, Some(b)) => Some(b),
336            (None, None) => None,
337        },
338        _ => None,
339    }
340}
341
342/// Create a `ProblemSize` with all variables set to the given value.
343fn make_problem_size(
344    vars: &std::collections::HashSet<&'static str>,
345    val: usize,
346) -> crate::types::ProblemSize {
347    crate::types::ProblemSize::new(vars.iter().map(|&v| (v, val)).collect())
348}
349
350/// Create a `ProblemSize` with all variables set to 1.
351fn unit_problem_size(vars: &std::collections::HashSet<&'static str>) -> crate::types::ProblemSize {
352    make_problem_size(vars, 1)
353}
354
355/// Check dominance numerically by evaluating at two scales.
356///
357/// Returns true if `big/small` ratio is > 1 and increasing between the two
358/// evaluation points, indicating `big` grows asymptotically faster.
359fn numerical_dominance_check(
360    small: &Expr,
361    big: &Expr,
362    vars: &std::collections::HashSet<&'static str>,
363) -> bool {
364    let size1 = make_problem_size(vars, 100);
365    let size2 = make_problem_size(vars, 10_000);
366
367    let s1 = small.eval(&size1);
368    let b1 = big.eval(&size1);
369    let s2 = small.eval(&size2);
370    let b2 = big.eval(&size2);
371
372    // Both must be finite and positive at both points
373    if !s1.is_finite() || !b1.is_finite() || !s2.is_finite() || !b2.is_finite() {
374        return false;
375    }
376    if s1 <= 1e-300 || b1 <= 1e-300 || s2 <= 1e-300 || b2 <= 1e-300 {
377        return false;
378    }
379
380    let ratio1 = b1 / s1;
381    let ratio2 = b2 / s2;
382
383    // Dominance: ratio is > 1 at both points and strictly increasing
384    ratio1 > 1.0 + 1e-10 && ratio2 > ratio1 * (1.0 + 1e-6)
385}
386
387#[cfg(test)]
388#[path = "unit_tests/big_o.rs"]
389mod tests;