1use crate::canonical::canonical_form;
11use crate::expr::Expr;
12use crate::rules::graph::{ReductionGraph, ReductionPath};
13use crate::rules::registry::ReductionOverhead;
14use std::collections::BTreeMap;
15use std::fmt;
16
/// Verdict produced by [`compare_overhead`] when a primitive rule's overhead
/// is checked against the overhead of a composite path.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ComparisonStatus {
    /// At least one output-size field is shared and every shared field of the
    /// composite passed the term-wise comparison against the primitive.
    Dominated,
    /// Some shared field failed the term-wise comparison, or the two
    /// overheads have no field in common.
    NotDominated,
    /// A shared field could not be expanded into a polynomial, or an expanded
    /// polynomial contains a negative coefficient.
    Unknown,
}
27
28#[derive(Debug, Clone)]
30pub struct DominatedRule {
31 pub source_name: &'static str,
32 pub source_variant: BTreeMap<String, String>,
33 pub target_name: &'static str,
34 pub target_variant: BTreeMap<String, String>,
35 pub primitive_overhead: ReductionOverhead,
36 pub dominating_path: ReductionPath,
37 pub composed_overhead: ReductionOverhead,
38 pub comparable_fields: Vec<String>,
39}
40
41impl DominatedRule {
42 pub fn source_display(&self) -> String {
43 format_problem_variant(self.source_name, &self.source_variant)
44 }
45
46 pub fn target_display(&self) -> String {
47 format_problem_variant(self.target_name, &self.target_variant)
48 }
49}
50
51impl fmt::Display for DominatedRule {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 write!(f, "{} -> {}", self.source_display(), self.target_display())
54 }
55}
56
57#[derive(Debug, Clone)]
59pub struct UnknownComparison {
60 pub source_name: &'static str,
61 pub source_variant: BTreeMap<String, String>,
62 pub target_name: &'static str,
63 pub target_variant: BTreeMap<String, String>,
64 pub candidate_path: ReductionPath,
65 pub reason: String,
66}
67
68impl UnknownComparison {
69 pub fn source_display(&self) -> String {
70 format_problem_variant(self.source_name, &self.source_variant)
71 }
72
73 pub fn target_display(&self) -> String {
74 format_problem_variant(self.target_name, &self.target_variant)
75 }
76}
77
78impl fmt::Display for UnknownComparison {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 write!(f, "{} -> {}", self.source_display(), self.target_display())
81 }
82}
83
/// Formats a problem name together with its variant map.
///
/// An empty `variant` yields just `name`. Otherwise the result is
/// `name {k1: "v1", k2: "v2"}`: entries appear in the map's key order and
/// each value is rendered with its `Debug` formatting (so it is quoted).
pub fn format_problem_variant(name: &str, variant: &BTreeMap<String, String>) -> String {
    let mut entries = variant.iter();
    let Some((first_key, first_value)) = entries.next() else {
        return name.to_string();
    };

    let mut rendered = format!("{name} {{{first_key}: {first_value:?}");
    for (key, value) in entries {
        rendered.push_str(&format!(", {key}: {value:?}"));
    }
    rendered.push('}');
    rendered
}
96
/// A single product term: a numeric coefficient times variables raised to
/// (possibly fractional) exponents.
#[derive(Debug, Clone)]
struct Monomial {
    /// Numeric factor of the term.
    coeff: f64,
    /// Exponent of each variable that occurs in the term, keyed by name.
    vars: BTreeMap<&'static str, f64>,
}

impl Monomial {
    /// Builds the variable-free term with coefficient `c`.
    fn constant(c: f64) -> Self {
        Monomial {
            coeff: c,
            vars: BTreeMap::new(),
        }
    }

    /// Builds the term `1 * name^1`.
    fn variable(name: &'static str) -> Self {
        Monomial {
            coeff: 1.0,
            vars: BTreeMap::from([(name, 1.0)]),
        }
    }

