problemreductions/rules/
registry.rs1use crate::expr::Expr;
4use crate::rules::traits::{DynAggregateReductionResult, DynReductionResult};
5use crate::types::ProblemSize;
6use std::any::Any;
7use std::collections::HashSet;
8
9#[derive(Clone, Debug, Default, serde::Serialize)]
11pub struct ReductionOverhead {
12 pub output_size: Vec<(&'static str, Expr)>,
15}
16
17impl ReductionOverhead {
18 pub fn new(output_size: Vec<(&'static str, Expr)>) -> Self {
19 Self { output_size }
20 }
21
22 pub fn identity(fields: &[&'static str]) -> Self {
25 Self {
26 output_size: fields.iter().map(|&f| (f, Expr::Var(f))).collect(),
27 }
28 }
29
30 pub fn evaluate_output_size(&self, input: &ProblemSize) -> ProblemSize {
36 let fields: Vec<_> = self
37 .output_size
38 .iter()
39 .map(|(name, expr)| (*name, expr.eval(input).round() as usize))
40 .collect();
41 ProblemSize::new(fields)
42 }
43
44 pub fn input_variable_names(&self) -> HashSet<&'static str> {
46 self.output_size
47 .iter()
48 .flat_map(|(_, expr)| expr.variables())
49 .collect()
50 }
51
52 pub fn compose(&self, next: &ReductionOverhead) -> ReductionOverhead {
57 use std::collections::HashMap;
58
59 let mapping: HashMap<&str, &Expr> = self
61 .output_size
62 .iter()
63 .map(|(name, expr)| (*name, expr))
64 .collect();
65
66 let composed = next
67 .output_size
68 .iter()
69 .map(|(name, expr)| (*name, expr.substitute(&mapping)))
70 .collect();
71
72 ReductionOverhead {
73 output_size: composed,
74 }
75 }
76
77 pub fn get(&self, name: &str) -> Option<&Expr> {
79 self.output_size
80 .iter()
81 .find(|(n, _)| *n == name)
82 .map(|(_, e)| e)
83 }
84}
85
86pub type ReduceFn = fn(&dyn Any) -> Box<dyn DynReductionResult>;
88
89pub type AggregateReduceFn = fn(&dyn Any) -> Box<dyn DynAggregateReductionResult>;
91
92#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
94pub struct EdgeCapabilities {
95 pub witness: bool,
96 pub aggregate: bool,
97 #[serde(default)]
100 pub turing: bool,
101}
102
103impl EdgeCapabilities {
104 pub const fn none() -> Self {
105 Self {
106 witness: false,
107 aggregate: false,
108 turing: false,
109 }
110 }
111
112 pub const fn witness_only() -> Self {
113 Self {
114 witness: true,
115 aggregate: false,
116 turing: false,
117 }
118 }
119
120 pub const fn aggregate_only() -> Self {
121 Self {
122 witness: false,
123 aggregate: true,
124 turing: false,
125 }
126 }
127
128 pub const fn both() -> Self {
129 Self {
130 witness: true,
131 aggregate: true,
132 turing: false,
133 }
134 }
135
136 pub const fn turing() -> Self {
137 Self {
138 witness: false,
139 aggregate: false,
140 turing: true,
141 }
142 }
143}
144
145impl Default for EdgeCapabilities {
148 fn default() -> Self {
149 Self::witness_only()
150 }
151}
152
153pub struct ReductionEntry {
156 pub source_name: &'static str,
158 pub target_name: &'static str,
160 pub source_variant_fn: fn() -> Vec<(&'static str, &'static str)>,
162 pub target_variant_fn: fn() -> Vec<(&'static str, &'static str)>,
164 pub overhead_fn: fn() -> ReductionOverhead,
166 pub module_path: &'static str,
168 pub reduce_fn: Option<ReduceFn>,
172 pub reduce_aggregate_fn: Option<AggregateReduceFn>,
177 pub capabilities: EdgeCapabilities,
179 pub overhead_eval_fn: fn(&dyn Any) -> ProblemSize,
183 pub source_size_fn: fn(&dyn Any) -> ProblemSize,
187}
188
189impl ReductionEntry {
190 pub fn overhead(&self) -> ReductionOverhead {
192 (self.overhead_fn)()
193 }
194
195 pub fn source_variant(&self) -> Vec<(&'static str, &'static str)> {
197 (self.source_variant_fn)()
198 }
199
200 pub fn target_variant(&self) -> Vec<(&'static str, &'static str)> {
202 (self.target_variant_fn)()
203 }
204
205 pub fn is_base_reduction(&self) -> bool {
207 let source = self.source_variant();
208 let target = self.target_variant();
209 let source_unweighted = source
210 .iter()
211 .find(|(k, _)| *k == "weight")
212 .map(|(_, v)| *v == "One")
213 .unwrap_or(true);
214 let target_unweighted = target
215 .iter()
216 .find(|(k, _)| *k == "weight")
217 .map(|(_, v)| *v == "One")
218 .unwrap_or(true);
219 source_unweighted && target_unweighted
220 }
221}
222
223impl std::fmt::Debug for ReductionEntry {
224 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225 f.debug_struct("ReductionEntry")
226 .field("source_name", &self.source_name)
227 .field("target_name", &self.target_name)
228 .field("source_variant", &self.source_variant())
229 .field("target_variant", &self.target_variant())
230 .field("overhead", &self.overhead())
231 .field("module_path", &self.module_path)
232 .field("capabilities", &self.capabilities)
233 .finish()
234 }
235}
236
237inventory::collect!(ReductionEntry);
238
239pub fn reduction_entries() -> Vec<&'static ReductionEntry> {
241 inventory::iter::<ReductionEntry>().collect()
242}
243
// Unit tests for this module live in the shared `unit_tests` tree. `#[path]`
// points the module at that file, and `#[cfg(test)]` compiles it only for
// test builds.
#[cfg(test)]
#[path = "../unit_tests/rules/registry.rs"]
mod tests;