Skip to main content

problemreductions/models/misc/
minimum_decision_tree.rs

//! Minimum Decision Tree problem implementation.
//!
//! Given a set of objects distinguished by binary tests, find a decision tree
//! that identifies each object with minimum total external path length
//! (the sum, over all objects, of the depth of the leaf that identifies it).

7use crate::registry::{FieldInfo, ProblemSchemaEntry};
8use crate::traits::Problem;
9use crate::types::Min;
10use serde::{Deserialize, Serialize};
11
12inventory::submit! {
13    ProblemSchemaEntry {
14        name: "MinimumDecisionTree",
15        display_name: "Minimum Decision Tree",
16        aliases: &[],
17        dimensions: &[],
18        module_path: module_path!(),
19        description: "Find decision tree identifying objects with minimum total path length",
20        fields: &[
21            FieldInfo { name: "test_matrix", type_name: "Vec<Vec<bool>>", description: "Binary matrix: test_matrix[j][i] = object i passes test j" },
22            FieldInfo { name: "num_objects", type_name: "usize", description: "Number of objects to identify" },
23            FieldInfo { name: "num_tests", type_name: "usize", description: "Number of available binary tests" },
24        ],
25    }
26}
27
28/// Minimum Decision Tree problem.
29///
30/// Given objects distinguished by binary tests, find a decision tree
31/// minimizing the total external path length (sum of leaf depths).
32///
33/// The configuration encodes a flattened complete binary tree of depth
34/// `num_objects - 1`. Each internal node stores either a test index
35/// (0..num_tests-1) or a sentinel value `num_tests` meaning "leaf".
36///
37/// # Example
38///
39/// ```
40/// use problemreductions::models::misc::MinimumDecisionTree;
41/// use problemreductions::{Problem, Solver, BruteForce};
42///
43/// let problem = MinimumDecisionTree::new(
44///     vec![
45///         vec![true, true, false, false],   // T0
46///         vec![true, false, false, false],   // T1
47///         vec![false, true, false, true],    // T2
48///     ],
49///     4,
50///     3,
51/// );
52/// let solver = BruteForce::new();
53/// let value = solver.solve(&problem);
54/// ```
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct MinimumDecisionTree {
57    /// Binary matrix: test_matrix[j][i] = true iff object i passes test j.
58    test_matrix: Vec<Vec<bool>>,
59    /// Number of objects.
60    num_objects: usize,
61    /// Number of tests.
62    num_tests: usize,
63}
64
65impl MinimumDecisionTree {
66    /// Create a new MinimumDecisionTree problem.
67    ///
68    /// # Panics
69    /// - If num_objects < 2 or num_tests < 1
70    /// - If test_matrix dimensions don't match
71    /// - If tests don't distinguish all object pairs
72    pub fn new(test_matrix: Vec<Vec<bool>>, num_objects: usize, num_tests: usize) -> Self {
73        assert!(num_objects >= 2, "Need at least 2 objects");
74        assert!(num_tests >= 1, "Need at least 1 test");
75        assert_eq!(
76            test_matrix.len(),
77            num_tests,
78            "test_matrix must have num_tests rows"
79        );
80        for (j, row) in test_matrix.iter().enumerate() {
81            assert_eq!(
82                row.len(),
83                num_objects,
84                "test_matrix[{j}] must have num_objects columns"
85            );
86        }
87        // Check that every pair of objects is distinguished by at least one test
88        for a in 0..num_objects {
89            for b in (a + 1)..num_objects {
90                let distinguished = (0..num_tests).any(|j| test_matrix[j][a] != test_matrix[j][b]);
91                assert!(
92                    distinguished,
93                    "Objects {a} and {b} are not distinguished by any test"
94                );
95            }
96        }
97        Self {
98            test_matrix,
99            num_objects,
100            num_tests,
101        }
102    }
103
104    /// Get the number of objects.
105    pub fn num_objects(&self) -> usize {
106        self.num_objects
107    }
108
109    /// Get the number of tests.
110    pub fn num_tests(&self) -> usize {
111        self.num_tests
112    }
113
114    /// Get the test matrix.
115    pub fn test_matrix(&self) -> &[Vec<bool>] {
116        &self.test_matrix
117    }
118
119    /// Number of internal node slots in the flattened complete binary tree.
120    fn num_tree_slots(&self) -> usize {
121        (1usize << (self.num_objects - 1)) - 1
122    }
123
124    /// Sentinel value meaning "this node is a leaf".
125    fn leaf_sentinel(&self) -> usize {
126        self.num_tests
127    }
128
129    /// Simulate the decision tree for all objects and return total external path length,
130    /// or None if the tree is invalid (doesn't identify all objects uniquely).
131    fn simulate(&self, config: &[usize]) -> Option<usize> {
132        let sentinel = self.leaf_sentinel();
133        let max_slots = self.num_tree_slots();
134        let mut seen_leaves = std::collections::HashSet::new();
135        let mut total_depth = 0usize;
136
137        for obj in 0..self.num_objects {
138            let mut node = 0usize;
139            let mut depth = 0usize;
140
141            loop {
142                if node >= max_slots || config[node] == sentinel {
143                    // Two objects at same leaf — invalid
144                    if !seen_leaves.insert(node) {
145                        return None;
146                    }
147                    total_depth += depth;
148                    break;
149                }
150
151                let test_idx = config[node];
152                debug_assert!(test_idx < self.num_tests);
153
154                let result = self.test_matrix[test_idx][obj];
155                node = if result { 2 * node + 2 } else { 2 * node + 1 };
156                depth += 1;
157
158                if depth > self.num_objects {
159                    return None;
160                }
161            }
162        }
163
164        Some(total_depth)
165    }
166}
167
168impl Problem for MinimumDecisionTree {
169    const NAME: &'static str = "MinimumDecisionTree";
170    type Value = Min<usize>;
171
172    fn dims(&self) -> Vec<usize> {
173        // Each internal node can hold test 0..num_tests-1 or sentinel (leaf)
174        vec![self.num_tests + 1; self.num_tree_slots()]
175    }
176
177    fn evaluate(&self, config: &[usize]) -> Min<usize> {
178        if config.len() != self.num_tree_slots() {
179            return Min(None);
180        }
181        Min(self.simulate(config))
182    }
183
184    fn variant() -> Vec<(&'static str, &'static str)> {
185        crate::variant_params![]
186    }
187}
188
189crate::declare_variants! {
190    default MinimumDecisionTree => "num_tests^num_objects",
191}
192
#[cfg(feature = "example-db")]
pub(crate) fn canonical_model_example_specs() -> Vec<crate::example_db::specs::ModelExampleSpec> {
    // Rows are tests T0..T2; columns are objects 0..3.
    let test_matrix = vec![
        vec![true, true, false, false],
        vec![true, false, false, false],
        vec![false, true, false, true],
    ];
    let problem = MinimumDecisionTree::new(test_matrix, 4, 3);

    // Heap-ordered slots: T0 at the root, T2 on the root's "fails" child
    // (slot 1), T1 on its "passes" child (slot 2); every other slot is the
    // leaf sentinel 3 (= num_tests). Each object ends at depth 2, total 8.
    let spec = crate::example_db::specs::ModelExampleSpec {
        id: "minimum_decision_tree",
        instance: Box::new(problem),
        optimal_config: vec![0, 2, 1, 3, 3, 3, 3],
        optimal_value: serde_json::json!(8),
    };
    vec![spec]
}
211
// Unit tests for this model live in a separate file, located via `#[path]`.
#[cfg(test)]
#[path = "../../unit_tests/models/misc/minimum_decision_tree.rs"]
mod tests;