Skip to main content

problemreductions/models/graph/
steiner_tree.rs

1//! Steiner Tree problem implementation.
2//!
3//! Given a weighted graph and a set of terminal vertices, find a minimum-weight
4//! tree that connects all terminals.
5
6use std::collections::BTreeSet;
7
8use num_traits::Zero;
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    registry::{FieldInfo, ProblemSchemaEntry, VariantDimension},
13    topology::{Graph, SimpleGraph},
14    traits::Problem,
15    types::{Min, One, WeightElement},
16};
17
// Register the SteinerTree schema entry (name, variant dimensions, serialized
// fields) with `crate::registry` via `inventory::submit!`.
inventory::submit! {
    ProblemSchemaEntry {
        name: "SteinerTree",
        display_name: "Steiner Tree",
        aliases: &[],
        // NOTE(review): arguments look like (dimension name, default value,
        // allowed values) — confirm against `VariantDimension::new`. The allowed
        // values should match the variants listed in `declare_variants!`.
        dimensions: &[
            VariantDimension::new("graph", "SimpleGraph", &["SimpleGraph"]),
            VariantDimension::new("weight", "i32", &["One", "i32"]),
        ],
        module_path: module_path!(),
        description: "Find minimum weight tree connecting terminal vertices",
        // These entries mirror the fields of the `SteinerTree` struct
        // (`graph`, `edge_weights`, `terminals`); keep them in sync.
        fields: &[
            FieldInfo { name: "graph", type_name: "G", description: "The underlying graph G=(V,E)" },
            FieldInfo { name: "edge_weights", type_name: "Vec<W>", description: "Edge weights w: E -> R" },
            FieldInfo { name: "terminals", type_name: "Vec<usize>", description: "Terminal vertices T that must be connected" },
        ],
    }
}
36
/// The Steiner Tree problem.
///
/// Given a weighted graph G = (V, E) with edge weights w_e and a set
/// of terminal vertices T, find a tree S in G such that T is a subset
/// of V(S), minimizing the total edge weight of S.
///
/// # Representation
///
/// Each edge is assigned a binary variable, in the same index order as the
/// graph's edge list (and as `edge_weights`):
/// - 0: edge is not in the Steiner tree
/// - 1: edge is in the Steiner tree
///
/// A valid Steiner tree requires:
/// - At least one edge is selected
/// - Selected edges form a tree (connected + acyclic)
/// - All terminal vertices are included
///
/// Configurations that violate these rules evaluate to `Min(None)`.
///
/// # Invariants
///
/// `SteinerTree::new` checks that `edge_weights` has one entry per edge and
/// that `terminals` holds at least two distinct, in-range vertices. The derived
/// `Deserialize` implementation does not re-run those checks.
///
/// # Type Parameters
///
/// * `G` - The graph type (e.g., `SimpleGraph`)
/// * `W` - The weight type for edges (the registered variants use `i32` and `One`)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SteinerTree<G, W> {
    /// The underlying graph.
    graph: G,
    /// Weights for each edge (in edge index order).
    edge_weights: Vec<W>,
    /// Terminal vertices that must be connected.
    terminals: Vec<usize>,
}
66
67impl<G: Graph, W: Clone + Default> SteinerTree<G, W> {
68    /// Create a SteinerTree problem from a graph, edge weights, and terminals.
69    pub fn new(graph: G, edge_weights: Vec<W>, terminals: Vec<usize>) -> Self {
70        assert_eq!(
71            edge_weights.len(),
72            graph.num_edges(),
73            "edge_weights length must match num_edges"
74        );
75        let n = graph.num_vertices();
76        let distinct_terminals: BTreeSet<_> = terminals.iter().copied().collect();
77        assert_eq!(
78            distinct_terminals.len(),
79            terminals.len(),
80            "terminals must be distinct"
81        );
82        for &t in &terminals {
83            assert!(t < n, "terminal {t} out of range (num_vertices = {n})");
84        }
85        assert!(terminals.len() >= 2, "at least 2 terminals required");
86        Self {
87            graph,
88            edge_weights,
89            terminals,
90        }
91    }
92
93    /// Create a SteinerTree problem with unit edge weights.
94    pub fn unit_weights(graph: G, terminals: Vec<usize>) -> Self
95    where
96        W: From<i32>,
97    {
98        let edge_weights = vec![W::from(1); graph.num_edges()];
99        Self::new(graph, edge_weights, terminals)
100    }
101
102    /// Get a reference to the underlying graph.
103    pub fn graph(&self) -> &G {
104        &self.graph
105    }
106
107    /// Get a reference to the edge weights.
108    pub fn edge_weights(&self) -> &[W] {
109        &self.edge_weights
110    }
111
112    /// Set new edge weights.
113    pub fn set_weights(&mut self, weights: Vec<W>) {
114        assert_eq!(weights.len(), self.graph.num_edges());
115        self.edge_weights = weights;
116    }
117
118    /// Get the edge weights as a Vec.
119    pub fn weights(&self) -> Vec<W> {
120        self.edge_weights.clone()
121    }
122
123    /// Get the terminal vertices.
124    pub fn terminals(&self) -> &[usize] {
125        &self.terminals
126    }
127
128    /// Check if the problem uses a non-unit weight type.
129    pub fn is_weighted(&self) -> bool
130    where
131        W: WeightElement,
132    {
133        !W::IS_UNIT
134    }
135
136    /// Check if a configuration is a valid Steiner tree.
137    pub fn is_valid_solution(&self, config: &[usize]) -> bool {
138        is_valid_steiner_tree(&self.graph, &self.terminals, config)
139    }
140}
141
142impl<G: Graph, W: WeightElement> SteinerTree<G, W> {
143    /// Get the number of vertices in the underlying graph.
144    pub fn num_vertices(&self) -> usize {
145        self.graph.num_vertices()
146    }
147
148    /// Get the number of edges in the underlying graph.
149    pub fn num_edges(&self) -> usize {
150        self.graph.num_edges()
151    }
152
153    /// Get the number of terminal vertices.
154    pub fn num_terminals(&self) -> usize {
155        self.terminals.len()
156    }
157}
158
159/// Check if a configuration forms a valid Steiner tree:
160/// 1. Selected edges form a connected subgraph containing all terminals
161/// 2. Selected edges are acyclic (tree property)
162fn is_valid_steiner_tree<G: Graph>(graph: &G, terminals: &[usize], config: &[usize]) -> bool {
163    let n = graph.num_vertices();
164    let edges = graph.edges();
165    if config.len() != edges.len() {
166        return false;
167    }
168
169    // Build adjacency list from selected edges
170    let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
171    let mut selected_count = 0usize;
172    let mut involved = vec![false; n];
173    for (idx, &sel) in config.iter().enumerate() {
174        if sel == 1 {
175            let (u, v) = edges[idx];
176            adj[u].push(v);
177            adj[v].push(u);
178            involved[u] = true;
179            involved[v] = true;
180            selected_count += 1;
181        }
182    }
183
184    if selected_count == 0 {
185        return false;
186    }
187
188    // BFS from first terminal to check connectivity
189    let start = terminals[0];
190    let mut visited = vec![false; n];
191    let mut queue = std::collections::VecDeque::new();
192    visited[start] = true;
193    queue.push_back(start);
194    while let Some(v) = queue.pop_front() {
195        for &u in &adj[v] {
196            if !visited[u] {
197                visited[u] = true;
198                queue.push_back(u);
199            }
200        }
201    }
202
203    // All terminals must be reachable
204    if !terminals.iter().all(|&t| visited[t]) {
205        return false;
206    }
207
208    // All involved vertices must be in one connected component
209    if (0..n).any(|i| involved[i] && !visited[i]) {
210        return false;
211    }
212
213    // Tree property: #edges == #involved_vertices - 1
214    let involved_count = involved.iter().filter(|&&x| x).count();
215    selected_count == involved_count - 1
216}
217
218impl<G, W> Problem for SteinerTree<G, W>
219where
220    G: Graph + crate::variant::VariantParam,
221    W: WeightElement + crate::variant::VariantParam,
222{
223    const NAME: &'static str = "SteinerTree";
224    type Value = Min<W::Sum>;
225
226    fn variant() -> Vec<(&'static str, &'static str)> {
227        crate::variant_params![G, W]
228    }
229
230    fn dims(&self) -> Vec<usize> {
231        vec![2; self.graph.num_edges()]
232    }
233
234    fn evaluate(&self, config: &[usize]) -> Min<W::Sum> {
235        if !is_valid_steiner_tree(&self.graph, &self.terminals, config) {
236            return Min(None);
237        }
238        let mut total = W::Sum::zero();
239        for (idx, &selected) in config.iter().enumerate() {
240            if selected == 1 {
241                if let Some(w) = self.edge_weights.get(idx) {
242                    total += w.to_sum();
243                }
244            }
245        }
246        Min(Some(total))
247    }
248}
249
// Concrete variants registered for this problem, each paired with a complexity
// expression over `num_terminals` and `num_vertices` (the accessors defined on
// `SteinerTree`). `default` presumably marks `SteinerTree<SimpleGraph, i32>` as
// the default variant, matching the "i32" default in the schema entry.
// NOTE(review): the expression resembles a Dreyfus–Wagner-style DP bound;
// confirm whether a shortest-path preprocessing term should also appear.
crate::declare_variants! {
    default SteinerTree<SimpleGraph, i32> => "3^num_terminals * num_vertices + 2^num_terminals * num_vertices^2",
    SteinerTree<SimpleGraph, One> => "3^num_terminals * num_vertices + 2^num_terminals * num_vertices^2",
}
254
/// Canonical example instance for the example database.
///
/// The instance is a 5-vertex graph with 7 weighted edges and terminals {0, 2, 4}.
/// The recorded optimum selects edges (0,1), (1,2), (1,3) and (3,4), for a
/// total weight of 2 + 2 + 1 + 1 = 6.
#[cfg(feature = "example-db")]
pub(crate) fn canonical_model_example_specs() -> Vec<crate::example_db::specs::ModelExampleSpec> {
    use crate::example_db::specs::ModelExampleSpec;

    let graph = SimpleGraph::new(
        5,
        vec![(0, 1), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4), (3, 4)],
    );
    let edge_weights = vec![2, 5, 2, 1, 5, 6, 1];
    let terminals = vec![0, 2, 4];

    let spec = ModelExampleSpec {
        id: "steiner_tree_simplegraph_i32",
        instance: Box::new(SteinerTree::new(graph, edge_weights, terminals)),
        optimal_config: vec![1, 0, 1, 1, 0, 0, 1],
        optimal_value: serde_json::json!(6),
    };
    vec![spec]
}
271
// Unit tests live in a separate file (located via the `#[path]` attribute) that
// mirrors this module's location under `unit_tests/`; compiled only for tests.
#[cfg(test)]
#[path = "../../unit_tests/models/graph/steiner_tree.rs"]
mod tests;