Skip to main content

problemreductions/
canonical.rs

1//! Exact symbolic canonicalization for `Expr`.
2//!
3//! Normalizes expressions into a canonical sum-of-terms form with signed
4//! coefficients and deterministic ordering, without losing algebraic precision.
5
6use std::collections::BTreeMap;
7
8use crate::expr::{CanonicalizationError, Expr};
9
/// Hard cap on the number of additive terms produced while expanding an
/// expression into canonical sum-of-monomials form.
///
/// Expanding a nested `(sum)^2 * (sum)^2` structure is exponential in nesting
/// depth: composed-path overheads that traverse quadratic-overhead reductions
/// (e.g. `QuadraticAssignment`) blow up to multi-GB of monomials and OOM/hang.
/// When the intermediate term count would exceed this cap we abandon expansion
/// and report the expression as `Unsupported`; callers (e.g. `big_o_of`) fall
/// back to printing the compact, un-expanded expression. See issue #1069.
///
/// The cap is enforced by `CanonicalSum::try_mul`, which compares the size of
/// each Cartesian product against it *before* the product is built.
///
/// Legitimate overhead expressions stay far below this bound (the worst
/// non-pathological case is a few hundred terms), so this never affects normal
/// output — it only stops pathological blowups. This is a stopgap guard; the
/// symbolic system is slated for a larger rework.
const MAX_CANONICAL_TERMS: usize = 50_000;
25
26/// An opaque non-polynomial factor (exp, log, fractional-power base).
27///
28/// Stored by its canonical string representation for deterministic ordering.
29#[derive(Clone, Debug, PartialEq)]
30struct OpaqueFactor {
31    /// The canonical string form (used for equality and ordering).
32    key: String,
33    /// The original `Expr` for reconstruction.
34    expr: Expr,
35}
36
37impl Eq for OpaqueFactor {}
38
39impl PartialOrd for OpaqueFactor {
40    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
41        Some(self.cmp(other))
42    }
43}
44
45impl Ord for OpaqueFactor {
46    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
47        self.key.cmp(&other.key)
48    }
49}
50
/// Returns the IEEE-754 bit pattern of `value`, with negative zero folded onto
/// positive zero so that both zeros produce the same bits.
fn normalized_f64_bits(value: f64) -> u64 {
    // `-0.0 == 0.0` is true, so this single comparison catches both zeros.
    let canonical = if value == 0.0 { 0.0f64 } else { value };
    canonical.to_bits()
}
58
/// A single additive term: coefficient × product of canonical factors.
///
/// The term represents `coeff * Π var^exponent * Π opaque`. Two terms are
/// "like terms" (and get merged by `CanonicalSum::simplify`) when their `vars`
/// and `opaque` parts agree; the coefficient is not part of that signature.
#[derive(Clone, Debug)]
struct CanonicalTerm {
    /// Signed numeric coefficient.
    coeff: f64,
    /// Polynomial variable exponents (variable_name → exponent).
    /// Exponents may be fractional or negative; entries whose exponent
    /// cancels to zero are removed when terms are multiplied.
    vars: BTreeMap<&'static str, f64>,
    /// Non-polynomial opaque factors, sorted by key.
    opaque: Vec<OpaqueFactor>,
}
69
70/// Try to merge a new opaque factor into an existing list using transcendental identities.
71/// Returns `Some(updated_list)` if a merge happened, `None` if no identity applies.
72fn try_merge_opaque(existing: &[OpaqueFactor], new: &OpaqueFactor) -> Option<Vec<OpaqueFactor>> {
73    for (i, existing_factor) in existing.iter().enumerate() {
74        // exp(a) * exp(b) -> exp(a + b)
75        if let (Expr::Exp(a), Expr::Exp(b)) = (&existing_factor.expr, &new.expr) {
76            let merged_arg = (**a).clone() + (**b).clone();
77            let merged_expr =
78                Expr::Exp(Box::new(canonical_form(&merged_arg).unwrap_or(merged_arg)));
79            let mut result = existing.to_vec();
80            result[i] = OpaqueFactor {
81                key: merged_expr.to_string(),
82                expr: merged_expr,
83            };
84            return Some(result);
85        }
86
87        // c^a * c^b -> c^(a+b) for matching positive constant base c
88        if let (Expr::Pow(base1, exp1), Expr::Pow(base2, exp2)) = (&existing_factor.expr, &new.expr)
89        {
90            if let (Some(c1), Some(c2)) = (base1.constant_value(), base2.constant_value()) {
91                if c1 > 0.0 && c2 > 0.0 && (c1 - c2).abs() < 1e-15 {
92                    let merged_exp = (**exp1).clone() + (**exp2).clone();
93                    let canon_exp = canonical_form(&merged_exp).unwrap_or(merged_exp);
94                    let merged_expr = Expr::Pow(base1.clone(), Box::new(canon_exp));
95                    let mut result = existing.to_vec();
96                    result[i] = OpaqueFactor {
97                        key: merged_expr.to_string(),
98                        expr: merged_expr,
99                    };
100                    return Some(result);
101                }
102            }
103        }
104    }
105    None
106}
107
/// A canonical sum of terms: the exact normal form of an expression.
///
/// The value represents the sum of all entries in `terms`. Like terms are only
/// merged, zero terms dropped, and the order made deterministic once
/// `simplify` has been applied; `add` and `mul` leave the list un-merged.
#[derive(Clone, Debug)]
pub(crate) struct CanonicalSum {
    /// The additive terms of the sum.
    terms: Vec<CanonicalTerm>,
}
113
114impl CanonicalTerm {
115    fn constant(c: f64) -> Self {
116        Self {
117            coeff: c,
118            vars: BTreeMap::new(),
119            opaque: Vec::new(),
120        }
121    }
122
123    fn variable(name: &'static str) -> Self {
124        let mut vars = BTreeMap::new();
125        vars.insert(name, 1.0);
126        Self {
127            coeff: 1.0,
128            vars,
129            opaque: Vec::new(),
130        }
131    }
132
133    fn opaque_factor(expr: Expr) -> Self {
134        let key = expr.to_string();
135        Self {
136            coeff: 1.0,
137            vars: BTreeMap::new(),
138            opaque: vec![OpaqueFactor { key, expr }],
139        }
140    }
141
142    /// Multiply two terms, applying transcendental identities:
143    /// - `exp(a) * exp(b) -> exp(a + b)`
144    /// - `c^a * c^b -> c^(a + b)` for matching constant base `c`
145    fn mul(&self, other: &CanonicalTerm) -> CanonicalTerm {
146        let coeff = self.coeff * other.coeff;
147        let mut vars = self.vars.clone();
148        for (&v, &e) in &other.vars {
149            *vars.entry(v).or_insert(0.0) += e;
150        }
151        // Remove zero-exponent variables
152        vars.retain(|_, e| e.abs() > 1e-15);
153
154        // Merge opaque factors with transcendental identities
155        let mut opaque = self.opaque.clone();
156        for other_factor in &other.opaque {
157            if let Some(merged) = try_merge_opaque(&opaque, other_factor) {
158                opaque = merged;
159            } else {
160                opaque.push(other_factor.clone());
161            }
162        }
163        opaque.sort();
164        CanonicalTerm {
165            coeff,
166            vars,
167            opaque,
168        }
169    }
170
171    /// Deterministic sort key for ordering terms in a sum.
172    fn sort_key(&self) -> (Vec<(&'static str, u64)>, Vec<String>) {
173        let vars: Vec<_> = self
174            .vars
175            .iter()
176            .map(|(&k, &v)| (k, normalized_f64_bits(v)))
177            .collect();
178        let opaque: Vec<_> = self.opaque.iter().map(|o| o.key.clone()).collect();
179        (vars, opaque)
180    }
181}
182
183impl CanonicalSum {
184    fn from_term(term: CanonicalTerm) -> Self {
185        Self { terms: vec![term] }
186    }
187
188    fn add(mut self, other: CanonicalSum) -> Self {
189        self.terms.extend(other.terms);
190        self
191    }
192
193    fn mul(&self, other: &CanonicalSum) -> CanonicalSum {
194        let mut terms = Vec::new();
195        for a in &self.terms {
196            for b in &other.terms {
197                terms.push(a.mul(b));
198            }
199        }
200        CanonicalSum { terms }
201    }
202
203    /// Multiply with a guard against pathological expansion (see
204    /// [`MAX_CANONICAL_TERMS`]). The Cartesian product size is checked *before*
205    /// it is materialized, so this never allocates the blown-up vector.
206    fn try_mul(&self, other: &CanonicalSum) -> Result<CanonicalSum, CanonicalizationError> {
207        let product = self.terms.len().saturating_mul(other.terms.len());
208        if product > MAX_CANONICAL_TERMS {
209            return Err(CanonicalizationError::Unsupported(format!(
210                "expression too large to canonicalize ({product} terms exceeds cap of {MAX_CANONICAL_TERMS})"
211            )));
212        }
213        Ok(self.mul(other))
214    }
215
216    /// Merge terms with the same signature and drop zero-coefficient terms.
217    /// Sort the result deterministically.
218    fn simplify(self) -> Self {
219        type SortKey = (Vec<(&'static str, u64)>, Vec<String>);
220        let mut groups: BTreeMap<SortKey, CanonicalTerm> = BTreeMap::new();
221
222        for term in self.terms {
223            let key = term.sort_key();
224            groups
225                .entry(key)
226                .and_modify(|existing| existing.coeff += term.coeff)
227                .or_insert(term);
228        }
229
230        let mut terms: Vec<_> = groups
231            .into_values()
232            .filter(|t| t.coeff.abs() > 1e-15)
233            .collect();
234
235        terms.sort_by(|a, b| a.sort_key().cmp(&b.sort_key()));
236
237        CanonicalSum { terms }
238    }
239}
240
241/// Normalize an expression into its exact canonical sum-of-terms form.
242///
243/// This performs exact symbolic simplification:
244/// - Flattens nested Add/Mul
245/// - Merges duplicate additive terms by summing coefficients
246/// - Merges repeated multiplicative factors into powers
247/// - Preserves signed coefficients (supports subtraction)
248/// - Preserves transcendental identities: exp(a)*exp(b)=exp(a+b), etc.
249/// - Produces deterministic ordering
250///
251/// Does NOT drop terms or constant factors — use `big_o_normal_form()` for that.
252pub fn canonical_form(expr: &Expr) -> Result<Expr, CanonicalizationError> {
253    let sum = expr_to_canonical(expr)?;
254    let simplified = sum.simplify();
255    Ok(canonical_sum_to_expr(&simplified))
256}
257
258fn expr_to_canonical(expr: &Expr) -> Result<CanonicalSum, CanonicalizationError> {
259    match expr {
260        Expr::Const(c) => Ok(CanonicalSum::from_term(CanonicalTerm::constant(*c))),
261        Expr::Var(name) => Ok(CanonicalSum::from_term(CanonicalTerm::variable(name))),
262        Expr::Add(a, b) => {
263            let ca = expr_to_canonical(a)?;
264            let cb = expr_to_canonical(b)?;
265            Ok(ca.add(cb))
266        }
267        Expr::Mul(a, b) => {
268            let ca = expr_to_canonical(a)?;
269            let cb = expr_to_canonical(b)?;
270            ca.try_mul(&cb)
271        }
272        Expr::Pow(base, exp) => canonicalize_pow(base, exp),
273        Expr::Exp(arg) => {
274            // Treat exp(canonicalized_arg) as an opaque factor
275            let inner = canonical_form(arg)?;
276            Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
277                Expr::Exp(Box::new(inner)),
278            )))
279        }
280        Expr::Log(arg) => {
281            let inner = canonical_form(arg)?;
282            Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
283                Expr::Log(Box::new(inner)),
284            )))
285        }
286        Expr::Sqrt(arg) => {
287            // sqrt(x) = x^0.5 — canonicalize as power
288            canonicalize_pow(arg, &Expr::Const(0.5))
289        }
290        Expr::Factorial(arg) => {
291            let inner = canonical_form(arg)?;
292            Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
293                Expr::Factorial(Box::new(inner)),
294            )))
295        }
296    }
297}
298
299fn canonicalize_pow(base: &Expr, exp: &Expr) -> Result<CanonicalSum, CanonicalizationError> {
300    match (base, exp) {
301        // Constant base, constant exp → numeric constant
302        (_, _) if base.constant_value().is_some() && exp.constant_value().is_some() => {
303            let b = base.constant_value().unwrap();
304            let e = exp.constant_value().unwrap();
305            Ok(CanonicalSum::from_term(CanonicalTerm::constant(b.powf(e))))
306        }
307        // Variable ^ constant exponent → vars map (supports fractional/negative exponents)
308        (Expr::Var(name), _) if exp.constant_value().is_some() => {
309            let e = exp.constant_value().unwrap();
310            if e.abs() < 1e-15 {
311                return Ok(CanonicalSum::from_term(CanonicalTerm::constant(1.0)));
312            }
313            let mut vars = BTreeMap::new();
314            vars.insert(*name, e);
315            Ok(CanonicalSum::from_term(CanonicalTerm {
316                coeff: 1.0,
317                vars,
318                opaque: Vec::new(),
319            }))
320        }
321        // Polynomial base ^ constant integer exponent → expand
322        (_, _) if exp.constant_value().is_some() => {
323            let e = exp.constant_value().unwrap();
324            if e >= 0.0 && (e - e.round()).abs() < 1e-10 {
325                let n = e.round() as usize;
326                let base_sum = expr_to_canonical(base)?;
327                if n == 0 {
328                    return Ok(CanonicalSum::from_term(CanonicalTerm::constant(1.0)));
329                }
330                let mut result = base_sum.clone();
331                for _ in 1..n {
332                    result = result.try_mul(&base_sum)?;
333                }
334                Ok(result)
335            } else {
336                // Fractional exponent with non-variable base → opaque
337                let canon_base = canonical_form(base)?;
338                Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
339                    Expr::Pow(Box::new(canon_base), Box::new(Expr::Const(e))),
340                )))
341            }
342        }
343        // Constant base ^ variable exponent → opaque (exponential growth)
344        (_, _) if base.constant_value().is_some() => {
345            let c = base.constant_value().unwrap();
346            if (c - 1.0).abs() < 1e-15 {
347                return Ok(CanonicalSum::from_term(CanonicalTerm::constant(1.0)));
348            }
349            if c <= 0.0 {
350                return Err(CanonicalizationError::Unsupported(format!(
351                    "{}^{}",
352                    base, exp
353                )));
354            }
355            let canon_exp = canonical_form(exp)?;
356            Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
357                Expr::Pow(Box::new(base.clone()), Box::new(canon_exp)),
358            )))
359        }
360        // Variable base ^ variable exponent → unsupported
361        _ => Err(CanonicalizationError::Unsupported(format!(
362            "{}^{}",
363            base, exp
364        ))),
365    }
366}
367
368fn canonical_sum_to_expr(sum: &CanonicalSum) -> Expr {
369    if sum.terms.is_empty() {
370        return Expr::Const(0.0);
371    }
372
373    let term_exprs: Vec<Expr> = sum.terms.iter().map(canonical_term_to_expr).collect();
374
375    let mut result = term_exprs[0].clone();
376    for term in &term_exprs[1..] {
377        result = result + term.clone();
378    }
379    result
380}
381
382fn canonical_term_to_expr(term: &CanonicalTerm) -> Expr {
383    let mut factors: Vec<Expr> = Vec::new();
384
385    // Add coefficient if not 1.0 (or -1.0, handled specially)
386    let (coeff_factor, sign) = if term.coeff < 0.0 {
387        (term.coeff.abs(), true)
388    } else {
389        (term.coeff, false)
390    };
391
392    let has_other_factors = !term.vars.is_empty() || !term.opaque.is_empty();
393
394    if (coeff_factor - 1.0).abs() > 1e-15 || !has_other_factors {
395        factors.push(Expr::Const(coeff_factor));
396    }
397
398    // Add variable powers
399    for (&var, &exp) in &term.vars {
400        if (exp - 1.0).abs() < 1e-15 {
401            factors.push(Expr::Var(var));
402        } else {
403            factors.push(Expr::pow(Expr::Var(var), Expr::Const(exp)));
404        }
405    }
406
407    // Add opaque factors
408    for opaque in &term.opaque {
409        factors.push(opaque.expr.clone());
410    }
411
412    let mut result = if factors.is_empty() {
413        Expr::Const(1.0)
414    } else {
415        let mut r = factors[0].clone();
416        for f in &factors[1..] {
417            r = r * f.clone();
418        }
419        r
420    };
421
422    if sign {
423        result = -result;
424    }
425
426    result
427}
428
429#[cfg(test)]
430#[path = "unit_tests/canonical.rs"]
431mod tests;