problemreductions/rules/
registry.rs

1//! Automatic reduction registration via inventory.
2
3use crate::expr::Expr;
4use crate::rules::traits::DynReductionResult;
5use crate::types::ProblemSize;
6use std::any::Any;
7use std::collections::HashSet;
8
9/// Overhead specification for a reduction.
10#[derive(Clone, Debug, Default, serde::Serialize)]
11pub struct ReductionOverhead {
12    /// Output size as expressions of input size variables.
13    /// Each entry is (output_field_name, expression).
14    pub output_size: Vec<(&'static str, Expr)>,
15}
16
17impl ReductionOverhead {
18    pub fn new(output_size: Vec<(&'static str, Expr)>) -> Self {
19        Self { output_size }
20    }
21
22    /// Identity overhead: each output field equals the same-named input field.
23    /// Used by variant cast reductions where problem size doesn't change.
24    pub fn identity(fields: &[&'static str]) -> Self {
25        Self {
26            output_size: fields.iter().map(|&f| (f, Expr::Var(f))).collect(),
27        }
28    }
29
30    /// Evaluate output size given input size.
31    ///
32    /// Uses `round()` for the f64 to usize conversion because expression values
33    /// are typically integers and any fractional results come from floating-point
34    /// arithmetic imprecision, not intentional fractions.
35    pub fn evaluate_output_size(&self, input: &ProblemSize) -> ProblemSize {
36        let fields: Vec<_> = self
37            .output_size
38            .iter()
39            .map(|(name, expr)| (*name, expr.eval(input).round() as usize))
40            .collect();
41        ProblemSize::new(fields)
42    }
43
44    /// Collect all input variable names referenced by the overhead expressions.
45    pub fn input_variable_names(&self) -> HashSet<&'static str> {
46        self.output_size
47            .iter()
48            .flat_map(|(_, expr)| expr.variables())
49            .collect()
50    }
51
52    /// Compose two overheads: substitute self's output into `next`'s input.
53    ///
54    /// Returns a new overhead whose expressions map from self's input variables
55    /// directly to `next`'s output variables.
56    pub fn compose(&self, next: &ReductionOverhead) -> ReductionOverhead {
57        use std::collections::HashMap;
58
59        // Build substitution map: output field name → output expression
60        let mapping: HashMap<&str, &Expr> = self
61            .output_size
62            .iter()
63            .map(|(name, expr)| (*name, expr))
64            .collect();
65
66        let composed = next
67            .output_size
68            .iter()
69            .map(|(name, expr)| (*name, expr.substitute(&mapping)))
70            .collect();
71
72        ReductionOverhead {
73            output_size: composed,
74        }
75    }
76
77    /// Get the expression for a named output field.
78    pub fn get(&self, name: &str) -> Option<&Expr> {
79        self.output_size
80            .iter()
81            .find(|(n, _)| *n == name)
82            .map(|(_, e)| e)
83    }
84}
85
86/// A registered reduction entry for static inventory registration.
87/// Uses function pointers to lazily derive variant fields from `Problem::variant()`.
88pub struct ReductionEntry {
89    /// Base name of source problem (e.g., "MaximumIndependentSet").
90    pub source_name: &'static str,
91    /// Base name of target problem (e.g., "MinimumVertexCover").
92    pub target_name: &'static str,
93    /// Function to derive source variant attributes from `Problem::variant()`.
94    pub source_variant_fn: fn() -> Vec<(&'static str, &'static str)>,
95    /// Function to derive target variant attributes from `Problem::variant()`.
96    pub target_variant_fn: fn() -> Vec<(&'static str, &'static str)>,
97    /// Function to create overhead information (lazy evaluation for static context).
98    pub overhead_fn: fn() -> ReductionOverhead,
99    /// Module path where the reduction is defined (from `module_path!()`).
100    pub module_path: &'static str,
101    /// Type-erased reduction executor.
102    /// Takes a `&dyn Any` (must be `&SourceType`), calls `ReduceTo::reduce_to()`,
103    /// and returns the result as a boxed `DynReductionResult`.
104    pub reduce_fn: fn(&dyn Any) -> Box<dyn DynReductionResult>,
105    /// Compiled overhead evaluation function.
106    /// Takes a `&dyn Any` (must be `&SourceType`), calls getter methods directly,
107    /// and returns the computed target problem size.
108    pub overhead_eval_fn: fn(&dyn Any) -> ProblemSize,
109}
110
111impl ReductionEntry {
112    /// Get the overhead by calling the function.
113    pub fn overhead(&self) -> ReductionOverhead {
114        (self.overhead_fn)()
115    }
116
117    /// Get the source variant by calling the function.
118    pub fn source_variant(&self) -> Vec<(&'static str, &'static str)> {
119        (self.source_variant_fn)()
120    }
121
122    /// Get the target variant by calling the function.
123    pub fn target_variant(&self) -> Vec<(&'static str, &'static str)> {
124        (self.target_variant_fn)()
125    }
126
127    /// Check if this reduction involves only the base (unweighted) variants.
128    pub fn is_base_reduction(&self) -> bool {
129        let source = self.source_variant();
130        let target = self.target_variant();
131        let source_unweighted = source
132            .iter()
133            .find(|(k, _)| *k == "weight")
134            .map(|(_, v)| *v == "One")
135            .unwrap_or(true);
136        let target_unweighted = target
137            .iter()
138            .find(|(k, _)| *k == "weight")
139            .map(|(_, v)| *v == "One")
140            .unwrap_or(true);
141        source_unweighted && target_unweighted
142    }
143}
144
145impl std::fmt::Debug for ReductionEntry {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        f.debug_struct("ReductionEntry")
148            .field("source_name", &self.source_name)
149            .field("target_name", &self.target_name)
150            .field("source_variant", &self.source_variant())
151            .field("target_variant", &self.target_variant())
152            .field("overhead", &self.overhead())
153            .field("module_path", &self.module_path)
154            .finish()
155    }
156}
157
// Makes `ReductionEntry` collectable by the `inventory` crate, so entries
// registered with `inventory::submit!` can be enumerated with
// `inventory::iter::<ReductionEntry>`. The submissions themselves are not in
// this file (presumably in each reduction's module, or generated by a macro).
inventory::collect!(ReductionEntry);

// Unit tests live in a separate file, loaded via `#[path]` and compiled only
// for test builds.
#[cfg(test)]
#[path = "../unit_tests/rules/registry.rs"]
mod tests;