problemreductions/
types.rs

1//! Common types used across the problemreductions library.
2
3use serde::{Deserialize, Serialize};
4use std::fmt;
5
6/// Marker trait for numeric weight types.
7///
8/// Weight subsumption uses Rust's `From` trait:
9/// - `i32 → f64` is valid (From<i32> for f64 exists)
10/// - `f64 → i32` is invalid (no lossless conversion)
11pub trait NumericWeight: Clone + Default + PartialOrd + num_traits::Num + num_traits::Zero + std::ops::AddAssign + 'static {}
12
13// Blanket implementation for any type satisfying the bounds
14impl<T> NumericWeight for T where T: Clone + Default + PartialOrd + num_traits::Num + num_traits::Zero + std::ops::AddAssign + 'static {}
15
16/// Specifies whether larger or smaller objective values are better.
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
18pub enum EnergyMode {
19    /// Larger objective values are better (maximization).
20    LargerSizeIsBetter,
21    /// Smaller objective values are better (minimization).
22    SmallerSizeIsBetter,
23}
24
25impl EnergyMode {
26    /// Returns true if this mode prefers larger values.
27    pub fn is_maximization(&self) -> bool {
28        matches!(self, EnergyMode::LargerSizeIsBetter)
29    }
30
31    /// Returns true if this mode prefers smaller values.
32    pub fn is_minimization(&self) -> bool {
33        matches!(self, EnergyMode::SmallerSizeIsBetter)
34    }
35
36    /// Compare two values according to this energy mode.
37    /// Returns true if `a` is better than `b`.
38    pub fn is_better<T: PartialOrd>(&self, a: &T, b: &T) -> bool {
39        match self {
40            EnergyMode::LargerSizeIsBetter => a > b,
41            EnergyMode::SmallerSizeIsBetter => a < b,
42        }
43    }
44
45    /// Compare two values according to this energy mode.
46    /// Returns true if `a` is better than or equal to `b`.
47    pub fn is_better_or_equal<T: PartialOrd>(&self, a: &T, b: &T) -> bool {
48        match self {
49            EnergyMode::LargerSizeIsBetter => a >= b,
50            EnergyMode::SmallerSizeIsBetter => a <= b,
51        }
52    }
53}
54
55/// The result of evaluating a solution's size/energy.
56#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
57pub struct SolutionSize<T> {
58    /// The objective value of the solution.
59    pub size: T,
60    /// Whether the solution satisfies all constraints.
61    pub is_valid: bool,
62}
63
64impl<T> SolutionSize<T> {
65    /// Create a new valid solution size.
66    pub fn valid(size: T) -> Self {
67        Self {
68            size,
69            is_valid: true,
70        }
71    }
72
73    /// Create a new invalid solution size.
74    pub fn invalid(size: T) -> Self {
75        Self {
76            size,
77            is_valid: false,
78        }
79    }
80
81    /// Create a new solution size with explicit validity.
82    pub fn new(size: T, is_valid: bool) -> Self {
83        Self { size, is_valid }
84    }
85}
86
87impl<T: fmt::Display> fmt::Display for SolutionSize<T> {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        write!(
90            f,
91            "SolutionSize({}, {})",
92            self.size,
93            if self.is_valid { "valid" } else { "invalid" }
94        )
95    }
96}
97
98/// Problem size metadata (varies by problem type).
99#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
100pub struct ProblemSize {
101    /// Named size components.
102    pub components: Vec<(String, usize)>,
103}
104
105impl ProblemSize {
106    /// Create a new problem size with named components.
107    pub fn new(components: Vec<(&str, usize)>) -> Self {
108        Self {
109            components: components
110                .into_iter()
111                .map(|(k, v)| (k.to_string(), v))
112                .collect(),
113        }
114    }
115
116    /// Get a size component by name.
117    pub fn get(&self, name: &str) -> Option<usize> {
118        self.components
119            .iter()
120            .find(|(k, _)| k == name)
121            .map(|(_, v)| *v)
122    }
123}
124
125impl fmt::Display for ProblemSize {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        write!(f, "ProblemSize{{")?;
128        for (i, (name, value)) in self.components.iter().enumerate() {
129            if i > 0 {
130                write!(f, ", ")?;
131            }
132            write!(f, "{}: {}", name, value)?;
133        }
134        write!(f, "}}")
135    }
136}
137
138/// A local constraint on a subset of variables.
139///
140/// The constraint specifies which configurations of the variables are valid.
141/// The `spec` vector is indexed by the configuration value (treating variables
142/// as digits in a base-`num_flavors` number).
143#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
144pub struct LocalConstraint {
145    /// Number of flavors (domain size) for each variable.
146    pub num_flavors: usize,
147    /// Indices of variables involved in this constraint.
148    pub variables: Vec<usize>,
149    /// Specification vector: `spec[config]` = true if config is valid.
150    /// Length must be num_flavors^variables.len().
151    pub spec: Vec<bool>,
152}
153
154impl LocalConstraint {
155    /// Create a new local constraint.
156    pub fn new(num_flavors: usize, variables: Vec<usize>, spec: Vec<bool>) -> Self {
157        debug_assert_eq!(
158            spec.len(),
159            num_flavors.pow(variables.len() as u32),
160            "spec length must be num_flavors^num_variables"
161        );
162        Self {
163            num_flavors,
164            variables,
165            spec,
166        }
167    }
168
169    /// Check if a configuration satisfies this constraint.
170    pub fn is_satisfied(&self, config: &[usize]) -> bool {
171        let index = self.config_to_index(config);
172        self.spec.get(index).copied().unwrap_or(false)
173    }
174
175    /// Convert a full configuration to an index into the spec vector.
176    fn config_to_index(&self, config: &[usize]) -> usize {
177        let mut index = 0;
178        for (i, &var) in self.variables.iter().enumerate() {
179            let value = config.get(var).copied().unwrap_or(0);
180            index += value * self.num_flavors.pow((self.variables.len() - 1 - i) as u32);
181        }
182        index
183    }
184
185    /// Get the number of variables in this constraint.
186    pub fn num_variables(&self) -> usize {
187        self.variables.len()
188    }
189}
190
191/// A local contribution to the solution size from a subset of variables.
192///
193/// Similar to LocalConstraint but stores objective values instead of validity.
194#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
195pub struct LocalSolutionSize<T> {
196    /// Number of flavors (domain size) for each variable.
197    pub num_flavors: usize,
198    /// Indices of variables involved.
199    pub variables: Vec<usize>,
200    /// Specification vector: `spec[config]` = contribution for that config.
201    /// Length must be num_flavors^variables.len().
202    pub spec: Vec<T>,
203}
204
205impl<T: Clone> LocalSolutionSize<T> {
206    /// Create a new local solution size.
207    pub fn new(num_flavors: usize, variables: Vec<usize>, spec: Vec<T>) -> Self {
208        debug_assert_eq!(
209            spec.len(),
210            num_flavors.pow(variables.len() as u32),
211            "spec length must be num_flavors^num_variables"
212        );
213        Self {
214            num_flavors,
215            variables,
216            spec,
217        }
218    }
219
220    /// Get the contribution from a configuration.
221    pub fn evaluate(&self, config: &[usize]) -> T {
222        let index = self.config_to_index(config);
223        self.spec[index].clone()
224    }
225
226    /// Convert a full configuration to an index into the spec vector.
227    fn config_to_index(&self, config: &[usize]) -> usize {
228        let mut index = 0;
229        for (i, &var) in self.variables.iter().enumerate() {
230            let value = config.get(var).copied().unwrap_or(0);
231            index += value * self.num_flavors.pow((self.variables.len() - 1 - i) as u32);
232        }
233        index
234    }
235
236    /// Get the number of variables in this local objective.
237    pub fn num_variables(&self) -> usize {
238        self.variables.len()
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_energy_mode() {
        let maximize = EnergyMode::LargerSizeIsBetter;
        let minimize = EnergyMode::SmallerSizeIsBetter;

        // Exactly one of the two predicates holds for each mode.
        assert!(maximize.is_maximization() && !maximize.is_minimization());
        assert!(minimize.is_minimization() && !minimize.is_maximization());

        // Strict comparison follows the direction of the mode.
        assert!(maximize.is_better(&10, &5));
        assert!(!maximize.is_better(&5, &10));
        assert!(minimize.is_better(&5, &10));
        assert!(!minimize.is_better(&10, &5));

        // Ties count as "better or equal" in both directions.
        for mode in &[maximize, minimize] {
            assert!(mode.is_better_or_equal(&10, &10));
        }
    }

    #[test]
    fn test_solution_size() {
        let valid = SolutionSize::valid(42);
        assert_eq!((valid.size, valid.is_valid), (42, true));

        assert!(!SolutionSize::invalid(0).is_valid);

        let explicit = SolutionSize::new(100, false);
        assert_eq!((explicit.size, explicit.is_valid), (100, false));
    }

    #[test]
    fn test_solution_size_display() {
        assert_eq!(SolutionSize::valid(42).to_string(), "SolutionSize(42, valid)");
        assert_eq!(SolutionSize::invalid(0).to_string(), "SolutionSize(0, invalid)");
    }

    #[test]
    fn test_problem_size() {
        let size = ProblemSize::new(vec![("vertices", 10), ("edges", 20)]);
        let expectations = [("vertices", Some(10)), ("edges", Some(20)), ("unknown", None)];
        for (name, expected) in &expectations {
            assert_eq!(size.get(name), *expected, "component {}", name);
        }
    }

    #[test]
    fn test_problem_size_display() {
        let cases = [
            (
                ProblemSize::new(vec![("vertices", 10), ("edges", 20)]),
                "ProblemSize{vertices: 10, edges: 20}",
            ),
            (ProblemSize::new(vec![]), "ProblemSize{}"),
            (ProblemSize::new(vec![("n", 5)]), "ProblemSize{n: 5}"),
        ];
        for (size, expected) in &cases {
            assert_eq!(size.to_string(), *expected);
        }
    }

    #[test]
    fn test_local_constraint() {
        // Equality constraint on two binary variables: only (0,0) and (1,1) are allowed.
        let constraint = LocalConstraint::new(2, vec![0, 1], vec![true, false, false, true]);
        let cases: [(&[usize], bool); 4] = [
            (&[0, 0], true),
            (&[0, 1], false),
            (&[1, 0], false),
            (&[1, 1], true),
        ];
        for &(config, expected) in &cases {
            assert_eq!(constraint.is_satisfied(config), expected, "config {:?}", config);
        }
        assert_eq!(constraint.num_variables(), 2);
    }

    #[test]
    fn test_local_constraint_out_of_bounds() {
        // Variables 5 and 6 are absent from the configuration, so both read as 0.
        let constraint = LocalConstraint::new(2, vec![5, 6], vec![true, false, false, true]);
        assert!(constraint.is_satisfied(&[0, 0, 0]));
    }

    #[test]
    fn test_local_solution_size() {
        // One binary variable: contributes 0 when unset and 5 when set.
        let objective = LocalSolutionSize::new(2, vec![0], vec![0, 5]);
        assert_eq!(objective.num_variables(), 1);
        assert_eq!(objective.evaluate(&[0]), 0);
        assert_eq!(objective.evaluate(&[1]), 5);
    }

    #[test]
    fn test_local_solution_size_multi_variable() {
        // Each spec entry equals its own index, so evaluate() exposes the index mapping.
        let objective = LocalSolutionSize::new(2, vec![0, 1], vec![0, 1, 2, 3]);
        let configs: [[usize; 2]; 4] = [[0, 0], [0, 1], [1, 0], [1, 1]];
        for (expected, config) in configs.iter().enumerate() {
            assert_eq!(objective.evaluate(config), expected);
        }
    }

    #[test]
    fn test_numeric_weight_impls() {
        fn requires_numeric_weight<W: NumericWeight>() {}

        requires_numeric_weight::<i32>();
        requires_numeric_weight::<i64>();
        requires_numeric_weight::<f32>();
        requires_numeric_weight::<f64>();
    }
}