problemreductions/models/graph/
steiner_tree.rs1use std::collections::BTreeSet;
7
8use num_traits::Zero;
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 registry::{FieldInfo, ProblemSchemaEntry, VariantDimension},
13 topology::{Graph, SimpleGraph},
14 traits::Problem,
15 types::{Min, One, WeightElement},
16};
17
18inventory::submit! {
19 ProblemSchemaEntry {
20 name: "SteinerTree",
21 display_name: "Steiner Tree",
22 aliases: &[],
23 dimensions: &[
24 VariantDimension::new("graph", "SimpleGraph", &["SimpleGraph"]),
25 VariantDimension::new("weight", "i32", &["One", "i32"]),
26 ],
27 module_path: module_path!(),
28 description: "Find minimum weight tree connecting terminal vertices",
29 fields: &[
30 FieldInfo { name: "graph", type_name: "G", description: "The underlying graph G=(V,E)" },
31 FieldInfo { name: "edge_weights", type_name: "Vec<W>", description: "Edge weights w: E -> R" },
32 FieldInfo { name: "terminals", type_name: "Vec<usize>", description: "Terminal vertices T that must be connected" },
33 ],
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct SteinerTree<G, W> {
59 graph: G,
61 edge_weights: Vec<W>,
63 terminals: Vec<usize>,
65}
66
67impl<G: Graph, W: Clone + Default> SteinerTree<G, W> {
68 pub fn new(graph: G, edge_weights: Vec<W>, terminals: Vec<usize>) -> Self {
70 assert_eq!(
71 edge_weights.len(),
72 graph.num_edges(),
73 "edge_weights length must match num_edges"
74 );
75 let n = graph.num_vertices();
76 let distinct_terminals: BTreeSet<_> = terminals.iter().copied().collect();
77 assert_eq!(
78 distinct_terminals.len(),
79 terminals.len(),
80 "terminals must be distinct"
81 );
82 for &t in &terminals {
83 assert!(t < n, "terminal {t} out of range (num_vertices = {n})");
84 }
85 assert!(terminals.len() >= 2, "at least 2 terminals required");
86 Self {
87 graph,
88 edge_weights,
89 terminals,
90 }
91 }
92
93 pub fn unit_weights(graph: G, terminals: Vec<usize>) -> Self
95 where
96 W: From<i32>,
97 {
98 let edge_weights = vec![W::from(1); graph.num_edges()];
99 Self::new(graph, edge_weights, terminals)
100 }
101
102 pub fn graph(&self) -> &G {
104 &self.graph
105 }
106
107 pub fn edge_weights(&self) -> &[W] {
109 &self.edge_weights
110 }
111
112 pub fn set_weights(&mut self, weights: Vec<W>) {
114 assert_eq!(weights.len(), self.graph.num_edges());
115 self.edge_weights = weights;
116 }
117
118 pub fn weights(&self) -> Vec<W> {
120 self.edge_weights.clone()
121 }
122
123 pub fn terminals(&self) -> &[usize] {
125 &self.terminals
126 }
127
128 pub fn is_weighted(&self) -> bool
130 where
131 W: WeightElement,
132 {
133 !W::IS_UNIT
134 }
135
136 pub fn is_valid_solution(&self, config: &[usize]) -> bool {
138 is_valid_steiner_tree(&self.graph, &self.terminals, config)
139 }
140}
141
142impl<G: Graph, W: WeightElement> SteinerTree<G, W> {
143 pub fn num_vertices(&self) -> usize {
145 self.graph.num_vertices()
146 }
147
148 pub fn num_edges(&self) -> usize {
150 self.graph.num_edges()
151 }
152
153 pub fn num_terminals(&self) -> usize {
155 self.terminals.len()
156 }
157}
158
159fn is_valid_steiner_tree<G: Graph>(graph: &G, terminals: &[usize], config: &[usize]) -> bool {
163 let n = graph.num_vertices();
164 let edges = graph.edges();
165 if config.len() != edges.len() {
166 return false;
167 }
168
169 let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
171 let mut selected_count = 0usize;
172 let mut involved = vec![false; n];
173 for (idx, &sel) in config.iter().enumerate() {
174 if sel == 1 {
175 let (u, v) = edges[idx];
176 adj[u].push(v);
177 adj[v].push(u);
178 involved[u] = true;
179 involved[v] = true;
180 selected_count += 1;
181 }
182 }
183
184 if selected_count == 0 {
185 return false;
186 }
187
188 let start = terminals[0];
190 let mut visited = vec![false; n];
191 let mut queue = std::collections::VecDeque::new();
192 visited[start] = true;
193 queue.push_back(start);
194 while let Some(v) = queue.pop_front() {
195 for &u in &adj[v] {
196 if !visited[u] {
197 visited[u] = true;
198 queue.push_back(u);
199 }
200 }
201 }
202
203 if !terminals.iter().all(|&t| visited[t]) {
205 return false;
206 }
207
208 if (0..n).any(|i| involved[i] && !visited[i]) {
210 return false;
211 }
212
213 let involved_count = involved.iter().filter(|&&x| x).count();
215 selected_count == involved_count - 1
216}
217
218impl<G, W> Problem for SteinerTree<G, W>
219where
220 G: Graph + crate::variant::VariantParam,
221 W: WeightElement + crate::variant::VariantParam,
222{
223 const NAME: &'static str = "SteinerTree";
224 type Value = Min<W::Sum>;
225
226 fn variant() -> Vec<(&'static str, &'static str)> {
227 crate::variant_params![G, W]
228 }
229
230 fn dims(&self) -> Vec<usize> {
231 vec![2; self.graph.num_edges()]
232 }
233
234 fn evaluate(&self, config: &[usize]) -> Min<W::Sum> {
235 if !is_valid_steiner_tree(&self.graph, &self.terminals, config) {
236 return Min(None);
237 }
238 let mut total = W::Sum::zero();
239 for (idx, &selected) in config.iter().enumerate() {
240 if selected == 1 {
241 if let Some(w) = self.edge_weights.get(idx) {
242 total += w.to_sum();
243 }
244 }
245 }
246 Min(Some(total))
247 }
248}
249
250crate::declare_variants! {
251 default SteinerTree<SimpleGraph, i32> => "3^num_terminals * num_vertices + 2^num_terminals * num_vertices^2",
252 SteinerTree<SimpleGraph, One> => "3^num_terminals * num_vertices + 2^num_terminals * num_vertices^2",
253}
254
/// Canonical example instance for the example database.
///
/// The instance is a 5-vertex graph with terminals {0, 2, 4}. The optimal
/// tree uses edges (0,1), (1,2), (1,3) and (3,4), with weights
/// 2 + 2 + 1 + 1 = 6.
#[cfg(feature = "example-db")]
pub(crate) fn canonical_model_example_specs() -> Vec<crate::example_db::specs::ModelExampleSpec> {
    let graph = SimpleGraph::new(
        5,
        vec![(0, 1), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4), (3, 4)],
    );
    // Weights aligned with the edge list above.
    let edge_weights: Vec<i32> = vec![2, 5, 2, 1, 5, 6, 1];
    let terminals = vec![0, 2, 4];

    let spec = crate::example_db::specs::ModelExampleSpec {
        id: "steiner_tree_simplegraph_i32",
        instance: Box::new(SteinerTree::new(graph, edge_weights, terminals)),
        optimal_config: vec![1, 0, 1, 1, 0, 0, 1],
        optimal_value: serde_json::json!(6),
    };
    vec![spec]
}
271
// Unit tests live in the shared unit_tests tree rather than inline.
#[cfg(test)]
#[path = "../../unit_tests/models/graph/steiner_tree.rs"]
mod tests;