Skip to main content

problemreductions/rules/
registry.rs

//! Automatic reduction registration via the `inventory` crate.
2
3use crate::expr::Expr;
4use crate::rules::traits::{DynAggregateReductionResult, DynReductionResult};
5use crate::types::ProblemSize;
6use std::any::Any;
7use std::collections::HashSet;
8
9/// Overhead specification for a reduction.
10#[derive(Clone, Debug, Default, serde::Serialize)]
11pub struct ReductionOverhead {
12    /// Output size as expressions of input size variables.
13    /// Each entry is (output_field_name, expression).
14    pub output_size: Vec<(&'static str, Expr)>,
15}
16
17impl ReductionOverhead {
18    pub fn new(output_size: Vec<(&'static str, Expr)>) -> Self {
19        Self { output_size }
20    }
21
22    /// Identity overhead: each output field equals the same-named input field.
23    /// Used by variant cast reductions where problem size doesn't change.
24    pub fn identity(fields: &[&'static str]) -> Self {
25        Self {
26            output_size: fields.iter().map(|&f| (f, Expr::Var(f))).collect(),
27        }
28    }
29
30    /// Evaluate output size given input size.
31    ///
32    /// Uses `round()` for the f64 to usize conversion because expression values
33    /// are typically integers and any fractional results come from floating-point
34    /// arithmetic imprecision, not intentional fractions.
35    pub fn evaluate_output_size(&self, input: &ProblemSize) -> ProblemSize {
36        let fields: Vec<_> = self
37            .output_size
38            .iter()
39            .map(|(name, expr)| (*name, expr.eval(input).round() as usize))
40            .collect();
41        ProblemSize::new(fields)
42    }
43
44    /// Collect all input variable names referenced by the overhead expressions.
45    pub fn input_variable_names(&self) -> HashSet<&'static str> {
46        self.output_size
47            .iter()
48            .flat_map(|(_, expr)| expr.variables())
49            .collect()
50    }
51
52    /// Compose two overheads: substitute self's output into `next`'s input.
53    ///
54    /// Returns a new overhead whose expressions map from self's input variables
55    /// directly to `next`'s output variables.
56    pub fn compose(&self, next: &ReductionOverhead) -> ReductionOverhead {
57        use std::collections::HashMap;
58
59        // Build substitution map: output field name → output expression
60        let mapping: HashMap<&str, &Expr> = self
61            .output_size
62            .iter()
63            .map(|(name, expr)| (*name, expr))
64            .collect();
65
66        let composed = next
67            .output_size
68            .iter()
69            .map(|(name, expr)| (*name, expr.substitute(&mapping)))
70            .collect();
71
72        ReductionOverhead {
73            output_size: composed,
74        }
75    }
76
77    /// Get the expression for a named output field.
78    pub fn get(&self, name: &str) -> Option<&Expr> {
79        self.output_size
80            .iter()
81            .find(|(n, _)| *n == name)
82            .map(|(_, e)| e)
83    }
84}
85
86/// Witness/config reduction executor stored in the inventory.
87pub type ReduceFn = fn(&dyn Any) -> Box<dyn DynReductionResult>;
88
89/// Aggregate/value reduction executor stored in the inventory.
90pub type AggregateReduceFn = fn(&dyn Any) -> Box<dyn DynAggregateReductionResult>;
91
92/// Execution capabilities carried by a reduction edge.
93#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
94pub struct EdgeCapabilities {
95    pub witness: bool,
96    pub aggregate: bool,
97    /// Turing (multi-query) reduction: solving the source requires multiple
98    /// adaptive queries to the target (e.g., binary search over a decision bound).
99    #[serde(default)]
100    pub turing: bool,
101}
102
103impl EdgeCapabilities {
104    pub const fn none() -> Self {
105        Self {
106            witness: false,
107            aggregate: false,
108            turing: false,
109        }
110    }
111
112    pub const fn witness_only() -> Self {
113        Self {
114            witness: true,
115            aggregate: false,
116            turing: false,
117        }
118    }
119
120    pub const fn aggregate_only() -> Self {
121        Self {
122            witness: false,
123            aggregate: true,
124            turing: false,
125        }
126    }
127
128    pub const fn both() -> Self {
129        Self {
130            witness: true,
131            aggregate: true,
132            turing: false,
133        }
134    }
135
136    pub const fn turing() -> Self {
137        Self {
138            witness: false,
139            aggregate: false,
140            turing: true,
141        }
142    }
143}
144
145/// Defaults to `witness_only()` — the conservative choice for edges registered
146/// via `#[reduction]`, which are witness/config reductions.
147impl Default for EdgeCapabilities {
148    fn default() -> Self {
149        Self::witness_only()
150    }
151}
152
153/// A registered reduction entry for static inventory registration.
154/// Uses function pointers to lazily derive variant fields from `Problem::variant()`.
155pub struct ReductionEntry {
156    /// Base name of source problem (e.g., "MaximumIndependentSet").
157    pub source_name: &'static str,
158    /// Base name of target problem (e.g., "MinimumVertexCover").
159    pub target_name: &'static str,
160    /// Function to derive source variant attributes from `Problem::variant()`.
161    pub source_variant_fn: fn() -> Vec<(&'static str, &'static str)>,
162    /// Function to derive target variant attributes from `Problem::variant()`.
163    pub target_variant_fn: fn() -> Vec<(&'static str, &'static str)>,
164    /// Function to create overhead information (lazy evaluation for static context).
165    pub overhead_fn: fn() -> ReductionOverhead,
166    /// Module path where the reduction is defined (from `module_path!()`).
167    pub module_path: &'static str,
168    /// Type-erased reduction executor.
169    /// Takes a `&dyn Any` (must be `&SourceType`), calls `ReduceTo::reduce_to()`,
170    /// and returns the result as a boxed `DynReductionResult`.
171    pub reduce_fn: Option<ReduceFn>,
172    /// Type-erased aggregate reduction executor.
173    /// Takes a `&dyn Any` (must be `&SourceType`), calls
174    /// `ReduceToAggregate::reduce_to_aggregate()`, and returns the result as a
175    /// boxed `DynAggregateReductionResult`.
176    pub reduce_aggregate_fn: Option<AggregateReduceFn>,
177    /// Capability metadata for runtime path filtering.
178    pub capabilities: EdgeCapabilities,
179    /// Compiled overhead evaluation function.
180    /// Takes a `&dyn Any` (must be `&SourceType`), calls getter methods directly,
181    /// and returns the computed target problem size.
182    pub overhead_eval_fn: fn(&dyn Any) -> ProblemSize,
183    /// Extract source problem size from a type-erased instance.
184    /// Takes a `&dyn Any` (must be `&SourceType`), calls getter methods,
185    /// and returns the source problem's size fields as a `ProblemSize`.
186    pub source_size_fn: fn(&dyn Any) -> ProblemSize,
187}
188
189impl ReductionEntry {
190    /// Get the overhead by calling the function.
191    pub fn overhead(&self) -> ReductionOverhead {
192        (self.overhead_fn)()
193    }
194
195    /// Get the source variant by calling the function.
196    pub fn source_variant(&self) -> Vec<(&'static str, &'static str)> {
197        (self.source_variant_fn)()
198    }
199
200    /// Get the target variant by calling the function.
201    pub fn target_variant(&self) -> Vec<(&'static str, &'static str)> {
202        (self.target_variant_fn)()
203    }
204
205    /// Check if this reduction involves only the base (unweighted) variants.
206    pub fn is_base_reduction(&self) -> bool {
207        let source = self.source_variant();
208        let target = self.target_variant();
209        let source_unweighted = source
210            .iter()
211            .find(|(k, _)| *k == "weight")
212            .map(|(_, v)| *v == "One")
213            .unwrap_or(true);
214        let target_unweighted = target
215            .iter()
216            .find(|(k, _)| *k == "weight")
217            .map(|(_, v)| *v == "One")
218            .unwrap_or(true);
219        source_unweighted && target_unweighted
220    }
221}
222
223impl std::fmt::Debug for ReductionEntry {
224    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225        f.debug_struct("ReductionEntry")
226            .field("source_name", &self.source_name)
227            .field("target_name", &self.target_name)
228            .field("source_variant", &self.source_variant())
229            .field("target_variant", &self.target_variant())
230            .field("overhead", &self.overhead())
231            .field("module_path", &self.module_path)
232            .field("capabilities", &self.capabilities)
233            .finish()
234    }
235}
236
237inventory::collect!(ReductionEntry);
238
239/// Return all registered reduction entries.
240pub fn reduction_entries() -> Vec<&'static ReductionEntry> {
241    inventory::iter::<ReductionEntry>().collect()
242}
243
// Unit tests for this module live in a separate file, named by the `#[path]` attribute
// below, and are compiled only for `cfg(test)` builds.
#[cfg(test)]
#[path = "../unit_tests/rules/registry.rs"]
mod tests;