1use crate::types::ProblemSize;
4use std::collections::{HashMap, HashSet};
5use std::fmt;
6
/// Symbolic expression over named size variables.
///
/// Trees are built either programmatically (through the `+ - * / neg`
/// operator impls and [`Expr::pow`]) or by parsing text with [`Expr::parse`].
/// Construction never simplifies: operators just wrap their operands in a new
/// node.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum Expr {
    /// Numeric literal.
    Const(f64),
    /// Named variable, resolved against a `ProblemSize` by `eval`
    /// (a name that lookup does not provide evaluates to 0).
    Var(&'static str),
    /// Sum `a + b`. Subtraction is encoded as `a + (-1) * b`.
    Add(Box<Expr>, Box<Expr>),
    /// Product `a * b`. Negation is encoded as `(-1) * a`.
    Mul(Box<Expr>, Box<Expr>),
    /// Power `base ^ exp`. Division is encoded as `a * b^(-1)`.
    Pow(Box<Expr>, Box<Expr>),
    /// Natural exponential `e^a`.
    Exp(Box<Expr>),
    /// Natural logarithm `ln(a)`.
    Log(Box<Expr>),
    /// Square root of `a`.
    Sqrt(Box<Expr>),
}
27
28impl Expr {
29 pub fn pow(base: Expr, exp: Expr) -> Self {
31 Expr::Pow(Box::new(base), Box::new(exp))
32 }
33
34 pub fn scale(self, c: f64) -> Self {
36 Expr::Const(c) * self
37 }
38
39 pub fn eval(&self, vars: &ProblemSize) -> f64 {
41 match self {
42 Expr::Const(c) => *c,
43 Expr::Var(name) => vars.get(name).unwrap_or(0) as f64,
44 Expr::Add(a, b) => a.eval(vars) + b.eval(vars),
45 Expr::Mul(a, b) => a.eval(vars) * b.eval(vars),
46 Expr::Pow(base, exp) => base.eval(vars).powf(exp.eval(vars)),
47 Expr::Exp(a) => a.eval(vars).exp(),
48 Expr::Log(a) => a.eval(vars).ln(),
49 Expr::Sqrt(a) => a.eval(vars).sqrt(),
50 }
51 }
52
53 pub fn variables(&self) -> HashSet<&'static str> {
55 let mut vars = HashSet::new();
56 self.collect_variables(&mut vars);
57 vars
58 }
59
60 fn collect_variables(&self, vars: &mut HashSet<&'static str>) {
61 match self {
62 Expr::Const(_) => {}
63 Expr::Var(name) => {
64 vars.insert(name);
65 }
66 Expr::Add(a, b) | Expr::Mul(a, b) | Expr::Pow(a, b) => {
67 a.collect_variables(vars);
68 b.collect_variables(vars);
69 }
70 Expr::Exp(a) | Expr::Log(a) | Expr::Sqrt(a) => {
71 a.collect_variables(vars);
72 }
73 }
74 }
75
76 pub fn substitute(&self, mapping: &HashMap<&str, &Expr>) -> Expr {
78 match self {
79 Expr::Const(c) => Expr::Const(*c),
80 Expr::Var(name) => {
81 if let Some(replacement) = mapping.get(name) {
82 (*replacement).clone()
83 } else {
84 Expr::Var(name)
85 }
86 }
87 Expr::Add(a, b) => a.substitute(mapping) + b.substitute(mapping),
88 Expr::Mul(a, b) => a.substitute(mapping) * b.substitute(mapping),
89 Expr::Pow(a, b) => Expr::pow(a.substitute(mapping), b.substitute(mapping)),
90 Expr::Exp(a) => Expr::Exp(Box::new(a.substitute(mapping))),
91 Expr::Log(a) => Expr::Log(Box::new(a.substitute(mapping))),
92 Expr::Sqrt(a) => Expr::Sqrt(Box::new(a.substitute(mapping))),
93 }
94 }
95
96 pub fn parse(input: &str) -> Expr {
107 Self::try_parse(input)
108 .unwrap_or_else(|e| panic!("failed to parse expression \"{input}\": {e}"))
109 }
110
111 pub fn try_parse(input: &str) -> Result<Expr, String> {
113 parse_to_expr(input)
114 }
115
116 pub fn is_polynomial(&self) -> bool {
118 match self {
119 Expr::Const(_) | Expr::Var(_) => true,
120 Expr::Add(a, b) | Expr::Mul(a, b) => a.is_polynomial() && b.is_polynomial(),
121 Expr::Pow(base, exp) => {
122 base.is_polynomial()
123 && matches!(exp.as_ref(), Expr::Const(c) if *c >= 0.0 && (*c - c.round()).abs() < 1e-10)
124 }
125 Expr::Exp(_) | Expr::Log(_) | Expr::Sqrt(_) => false,
126 }
127 }
128
129 pub fn is_valid_complexity_notation(&self) -> bool {
140 self.is_valid_complexity_notation_inner()
141 }
142
143 fn is_valid_complexity_notation_inner(&self) -> bool {
144 match self {
145 Expr::Const(c) => (*c - 1.0).abs() < 1e-10,
146 Expr::Var(_) => true,
147 Expr::Add(a, b) => {
148 a.constant_value().is_none()
149 && b.constant_value().is_none()
150 && a.is_valid_complexity_notation_inner()
151 && b.is_valid_complexity_notation_inner()
152 }
153 Expr::Mul(a, b) => {
154 a.constant_value().is_none()
155 && b.constant_value().is_none()
156 && a.is_valid_complexity_notation_inner()
157 && b.is_valid_complexity_notation_inner()
158 }
159 Expr::Pow(base, exp) => {
160 let base_is_constant = base.constant_value().is_some();
161 let exp_is_constant = exp.constant_value().is_some();
162
163 let base_ok = if base_is_constant {
164 base.is_valid_exponential_base()
165 } else {
166 base.is_valid_complexity_notation_inner()
167 };
168
169 let exp_ok = if exp_is_constant {
170 true
171 } else {
172 exp.is_valid_complexity_notation_inner()
173 };
174
175 base_ok && exp_ok
176 }
177 Expr::Exp(a) | Expr::Log(a) | Expr::Sqrt(a) => a.is_valid_complexity_notation_inner(),
178 }
179 }
180
181 fn is_valid_exponential_base(&self) -> bool {
182 self.constant_value().is_some_and(|c| c > 0.0)
183 }
184
185 pub(crate) fn constant_value(&self) -> Option<f64> {
186 match self {
187 Expr::Const(c) => Some(*c),
188 Expr::Var(_) => None,
189 Expr::Add(a, b) => Some(a.constant_value()? + b.constant_value()?),
190 Expr::Mul(a, b) => Some(a.constant_value()? * b.constant_value()?),
191 Expr::Pow(base, exp) => Some(base.constant_value()?.powf(exp.constant_value()?)),
192 Expr::Exp(a) => Some(a.constant_value()?.exp()),
193 Expr::Log(a) => Some(a.constant_value()?.ln()),
194 Expr::Sqrt(a) => Some(a.constant_value()?.sqrt()),
195 }
196 }
197}
198
199impl fmt::Display for Expr {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 match self {
202 Expr::Const(c) => {
203 let ci = c.round() as i64;
204 if (*c - ci as f64).abs() < 1e-10 {
205 write!(f, "{ci}")
206 } else {
207 write!(f, "{c}")
208 }
209 }
210 Expr::Var(name) => write!(f, "{name}"),
211 Expr::Add(a, b) => write!(f, "{a} + {b}"),
212 Expr::Mul(a, b) => {
213 let left = if matches!(a.as_ref(), Expr::Add(_, _)) {
214 format!("({a})")
215 } else {
216 format!("{a}")
217 };
218 let right = if matches!(b.as_ref(), Expr::Add(_, _)) {
219 format!("({b})")
220 } else {
221 format!("{b}")
222 };
223 write!(f, "{left} * {right}")
224 }
225 Expr::Pow(base, exp) => {
226 if let Expr::Const(e) = exp.as_ref() {
228 if (*e - 0.5).abs() < 1e-15 {
229 return write!(f, "sqrt({base})");
230 }
231 }
232 let base_str = if matches!(base.as_ref(), Expr::Add(_, _) | Expr::Mul(_, _)) {
233 format!("({base})")
234 } else {
235 format!("{base}")
236 };
237 let exp_str = if matches!(exp.as_ref(), Expr::Add(_, _) | Expr::Mul(_, _)) {
238 format!("({exp})")
239 } else {
240 format!("{exp}")
241 };
242 write!(f, "{base_str}^{exp_str}")
243 }
244 Expr::Exp(a) => write!(f, "exp({a})"),
245 Expr::Log(a) => write!(f, "log({a})"),
246 Expr::Sqrt(a) => write!(f, "sqrt({a})"),
247 }
248 }
249}
250
251impl std::ops::Add for Expr {
252 type Output = Self;
253
254 fn add(self, other: Self) -> Self {
255 Expr::Add(Box::new(self), Box::new(other))
256 }
257}
258
259impl std::ops::Mul for Expr {
260 type Output = Self;
261
262 fn mul(self, other: Self) -> Self {
263 Expr::Mul(Box::new(self), Box::new(other))
264 }
265}
266
267impl std::ops::Sub for Expr {
268 type Output = Self;
269
270 fn sub(self, other: Self) -> Self {
271 self + Expr::Const(-1.0) * other
272 }
273}
274
275impl std::ops::Div for Expr {
276 type Output = Self;
277
278 fn div(self, other: Self) -> Self {
279 self * Expr::pow(other, Expr::Const(-1.0))
280 }
281}
282
283impl std::ops::Neg for Expr {
284 type Output = Self;
285
286 fn neg(self) -> Self {
287 Expr::Const(-1.0) * self
288 }
289}
290
/// Error returned by [`asymptotic_normal_form`] when an expression cannot be
/// reduced.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AsymptoticAnalysisError {
    /// The expression uses a construct the analysis does not handle; the
    /// payload is text describing it, echoed in the `Display` message.
    Unsupported(String),
}

