problemreductions/rules/
cost.rs1use crate::rules::registry::ReductionOverhead;
4use crate::types::ProblemSize;
5
6pub trait PathCostFn {
8 fn edge_cost(&self, overhead: &ReductionOverhead, current_size: &ProblemSize) -> f64;
10}
11
12pub struct Minimize(pub &'static str);
14
15impl PathCostFn for Minimize {
16 fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 {
17 overhead.evaluate_output_size(size).get(self.0).unwrap_or(0) as f64
18 }
19}
20
21pub struct MinimizeWeighted(pub Vec<(&'static str, f64)>);
23
24impl PathCostFn for MinimizeWeighted {
25 fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 {
26 let output = overhead.evaluate_output_size(size);
27 self.0.iter()
28 .map(|(field, weight)| weight * output.get(field).unwrap_or(0) as f64)
29 .sum()
30 }
31}
32
33pub struct MinimizeMax(pub Vec<&'static str>);
35
36impl PathCostFn for MinimizeMax {
37 fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 {
38 let output = overhead.evaluate_output_size(size);
39 self.0.iter()
40 .map(|field| output.get(field).unwrap_or(0) as f64)
41 .fold(0.0, f64::max)
42 }
43}
44
45pub struct MinimizeLexicographic(pub Vec<&'static str>);
47
48impl PathCostFn for MinimizeLexicographic {
49 fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 {
50 let output = overhead.evaluate_output_size(size);
51 let mut cost = 0.0;
52 let mut scale = 1.0;
53 for field in &self.0 {
54 cost += scale * output.get(field).unwrap_or(0) as f64;
55 scale *= 1e-10;
56 }
57 cost
58 }
59}
60
61pub struct MinimizeSteps;
63
64impl PathCostFn for MinimizeSteps {
65 fn edge_cost(&self, _overhead: &ReductionOverhead, _size: &ProblemSize) -> f64 {
66 1.0
67 }
68}
69
70pub struct CustomCost<F>(pub F);
72
73impl<F: Fn(&ReductionOverhead, &ProblemSize) -> f64> PathCostFn for CustomCost<F> {
74 fn edge_cost(&self, overhead: &ReductionOverhead, size: &ProblemSize) -> f64 {
75 (self.0)(overhead, size)
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::polynomial::Polynomial;

    /// Overhead shared by most tests: output `n` is `2 * n`, output `m` is `m`.
    fn test_overhead() -> ReductionOverhead {
        ReductionOverhead::new(vec![
            ("n", Polynomial::var("n").scale(2.0)),
            ("m", Polynomial::var("m")),
        ])
    }

    #[test]
    fn test_minimize_single() {
        let cost_fn = Minimize("n");
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        assert_eq!(cost_fn.edge_cost(&overhead, &size), 20.0);
    }

    #[test]
    fn test_minimize_weighted() {
        let cost_fn = MinimizeWeighted(vec![("n", 1.0), ("m", 2.0)]);
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        // 1.0 * 20 + 2.0 * 5
        assert_eq!(cost_fn.edge_cost(&overhead, &size), 30.0);
    }

    #[test]
    fn test_minimize_weighted_missing_field() {
        let cost_fn = MinimizeWeighted(vec![("nonexistent", 5.0), ("m", 2.0)]);
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        // The unknown field contributes nothing; only 2.0 * 5 remains.
        assert_eq!(cost_fn.edge_cost(&overhead, &size), 10.0);
    }

    #[test]
    fn test_minimize_steps() {
        let cost_fn = MinimizeSteps;
        let size = ProblemSize::new(vec![("n", 100)]);
        let overhead = test_overhead();

        assert_eq!(cost_fn.edge_cost(&overhead, &size), 1.0);
    }

    #[test]
    fn test_minimize_max() {
        let cost_fn = MinimizeMax(vec!["n", "m"]);
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        // max(20, 5)
        assert_eq!(cost_fn.edge_cost(&overhead, &size), 20.0);
    }

    #[test]
    fn test_minimize_lexicographic() {
        let cost_fn = MinimizeLexicographic(vec!["n", "m"]);
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        // 20 + 5e-10: dominated by the primary field, nudged by the secondary.
        let cost = cost_fn.edge_cost(&overhead, &size);
        assert!(cost > 20.0 && cost < 20.001);
    }

    #[test]
    fn test_minimize_lexicographic_primary_dominates() {
        let cost_fn = MinimizeLexicographic(vec!["n", "m"]);
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        // Smaller primary field (n = 10) but a much larger secondary field (m = 5000).
        let small_primary = ReductionOverhead::new(vec![
            ("n", Polynomial::var("n")),
            ("m", Polynomial::var("m").scale(1000.0)),
        ]);
        // Larger primary field (n = 20) with a small secondary field (m = 5).
        let large_primary = test_overhead();

        assert!(
            cost_fn.edge_cost(&small_primary, &size) < cost_fn.edge_cost(&large_primary, &size)
        );
    }

    #[test]
    fn test_minimize_lexicographic_empty() {
        let cost_fn = MinimizeLexicographic(vec![]);
        let size = ProblemSize::new(vec![("n", 10)]);
        let overhead = test_overhead();

        assert_eq!(cost_fn.edge_cost(&overhead, &size), 0.0);
    }

    #[test]
    fn test_custom_cost() {
        let cost_fn = CustomCost(|overhead: &ReductionOverhead, size: &ProblemSize| {
            let output = overhead.evaluate_output_size(size);
            (output.get("n").unwrap_or(0) + output.get("m").unwrap_or(0)) as f64
        });
        let size = ProblemSize::new(vec![("n", 10), ("m", 5)]);
        let overhead = test_overhead();

        // 20 + 5
        assert_eq!(cost_fn.edge_cost(&overhead, &size), 25.0);
    }

    #[test]
    fn test_minimize_missing_field() {
        let cost_fn = Minimize("nonexistent");
        let size = ProblemSize::new(vec![("n", 10)]);
        let overhead = test_overhead();

        assert_eq!(cost_fn.edge_cost(&overhead, &size), 0.0);
    }

    #[test]
    fn test_minimize_max_empty() {
        let cost_fn = MinimizeMax(vec![]);
        let size = ProblemSize::new(vec![("n", 10)]);
        let overhead = test_overhead();

        assert_eq!(cost_fn.edge_cost(&overhead, &size), 0.0);
    }
}