problemreductions/models/misc/
minimum_decision_tree.rs1use crate::registry::{FieldInfo, ProblemSchemaEntry};
8use crate::traits::Problem;
9use crate::types::Min;
10use serde::{Deserialize, Serialize};
11
12inventory::submit! {
13 ProblemSchemaEntry {
14 name: "MinimumDecisionTree",
15 display_name: "Minimum Decision Tree",
16 aliases: &[],
17 dimensions: &[],
18 module_path: module_path!(),
19 description: "Find decision tree identifying objects with minimum total path length",
20 fields: &[
21 FieldInfo { name: "test_matrix", type_name: "Vec<Vec<bool>>", description: "Binary matrix: test_matrix[j][i] = object i passes test j" },
22 FieldInfo { name: "num_objects", type_name: "usize", description: "Number of objects to identify" },
23 FieldInfo { name: "num_tests", type_name: "usize", description: "Number of available binary tests" },
24 ],
25 }
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct MinimumDecisionTree {
57 test_matrix: Vec<Vec<bool>>,
59 num_objects: usize,
61 num_tests: usize,
63}
64
65impl MinimumDecisionTree {
66 pub fn new(test_matrix: Vec<Vec<bool>>, num_objects: usize, num_tests: usize) -> Self {
73 assert!(num_objects >= 2, "Need at least 2 objects");
74 assert!(num_tests >= 1, "Need at least 1 test");
75 assert_eq!(
76 test_matrix.len(),
77 num_tests,
78 "test_matrix must have num_tests rows"
79 );
80 for (j, row) in test_matrix.iter().enumerate() {
81 assert_eq!(
82 row.len(),
83 num_objects,
84 "test_matrix[{j}] must have num_objects columns"
85 );
86 }
87 for a in 0..num_objects {
89 for b in (a + 1)..num_objects {
90 let distinguished = (0..num_tests).any(|j| test_matrix[j][a] != test_matrix[j][b]);
91 assert!(
92 distinguished,
93 "Objects {a} and {b} are not distinguished by any test"
94 );
95 }
96 }
97 Self {
98 test_matrix,
99 num_objects,
100 num_tests,
101 }
102 }
103
104 pub fn num_objects(&self) -> usize {
106 self.num_objects
107 }
108
109 pub fn num_tests(&self) -> usize {
111 self.num_tests
112 }
113
114 pub fn test_matrix(&self) -> &[Vec<bool>] {
116 &self.test_matrix
117 }
118
119 fn num_tree_slots(&self) -> usize {
121 (1usize << (self.num_objects - 1)) - 1
122 }
123
124 fn leaf_sentinel(&self) -> usize {
126 self.num_tests
127 }
128
129 fn simulate(&self, config: &[usize]) -> Option<usize> {
132 let sentinel = self.leaf_sentinel();
133 let max_slots = self.num_tree_slots();
134 let mut seen_leaves = std::collections::HashSet::new();
135 let mut total_depth = 0usize;
136
137 for obj in 0..self.num_objects {
138 let mut node = 0usize;
139 let mut depth = 0usize;
140
141 loop {
142 if node >= max_slots || config[node] == sentinel {
143 if !seen_leaves.insert(node) {
145 return None;
146 }
147 total_depth += depth;
148 break;
149 }
150
151 let test_idx = config[node];
152 debug_assert!(test_idx < self.num_tests);
153
154 let result = self.test_matrix[test_idx][obj];
155 node = if result { 2 * node + 2 } else { 2 * node + 1 };
156 depth += 1;
157
158 if depth > self.num_objects {
159 return None;
160 }
161 }
162 }
163
164 Some(total_depth)
165 }
166}
167
168impl Problem for MinimumDecisionTree {
169 const NAME: &'static str = "MinimumDecisionTree";
170 type Value = Min<usize>;
171
172 fn dims(&self) -> Vec<usize> {
173 vec![self.num_tests + 1; self.num_tree_slots()]
175 }
176
177 fn evaluate(&self, config: &[usize]) -> Min<usize> {
178 if config.len() != self.num_tree_slots() {
179 return Min(None);
180 }
181 Min(self.simulate(config))
182 }
183
184 fn variant() -> Vec<(&'static str, &'static str)> {
185 crate::variant_params![]
186 }
187}
188
189crate::declare_variants! {
190 default MinimumDecisionTree => "num_tests^num_objects",
191}
192
/// Example specs for this model: a single instance with 4 objects and 3 tests,
/// together with a configuration recorded as optimal and its value.
#[cfg(feature = "example-db")]
pub(crate) fn canonical_model_example_specs() -> Vec<crate::example_db::specs::ModelExampleSpec> {
    use crate::example_db::specs::ModelExampleSpec;

    // Rows are tests, columns are objects.
    let test_matrix = vec![
        vec![true, true, false, false],
        vec![true, false, false, false],
        vec![false, true, false, true],
    ];
    let instance = MinimumDecisionTree::new(test_matrix, 4, 3);

    // Slot 0 applies test 0; slot 1 (its `false` branch) applies test 2 and
    // slot 2 (its `true` branch) applies test 1. The remaining slots hold 3,
    // which equals `num_tests` and therefore marks leaves. Every object stops
    // at depth 2, giving a total of 8.
    let optimal_config = vec![0, 2, 1, 3, 3, 3, 3];

    vec![ModelExampleSpec {
        id: "minimum_decision_tree",
        instance: Box::new(instance),
        optimal_config,
        optimal_value: serde_json::json!(8),
    }]
}
211
212#[cfg(test)]
213#[path = "../../unit_tests/models/misc/minimum_decision_tree.rs"]
214mod tests;