impl fmt::Display for AsymptoticAnalysisError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so this destructuring is irrefutable.
        let Self::Unsupported(detail) = self;
        write!(f, "unsupported asymptotic expression: {detail}")
    }
}

impl std::error::Error for AsymptoticAnalysisError {}
306
/// Error reported when an expression cannot be canonicalized.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CanonicalizationError {
    /// The expression uses a construct canonicalization does not handle;
    /// the payload is text describing it, echoed in the `Display` message.
    Unsupported(String),
}

impl fmt::Display for CanonicalizationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so this destructuring is irrefutable.
        let Self::Unsupported(detail) = self;
        write!(f, "unsupported expression for canonicalization: {detail}")
    }
}

impl std::error::Error for CanonicalizationError {}
325
326pub fn asymptotic_normal_form(expr: &Expr) -> Result<Expr, AsymptoticAnalysisError> {
330 crate::big_o::big_o_normal_form(expr)
331}
332
333fn parse_to_expr(input: &str) -> Result<Expr, String> {
340 let tokens = tokenize_expr(input)?;
341 let mut parser = ExprParser::new(tokens);
342 let expr = parser.parse_additive()?;
343 if parser.pos != parser.tokens.len() {
344 return Err(format!("trailing tokens at position {}", parser.pos));
345 }
346 Ok(expr)
347}
348
/// Lexical token produced by `tokenize_expr` and consumed by `ExprParser`.
#[derive(Debug, Clone, PartialEq)]
enum ExprToken {
    /// Unsigned numeric literal; a leading `-` is lexed as a separate `Minus`.
    Number(f64),
    /// Identifier: a variable name, or a function name when followed by `(`.
    Ident(String),
    /// `+`
    Plus,
    /// `-` (binary subtraction or unary negation, decided by the parser)
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^`
    Caret,
    /// `(`
    LParen,
    /// `)`
    RParen,
}
361
362fn tokenize_expr(input: &str) -> Result<Vec<ExprToken>, String> {
363 let mut tokens = Vec::new();
364 let mut chars = input.chars().peekable();
365 while let Some(&ch) = chars.peek() {
366 match ch {
367 ' ' | '\t' | '\n' => {
368 chars.next();
369 }
370 '+' => {
371 chars.next();
372 tokens.push(ExprToken::Plus);
373 }
374 '-' => {
375 chars.next();
376 tokens.push(ExprToken::Minus);
377 }
378 '*' => {
379 chars.next();
380 tokens.push(ExprToken::Star);
381 }
382 '/' => {
383 chars.next();
384 tokens.push(ExprToken::Slash);
385 }
386 '^' => {
387 chars.next();
388 tokens.push(ExprToken::Caret);
389 }
390 '(' => {
391 chars.next();
392 tokens.push(ExprToken::LParen);
393 }
394 ')' => {
395 chars.next();
396 tokens.push(ExprToken::RParen);
397 }
398 c if c.is_ascii_digit() || c == '.' => {
399 let mut num = String::new();
400 while let Some(&c) = chars.peek() {
401 if c.is_ascii_digit() || c == '.' {
402 num.push(c);
403 chars.next();
404 } else {
405 break;
406 }
407 }
408 tokens.push(ExprToken::Number(
409 num.parse().map_err(|_| format!("invalid number: {num}"))?,
410 ));
411 }
412 c if c.is_ascii_alphabetic() || c == '_' => {
413 let mut ident = String::new();
414 while let Some(&c) = chars.peek() {
415 if c.is_ascii_alphanumeric() || c == '_' {
416 ident.push(c);
417 chars.next();
418 } else {
419 break;
420 }
421 }
422 tokens.push(ExprToken::Ident(ident));
423 }
424 _ => return Err(format!("unexpected character: '{ch}'")),
425 }
426 }
427 Ok(tokens)
428}
429
430struct ExprParser {
431 tokens: Vec<ExprToken>,
432 pos: usize,
433}
434
435impl ExprParser {
436 fn new(tokens: Vec<ExprToken>) -> Self {
437 Self { tokens, pos: 0 }
438 }
439
440 fn peek(&self) -> Option<&ExprToken> {
441 self.tokens.get(self.pos)
442 }
443
444 fn advance(&mut self) -> Option<ExprToken> {
445 let tok = self.tokens.get(self.pos).cloned();
446 self.pos += 1;
447 tok
448 }
449
450 fn expect(&mut self, expected: &ExprToken) -> Result<(), String> {
451 match self.advance() {
452 Some(ref tok) if tok == expected => Ok(()),
453 Some(tok) => Err(format!("expected {expected:?}, got {tok:?}")),
454 None => Err(format!("expected {expected:?}, got end of input")),
455 }
456 }
457
458 fn parse_additive(&mut self) -> Result<Expr, String> {
459 let mut left = self.parse_multiplicative()?;
460 while matches!(self.peek(), Some(ExprToken::Plus) | Some(ExprToken::Minus)) {
461 let op = self.advance().unwrap();
462 let right = self.parse_multiplicative()?;
463 left = match op {
464 ExprToken::Plus => left + right,
465 ExprToken::Minus => left - right,
466 _ => unreachable!(),
467 };
468 }
469 Ok(left)
470 }
471
472 fn parse_multiplicative(&mut self) -> Result<Expr, String> {
473 let mut left = self.parse_unary()?;
474 while matches!(self.peek(), Some(ExprToken::Star) | Some(ExprToken::Slash)) {
475 let op = self.advance().unwrap();
476 let right = self.parse_unary()?;
477 left = match op {
478 ExprToken::Star => left * right,
479 ExprToken::Slash => left / right,
480 _ => unreachable!(),
481 };
482 }
483 Ok(left)
484 }
485
486 fn parse_power(&mut self) -> Result<Expr, String> {
487 let base = self.parse_primary()?;
488 if matches!(self.peek(), Some(ExprToken::Caret)) {
489 self.advance();
490 let exp = self.parse_unary()?; Ok(Expr::pow(base, exp))
492 } else {
493 Ok(base)
494 }
495 }
496
497 fn parse_unary(&mut self) -> Result<Expr, String> {
498 if matches!(self.peek(), Some(ExprToken::Minus)) {
499 self.advance();
500 let expr = self.parse_unary()?;
501 Ok(-expr)
502 } else {
503 self.parse_power()
504 }
505 }
506
507 fn parse_primary(&mut self) -> Result<Expr, String> {
508 match self.advance() {
509 Some(ExprToken::Number(n)) => Ok(Expr::Const(n)),
510 Some(ExprToken::Ident(name)) => {
511 if matches!(self.peek(), Some(ExprToken::LParen)) {
512 self.advance();
513 let arg = self.parse_additive()?;
514 self.expect(&ExprToken::RParen)?;
515 match name.as_str() {
516 "exp" => Ok(Expr::Exp(Box::new(arg))),
517 "log" => Ok(Expr::Log(Box::new(arg))),
518 "sqrt" => Ok(Expr::Sqrt(Box::new(arg))),
519 _ => Err(format!("unknown function: {name}")),
520 }
521 } else {
522 let leaked: &'static str = Box::leak(name.into_boxed_str());
524 Ok(Expr::Var(leaked))
525 }
526 }
527 Some(ExprToken::LParen) => {
528 let expr = self.parse_additive()?;
529 self.expect(&ExprToken::RParen)?;
530 Ok(expr)
531 }
532 Some(tok) => Err(format!("unexpected token: {tok:?}")),
533 None => Err("unexpected end of input".to_string()),
534 }
535 }
536}
537
// Unit tests live in a separate file, pulled in through `#[path]` and
// compiled only for test builds.
#[cfg(test)]
#[path = "unit_tests/expr.rs"]
mod tests;