problemreductions/
expr.rs

//! General symbolic expression AST for describing reduction overhead, plus a small runtime parser for expression strings.
2
3use crate::types::ProblemSize;
4use std::collections::{HashMap, HashSet};
5use std::fmt;
6
/// A symbolic math expression over problem size variables.
///
/// Expressions are plain trees: the constructors and operator overloads in this
/// module build nodes without simplifying them. Subtraction and division have no
/// variants of their own. The operator overloads encode `a - b` as `a + (-1) * b`
/// and `a / b` as `a * b^(-1)`.
///
/// NOTE(review): serde derives variant indices from declaration order, so
/// reordering the variants may break previously serialized data in index-based
/// formats. Avoid reordering.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum Expr {
    /// Numeric constant.
    Const(f64),
    /// Named variable (e.g., "num_vertices"). Names must be `'static`; names
    /// produced by the runtime parser are leaked to obtain that lifetime.
    Var(&'static str),
    /// Addition: a + b.
    Add(Box<Expr>, Box<Expr>),
    /// Multiplication: a * b.
    Mul(Box<Expr>, Box<Expr>),
    /// Exponentiation: base ^ exponent.
    Pow(Box<Expr>, Box<Expr>),
    /// Exponential function: exp(a).
    Exp(Box<Expr>),
    /// Natural logarithm: log(a), evaluated with `f64::ln`.
    Log(Box<Expr>),
    /// Square root: sqrt(a).
    Sqrt(Box<Expr>),
}
27
28impl Expr {
29    /// Convenience constructor for exponentiation.
30    pub fn pow(base: Expr, exp: Expr) -> Self {
31        Expr::Pow(Box::new(base), Box::new(exp))
32    }
33
34    /// Multiply expression by a scalar constant.
35    pub fn scale(self, c: f64) -> Self {
36        Expr::Const(c) * self
37    }
38
39    /// Evaluate the expression given concrete variable values.
40    pub fn eval(&self, vars: &ProblemSize) -> f64 {
41        match self {
42            Expr::Const(c) => *c,
43            Expr::Var(name) => vars.get(name).unwrap_or(0) as f64,
44            Expr::Add(a, b) => a.eval(vars) + b.eval(vars),
45            Expr::Mul(a, b) => a.eval(vars) * b.eval(vars),
46            Expr::Pow(base, exp) => base.eval(vars).powf(exp.eval(vars)),
47            Expr::Exp(a) => a.eval(vars).exp(),
48            Expr::Log(a) => a.eval(vars).ln(),
49            Expr::Sqrt(a) => a.eval(vars).sqrt(),
50        }
51    }
52
53    /// Collect all variable names referenced in this expression.
54    pub fn variables(&self) -> HashSet<&'static str> {
55        let mut vars = HashSet::new();
56        self.collect_variables(&mut vars);
57        vars
58    }
59
60    fn collect_variables(&self, vars: &mut HashSet<&'static str>) {
61        match self {
62            Expr::Const(_) => {}
63            Expr::Var(name) => {
64                vars.insert(name);
65            }
66            Expr::Add(a, b) | Expr::Mul(a, b) | Expr::Pow(a, b) => {
67                a.collect_variables(vars);
68                b.collect_variables(vars);
69            }
70            Expr::Exp(a) | Expr::Log(a) | Expr::Sqrt(a) => {
71                a.collect_variables(vars);
72            }
73        }
74    }
75
76    /// Substitute variables with other expressions.
77    pub fn substitute(&self, mapping: &HashMap<&str, &Expr>) -> Expr {
78        match self {
79            Expr::Const(c) => Expr::Const(*c),
80            Expr::Var(name) => {
81                if let Some(replacement) = mapping.get(name) {
82                    (*replacement).clone()
83                } else {
84                    Expr::Var(name)
85                }
86            }
87            Expr::Add(a, b) => a.substitute(mapping) + b.substitute(mapping),
88            Expr::Mul(a, b) => a.substitute(mapping) * b.substitute(mapping),
89            Expr::Pow(a, b) => Expr::pow(a.substitute(mapping), b.substitute(mapping)),
90            Expr::Exp(a) => Expr::Exp(Box::new(a.substitute(mapping))),
91            Expr::Log(a) => Expr::Log(Box::new(a.substitute(mapping))),
92            Expr::Sqrt(a) => Expr::Sqrt(Box::new(a.substitute(mapping))),
93        }
94    }
95
96    /// Parse an expression string into an `Expr` at runtime.
97    ///
98    /// **Memory note:** Variable names are leaked to `&'static str` via `Box::leak`
99    /// since `Expr::Var` requires static lifetimes. Each unique variable name leaks
100    /// a small allocation that is never freed. This is acceptable for testing and
101    /// one-time cross-check evaluation, but should not be used in hot loops with
102    /// dynamic input.
103    ///
104    /// # Panics
105    /// Panics if the expression string has invalid syntax.
106    pub fn parse(input: &str) -> Expr {
107        Self::try_parse(input)
108            .unwrap_or_else(|e| panic!("failed to parse expression \"{input}\": {e}"))
109    }
110
111    /// Parse an expression string into an `Expr`, returning a normal error on failure.
112    pub fn try_parse(input: &str) -> Result<Expr, String> {
113        parse_to_expr(input)
114    }
115
116    /// Check if this expression is a polynomial (no exp/log/sqrt, integer exponents only).
117    pub fn is_polynomial(&self) -> bool {
118        match self {
119            Expr::Const(_) | Expr::Var(_) => true,
120            Expr::Add(a, b) | Expr::Mul(a, b) => a.is_polynomial() && b.is_polynomial(),
121            Expr::Pow(base, exp) => {
122                base.is_polynomial()
123                    && matches!(exp.as_ref(), Expr::Const(c) if *c >= 0.0 && (*c - c.round()).abs() < 1e-10)
124            }
125            Expr::Exp(_) | Expr::Log(_) | Expr::Sqrt(_) => false,
126        }
127    }
128
129    /// Check whether this expression is suitable for asymptotic complexity notation.
130    ///
131    /// This is intentionally conservative for symbolic size formulas:
132    /// - rejects explicit multiplicative constant factors like `3 * n`
133    /// - rejects additive constant terms like `n + 1`
134    /// - allows constants used as exponents (e.g. `n^(1/3)`)
135    /// - allows constants used as exponential bases (e.g. `2^n`)
136    ///
137    /// The goal is to accept expressions that already look like reduced
138    /// asymptotic notation, rather than exact-count formulas.
139    pub fn is_valid_complexity_notation(&self) -> bool {
140        self.is_valid_complexity_notation_inner()
141    }
142
143    fn is_valid_complexity_notation_inner(&self) -> bool {
144        match self {
145            Expr::Const(c) => (*c - 1.0).abs() < 1e-10,
146            Expr::Var(_) => true,
147            Expr::Add(a, b) => {
148                a.constant_value().is_none()
149                    && b.constant_value().is_none()
150                    && a.is_valid_complexity_notation_inner()
151                    && b.is_valid_complexity_notation_inner()
152            }
153            Expr::Mul(a, b) => {
154                a.constant_value().is_none()
155                    && b.constant_value().is_none()
156                    && a.is_valid_complexity_notation_inner()
157                    && b.is_valid_complexity_notation_inner()
158            }
159            Expr::Pow(base, exp) => {
160                let base_is_constant = base.constant_value().is_some();
161                let exp_is_constant = exp.constant_value().is_some();
162
163                let base_ok = if base_is_constant {
164                    base.is_valid_exponential_base()
165                } else {
166                    base.is_valid_complexity_notation_inner()
167                };
168
169                let exp_ok = if exp_is_constant {
170                    true
171                } else {
172                    exp.is_valid_complexity_notation_inner()
173                };
174
175                base_ok && exp_ok
176            }
177            Expr::Exp(a) | Expr::Log(a) | Expr::Sqrt(a) => a.is_valid_complexity_notation_inner(),
178        }
179    }
180
181    fn is_valid_exponential_base(&self) -> bool {
182        self.constant_value().is_some_and(|c| c > 0.0)
183    }
184
185    pub(crate) fn constant_value(&self) -> Option<f64> {
186        match self {
187            Expr::Const(c) => Some(*c),
188            Expr::Var(_) => None,
189            Expr::Add(a, b) => Some(a.constant_value()? + b.constant_value()?),
190            Expr::Mul(a, b) => Some(a.constant_value()? * b.constant_value()?),
191            Expr::Pow(base, exp) => Some(base.constant_value()?.powf(exp.constant_value()?)),
192            Expr::Exp(a) => Some(a.constant_value()?.exp()),
193            Expr::Log(a) => Some(a.constant_value()?.ln()),
194            Expr::Sqrt(a) => Some(a.constant_value()?.sqrt()),
195        }
196    }
197}
198
199impl fmt::Display for Expr {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        match self {
202            Expr::Const(c) => {
203                let ci = c.round() as i64;
204                if (*c - ci as f64).abs() < 1e-10 {
205                    write!(f, "{ci}")
206                } else {
207                    write!(f, "{c}")
208                }
209            }
210            Expr::Var(name) => write!(f, "{name}"),
211            Expr::Add(a, b) => write!(f, "{a} + {b}"),
212            Expr::Mul(a, b) => {
213                let left = if matches!(a.as_ref(), Expr::Add(_, _)) {
214                    format!("({a})")
215                } else {
216                    format!("{a}")
217                };
218                let right = if matches!(b.as_ref(), Expr::Add(_, _)) {
219                    format!("({b})")
220                } else {
221                    format!("{b}")
222                };
223                write!(f, "{left} * {right}")
224            }
225            Expr::Pow(base, exp) => {
226                // Special case: x^0.5 → sqrt(x)
227                if let Expr::Const(e) = exp.as_ref() {
228                    if (*e - 0.5).abs() < 1e-15 {
229                        return write!(f, "sqrt({base})");
230                    }
231                }
232                let base_str = if matches!(base.as_ref(), Expr::Add(_, _) | Expr::Mul(_, _)) {
233                    format!("({base})")
234                } else {
235                    format!("{base}")
236                };
237                let exp_str = if matches!(exp.as_ref(), Expr::Add(_, _) | Expr::Mul(_, _)) {
238                    format!("({exp})")
239                } else {
240                    format!("{exp}")
241                };
242                write!(f, "{base_str}^{exp_str}")
243            }
244            Expr::Exp(a) => write!(f, "exp({a})"),
245            Expr::Log(a) => write!(f, "log({a})"),
246            Expr::Sqrt(a) => write!(f, "sqrt({a})"),
247        }
248    }
249}
250
251impl std::ops::Add for Expr {
252    type Output = Self;
253
254    fn add(self, other: Self) -> Self {
255        Expr::Add(Box::new(self), Box::new(other))
256    }
257}
258
259impl std::ops::Mul for Expr {
260    type Output = Self;
261
262    fn mul(self, other: Self) -> Self {
263        Expr::Mul(Box::new(self), Box::new(other))
264    }
265}
266
267impl std::ops::Sub for Expr {
268    type Output = Self;
269
270    fn sub(self, other: Self) -> Self {
271        self + Expr::Const(-1.0) * other
272    }
273}
274
275impl std::ops::Div for Expr {
276    type Output = Self;
277
278    fn div(self, other: Self) -> Self {
279        self * Expr::pow(other, Expr::Const(-1.0))
280    }
281}
282
283impl std::ops::Neg for Expr {
284    type Output = Self;
285
286    fn neg(self) -> Self {
287        Expr::Const(-1.0) * self
288    }
289}
290
/// Error produced when the asymptotic behavior of an expression cannot be analyzed.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AsymptoticAnalysisError {
    /// The analysis does not support the given expression; the payload describes it.
    Unsupported(String),
}

