Skip to main content

problemreductions/
types.rs

1//! Common types used across the problemreductions library.
2
3use serde::de::{self, DeserializeOwned, Visitor};
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::fmt;
6
7/// Bound for objective value types (i32, f64, etc.)
8pub trait NumericSize:
9    Clone
10    + Default
11    + PartialOrd
12    + num_traits::Num
13    + num_traits::Zero
14    + num_traits::Bounded
15    + std::ops::AddAssign
16    + 'static
17{
18}
19
20impl<T> NumericSize for T where
21    T: Clone
22        + Default
23        + PartialOrd
24        + num_traits::Num
25        + num_traits::Zero
26        + num_traits::Bounded
27        + std::ops::AddAssign
28        + 'static
29{
30}
31
32/// Maps a weight element to its sum/metric type.
33///
34/// This decouples the per-element weight type from the accumulation type.
35/// For concrete weights (`i32`, `f64`), `Sum` is the same type.
36/// For the unit weight `One`, `Sum = i32`.
37pub trait WeightElement: Clone + Default + 'static {
38    /// The numeric type used for sums and comparisons.
39    type Sum: NumericSize;
40    /// Whether this is the unit weight type (`One`).
41    const IS_UNIT: bool;
42    /// Convert this weight element to the sum type.
43    fn to_sum(&self) -> Self::Sum;
44}
45
46impl WeightElement for i32 {
47    type Sum = i32;
48    const IS_UNIT: bool = false;
49    fn to_sum(&self) -> i32 {
50        *self
51    }
52}
53
54impl WeightElement for f64 {
55    type Sum = f64;
56    const IS_UNIT: bool = false;
57    fn to_sum(&self) -> f64 {
58        *self
59    }
60}
61
/// Unit weight: a field-less marker standing for the constant `1`.
///
/// Using `One` as the weight type parameter `W` declares that every weight
/// is uniformly 1; `One::to_sum()` yields `1i32`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct One;
68
69impl Serialize for One {
70    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
71    where
72        S: Serializer,
73    {
74        serializer.serialize_i32(1)
75    }
76}
77
78impl<'de> Deserialize<'de> for One {
79    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
80    where
81        D: Deserializer<'de>,
82    {
83        struct OneVisitor;
84
85        impl<'de> Visitor<'de> for OneVisitor {
86            type Value = One;
87
88            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89                formatter.write_str("the unit weight `One` encoded as 1 or unit/null")
90            }
91
92            fn visit_i64<E>(self, value: i64) -> Result<One, E>
93            where
94                E: de::Error,
95            {
96                if value == 1 {
97                    Ok(One)
98                } else {
99                    Err(E::custom(format!("expected 1 for One, got {value}")))
100                }
101            }
102
103            fn visit_u64<E>(self, value: u64) -> Result<One, E>
104            where
105                E: de::Error,
106            {
107                if value == 1 {
108                    Ok(One)
109                } else {
110                    Err(E::custom(format!("expected 1 for One, got {value}")))
111                }
112            }
113
114            fn visit_unit<E>(self) -> Result<One, E>
115            where
116                E: de::Error,
117            {
118                Ok(One)
119            }
120
121            fn visit_none<E>(self) -> Result<One, E>
122            where
123                E: de::Error,
124            {
125                Ok(One)
126            }
127
128            fn visit_str<E>(self, value: &str) -> Result<One, E>
129            where
130                E: de::Error,
131            {
132                if value == "One" {
133                    Ok(One)
134                } else {
135                    Err(E::custom(format!("expected \"One\" for One, got {value}")))
136                }
137            }
138        }
139
140        deserializer.deserialize_any(OneVisitor)
141    }
142}
143
144impl WeightElement for One {
145    type Sum = i32;
146    const IS_UNIT: bool = true;
147    fn to_sum(&self) -> i32 {
148        1
149    }
150}
151
152impl std::fmt::Display for One {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        write!(f, "One")
155    }
156}
157
158impl From<i32> for One {
159    fn from(_: i32) -> Self {
160        One
161    }
162}
163
164/// Backward-compatible alias for `One`.
165pub type Unweighted = One;
166
167/// Foldable aggregate values for enumerating a problem's configuration space.
168pub trait Aggregate: Clone + fmt::Debug + Serialize + DeserializeOwned {
169    /// Neutral element for folding.
170    fn identity() -> Self;
171
172    /// Associative combine operation.
173    fn combine(self, other: Self) -> Self;
174
175    /// Whether this aggregate admits representative witness configurations.
176    fn supports_witnesses() -> bool {
177        false
178    }
179
180    /// Whether a configuration-level value belongs to the witness set
181    /// for the final aggregate value.
182    fn contributes_to_witnesses(_config_value: &Self, _total: &Self) -> bool {
183        false
184    }
185}
186
187/// Maximum aggregate over feasible values.
188#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
189pub struct Max<V>(pub Option<V>);
190
191impl<V: fmt::Debug + PartialOrd + Clone + Serialize + DeserializeOwned> Aggregate for Max<V> {
192    fn identity() -> Self {
193        Max(None)
194    }
195
196    fn combine(self, other: Self) -> Self {
197        use std::cmp::Ordering;
198
199        match (self.0, other.0) {
200            (None, rhs) => Max(rhs),
201            (lhs, None) => Max(lhs),
202            (Some(lhs), Some(rhs)) => {
203                let ord = lhs.partial_cmp(&rhs).expect("cannot compare values (NaN?)");
204                match ord {
205                    Ordering::Less => Max(Some(rhs)),
206                    Ordering::Equal | Ordering::Greater => Max(Some(lhs)),
207                }
208            }
209        }
210    }
211
212    fn supports_witnesses() -> bool {
213        true
214    }
215
216    fn contributes_to_witnesses(config_value: &Self, total: &Self) -> bool {
217        matches!((config_value, total), (Max(Some(value)), Max(Some(best))) if value == best)
218    }
219}
220
221impl<V: fmt::Display> fmt::Display for Max<V> {
222    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
223        match &self.0 {
224            Some(value) => write!(f, "Max({value})"),
225            None => write!(f, "Max(None)"),
226        }
227    }
228}
229
230impl<V> Max<V> {
231    pub fn is_valid(&self) -> bool {
232        self.0.is_some()
233    }
234
235    pub fn size(&self) -> Option<&V> {
236        self.0.as_ref()
237    }
238
239    pub fn unwrap(self) -> V {
240        self.0.expect("called unwrap on invalid Max value")
241    }
242}
243
244/// Minimum aggregate over feasible values.
245#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
246pub struct Min<V>(pub Option<V>);
247
248impl<V: fmt::Debug + PartialOrd + Clone + Serialize + DeserializeOwned> Aggregate for Min<V> {
249    fn identity() -> Self {
250        Min(None)
251    }
252
253    fn combine(self, other: Self) -> Self {
254        use std::cmp::Ordering;
255
256        match (self.0, other.0) {
257            (None, rhs) => Min(rhs),
258            (lhs, None) => Min(lhs),
259            (Some(lhs), Some(rhs)) => {
260                let ord = lhs.partial_cmp(&rhs).expect("cannot compare values (NaN?)");
261                match ord {
262                    Ordering::Greater => Min(Some(rhs)),
263                    Ordering::Equal | Ordering::Less => Min(Some(lhs)),
264                }
265            }
266        }
267    }
268
269    fn supports_witnesses() -> bool {
270        true
271    }
272
273    fn contributes_to_witnesses(config_value: &Self, total: &Self) -> bool {
274        matches!((config_value, total), (Min(Some(value)), Min(Some(best))) if value == best)
275    }
276}
277
278impl<V: fmt::Display> fmt::Display for Min<V> {
279    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280        match &self.0 {
281            Some(value) => write!(f, "Min({value})"),
282            None => write!(f, "Min(None)"),
283        }
284    }
285}
286
287impl<V> Min<V> {
288    pub fn is_valid(&self) -> bool {
289        self.0.is_some()
290    }
291
292    pub fn size(&self) -> Option<&V> {
293        self.0.as_ref()
294    }
295
296    pub fn unwrap(self) -> V {
297        self.0.expect("called unwrap on invalid Min value")
298    }
299}
300
301/// Trait for aggregate values that represent optimization objectives.
302pub trait OptimizationValue: Aggregate {
303    /// The inner numeric type used for comparisons with decision bounds.
304    type Inner: Clone + PartialOrd + fmt::Debug + Serialize + DeserializeOwned;
305
306    /// Whether this aggregate value satisfies the provided decision bound.
307    fn meets_bound(value: &Self, bound: &Self::Inner) -> bool;
308}
309
310impl<V: fmt::Debug + PartialOrd + Clone + Serialize + DeserializeOwned> OptimizationValue
311    for Min<V>
312{
313    type Inner = V;
314
315    fn meets_bound(value: &Self, bound: &V) -> bool {
316        matches!(&value.0, Some(v) if *v <= *bound)
317    }
318}
319
320impl<V: fmt::Debug + PartialOrd + Clone + Serialize + DeserializeOwned> OptimizationValue
321    for Max<V>
322{
323    type Inner = V;
324
325    fn meets_bound(value: &Self, bound: &V) -> bool {
326        matches!(&value.0, Some(v) if *v >= *bound)
327    }
328}
329
330/// Sum aggregate for value-only problems.
331#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
332pub struct Sum<W>(pub W);
333
334impl<W: fmt::Debug + NumericSize + Serialize + DeserializeOwned> Aggregate for Sum<W> {
335    fn identity() -> Self {
336        Sum(W::zero())
337    }
338
339    fn combine(self, other: Self) -> Self {
340        let mut total = self.0;
341        total += other.0;
342        Sum(total)
343    }
344}
345
346impl<W: fmt::Display> fmt::Display for Sum<W> {
347    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
348        write!(f, "Sum({})", self.0)
349    }
350}
351
352/// Disjunction aggregate for existential satisfaction.
353#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
354pub struct Or(pub bool);
355
356impl Or {
357    pub fn is_valid(&self) -> bool {
358        self.0
359    }
360
361    pub fn unwrap(self) -> bool {
362        self.0
363    }
364}
365
366impl Aggregate for Or {
367    fn identity() -> Self {
368        Or(false)
369    }
370
371    fn combine(self, other: Self) -> Self {
372        Or(self.0 || other.0)
373    }
374
375    fn supports_witnesses() -> bool {
376        true
377    }
378
379    fn contributes_to_witnesses(config_value: &Self, total: &Self) -> bool {
380        config_value.0 && total.0
381    }
382}
383
384impl fmt::Display for Or {
385    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
386        write!(f, "Or({})", self.0)
387    }
388}
389
390impl std::ops::Not for Or {
391    type Output = bool;
392
393    fn not(self) -> Self::Output {
394        !self.0
395    }
396}
397
398impl PartialEq<bool> for Or {
399    fn eq(&self, other: &bool) -> bool {
400        self.0 == *other
401    }
402}
403
404impl PartialEq<Or> for bool {
405    fn eq(&self, other: &Or) -> bool {
406        *self == other.0
407    }
408}
409
410/// Conjunction aggregate for universal satisfaction.
411#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
412pub struct And(pub bool);
413
414impl Aggregate for And {
415    fn identity() -> Self {
416        And(true)
417    }
418
419    fn combine(self, other: Self) -> Self {
420        And(self.0 && other.0)
421    }
422}
423
424impl fmt::Display for And {
425    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426        write!(f, "And({})", self.0)
427    }
428}
429
430#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
431pub enum ExtremumSense {
432    Maximize,
433    Minimize,
434}
435
436#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
437pub struct Extremum<V> {
438    pub sense: ExtremumSense,
439    pub value: Option<V>,
440}
441
442impl<V> Extremum<V> {
443    pub fn maximize(value: Option<V>) -> Self {
444        Self {
445            sense: ExtremumSense::Maximize,
446            value,
447        }
448    }
449
450    pub fn minimize(value: Option<V>) -> Self {
451        Self {
452            sense: ExtremumSense::Minimize,
453            value,
454        }
455    }
456
457    pub fn is_valid(&self) -> bool {
458        self.value.is_some()
459    }
460
461    pub fn size(&self) -> Option<&V> {
462        self.value.as_ref()
463    }
464
465    pub fn unwrap(self) -> V {
466        self.value.expect("called unwrap on invalid Extremum value")
467    }
468}
469
470impl<V: fmt::Debug + PartialOrd + Clone + Serialize + DeserializeOwned> Aggregate for Extremum<V> {
471    fn identity() -> Self {
472        Self::maximize(None)
473    }
474
475    fn combine(self, other: Self) -> Self {
476        use std::cmp::Ordering;
477
478        match (self.value, other.value) {
479            (None, rhs) => Self {
480                sense: other.sense,
481                value: rhs,
482            },
483            (lhs, None) => Self {
484                sense: self.sense,
485                value: lhs,
486            },
487            (Some(lhs), Some(rhs)) => {
488                assert_eq!(
489                    self.sense, other.sense,
490                    "cannot combine Extremum values with different senses"
491                );
492                let ord = lhs.partial_cmp(&rhs).expect("cannot compare values (NaN?)");
493                let keep_self = match self.sense {
494                    ExtremumSense::Maximize => matches!(ord, Ordering::Equal | Ordering::Greater),
495                    ExtremumSense::Minimize => matches!(ord, Ordering::Equal | Ordering::Less),
496                };
497                if keep_self {
498                    Self {
499                        sense: self.sense,
500                        value: Some(lhs),
501                    }
502                } else {
503                    Self {
504                        sense: other.sense,
505                        value: Some(rhs),
506                    }
507                }
508            }
509        }
510    }
511
512    fn supports_witnesses() -> bool {
513        true
514    }
515
516    fn contributes_to_witnesses(config_value: &Self, total: &Self) -> bool {
517        matches!(
518            (config_value.value.as_ref(), total.value.as_ref()),
519            (Some(value), Some(best)) if config_value.sense == total.sense && value == best
520        )
521    }
522}
523
524impl<V: fmt::Display> fmt::Display for Extremum<V> {
525    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
526        match (&self.sense, &self.value) {
527            (ExtremumSense::Maximize, Some(value)) => write!(f, "Max({value})"),
528            (ExtremumSense::Maximize, None) => write!(f, "Max(None)"),
529            (ExtremumSense::Minimize, Some(value)) => write!(f, "Min({value})"),
530            (ExtremumSense::Minimize, None) => write!(f, "Min(None)"),
531        }
532    }
533}
534
535/// Problem size metadata (varies by problem type).
536#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
537pub struct ProblemSize {
538    /// Named size components.
539    pub components: Vec<(String, usize)>,
540}
541
542impl ProblemSize {
543    /// Create a new problem size with named components.
544    pub fn new(components: Vec<(&str, usize)>) -> Self {
545        Self {
546            components: components
547                .into_iter()
548                .map(|(k, v)| (k.to_string(), v))
549                .collect(),
550        }
551    }
552
553    /// Get a size component by name.
554    pub fn get(&self, name: &str) -> Option<usize> {
555        self.components
556            .iter()
557            .find(|(k, _)| k == name)
558            .map(|(_, v)| *v)
559    }
560
561    /// Sum of all component values.
562    pub fn total(&self) -> usize {
563        self.components.iter().map(|(_, v)| *v).sum()
564    }
565}
566
567impl fmt::Display for ProblemSize {
568    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
569        write!(f, "ProblemSize{{")?;
570        for (i, (name, value)) in self.components.iter().enumerate() {
571            if i > 0 {
572                write!(f, ", ")?;
573            }
574            write!(f, "{}: {}", name, value)?;
575        }
576        write!(f, "}}")
577    }
578}
579
580use crate::impl_variant_param;
581
582impl_variant_param!(f64, "weight");
583impl_variant_param!(i32, "weight", parent: f64, cast: |w| *w as f64);
584impl_variant_param!(One, "weight", parent: i32, cast: |_| 1i32);
585
586#[cfg(test)]
587#[path = "unit_tests/types.rs"]
588mod tests;
589
590#[cfg(test)]
591#[path = "unit_tests/types_optimization_value.rs"]
592mod optimization_value_tests;