1use crate::canonical::canonical_form;
7use crate::expr::{AsymptoticAnalysisError, CanonicalizationError, Expr};
8
/// One additive term of a canonical expression after its constant
/// coefficient has been stripped off (see `project_term`).
#[derive(Clone, Debug)]
struct ProjectedTerm {
    /// Product of the term's non-constant factors.
    expr: Expr,
    /// `true` when the product of the stripped constant factors was negative.
    negative: bool,
}
14
15pub fn big_o_normal_form(expr: &Expr) -> Result<Expr, AsymptoticAnalysisError> {
23 let canonical = canonical_form(expr).map_err(|e| match e {
24 CanonicalizationError::Unsupported(s) => AsymptoticAnalysisError::Unsupported(s),
25 })?;
26
27 project_big_o(&canonical)
28}
29
30fn project_big_o(expr: &Expr) -> Result<Expr, AsymptoticAnalysisError> {
32 let mut terms = Vec::new();
34 collect_additive_terms(expr, &mut terms);
35
36 let mut projected: Vec<ProjectedTerm> = Vec::new();
38 for term in &terms {
39 if let Some(projected_term) = project_term(term)? {
40 projected.push(projected_term);
41 }
42 }
44
45 let survivors = remove_dominated_terms(projected);
47
48 if survivors.is_empty() {
49 return Ok(Expr::Const(1.0));
51 }
52
53 if let Some(negative) = survivors.iter().find(|term| term.negative) {
54 return Err(AsymptoticAnalysisError::Unsupported(format!(
55 "-1 * {}",
56 negative.expr
57 )));
58 }
59
60 let mut seen = std::collections::BTreeSet::new();
62 let mut deduped = Vec::new();
63 for term in survivors {
64 let key = term.expr.to_string();
65 if seen.insert(key) {
66 deduped.push(term);
67 }
68 }
69
70 let mut result = deduped[0].expr.clone();
72 for term in &deduped[1..] {
73 result = result + term.expr.clone();
74 }
75
76 Ok(result)
77}
78
79fn collect_additive_terms(expr: &Expr, out: &mut Vec<Expr>) {
80 match expr {
81 Expr::Add(a, b) => {
82 collect_additive_terms(a, out);
83 collect_additive_terms(b, out);
84 }
85 other => out.push(other.clone()),
86 }
87}
88
89fn project_term(term: &Expr) -> Result<Option<ProjectedTerm>, AsymptoticAnalysisError> {
92 if term.constant_value().is_some() {
93 return Ok(None); }
95
96 let mut factors = Vec::new();
98 collect_multiplicative_factors(term, &mut factors);
99
100 let mut coeff = 1.0;
101 let mut symbolic = Vec::new();
102 for factor in &factors {
103 if let Some(c) = factor.constant_value() {
104 coeff *= c;
105 continue;
106 }
107 if contains_negative_exponent(factor) {
108 return Err(AsymptoticAnalysisError::Unsupported(term.to_string()));
109 }
110 symbolic.push(factor.clone());
111 }
112
113 if symbolic.is_empty() {
114 return Ok(None);
115 }
116
117 let mut result = symbolic[0].clone();
118 for f in &symbolic[1..] {
119 result = result * f.clone();
120 }
121
122 Ok(Some(ProjectedTerm {
123 expr: result,
124 negative: coeff < 0.0,
125 }))
126}
127
128fn collect_multiplicative_factors(expr: &Expr, out: &mut Vec<Expr>) {
129 match expr {
130 Expr::Mul(a, b) => {
131 collect_multiplicative_factors(a, out);
132 collect_multiplicative_factors(b, out);
133 }
134 other => out.push(other.clone()),
135 }
136}
137
138fn remove_dominated_terms(terms: Vec<ProjectedTerm>) -> Vec<ProjectedTerm> {
143 if terms.len() <= 1 {
144 return terms;
145 }
146
147 let mut survivors = Vec::new();
148 for (i, term) in terms.iter().enumerate() {
149 let is_dominated = terms
150 .iter()
151 .enumerate()
152 .any(|(j, other)| i != j && term_dominated_by(&term.expr, &other.expr));
153 if !is_dominated {
154 survivors.push(term.clone());
155 }
156 }
157 survivors
158}
159
160fn term_dominated_by(small: &Expr, big: &Expr) -> bool {
167 let small_exps = extract_var_exponents(small);
169 let big_exps = extract_var_exponents(big);
170 if let (Some(ref se), Some(ref be)) = (small_exps, big_exps) {
171 return polynomial_dominated(se, be);
172 }
173
174 let small_vars = small.variables();
176 let big_vars = big.variables();
177 if small_vars.is_empty() || big_vars.is_empty() || !small_vars.is_subset(&big_vars) {
178 return false;
179 }
180
181 let small_has_exp = has_exponential_growth(small);
183 let big_has_exp = has_exponential_growth(big);
184 match (small_has_exp, big_has_exp) {
185 (false, true) => return true, (true, false) => return false, (true, true) => {
188 if let (Some(sb), Some(bb)) = (effective_exp_base(small), effective_exp_base(big)) {
190 if bb > sb * (1.0 + 1e-10) {
191 return true;
192 }
193 }
194 return false;
195 }
196 (false, false) => {} }
198
199 if small_vars == big_vars {
202 return numerical_dominance_check(small, big, &small_vars);
203 }
204
205 false
206}
207
/// Compares two monomials given as `variable -> exponent` maps.
///
/// Returns `true` when both of these hold:
/// * every exponent in `se` is at most the matching exponent in `be`, where a
///   missing variable counts as exponent 0;
/// * at least one gap is strict, either on a variable in `se` or on a
///   variable that appears only in `be` with a positive exponent.
///
/// Comparisons use an absolute tolerance of `1e-15`.
fn polynomial_dominated(
    se: &std::collections::BTreeMap<&'static str, f64>,
    be: &std::collections::BTreeMap<&'static str, f64>,
) -> bool {
    const TOL: f64 = 1e-15;
    let big_exponent_of = |var: &&'static str| be.get(var).copied().unwrap_or(0.0);

    // A variable raised higher in `small` rules out domination outright.
    if se.iter().any(|(var, &s)| s > big_exponent_of(var) + TOL) {
        return false;
    }

    // Every exponent is <= its counterpart, so require one strict gap.
    let shared_gap = se.iter().any(|(var, &s)| s < big_exponent_of(var) - TOL);
    let extra_growth = be.iter().any(|(var, &b)| !se.contains_key(var) && b > TOL);
    shared_gap || extra_growth
}
237
238fn extract_var_exponents(expr: &Expr) -> Option<std::collections::BTreeMap<&'static str, f64>> {
241 use std::collections::BTreeMap;
242 let mut exps = BTreeMap::new();
243 extract_var_exponents_inner(expr, &mut exps)?;
244 Some(exps)
245}
246
247fn extract_var_exponents_inner(
248 expr: &Expr,
249 exps: &mut std::collections::BTreeMap<&'static str, f64>,
250) -> Option<()> {
251 match expr {
252 Expr::Var(name) => {
253 *exps.entry(name).or_insert(0.0) += 1.0;
254 Some(())
255 }
256 Expr::Pow(base, exp) => {
257 if let (Expr::Var(name), Some(e)) = (base.as_ref(), exp.constant_value()) {
258 if e < 0.0 {
259 return None;
260 }
261 *exps.entry(name).or_insert(0.0) += e;
262 Some(())
263 } else {
264 None }
266 }
267 Expr::Mul(a, b) => {
268 extract_var_exponents_inner(a, exps)?;
269 extract_var_exponents_inner(b, exps)
270 }
271 Expr::Const(_) => Some(()), _ => None, }
274}
275
276fn contains_negative_exponent(expr: &Expr) -> bool {
277 match expr {
278 Expr::Pow(_, exp) => exp.constant_value().is_some_and(|e| e < 0.0),
279 Expr::Mul(a, b) | Expr::Add(a, b) => {
280 contains_negative_exponent(a) || contains_negative_exponent(b)
281 }
282 Expr::Exp(arg) | Expr::Log(arg) | Expr::Sqrt(arg) | Expr::Factorial(arg) => {
283 contains_negative_exponent(arg)
284 }
285 Expr::Const(_) | Expr::Var(_) => false,
286 }
287}
288
289fn has_exponential_growth(expr: &Expr) -> bool {
293 match expr {
294 Expr::Exp(arg) => !arg.variables().is_empty(),
295 Expr::Pow(base, exp) => {
296 base.constant_value().is_some_and(|c| c > 1.0) && !exp.variables().is_empty()
297 }
298 Expr::Mul(a, b) => has_exponential_growth(a) || has_exponential_growth(b),
299 _ => false,
300 }
301}
302
303fn effective_exp_base(expr: &Expr) -> Option<f64> {
308 match expr {
309 Expr::Exp(arg) => {
310 let vars = arg.variables();
311 if vars.is_empty() {
312 None
313 } else {
314 let size = unit_problem_size(&vars);
315 let rate = arg.eval(&size);
316 Some(std::f64::consts::E.powf(rate))
317 }
318 }
319 Expr::Pow(base, exp) => {
320 if let Some(c) = base.constant_value() {
321 let vars = exp.variables();
322 if c > 1.0 && !vars.is_empty() {
323 let size = unit_problem_size(&vars);
324 let exp_at_1 = exp.eval(&size);
325 Some(c.powf(exp_at_1))
326 } else {
327 None
328 }
329 } else {
330 None
331 }
332 }
333 Expr::Mul(a, b) => match (effective_exp_base(a), effective_exp_base(b)) {
334 (Some(ba), Some(bb)) => Some(ba * bb),
335 (Some(b), None) | (None, Some(b)) => Some(b),
336 (None, None) => None,
337 },
338 _ => None,
339 }
340}
341
342fn make_problem_size(
344 vars: &std::collections::HashSet<&'static str>,
345 val: usize,
346) -> crate::types::ProblemSize {
347 crate::types::ProblemSize::new(vars.iter().map(|&v| (v, val)).collect())
348}
349
350fn unit_problem_size(vars: &std::collections::HashSet<&'static str>) -> crate::types::ProblemSize {
352 make_problem_size(vars, 1)
353}
354
355fn numerical_dominance_check(
360 small: &Expr,
361 big: &Expr,
362 vars: &std::collections::HashSet<&'static str>,
363) -> bool {
364 let size1 = make_problem_size(vars, 100);
365 let size2 = make_problem_size(vars, 10_000);
366
367 let s1 = small.eval(&size1);
368 let b1 = big.eval(&size1);
369 let s2 = small.eval(&size2);
370 let b2 = big.eval(&size2);
371
372 if !s1.is_finite() || !b1.is_finite() || !s2.is_finite() || !b2.is_finite() {
374 return false;
375 }
376 if s1 <= 1e-300 || b1 <= 1e-300 || s2 <= 1e-300 || b2 <= 1e-300 {
377 return false;
378 }
379
380 let ratio1 = b1 / s1;
381 let ratio2 = b2 / s2;
382
383 ratio1 > 1.0 + 1e-10 && ratio2 > ratio1 * (1.0 + 1e-6)
385}
386
// Unit tests live in a separate file, referenced through the `#[path]`
// attribute below (`unit_tests/big_o.rs`).
#[cfg(test)]
#[path = "unit_tests/big_o.rs"]
mod tests;