problemreductions/rules/
registry.rs1use crate::expr::Expr;
4use crate::rules::traits::DynReductionResult;
5use crate::types::ProblemSize;
6use std::any::Any;
7use std::collections::HashSet;
8
/// Symbolic description of the output size produced by a reduction.
///
/// Each entry pairs the name of an output size field with an [`Expr`].
/// [`ReductionOverhead::evaluate_output_size`] evaluates those expressions against a
/// concrete input [`ProblemSize`].
#[derive(Clone, Debug, Default, serde::Serialize)]
pub struct ReductionOverhead {
    /// `(output_field_name, expression)` pairs, stored in the order they were supplied.
    pub output_size: Vec<(&'static str, Expr)>,
}
16
17impl ReductionOverhead {
18 pub fn new(output_size: Vec<(&'static str, Expr)>) -> Self {
19 Self { output_size }
20 }
21
22 pub fn identity(fields: &[&'static str]) -> Self {
25 Self {
26 output_size: fields.iter().map(|&f| (f, Expr::Var(f))).collect(),
27 }
28 }
29
30 pub fn evaluate_output_size(&self, input: &ProblemSize) -> ProblemSize {
36 let fields: Vec<_> = self
37 .output_size
38 .iter()
39 .map(|(name, expr)| (*name, expr.eval(input).round() as usize))
40 .collect();
41 ProblemSize::new(fields)
42 }
43
44 pub fn input_variable_names(&self) -> HashSet<&'static str> {
46 self.output_size
47 .iter()
48 .flat_map(|(_, expr)| expr.variables())
49 .collect()
50 }
51
52 pub fn compose(&self, next: &ReductionOverhead) -> ReductionOverhead {
57 use std::collections::HashMap;
58
59 let mapping: HashMap<&str, &Expr> = self
61 .output_size
62 .iter()
63 .map(|(name, expr)| (*name, expr))
64 .collect();
65
66 let composed = next
67 .output_size
68 .iter()
69 .map(|(name, expr)| (*name, expr.substitute(&mapping)))
70 .collect();
71
72 ReductionOverhead {
73 output_size: composed,
74 }
75 }
76
77 pub fn get(&self, name: &str) -> Option<&Expr> {
79 self.output_size
80 .iter()
81 .find(|(n, _)| *n == name)
82 .map(|(_, e)| e)
83 }
84}
85
/// One registered reduction from a source problem to a target problem.
///
/// The type is declared as an `inventory` collection via the `inventory::collect!`
/// invocation in this file.
pub struct ReductionEntry {
    /// Name of the source problem.
    pub source_name: &'static str,
    /// Name of the target problem.
    pub target_name: &'static str,
    /// Produces the source variant as `(key, value)` pairs (for example a `"weight"` key).
    pub source_variant_fn: fn() -> Vec<(&'static str, &'static str)>,
    /// Produces the target variant as `(key, value)` pairs.
    pub target_variant_fn: fn() -> Vec<(&'static str, &'static str)>,
    /// Produces the symbolic size overhead of this reduction.
    pub overhead_fn: fn() -> ReductionOverhead,
    /// NOTE(review): presumably the Rust module path where the reduction is defined; confirm.
    pub module_path: &'static str,
    /// Type-erased reduction entry point: takes a `&dyn Any` and returns a boxed
    /// [`DynReductionResult`].
    /// NOTE(review): the expected concrete type behind `&dyn Any`, and the behavior on a
    /// mismatch, are not visible here.
    pub reduce_fn: fn(&dyn Any) -> Box<dyn DynReductionResult>,
    /// Type-erased size evaluation: takes a `&dyn Any` and returns a [`ProblemSize`].
    /// NOTE(review): presumably the target size computed from a concrete source instance;
    /// confirm against the code that registers entries.
    pub overhead_eval_fn: fn(&dyn Any) -> ProblemSize,
}
110
111impl ReductionEntry {
112 pub fn overhead(&self) -> ReductionOverhead {
114 (self.overhead_fn)()
115 }
116
117 pub fn source_variant(&self) -> Vec<(&'static str, &'static str)> {
119 (self.source_variant_fn)()
120 }
121
122 pub fn target_variant(&self) -> Vec<(&'static str, &'static str)> {
124 (self.target_variant_fn)()
125 }
126
127 pub fn is_base_reduction(&self) -> bool {
129 let source = self.source_variant();
130 let target = self.target_variant();
131 let source_unweighted = source
132 .iter()
133 .find(|(k, _)| *k == "weight")
134 .map(|(_, v)| *v == "One")
135 .unwrap_or(true);
136 let target_unweighted = target
137 .iter()
138 .find(|(k, _)| *k == "weight")
139 .map(|(_, v)| *v == "One")
140 .unwrap_or(true);
141 source_unweighted && target_unweighted
142 }
143}
144
145impl std::fmt::Debug for ReductionEntry {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 f.debug_struct("ReductionEntry")
148 .field("source_name", &self.source_name)
149 .field("target_name", &self.target_name)
150 .field("source_variant", &self.source_variant())
151 .field("target_variant", &self.target_variant())
152 .field("overhead", &self.overhead())
153 .field("module_path", &self.module_path)
154 .finish()
155 }
156}
157
// Declares `ReductionEntry` as a type collected by the `inventory` crate, so that entries
// registered elsewhere can be gathered into one registry.
inventory::collect!(ReductionEntry);
159
// Unit tests for this module live in a separate file and are compiled only in test builds.
#[cfg(test)]
#[path = "../unit_tests/rules/registry.rs"]
mod tests;