    /// Returns the product of two terms: coefficients are multiplied and the
    /// exponents of matching variables are added.
    fn mul(&self, other: &Monomial) -> Monomial {
        let mut product = self.clone();
        product.coeff *= other.coeff;
        for (name, exponent) in other.vars.iter() {
            *product.vars.entry(*name).or_insert(0.0) += *exponent;
        }
        product
    }
}
131
132#[derive(Debug, Clone)]
134struct NormalizedPoly {
135 terms: Vec<Monomial>,
136}
137
138impl NormalizedPoly {
139 fn add(mut self, other: NormalizedPoly) -> NormalizedPoly {
140 self.terms.extend(other.terms);
141 self
142 }
143
144 fn mul(&self, other: &NormalizedPoly) -> NormalizedPoly {
145 let mut terms = Vec::new();
146 for a in &self.terms {
147 for b in &other.terms {
148 terms.push(a.mul(b));
149 }
150 }
151 NormalizedPoly { terms }
152 }
153
154 fn has_negative_coefficients(&self) -> bool {
156 self.terms.iter().any(|m| m.coeff < -1e-15)
157 }
158}
159
160fn normalize_polynomial(expr: &Expr) -> Result<NormalizedPoly, String> {
166 match expr {
167 Expr::Const(c) => Ok(NormalizedPoly {
168 terms: vec![Monomial::constant(*c)],
169 }),
170 Expr::Var(v) => Ok(NormalizedPoly {
171 terms: vec![Monomial::variable(v)],
172 }),
173 Expr::Add(a, b) => {
174 let pa = normalize_polynomial(a)?;
175 let pb = normalize_polynomial(b)?;
176 Ok(pa.add(pb))
177 }
178 Expr::Mul(a, b) => {
179 let pa = normalize_polynomial(a)?;
180 let pb = normalize_polynomial(b)?;
181 Ok(pa.mul(&pb))
182 }
183 Expr::Pow(base, exp) => {
184 if let Expr::Const(c) = exp.as_ref() {
185 if *c < 0.0 {
186 return Err(format!("negative exponent: {c}"));
187 }
188 let pb = normalize_polynomial(base)?;
189 if pb.terms.len() == 1 {
191 let m = &pb.terms[0];
192 let coeff = m.coeff.powf(*c);
193 let vars: BTreeMap<_, _> = m.vars.iter().map(|(&v, &e)| (v, e * c)).collect();
194 return Ok(NormalizedPoly {
195 terms: vec![Monomial { coeff, vars }],
196 });
197 }
198 let n = *c as usize;
200 if c.fract().abs() < 1e-10 {
201 if n == 0 {
202 return Ok(NormalizedPoly {
203 terms: vec![Monomial::constant(1.0)],
204 });
205 }
206 let mut result = pb.clone();
207 for _ in 1..n {
208 result = result.mul(&pb);
209 }
210 return Ok(result);
211 }
212 Err(format!(
213 "non-integer power of multi-term polynomial: ({base})^{c}"
214 ))
215 } else {
216 Err(format!("variable exponent: ({base})^({exp})"))
217 }
218 }
219 Expr::Exp(_) => Err("exp() not supported".into()),
220 Expr::Log(_) => Err("log() not supported".into()),
221 Expr::Sqrt(_) => Err("sqrt() not supported".into()),
222 }
223}
224
225fn prepare_expr_for_comparison(expr: &Expr) -> Expr {
226 canonical_form(expr).unwrap_or_else(|_| expr.clone())
227}
228
229fn monomial_dominated_by(small: &Monomial, big: &Monomial) -> bool {
236 for (&var, &exp_small) in &small.vars {
237 let exp_big = big.vars.get(var).copied().unwrap_or(0.0);
238 if exp_small > exp_big + 1e-10 {
239 return false;
240 }
241 }
242 true
243}
244
245fn poly_leq(a: &NormalizedPoly, b: &NormalizedPoly) -> bool {
250 let b_positive: Vec<&Monomial> = b.terms.iter().filter(|m| m.coeff > 1e-15).collect();
251
252 for a_term in &a.terms {
253 if a_term.coeff <= 1e-15 {
254 continue; }
256 let dominated = b_positive
257 .iter()
258 .any(|b_term| monomial_dominated_by(a_term, b_term));
259 if !dominated {
260 return false;
261 }
262 }
263 true
264}
265
266pub fn compare_overhead(
275 primitive: &ReductionOverhead,
276 composite: &ReductionOverhead,
277) -> ComparisonStatus {
278 let comp_map: std::collections::HashMap<&str, &Expr> = composite
279 .output_size
280 .iter()
281 .map(|(name, expr)| (*name, expr))
282 .collect();
283
284 let mut any_common = false;
285
286 for (field, prim_expr) in &primitive.output_size {
287 let Some(comp_expr) = comp_map.get(field) else {
288 continue;
289 };
290 any_common = true;
291
292 let primitive_prepared = prepare_expr_for_comparison(prim_expr);
293 let composite_prepared = prepare_expr_for_comparison(comp_expr);
294
295 if primitive_prepared == composite_prepared {
296 continue;
297 }
298
299 let primitive_poly = match normalize_polynomial(&primitive_prepared) {
300 Ok(p) => p,
301 Err(_) => return ComparisonStatus::Unknown,
302 };
303 let composite_poly = match normalize_polynomial(&composite_prepared) {
304 Ok(p) => p,
305 Err(_) => return ComparisonStatus::Unknown,
306 };
307
308 if primitive_poly.has_negative_coefficients() || composite_poly.has_negative_coefficients()
310 {
311 return ComparisonStatus::Unknown;
312 }
313
314 if !poly_leq(&composite_poly, &primitive_poly) {
316 return ComparisonStatus::NotDominated;
317 }
318 }
319
320 if any_common {
321 ComparisonStatus::Dominated
322 } else {
323 ComparisonStatus::NotDominated
324 }
325}
326
327pub fn find_dominated_rules(
343 graph: &ReductionGraph,
344) -> (Vec<DominatedRule>, Vec<UnknownComparison>) {
345 let mut dominated = Vec::new();
346 let mut unknown = Vec::new();
347
348 for edge_info in all_edges(graph) {
349 let paths = graph.find_all_paths(
350 edge_info.source_name,
351 &edge_info.source_variant,
352 edge_info.target_name,
353 &edge_info.target_variant,
354 );
355
356 let mut best_dominating: Option<(ReductionPath, ReductionOverhead, Vec<String>)> = None;
357
358 for path in paths {
359 if path.len() <= 1 {
360 continue; }
362
363 let composed = graph.compose_path_overhead(&path);
364
365 match compare_overhead(&edge_info.overhead, &composed) {
366 ComparisonStatus::Dominated => {
367 let comparable_fields = common_fields(&edge_info.overhead, &composed);
368 let is_better = match &best_dominating {
369 None => true,
370 Some((best_path, _, _)) => path.len() < best_path.len(),
371 };
372 if is_better {
373 best_dominating = Some((path, composed, comparable_fields));
374 }
375 }
376 ComparisonStatus::Unknown => {
377 unknown.push(UnknownComparison {
378 source_name: edge_info.source_name,
379 source_variant: edge_info.source_variant.clone(),
380 target_name: edge_info.target_name,
381 target_variant: edge_info.target_variant.clone(),
382 candidate_path: path,
383 reason: "expression comparison returned Unknown".into(),
384 });
385 }
386 ComparisonStatus::NotDominated => {}
387 }
388 }
389
390 if let Some((path, composed, fields)) = best_dominating {
391 dominated.push(DominatedRule {
392 source_name: edge_info.source_name,
393 source_variant: edge_info.source_variant.clone(),
394 target_name: edge_info.target_name,
395 target_variant: edge_info.target_variant.clone(),
396 primitive_overhead: edge_info.overhead.clone(),
397 dominating_path: path,
398 composed_overhead: composed,
399 comparable_fields: fields,
400 });
401 }
402 }
403
404 dominated.sort_by(|a, b| {
406 (
407 format_problem_variant(a.source_name, &a.source_variant),
408 format_problem_variant(a.target_name, &a.target_variant),
409 a.dominating_path.len(),
410 )
411 .cmp(&(
412 format_problem_variant(b.source_name, &b.source_variant),
413 format_problem_variant(b.target_name, &b.target_variant),
414 b.dominating_path.len(),
415 ))
416 });
417 unknown.sort_by(|a, b| {
418 (
419 format_problem_variant(a.source_name, &a.source_variant),
420 format_problem_variant(a.target_name, &a.target_variant),
421 )
422 .cmp(&(
423 format_problem_variant(b.source_name, &b.source_variant),
424 format_problem_variant(b.target_name, &b.target_variant),
425 ))
426 });
427
428 (dominated, unknown)
429}
430
431fn common_fields(a: &ReductionOverhead, b: &ReductionOverhead) -> Vec<String> {
433 let b_fields: std::collections::HashSet<&str> = b.output_size.iter().map(|(n, _)| *n).collect();
434 a.output_size
435 .iter()
436 .filter(|&(f, _)| b_fields.contains(f))
437 .map(|(f, _)| f.to_string())
438 .collect()
439}
440
441fn all_edges(graph: &ReductionGraph) -> Vec<crate::rules::graph::ReductionEdgeInfo> {
443 let mut edges = Vec::new();
444 for name in graph.problem_types() {
445 edges.extend(graph.outgoing_reductions(name));
446 }
447 edges
448}
449
450#[cfg(test)]
451#[path = "../unit_tests/rules/analysis.rs"]
452mod tests;