// problemreductions/rules/cost.rs

//! Cost functions for reduction path optimization.

3use crate::rules::registry::ReductionOverhead;
4use crate::types::ProblemSize;
5
6/// User-defined cost function for path optimization.
7pub trait PathCostFn {
8    /// Compute cost of taking an edge given current problem size.
9    fn edge_cost(&self, overhead: &ReductionOverhead, current_size: &ProblemSize) -> f64;
10}
11
12/// Minimize a single output field.
13pub struct Minimize(pub &'static str);
14
15impl PathCostFn for Minimize {
16    fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 {
17        overhead.evaluate_output_size(size).get(self.0).unwrap_or(0) as f64
18    }
19}
20
21/// Minimize weighted sum of output fields.
22pub struct MinimizeWeighted(pub Vec<(&'static str, f64)>);
23
24impl PathCostFn for MinimizeWeighted {
25    fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 {
26        let output = overhead.evaluate_output_size(size);
27        self.0.iter()
28            .map(|(field, weight)| weight * output.get(field).unwrap_or(0) as f64)
29            .sum()
30    }
31}
32
33/// Minimize the maximum of specified fields.
34pub struct MinimizeMax(pub Vec<&'static str>);
35
36impl PathCostFn for MinimizeMax {
37    fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 {
38        let output = overhead.evaluate_output_size(size);
39        self.0.iter()
40            .map(|field| output.get(field).unwrap_or(0) as f64)
41            .fold(0.0, f64::max)
42    }
43}
44
45/// Lexicographic: minimize first field, break ties with subsequent.
46pub struct MinimizeLexicographic(pub Vec<&'static str>);
47
48impl PathCostFn for MinimizeLexicographic {
49    fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 {
50        let output = overhead.evaluate_output_size(size);
51        let mut cost = 0.0;
52        let mut scale = 1.0;
53        for field in &self.0 {
54            cost += scale * output.get(field).unwrap_or(0) as f64;
55            scale *= 1e-10;
56        }
57        cost
58    }
59}
60
61/// Minimize number of reduction steps.
62pub struct MinimizeSteps;
63
64impl PathCostFn for MinimizeSteps {
65    fn edge_cost(&self, _overhead: &ReductionOverhead, _size: &ProblemSize) -> f64 {
66        1.0
67    }
68}
69
70/// Custom cost function from closure.
71pub struct CustomCost<F>(pub F);
72
73impl<F: Fn(&ReductionOverhead, &ProblemSize) -> f64> PathCostFn for CustomCost<F> {
74    fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 {
75        (self.0)(overhead, size)
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::polynomial::Polynomial;

    /// Overhead shared by the tests: output `n` = 2 * input `n`, output `m` = input `m`.
    fn test_overhead() -> ReductionOverhead {
        ReductionOverhead::new(vec![
            ("n", Polynomial::var("n").scale(2.0)),
            ("m", Polynomial::var("m")),
        ])
    }

    #[test]
    fn test_minimize_single() {
        let cost_fn = Minimize("n");
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        assert_eq!(cost_fn.edge_cost(&overhead, &size), 20.0); // 2 * 10
    }

    #[test]
    fn test_minimize_weighted() {
        let cost_fn = MinimizeWeighted(vec![("n", 1.0), ("m", 2.0)]);
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        // output n = 20, output m = 5
        // cost = 1.0 * 20 + 2.0 * 5 = 30
        assert_eq!(cost_fn.edge_cost(&overhead, &size), 30.0);
    }

    #[test]
    fn test_minimize_weighted_missing_field() {
        let cost_fn = MinimizeWeighted(vec![("n", 1.0), ("nonexistent", 100.0)]);
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        // Only output n = 20 contributes; the unknown field counts as 0.
        assert_eq!(cost_fn.edge_cost(&overhead, &size), 20.0);
    }

    #[test]
    fn test_minimize_steps() {
        let cost_fn = MinimizeSteps;
        let size = ProblemSize::new(vec![("n", 100)]);
        let overhead = test_overhead();

        assert_eq!(cost_fn.edge_cost(&overhead, &size), 1.0);
    }

    #[test]
    fn test_minimize_max() {
        let cost_fn = MinimizeMax(vec!["n", "m"]);
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        // output n = 20, output m = 5
        // max(20, 5) = 20
        assert_eq!(cost_fn.edge_cost(&overhead, &size), 20.0);
    }

    #[test]
    fn test_minimize_lexicographic() {
        let cost_fn = MinimizeLexicographic(vec!["n", "m"]);
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        // output n = 20, output m = 5
        // cost = 20 * 1.0 + 5 * 1e-10 = 20.0000000005
        let cost = cost_fn.edge_cost(&overhead, &size);
        assert!(cost > 20.0 && cost < 20.001);
    }

    #[test]
    fn test_minimize_lexicographic_tie_break() {
        let cost_fn = MinimizeLexicographic(vec!["n", "m"]);
        let overhead = test_overhead();
        let smaller_m = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let larger_m = ProblemSize::new(vec![("n", 10), ("m", 6)]);

        // Equal primary field (output n = 20), so the secondary field decides.
        assert!(
            cost_fn.edge_cost(&overhead, &smaller_m) < cost_fn.edge_cost(&overhead, &larger_m)
        );
    }

    #[test]
    fn test_minimize_lexicographic_primary_dominates() {
        let cost_fn = MinimizeLexicographic(vec!["n", "m"]);
        let overhead = test_overhead();
        let small_n = ProblemSize::new(vec![("n", 9), ("m", 1000)]);
        let large_n = ProblemSize::new(vec![("n", 10), ("m", 0)]);

        // 18 + 1000 * 1e-10 < 20 + 0: a smaller primary field wins even with a
        // much larger secondary field.
        assert!(cost_fn.edge_cost(&overhead, &small_n) < cost_fn.edge_cost(&overhead, &large_n));
    }

    #[test]
    fn test_custom_cost() {
        let cost_fn = CustomCost(|overhead: &ReductionOverhead, size: &ProblemSize| {
            let output = overhead.evaluate_output_size(size);
            (output.get("n").unwrap_or(0) + output.get("m").unwrap_or(0)) as f64
        });
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        // output n = 20, output m = 5
        // custom = 20 + 5 = 25
        assert_eq!(cost_fn.edge_cost(&overhead, &size), 25.0);
    }

    #[test]
    fn test_minimize_missing_field() {
        let cost_fn = Minimize("nonexistent");
        let size = ProblemSize::new(vec![("n", 10)]);
        let overhead = test_overhead();

        assert_eq!(cost_fn.edge_cost(&overhead, &size), 0.0);
    }

    #[test]
    fn test_minimize_max_empty() {
        let cost_fn = MinimizeMax(vec![]);
        let size = ProblemSize::new(vec![("n", 10)]);
        let overhead = test_overhead();

        assert_eq!(cost_fn.edge_cost(&overhead, &size), 0.0);
    }

    #[test]
    fn test_minimize_max_missing_field() {
        let cost_fn = MinimizeMax(vec!["nonexistent"]);
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        // A missing field counts as 0, and the fold starts at 0.
        assert_eq!(cost_fn.edge_cost(&overhead, &size), 0.0);
    }
}