1use std::collections::BTreeMap;
7
8use crate::expr::{CanonicalizationError, Expr};
9
/// Largest number of terms one multiplication of two canonical sums may
/// produce. `CanonicalSum::try_mul` checks `len(lhs) * len(rhs)` against this
/// bound and returns `CanonicalizationError::Unsupported` instead of expanding.
const MAX_CANONICAL_TERMS: usize = 50_000;
25
26#[derive(Clone, Debug, PartialEq)]
30struct OpaqueFactor {
31 key: String,
33 expr: Expr,
35}
36
37impl Eq for OpaqueFactor {}
38
39impl PartialOrd for OpaqueFactor {
40 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
41 Some(self.cmp(other))
42 }
43}
44
45impl Ord for OpaqueFactor {
46 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
47 self.key.cmp(&other.key)
48 }
49}
50
/// Returns the IEEE-754 bit pattern of `value`, mapping `-0.0` to the bits of
/// `+0.0`.
///
/// Used for the exponent part of term keys, so two exponents that differ only
/// in the sign of zero produce the same key.
fn normalized_f64_bits(value: f64) -> u64 {
    // `-0.0 == 0.0` is true under IEEE comparison, so this guard catches both zeros.
    match value {
        zero if zero == 0.0 => 0.0f64.to_bits(),
        other => other.to_bits(),
    }
}
58
59#[derive(Clone, Debug)]
61struct CanonicalTerm {
62 coeff: f64,
64 vars: BTreeMap<&'static str, f64>,
66 opaque: Vec<OpaqueFactor>,
68}
69
70fn try_merge_opaque(existing: &[OpaqueFactor], new: &OpaqueFactor) -> Option<Vec<OpaqueFactor>> {
73 for (i, existing_factor) in existing.iter().enumerate() {
74 if let (Expr::Exp(a), Expr::Exp(b)) = (&existing_factor.expr, &new.expr) {
76 let merged_arg = (**a).clone() + (**b).clone();
77 let merged_expr =
78 Expr::Exp(Box::new(canonical_form(&merged_arg).unwrap_or(merged_arg)));
79 let mut result = existing.to_vec();
80 result[i] = OpaqueFactor {
81 key: merged_expr.to_string(),
82 expr: merged_expr,
83 };
84 return Some(result);
85 }
86
87 if let (Expr::Pow(base1, exp1), Expr::Pow(base2, exp2)) = (&existing_factor.expr, &new.expr)
89 {
90 if let (Some(c1), Some(c2)) = (base1.constant_value(), base2.constant_value()) {
91 if c1 > 0.0 && c2 > 0.0 && (c1 - c2).abs() < 1e-15 {
92 let merged_exp = (**exp1).clone() + (**exp2).clone();
93 let canon_exp = canonical_form(&merged_exp).unwrap_or(merged_exp);
94 let merged_expr = Expr::Pow(base1.clone(), Box::new(canon_exp));
95 let mut result = existing.to_vec();
96 result[i] = OpaqueFactor {
97 key: merged_expr.to_string(),
98 expr: merged_expr,
99 };
100 return Some(result);
101 }
102 }
103 }
104 }
105 None
106}
107
108#[derive(Clone, Debug)]
110pub(crate) struct CanonicalSum {
111 terms: Vec<CanonicalTerm>,
112}
113
114impl CanonicalTerm {
115 fn constant(c: f64) -> Self {
116 Self {
117 coeff: c,
118 vars: BTreeMap::new(),
119 opaque: Vec::new(),
120 }
121 }
122
123 fn variable(name: &'static str) -> Self {
124 let mut vars = BTreeMap::new();
125 vars.insert(name, 1.0);
126 Self {
127 coeff: 1.0,
128 vars,
129 opaque: Vec::new(),
130 }
131 }
132
133 fn opaque_factor(expr: Expr) -> Self {
134 let key = expr.to_string();
135 Self {
136 coeff: 1.0,
137 vars: BTreeMap::new(),
138 opaque: vec![OpaqueFactor { key, expr }],
139 }
140 }
141
142 fn mul(&self, other: &CanonicalTerm) -> CanonicalTerm {
146 let coeff = self.coeff * other.coeff;
147 let mut vars = self.vars.clone();
148 for (&v, &e) in &other.vars {
149 *vars.entry(v).or_insert(0.0) += e;
150 }
151 vars.retain(|_, e| e.abs() > 1e-15);
153
154 let mut opaque = self.opaque.clone();
156 for other_factor in &other.opaque {
157 if let Some(merged) = try_merge_opaque(&opaque, other_factor) {
158 opaque = merged;
159 } else {
160 opaque.push(other_factor.clone());
161 }
162 }
163 opaque.sort();
164 CanonicalTerm {
165 coeff,
166 vars,
167 opaque,
168 }
169 }
170
171 fn sort_key(&self) -> (Vec<(&'static str, u64)>, Vec<String>) {
173 let vars: Vec<_> = self
174 .vars
175 .iter()
176 .map(|(&k, &v)| (k, normalized_f64_bits(v)))
177 .collect();
178 let opaque: Vec<_> = self.opaque.iter().map(|o| o.key.clone()).collect();
179 (vars, opaque)
180 }
181}
182
183impl CanonicalSum {
184 fn from_term(term: CanonicalTerm) -> Self {
185 Self { terms: vec![term] }
186 }
187
188 fn add(mut self, other: CanonicalSum) -> Self {
189 self.terms.extend(other.terms);
190 self
191 }
192
193 fn mul(&self, other: &CanonicalSum) -> CanonicalSum {
194 let mut terms = Vec::new();
195 for a in &self.terms {
196 for b in &other.terms {
197 terms.push(a.mul(b));
198 }
199 }
200 CanonicalSum { terms }
201 }
202
203 fn try_mul(&self, other: &CanonicalSum) -> Result<CanonicalSum, CanonicalizationError> {
207 let product = self.terms.len().saturating_mul(other.terms.len());
208 if product > MAX_CANONICAL_TERMS {
209 return Err(CanonicalizationError::Unsupported(format!(
210 "expression too large to canonicalize ({product} terms exceeds cap of {MAX_CANONICAL_TERMS})"
211 )));
212 }
213 Ok(self.mul(other))
214 }
215
216 fn simplify(self) -> Self {
219 type SortKey = (Vec<(&'static str, u64)>, Vec<String>);
220 let mut groups: BTreeMap<SortKey, CanonicalTerm> = BTreeMap::new();
221
222 for term in self.terms {
223 let key = term.sort_key();
224 groups
225 .entry(key)
226 .and_modify(|existing| existing.coeff += term.coeff)
227 .or_insert(term);
228 }
229
230 let mut terms: Vec<_> = groups
231 .into_values()
232 .filter(|t| t.coeff.abs() > 1e-15)
233 .collect();
234
235 terms.sort_by(|a, b| a.sort_key().cmp(&b.sort_key()));
236
237 CanonicalSum { terms }
238 }
239}
240
241pub fn canonical_form(expr: &Expr) -> Result<Expr, CanonicalizationError> {
253 let sum = expr_to_canonical(expr)?;
254 let simplified = sum.simplify();
255 Ok(canonical_sum_to_expr(&simplified))
256}
257
258fn expr_to_canonical(expr: &Expr) -> Result<CanonicalSum, CanonicalizationError> {
259 match expr {
260 Expr::Const(c) => Ok(CanonicalSum::from_term(CanonicalTerm::constant(*c))),
261 Expr::Var(name) => Ok(CanonicalSum::from_term(CanonicalTerm::variable(name))),
262 Expr::Add(a, b) => {
263 let ca = expr_to_canonical(a)?;
264 let cb = expr_to_canonical(b)?;
265 Ok(ca.add(cb))
266 }
267 Expr::Mul(a, b) => {
268 let ca = expr_to_canonical(a)?;
269 let cb = expr_to_canonical(b)?;
270 ca.try_mul(&cb)
271 }
272 Expr::Pow(base, exp) => canonicalize_pow(base, exp),
273 Expr::Exp(arg) => {
274 let inner = canonical_form(arg)?;
276 Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
277 Expr::Exp(Box::new(inner)),
278 )))
279 }
280 Expr::Log(arg) => {
281 let inner = canonical_form(arg)?;
282 Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
283 Expr::Log(Box::new(inner)),
284 )))
285 }
286 Expr::Sqrt(arg) => {
287 canonicalize_pow(arg, &Expr::Const(0.5))
289 }
290 Expr::Factorial(arg) => {
291 let inner = canonical_form(arg)?;
292 Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
293 Expr::Factorial(Box::new(inner)),
294 )))
295 }
296 }
297}
298
299fn canonicalize_pow(base: &Expr, exp: &Expr) -> Result<CanonicalSum, CanonicalizationError> {
300 match (base, exp) {
301 (_, _) if base.constant_value().is_some() && exp.constant_value().is_some() => {
303 let b = base.constant_value().unwrap();
304 let e = exp.constant_value().unwrap();
305 Ok(CanonicalSum::from_term(CanonicalTerm::constant(b.powf(e))))
306 }
307 (Expr::Var(name), _) if exp.constant_value().is_some() => {
309 let e = exp.constant_value().unwrap();
310 if e.abs() < 1e-15 {
311 return Ok(CanonicalSum::from_term(CanonicalTerm::constant(1.0)));
312 }
313 let mut vars = BTreeMap::new();
314 vars.insert(*name, e);
315 Ok(CanonicalSum::from_term(CanonicalTerm {
316 coeff: 1.0,
317 vars,
318 opaque: Vec::new(),
319 }))
320 }
321 (_, _) if exp.constant_value().is_some() => {
323 let e = exp.constant_value().unwrap();
324 if e >= 0.0 && (e - e.round()).abs() < 1e-10 {
325 let n = e.round() as usize;
326 let base_sum = expr_to_canonical(base)?;
327 if n == 0 {
328 return Ok(CanonicalSum::from_term(CanonicalTerm::constant(1.0)));
329 }
330 let mut result = base_sum.clone();
331 for _ in 1..n {
332 result = result.try_mul(&base_sum)?;
333 }
334 Ok(result)
335 } else {
336 let canon_base = canonical_form(base)?;
338 Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
339 Expr::Pow(Box::new(canon_base), Box::new(Expr::Const(e))),
340 )))
341 }
342 }
343 (_, _) if base.constant_value().is_some() => {
345 let c = base.constant_value().unwrap();
346 if (c - 1.0).abs() < 1e-15 {
347 return Ok(CanonicalSum::from_term(CanonicalTerm::constant(1.0)));
348 }
349 if c <= 0.0 {
350 return Err(CanonicalizationError::Unsupported(format!(
351 "{}^{}",
352 base, exp
353 )));
354 }
355 let canon_exp = canonical_form(exp)?;
356 Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
357 Expr::Pow(Box::new(base.clone()), Box::new(canon_exp)),
358 )))
359 }
360 _ => Err(CanonicalizationError::Unsupported(format!(
362 "{}^{}",
363 base, exp
364 ))),
365 }
366}
367
368fn canonical_sum_to_expr(sum: &CanonicalSum) -> Expr {
369 if sum.terms.is_empty() {
370 return Expr::Const(0.0);
371 }
372
373 let term_exprs: Vec<Expr> = sum.terms.iter().map(canonical_term_to_expr).collect();
374
375 let mut result = term_exprs[0].clone();
376 for term in &term_exprs[1..] {
377 result = result + term.clone();
378 }
379 result
380}
381
382fn canonical_term_to_expr(term: &CanonicalTerm) -> Expr {
383 let mut factors: Vec<Expr> = Vec::new();
384
385 let (coeff_factor, sign) = if term.coeff < 0.0 {
387 (term.coeff.abs(), true)
388 } else {
389 (term.coeff, false)
390 };
391
392 let has_other_factors = !term.vars.is_empty() || !term.opaque.is_empty();
393
394 if (coeff_factor - 1.0).abs() > 1e-15 || !has_other_factors {
395 factors.push(Expr::Const(coeff_factor));
396 }
397
398 for (&var, &exp) in &term.vars {
400 if (exp - 1.0).abs() < 1e-15 {
401 factors.push(Expr::Var(var));
402 } else {
403 factors.push(Expr::pow(Expr::Var(var), Expr::Const(exp)));
404 }
405 }
406
407 for opaque in &term.opaque {
409 factors.push(opaque.expr.clone());
410 }
411
412 let mut result = if factors.is_empty() {
413 Expr::Const(1.0)
414 } else {
415 let mut r = factors[0].clone();
416 for f in &factors[1..] {
417 r = r * f.clone();
418 }
419 r
420 };
421
422 if sign {
423 result = -result;
424 }
425
426 result
427}
428
429#[cfg(test)]
430#[path = "unit_tests/canonical.rs"]
431mod tests;