Skip to main content

problemreductions/
canonical.rs

1//! Exact symbolic canonicalization for `Expr`.
2//!
3//! Normalizes expressions into a canonical sum-of-terms form with signed
4//! coefficients and deterministic ordering, without losing algebraic precision.
5
6use std::collections::BTreeMap;
7
8use crate::expr::{CanonicalizationError, Expr};
9
10/// An opaque non-polynomial factor (exp, log, fractional-power base).
11///
12/// Stored by its canonical string representation for deterministic ordering.
13#[derive(Clone, Debug, PartialEq)]
14struct OpaqueFactor {
15    /// The canonical string form (used for equality and ordering).
16    key: String,
17    /// The original `Expr` for reconstruction.
18    expr: Expr,
19}
20
21impl Eq for OpaqueFactor {}
22
23impl PartialOrd for OpaqueFactor {
24    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
25        Some(self.cmp(other))
26    }
27}
28
29impl Ord for OpaqueFactor {
30    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
31        self.key.cmp(&other.key)
32    }
33}
34
/// Bit pattern of `value` for use in sort keys, with `-0.0` folded onto `+0.0`.
///
/// `0.0 == -0.0` numerically, but their raw bits differ in the sign bit. Without folding,
/// equal exponents would be split into separate keys.
fn normalized_f64_bits(value: f64) -> u64 {
    match value {
        // Both zeros map to the bit pattern of `+0.0`, which is all zeros.
        v if v == 0.0 => 0,
        v => v.to_bits(),
    }
}
42
43/// A single additive term: coefficient × product of canonical factors.
44#[derive(Clone, Debug)]
45struct CanonicalTerm {
46    /// Signed numeric coefficient.
47    coeff: f64,
48    /// Polynomial variable exponents (variable_name → exponent).
49    vars: BTreeMap<&'static str, f64>,
50    /// Non-polynomial opaque factors, sorted by key.
51    opaque: Vec<OpaqueFactor>,
52}
53
54/// Try to merge a new opaque factor into an existing list using transcendental identities.
55/// Returns `Some(updated_list)` if a merge happened, `None` if no identity applies.
56fn try_merge_opaque(existing: &[OpaqueFactor], new: &OpaqueFactor) -> Option<Vec<OpaqueFactor>> {
57    for (i, existing_factor) in existing.iter().enumerate() {
58        // exp(a) * exp(b) -> exp(a + b)
59        if let (Expr::Exp(a), Expr::Exp(b)) = (&existing_factor.expr, &new.expr) {
60            let merged_arg = (**a).clone() + (**b).clone();
61            let merged_expr =
62                Expr::Exp(Box::new(canonical_form(&merged_arg).unwrap_or(merged_arg)));
63            let mut result = existing.to_vec();
64            result[i] = OpaqueFactor {
65                key: merged_expr.to_string(),
66                expr: merged_expr,
67            };
68            return Some(result);
69        }
70
71        // c^a * c^b -> c^(a+b) for matching positive constant base c
72        if let (Expr::Pow(base1, exp1), Expr::Pow(base2, exp2)) = (&existing_factor.expr, &new.expr)
73        {
74            if let (Some(c1), Some(c2)) = (base1.constant_value(), base2.constant_value()) {
75                if c1 > 0.0 && c2 > 0.0 && (c1 - c2).abs() < 1e-15 {
76                    let merged_exp = (**exp1).clone() + (**exp2).clone();
77                    let canon_exp = canonical_form(&merged_exp).unwrap_or(merged_exp);
78                    let merged_expr = Expr::Pow(base1.clone(), Box::new(canon_exp));
79                    let mut result = existing.to_vec();
80                    result[i] = OpaqueFactor {
81                        key: merged_expr.to_string(),
82                        expr: merged_expr,
83                    };
84                    return Some(result);
85                }
86            }
87        }
88    }
89    None
90}
91
92/// A canonical sum of terms: the exact normal form of an expression.
93#[derive(Clone, Debug)]
94pub(crate) struct CanonicalSum {
95    terms: Vec<CanonicalTerm>,
96}
97
98impl CanonicalTerm {
99    fn constant(c: f64) -> Self {
100        Self {
101            coeff: c,
102            vars: BTreeMap::new(),
103            opaque: Vec::new(),
104        }
105    }
106
107    fn variable(name: &'static str) -> Self {
108        let mut vars = BTreeMap::new();
109        vars.insert(name, 1.0);
110        Self {
111            coeff: 1.0,
112            vars,
113            opaque: Vec::new(),
114        }
115    }
116
117    fn opaque_factor(expr: Expr) -> Self {
118        let key = expr.to_string();
119        Self {
120            coeff: 1.0,
121            vars: BTreeMap::new(),
122            opaque: vec![OpaqueFactor { key, expr }],
123        }
124    }
125
126    /// Multiply two terms, applying transcendental identities:
127    /// - `exp(a) * exp(b) -> exp(a + b)`
128    /// - `c^a * c^b -> c^(a + b)` for matching constant base `c`
129    fn mul(&self, other: &CanonicalTerm) -> CanonicalTerm {
130        let coeff = self.coeff * other.coeff;
131        let mut vars = self.vars.clone();
132        for (&v, &e) in &other.vars {
133            *vars.entry(v).or_insert(0.0) += e;
134        }
135        // Remove zero-exponent variables
136        vars.retain(|_, e| e.abs() > 1e-15);
137
138        // Merge opaque factors with transcendental identities
139        let mut opaque = self.opaque.clone();
140        for other_factor in &other.opaque {
141            if let Some(merged) = try_merge_opaque(&opaque, other_factor) {
142                opaque = merged;
143            } else {
144                opaque.push(other_factor.clone());
145            }
146        }
147        opaque.sort();
148        CanonicalTerm {
149            coeff,
150            vars,
151            opaque,
152        }
153    }
154
155    /// Deterministic sort key for ordering terms in a sum.
156    fn sort_key(&self) -> (Vec<(&'static str, u64)>, Vec<String>) {
157        let vars: Vec<_> = self
158            .vars
159            .iter()
160            .map(|(&k, &v)| (k, normalized_f64_bits(v)))
161            .collect();
162        let opaque: Vec<_> = self.opaque.iter().map(|o| o.key.clone()).collect();
163        (vars, opaque)
164    }
165}
166
167impl CanonicalSum {
168    fn from_term(term: CanonicalTerm) -> Self {
169        Self { terms: vec![term] }
170    }
171
172    fn add(mut self, other: CanonicalSum) -> Self {
173        self.terms.extend(other.terms);
174        self
175    }
176
177    fn mul(&self, other: &CanonicalSum) -> CanonicalSum {
178        let mut terms = Vec::new();
179        for a in &self.terms {
180            for b in &other.terms {
181                terms.push(a.mul(b));
182            }
183        }
184        CanonicalSum { terms }
185    }
186
187    /// Merge terms with the same signature and drop zero-coefficient terms.
188    /// Sort the result deterministically.
189    fn simplify(self) -> Self {
190        type SortKey = (Vec<(&'static str, u64)>, Vec<String>);
191        let mut groups: BTreeMap<SortKey, CanonicalTerm> = BTreeMap::new();
192
193        for term in self.terms {
194            let key = term.sort_key();
195            groups
196                .entry(key)
197                .and_modify(|existing| existing.coeff += term.coeff)
198                .or_insert(term);
199        }
200
201        let mut terms: Vec<_> = groups
202            .into_values()
203            .filter(|t| t.coeff.abs() > 1e-15)
204            .collect();
205
206        terms.sort_by(|a, b| a.sort_key().cmp(&b.sort_key()));
207
208        CanonicalSum { terms }
209    }
210}
211
212/// Normalize an expression into its exact canonical sum-of-terms form.
213///
214/// This performs exact symbolic simplification:
215/// - Flattens nested Add/Mul
216/// - Merges duplicate additive terms by summing coefficients
217/// - Merges repeated multiplicative factors into powers
218/// - Preserves signed coefficients (supports subtraction)
219/// - Preserves transcendental identities: exp(a)*exp(b)=exp(a+b), etc.
220/// - Produces deterministic ordering
221///
222/// Does NOT drop terms or constant factors — use `big_o_normal_form()` for that.
223pub fn canonical_form(expr: &Expr) -> Result<Expr, CanonicalizationError> {
224    let sum = expr_to_canonical(expr)?;
225    let simplified = sum.simplify();
226    Ok(canonical_sum_to_expr(&simplified))
227}
228
229fn expr_to_canonical(expr: &Expr) -> Result<CanonicalSum, CanonicalizationError> {
230    match expr {
231        Expr::Const(c) => Ok(CanonicalSum::from_term(CanonicalTerm::constant(*c))),
232        Expr::Var(name) => Ok(CanonicalSum::from_term(CanonicalTerm::variable(name))),
233        Expr::Add(a, b) => {
234            let ca = expr_to_canonical(a)?;
235            let cb = expr_to_canonical(b)?;
236            Ok(ca.add(cb))
237        }
238        Expr::Mul(a, b) => {
239            let ca = expr_to_canonical(a)?;
240            let cb = expr_to_canonical(b)?;
241            Ok(ca.mul(&cb))
242        }
243        Expr::Pow(base, exp) => canonicalize_pow(base, exp),
244        Expr::Exp(arg) => {
245            // Treat exp(canonicalized_arg) as an opaque factor
246            let inner = canonical_form(arg)?;
247            Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
248                Expr::Exp(Box::new(inner)),
249            )))
250        }
251        Expr::Log(arg) => {
252            let inner = canonical_form(arg)?;
253            Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
254                Expr::Log(Box::new(inner)),
255            )))
256        }
257        Expr::Sqrt(arg) => {
258            // sqrt(x) = x^0.5 — canonicalize as power
259            canonicalize_pow(arg, &Expr::Const(0.5))
260        }
261        Expr::Factorial(arg) => {
262            let inner = canonical_form(arg)?;
263            Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
264                Expr::Factorial(Box::new(inner)),
265            )))
266        }
267    }
268}
269
270fn canonicalize_pow(base: &Expr, exp: &Expr) -> Result<CanonicalSum, CanonicalizationError> {
271    match (base, exp) {
272        // Constant base, constant exp → numeric constant
273        (_, _) if base.constant_value().is_some() && exp.constant_value().is_some() => {
274            let b = base.constant_value().unwrap();
275            let e = exp.constant_value().unwrap();
276            Ok(CanonicalSum::from_term(CanonicalTerm::constant(b.powf(e))))
277        }
278        // Variable ^ constant exponent → vars map (supports fractional/negative exponents)
279        (Expr::Var(name), _) if exp.constant_value().is_some() => {
280            let e = exp.constant_value().unwrap();
281            if e.abs() < 1e-15 {
282                return Ok(CanonicalSum::from_term(CanonicalTerm::constant(1.0)));
283            }
284            let mut vars = BTreeMap::new();
285            vars.insert(*name, e);
286            Ok(CanonicalSum::from_term(CanonicalTerm {
287                coeff: 1.0,
288                vars,
289                opaque: Vec::new(),
290            }))
291        }
292        // Polynomial base ^ constant integer exponent → expand
293        (_, _) if exp.constant_value().is_some() => {
294            let e = exp.constant_value().unwrap();
295            if e >= 0.0 && (e - e.round()).abs() < 1e-10 {
296                let n = e.round() as usize;
297                let base_sum = expr_to_canonical(base)?;
298                if n == 0 {
299                    return Ok(CanonicalSum::from_term(CanonicalTerm::constant(1.0)));
300                }
301                let mut result = base_sum.clone();
302                for _ in 1..n {
303                    result = result.mul(&base_sum);
304                }
305                Ok(result)
306            } else {
307                // Fractional exponent with non-variable base → opaque
308                let canon_base = canonical_form(base)?;
309                Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
310                    Expr::Pow(Box::new(canon_base), Box::new(Expr::Const(e))),
311                )))
312            }
313        }
314        // Constant base ^ variable exponent → opaque (exponential growth)
315        (_, _) if base.constant_value().is_some() => {
316            let c = base.constant_value().unwrap();
317            if (c - 1.0).abs() < 1e-15 {
318                return Ok(CanonicalSum::from_term(CanonicalTerm::constant(1.0)));
319            }
320            if c <= 0.0 {
321                return Err(CanonicalizationError::Unsupported(format!(
322                    "{}^{}",
323                    base, exp
324                )));
325            }
326            let canon_exp = canonical_form(exp)?;
327            Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
328                Expr::Pow(Box::new(base.clone()), Box::new(canon_exp)),
329            )))
330        }
331        // Variable base ^ variable exponent → unsupported
332        _ => Err(CanonicalizationError::Unsupported(format!(
333            "{}^{}",
334            base, exp
335        ))),
336    }
337}
338
339fn canonical_sum_to_expr(sum: &CanonicalSum) -> Expr {
340    if sum.terms.is_empty() {
341        return Expr::Const(0.0);
342    }
343
344    let term_exprs: Vec<Expr> = sum.terms.iter().map(canonical_term_to_expr).collect();
345
346    let mut result = term_exprs[0].clone();
347    for term in &term_exprs[1..] {
348        result = result + term.clone();
349    }
350    result
351}
352
353fn canonical_term_to_expr(term: &CanonicalTerm) -> Expr {
354    let mut factors: Vec<Expr> = Vec::new();
355
356    // Add coefficient if not 1.0 (or -1.0, handled specially)
357    let (coeff_factor, sign) = if term.coeff < 0.0 {
358        (term.coeff.abs(), true)
359    } else {
360        (term.coeff, false)
361    };
362
363    let has_other_factors = !term.vars.is_empty() || !term.opaque.is_empty();
364
365    if (coeff_factor - 1.0).abs() > 1e-15 || !has_other_factors {
366        factors.push(Expr::Const(coeff_factor));
367    }
368
369    // Add variable powers
370    for (&var, &exp) in &term.vars {
371        if (exp - 1.0).abs() < 1e-15 {
372            factors.push(Expr::Var(var));
373        } else {
374            factors.push(Expr::pow(Expr::Var(var), Expr::Const(exp)));
375        }
376    }
377
378    // Add opaque factors
379    for opaque in &term.opaque {
380        factors.push(opaque.expr.clone());
381    }
382
383    let mut result = if factors.is_empty() {
384        Expr::Const(1.0)
385    } else {
386        let mut r = factors[0].clone();
387        for f in &factors[1..] {
388            r = r * f.clone();
389        }
390        r
391    };
392
393    if sign {
394        result = -result;
395    }
396
397    result
398}
399
400#[cfg(test)]
401#[path = "unit_tests/canonical.rs"]
402mod tests;