1use std::collections::BTreeMap;
7
8use crate::expr::{CanonicalizationError, Expr};
9
10#[derive(Clone, Debug, PartialEq)]
14struct OpaqueFactor {
15 key: String,
17 expr: Expr,
19}
20
21impl Eq for OpaqueFactor {}
22
23impl PartialOrd for OpaqueFactor {
24 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
25 Some(self.cmp(other))
26 }
27}
28
29impl Ord for OpaqueFactor {
30 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
31 self.key.cmp(&other.key)
32 }
33}
34
/// Returns the IEEE-754 bit pattern of `value`, first collapsing `-0.0` onto
/// `+0.0` so both zeros yield the same bits (and therefore the same key).
/// All other values, NaNs included, keep their exact bit pattern.
fn normalized_f64_bits(value: f64) -> u64 {
    // `-0.0 == 0.0` is true, so this catches both signed zeros.
    let unsigned_zero_or_value = if value == 0.0 { 0.0f64 } else { value };
    unsigned_zero_or_value.to_bits()
}
42
43#[derive(Clone, Debug)]
45struct CanonicalTerm {
46 coeff: f64,
48 vars: BTreeMap<&'static str, f64>,
50 opaque: Vec<OpaqueFactor>,
52}
53
54fn try_merge_opaque(existing: &[OpaqueFactor], new: &OpaqueFactor) -> Option<Vec<OpaqueFactor>> {
57 for (i, existing_factor) in existing.iter().enumerate() {
58 if let (Expr::Exp(a), Expr::Exp(b)) = (&existing_factor.expr, &new.expr) {
60 let merged_arg = (**a).clone() + (**b).clone();
61 let merged_expr =
62 Expr::Exp(Box::new(canonical_form(&merged_arg).unwrap_or(merged_arg)));
63 let mut result = existing.to_vec();
64 result[i] = OpaqueFactor {
65 key: merged_expr.to_string(),
66 expr: merged_expr,
67 };
68 return Some(result);
69 }
70
71 if let (Expr::Pow(base1, exp1), Expr::Pow(base2, exp2)) = (&existing_factor.expr, &new.expr)
73 {
74 if let (Some(c1), Some(c2)) = (base1.constant_value(), base2.constant_value()) {
75 if c1 > 0.0 && c2 > 0.0 && (c1 - c2).abs() < 1e-15 {
76 let merged_exp = (**exp1).clone() + (**exp2).clone();
77 let canon_exp = canonical_form(&merged_exp).unwrap_or(merged_exp);
78 let merged_expr = Expr::Pow(base1.clone(), Box::new(canon_exp));
79 let mut result = existing.to_vec();
80 result[i] = OpaqueFactor {
81 key: merged_expr.to_string(),
82 expr: merged_expr,
83 };
84 return Some(result);
85 }
86 }
87 }
88 }
89 None
90}
91
92#[derive(Clone, Debug)]
94pub(crate) struct CanonicalSum {
95 terms: Vec<CanonicalTerm>,
96}
97
98impl CanonicalTerm {
99 fn constant(c: f64) -> Self {
100 Self {
101 coeff: c,
102 vars: BTreeMap::new(),
103 opaque: Vec::new(),
104 }
105 }
106
107 fn variable(name: &'static str) -> Self {
108 let mut vars = BTreeMap::new();
109 vars.insert(name, 1.0);
110 Self {
111 coeff: 1.0,
112 vars,
113 opaque: Vec::new(),
114 }
115 }
116
117 fn opaque_factor(expr: Expr) -> Self {
118 let key = expr.to_string();
119 Self {
120 coeff: 1.0,
121 vars: BTreeMap::new(),
122 opaque: vec![OpaqueFactor { key, expr }],
123 }
124 }
125
126 fn mul(&self, other: &CanonicalTerm) -> CanonicalTerm {
130 let coeff = self.coeff * other.coeff;
131 let mut vars = self.vars.clone();
132 for (&v, &e) in &other.vars {
133 *vars.entry(v).or_insert(0.0) += e;
134 }
135 vars.retain(|_, e| e.abs() > 1e-15);
137
138 let mut opaque = self.opaque.clone();
140 for other_factor in &other.opaque {
141 if let Some(merged) = try_merge_opaque(&opaque, other_factor) {
142 opaque = merged;
143 } else {
144 opaque.push(other_factor.clone());
145 }
146 }
147 opaque.sort();
148 CanonicalTerm {
149 coeff,
150 vars,
151 opaque,
152 }
153 }
154
155 fn sort_key(&self) -> (Vec<(&'static str, u64)>, Vec<String>) {
157 let vars: Vec<_> = self
158 .vars
159 .iter()
160 .map(|(&k, &v)| (k, normalized_f64_bits(v)))
161 .collect();
162 let opaque: Vec<_> = self.opaque.iter().map(|o| o.key.clone()).collect();
163 (vars, opaque)
164 }
165}
166
167impl CanonicalSum {
168 fn from_term(term: CanonicalTerm) -> Self {
169 Self { terms: vec![term] }
170 }
171
172 fn add(mut self, other: CanonicalSum) -> Self {
173 self.terms.extend(other.terms);
174 self
175 }
176
177 fn mul(&self, other: &CanonicalSum) -> CanonicalSum {
178 let mut terms = Vec::new();
179 for a in &self.terms {
180 for b in &other.terms {
181 terms.push(a.mul(b));
182 }
183 }
184 CanonicalSum { terms }
185 }
186
187 fn simplify(self) -> Self {
190 type SortKey = (Vec<(&'static str, u64)>, Vec<String>);
191 let mut groups: BTreeMap<SortKey, CanonicalTerm> = BTreeMap::new();
192
193 for term in self.terms {
194 let key = term.sort_key();
195 groups
196 .entry(key)
197 .and_modify(|existing| existing.coeff += term.coeff)
198 .or_insert(term);
199 }
200
201 let mut terms: Vec<_> = groups
202 .into_values()
203 .filter(|t| t.coeff.abs() > 1e-15)
204 .collect();
205
206 terms.sort_by(|a, b| a.sort_key().cmp(&b.sort_key()));
207
208 CanonicalSum { terms }
209 }
210}
211
212pub fn canonical_form(expr: &Expr) -> Result<Expr, CanonicalizationError> {
224 let sum = expr_to_canonical(expr)?;
225 let simplified = sum.simplify();
226 Ok(canonical_sum_to_expr(&simplified))
227}
228
229fn expr_to_canonical(expr: &Expr) -> Result<CanonicalSum, CanonicalizationError> {
230 match expr {
231 Expr::Const(c) => Ok(CanonicalSum::from_term(CanonicalTerm::constant(*c))),
232 Expr::Var(name) => Ok(CanonicalSum::from_term(CanonicalTerm::variable(name))),
233 Expr::Add(a, b) => {
234 let ca = expr_to_canonical(a)?;
235 let cb = expr_to_canonical(b)?;
236 Ok(ca.add(cb))
237 }
238 Expr::Mul(a, b) => {
239 let ca = expr_to_canonical(a)?;
240 let cb = expr_to_canonical(b)?;
241 Ok(ca.mul(&cb))
242 }
243 Expr::Pow(base, exp) => canonicalize_pow(base, exp),
244 Expr::Exp(arg) => {
245 let inner = canonical_form(arg)?;
247 Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
248 Expr::Exp(Box::new(inner)),
249 )))
250 }
251 Expr::Log(arg) => {
252 let inner = canonical_form(arg)?;
253 Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
254 Expr::Log(Box::new(inner)),
255 )))
256 }
257 Expr::Sqrt(arg) => {
258 canonicalize_pow(arg, &Expr::Const(0.5))
260 }
261 Expr::Factorial(arg) => {
262 let inner = canonical_form(arg)?;
263 Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
264 Expr::Factorial(Box::new(inner)),
265 )))
266 }
267 }
268}
269
270fn canonicalize_pow(base: &Expr, exp: &Expr) -> Result<CanonicalSum, CanonicalizationError> {
271 match (base, exp) {
272 (_, _) if base.constant_value().is_some() && exp.constant_value().is_some() => {
274 let b = base.constant_value().unwrap();
275 let e = exp.constant_value().unwrap();
276 Ok(CanonicalSum::from_term(CanonicalTerm::constant(b.powf(e))))
277 }
278 (Expr::Var(name), _) if exp.constant_value().is_some() => {
280 let e = exp.constant_value().unwrap();
281 if e.abs() < 1e-15 {
282 return Ok(CanonicalSum::from_term(CanonicalTerm::constant(1.0)));
283 }
284 let mut vars = BTreeMap::new();
285 vars.insert(*name, e);
286 Ok(CanonicalSum::from_term(CanonicalTerm {
287 coeff: 1.0,
288 vars,
289 opaque: Vec::new(),
290 }))
291 }
292 (_, _) if exp.constant_value().is_some() => {
294 let e = exp.constant_value().unwrap();
295 if e >= 0.0 && (e - e.round()).abs() < 1e-10 {
296 let n = e.round() as usize;
297 let base_sum = expr_to_canonical(base)?;
298 if n == 0 {
299 return Ok(CanonicalSum::from_term(CanonicalTerm::constant(1.0)));
300 }
301 let mut result = base_sum.clone();
302 for _ in 1..n {
303 result = result.mul(&base_sum);
304 }
305 Ok(result)
306 } else {
307 let canon_base = canonical_form(base)?;
309 Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
310 Expr::Pow(Box::new(canon_base), Box::new(Expr::Const(e))),
311 )))
312 }
313 }
314 (_, _) if base.constant_value().is_some() => {
316 let c = base.constant_value().unwrap();
317 if (c - 1.0).abs() < 1e-15 {
318 return Ok(CanonicalSum::from_term(CanonicalTerm::constant(1.0)));
319 }
320 if c <= 0.0 {
321 return Err(CanonicalizationError::Unsupported(format!(
322 "{}^{}",
323 base, exp
324 )));
325 }
326 let canon_exp = canonical_form(exp)?;
327 Ok(CanonicalSum::from_term(CanonicalTerm::opaque_factor(
328 Expr::Pow(Box::new(base.clone()), Box::new(canon_exp)),
329 )))
330 }
331 _ => Err(CanonicalizationError::Unsupported(format!(
333 "{}^{}",
334 base, exp
335 ))),
336 }
337}
338
339fn canonical_sum_to_expr(sum: &CanonicalSum) -> Expr {
340 if sum.terms.is_empty() {
341 return Expr::Const(0.0);
342 }
343
344 let term_exprs: Vec<Expr> = sum.terms.iter().map(canonical_term_to_expr).collect();
345
346 let mut result = term_exprs[0].clone();
347 for term in &term_exprs[1..] {
348 result = result + term.clone();
349 }
350 result
351}
352
353fn canonical_term_to_expr(term: &CanonicalTerm) -> Expr {
354 let mut factors: Vec<Expr> = Vec::new();
355
356 let (coeff_factor, sign) = if term.coeff < 0.0 {
358 (term.coeff.abs(), true)
359 } else {
360 (term.coeff, false)
361 };
362
363 let has_other_factors = !term.vars.is_empty() || !term.opaque.is_empty();
364
365 if (coeff_factor - 1.0).abs() > 1e-15 || !has_other_factors {
366 factors.push(Expr::Const(coeff_factor));
367 }
368
369 for (&var, &exp) in &term.vars {
371 if (exp - 1.0).abs() < 1e-15 {
372 factors.push(Expr::Var(var));
373 } else {
374 factors.push(Expr::pow(Expr::Var(var), Expr::Const(exp)));
375 }
376 }
377
378 for opaque in &term.opaque {
380 factors.push(opaque.expr.clone());
381 }
382
383 let mut result = if factors.is_empty() {
384 Expr::Const(1.0)
385 } else {
386 let mut r = factors[0].clone();
387 for f in &factors[1..] {
388 r = r * f.clone();
389 }
390 r
391 };
392
393 if sign {
394 result = -result;
395 }
396
397 result
398}
399
400#[cfg(test)]
401#[path = "unit_tests/canonical.rs"]
402mod tests;