impl fmt::Display for AsymptoticAnalysisError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so the destructuring pattern is irrefutable.
        let Self::Unsupported(expr) = self;
        write!(f, "unsupported asymptotic expression: {}", expr)
    }
}

impl std::error::Error for AsymptoticAnalysisError {}
306
/// Error produced when an expression cannot be brought into exact canonical form.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum CanonicalizationError {
    /// Expression cannot be canonicalized (e.g., variable in both base and exponent).
    Unsupported(String),
}

impl fmt::Display for CanonicalizationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so the destructuring pattern is irrefutable.
        let Self::Unsupported(expr) = self;
        write!(f, "unsupported expression for canonicalization: {}", expr)
    }
}

impl std::error::Error for CanonicalizationError {}
325
326/// Return a normalized `Expr` representing the asymptotic behavior of `expr`.
327///
328/// This is now a compatibility wrapper for `big_o_normal_form()`.
329pub fn asymptotic_normal_form(expr: &Expr) -> Result<Expr, AsymptoticAnalysisError> {
330    crate::big_o::big_o_normal_form(expr)
331}
332
333// --- Runtime expression parser ---
334
335/// Parse an expression string into an `Expr`.
336///
337/// Uses the same grammar as the proc macro parser. Variable names are leaked
338/// to `&'static str` for compatibility with `Expr::Var`.
339fn parse_to_expr(input: &str) -> Result<Expr, String> {
340    let tokens = tokenize_expr(input)?;
341    let mut parser = ExprParser::new(tokens);
342    let expr = parser.parse_additive()?;
343    if parser.pos != parser.tokens.len() {
344        return Err(format!("trailing tokens at position {}", parser.pos));
345    }
346    Ok(expr)
347}
348
/// A lexical token of the runtime expression grammar.
#[derive(Clone, Debug, PartialEq)]
enum ExprToken {
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `+`
    Plus,
    /// `-`, used for both binary subtraction and unary negation.
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^`
    Caret,
    /// Numeric literal, already converted to `f64`.
    Number(f64),
    /// Identifier: a variable name or a function name such as `exp`.
    Ident(String),
}
361
362fn tokenize_expr(input: &str) -> Result<Vec<ExprToken>, String> {
363    let mut tokens = Vec::new();
364    let mut chars = input.chars().peekable();
365    while let Some(&ch) = chars.peek() {
366        match ch {
367            ' ' | '\t' | '\n' => {
368                chars.next();
369            }
370            '+' => {
371                chars.next();
372                tokens.push(ExprToken::Plus);
373            }
374            '-' => {
375                chars.next();
376                tokens.push(ExprToken::Minus);
377            }
378            '*' => {
379                chars.next();
380                tokens.push(ExprToken::Star);
381            }
382            '/' => {
383                chars.next();
384                tokens.push(ExprToken::Slash);
385            }
386            '^' => {
387                chars.next();
388                tokens.push(ExprToken::Caret);
389            }
390            '(' => {
391                chars.next();
392                tokens.push(ExprToken::LParen);
393            }
394            ')' => {
395                chars.next();
396                tokens.push(ExprToken::RParen);
397            }
398            c if c.is_ascii_digit() || c == '.' => {
399                let mut num = String::new();
400                while let Some(&c) = chars.peek() {
401                    if c.is_ascii_digit() || c == '.' {
402                        num.push(c);
403                        chars.next();
404                    } else {
405                        break;
406                    }
407                }
408                tokens.push(ExprToken::Number(
409                    num.parse().map_err(|_| format!("invalid number: {num}"))?,
410                ));
411            }
412            c if c.is_ascii_alphabetic() || c == '_' => {
413                let mut ident = String::new();
414                while let Some(&c) = chars.peek() {
415                    if c.is_ascii_alphanumeric() || c == '_' {
416                        ident.push(c);
417                        chars.next();
418                    } else {
419                        break;
420                    }
421                }
422                tokens.push(ExprToken::Ident(ident));
423            }
424            _ => return Err(format!("unexpected character: '{ch}'")),
425        }
426    }
427    Ok(tokens)
428}
429
430struct ExprParser {
431    tokens: Vec<ExprToken>,
432    pos: usize,
433}
434
435impl ExprParser {
436    fn new(tokens: Vec<ExprToken>) -> Self {
437        Self { tokens, pos: 0 }
438    }
439
440    fn peek(&self) -> Option<&ExprToken> {
441        self.tokens.get(self.pos)
442    }
443
444    fn advance(&mut self) -> Option<ExprToken> {
445        let tok = self.tokens.get(self.pos).cloned();
446        self.pos += 1;
447        tok
448    }
449
450    fn expect(&mut self, expected: &ExprToken) -> Result<(), String> {
451        match self.advance() {
452            Some(ref tok) if tok == expected => Ok(()),
453            Some(tok) => Err(format!("expected {expected:?}, got {tok:?}")),
454            None => Err(format!("expected {expected:?}, got end of input")),
455        }
456    }
457
458    fn parse_additive(&mut self) -> Result<Expr, String> {
459        let mut left = self.parse_multiplicative()?;
460        while matches!(self.peek(), Some(ExprToken::Plus) | Some(ExprToken::Minus)) {
461            let op = self.advance().unwrap();
462            let right = self.parse_multiplicative()?;
463            left = match op {
464                ExprToken::Plus => left + right,
465                ExprToken::Minus => left - right,
466                _ => unreachable!(),
467            };
468        }
469        Ok(left)
470    }
471
472    fn parse_multiplicative(&mut self) -> Result<Expr, String> {
473        let mut left = self.parse_unary()?;
474        while matches!(self.peek(), Some(ExprToken::Star) | Some(ExprToken::Slash)) {
475            let op = self.advance().unwrap();
476            let right = self.parse_unary()?;
477            left = match op {
478                ExprToken::Star => left * right,
479                ExprToken::Slash => left / right,
480                _ => unreachable!(),
481            };
482        }
483        Ok(left)
484    }
485
486    fn parse_power(&mut self) -> Result<Expr, String> {
487        let base = self.parse_primary()?;
488        if matches!(self.peek(), Some(ExprToken::Caret)) {
489            self.advance();
490            let exp = self.parse_unary()?; // right-associative, allows unary minus in exponent
491            Ok(Expr::pow(base, exp))
492        } else {
493            Ok(base)
494        }
495    }
496
497    fn parse_unary(&mut self) -> Result<Expr, String> {
498        if matches!(self.peek(), Some(ExprToken::Minus)) {
499            self.advance();
500            let expr = self.parse_unary()?;
501            Ok(-expr)
502        } else {
503            self.parse_power()
504        }
505    }
506
507    fn parse_primary(&mut self) -> Result<Expr, String> {
508        match self.advance() {
509            Some(ExprToken::Number(n)) => Ok(Expr::Const(n)),
510            Some(ExprToken::Ident(name)) => {
511                if matches!(self.peek(), Some(ExprToken::LParen)) {
512                    self.advance();
513                    let arg = self.parse_additive()?;
514                    self.expect(&ExprToken::RParen)?;
515                    match name.as_str() {
516                        "exp" => Ok(Expr::Exp(Box::new(arg))),
517                        "log" => Ok(Expr::Log(Box::new(arg))),
518                        "sqrt" => Ok(Expr::Sqrt(Box::new(arg))),
519                        _ => Err(format!("unknown function: {name}")),
520                    }
521                } else {
522                    // Leak the string to get &'static str for Expr::Var
523                    let leaked: &'static str = Box::leak(name.into_boxed_str());
524                    Ok(Expr::Var(leaked))
525                }
526            }
527            Some(ExprToken::LParen) => {
528                let expr = self.parse_additive()?;
529                self.expect(&ExprToken::RParen)?;
530                Ok(expr)
531            }
532            Some(tok) => Err(format!("unexpected token: {tok:?}")),
533            None => Err("unexpected end of input".to_string()),
534        }
535    }
536}
537
538#[cfg(test)]
539#[path = "unit_tests/expr.rs"]
540mod tests;