Skip to main content

problemreductions/models/graph/
steiner_tree_in_graphs.rs

1//! Steiner Tree in Graphs problem implementation.
2//!
3//! The Steiner Tree problem asks for a minimum-weight subtree of a graph
4//! that connects all terminal vertices.
5
6use crate::registry::{FieldInfo, ProblemSchemaEntry, VariantDimension};
7use crate::topology::{Graph, SimpleGraph};
8use crate::traits::Problem;
9use crate::types::{Min, One, WeightElement};
10use num_traits::Zero;
11use serde::{Deserialize, Serialize};
12
13inventory::submit! {
14    ProblemSchemaEntry {
15        name: "SteinerTreeInGraphs",
16        display_name: "Steiner Tree in Graphs",
17        aliases: &[],
18        dimensions: &[
19            VariantDimension::new("graph", "SimpleGraph", &["SimpleGraph"]),
20            VariantDimension::new("weight", "i32", &["One", "i32"]),
21        ],
22        module_path: module_path!(),
23        description: "Find minimum weight subtree connecting all terminal vertices",
24        fields: &[
25            FieldInfo { name: "graph", type_name: "G", description: "The underlying graph G=(V,E)" },
26            FieldInfo { name: "terminals", type_name: "Vec<usize>", description: "Required terminal vertices R ⊆ V" },
27            FieldInfo { name: "edge_weights", type_name: "Vec<W>", description: "Edge weights w: E -> R" },
28        ],
29    }
30}
31
32/// The Steiner Tree in Graphs problem.
33///
34/// Given a weighted graph G = (V, E) with edge weights w_e and a
35/// subset R ⊆ V of required terminal vertices, find a subtree T of G
36/// that includes all vertices of R and minimizes the total edge weight
37/// Σ_{e ∈ T} w(e).
38///
39/// # Representation
40///
41/// Each edge is assigned a binary variable:
42/// - 0: edge is not in the tree
43/// - 1: edge is in the tree
44///
45/// A valid Steiner tree requires:
46/// - All terminal vertices are connected through selected edges
47/// - Selected edges form a connected subgraph (optimally a tree)
48///
49/// # Type Parameters
50///
51/// * `G` - The graph type (e.g., `SimpleGraph`)
52/// * `W` - The weight type for edges (e.g., `i32`, `f64`)
53///
54/// # Example
55///
56/// ```
57/// use problemreductions::models::graph::SteinerTreeInGraphs;
58/// use problemreductions::topology::SimpleGraph;
59/// use problemreductions::{Problem, Solver, BruteForce};
60///
61/// // Path graph 0-1-2-3, terminals {0, 3}
62/// let graph = SimpleGraph::new(4, vec![(0, 1), (1, 2), (2, 3)]);
63/// let problem = SteinerTreeInGraphs::new(graph, vec![0, 3], vec![1, 1, 1]);
64///
65/// let solver = BruteForce::new();
66/// let solution = solver.find_witness(&problem).unwrap();
67/// // Optimal: select all 3 edges (the only path from 0 to 3)
68/// assert_eq!(solution, vec![1, 1, 1]);
69/// ```
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct SteinerTreeInGraphs<G, W> {
72    /// The underlying graph.
73    graph: G,
74    /// Required terminal vertices.
75    terminals: Vec<usize>,
76    /// Weights for each edge (in edge index order).
77    edge_weights: Vec<W>,
78}
79
80impl<G: Graph, W: Clone + Default> SteinerTreeInGraphs<G, W> {
81    /// Create a SteinerTreeInGraphs problem from a graph, terminals, and edge weights.
82    ///
83    /// # Panics
84    /// Panics if `edge_weights.len() != graph.num_edges()` or any terminal index is out of bounds.
85    pub fn new(graph: G, terminals: Vec<usize>, edge_weights: Vec<W>) -> Self {
86        assert_eq!(
87            edge_weights.len(),
88            graph.num_edges(),
89            "edge_weights length must match num_edges"
90        );
91        for &t in &terminals {
92            assert!(
93                t < graph.num_vertices(),
94                "terminal vertex {} out of bounds (num_vertices = {})",
95                t,
96                graph.num_vertices()
97            );
98        }
99        Self {
100            graph,
101            terminals,
102            edge_weights,
103        }
104    }
105
106    /// Get a reference to the underlying graph.
107    pub fn graph(&self) -> &G {
108        &self.graph
109    }
110
111    /// Get the terminal vertices.
112    pub fn terminals(&self) -> &[usize] {
113        &self.terminals
114    }
115
116    /// Get all edges with their weights.
117    pub fn edges(&self) -> Vec<(usize, usize, W)> {
118        self.graph
119            .edges()
120            .into_iter()
121            .zip(self.edge_weights.iter().cloned())
122            .map(|((u, v), w)| (u, v, w))
123            .collect()
124    }
125
126    /// Set new weights for the problem.
127    pub fn set_weights(&mut self, weights: Vec<W>) {
128        assert_eq!(weights.len(), self.graph.num_edges());
129        self.edge_weights = weights;
130    }
131
132    /// Get the weights for the problem.
133    pub fn weights(&self) -> Vec<W> {
134        self.edge_weights.clone()
135    }
136
137    /// Check if the problem uses a non-unit weight type.
138    pub fn is_weighted(&self) -> bool
139    where
140        W: WeightElement,
141    {
142        !W::IS_UNIT
143    }
144
145    /// Check if a configuration is a valid Steiner tree.
146    pub fn is_valid_solution(&self, config: &[usize]) -> bool {
147        if config.len() != self.graph.num_edges() {
148            return false;
149        }
150        let selected: Vec<bool> = config.iter().map(|&s| s == 1).collect();
151        is_steiner_tree(&self.graph, &self.terminals, &selected)
152    }
153}
154
155impl<G: Graph, W: WeightElement> SteinerTreeInGraphs<G, W> {
156    /// Get the number of vertices in the underlying graph.
157    pub fn num_vertices(&self) -> usize {
158        self.graph().num_vertices()
159    }
160
161    /// Get the number of edges in the underlying graph.
162    pub fn num_edges(&self) -> usize {
163        self.graph().num_edges()
164    }
165
166    /// Get the number of terminal vertices.
167    pub fn num_terminals(&self) -> usize {
168        self.terminals.len()
169    }
170}
171
172impl<G, W> Problem for SteinerTreeInGraphs<G, W>
173where
174    G: Graph + crate::variant::VariantParam,
175    W: WeightElement + crate::variant::VariantParam,
176{
177    const NAME: &'static str = "SteinerTreeInGraphs";
178    type Value = Min<W::Sum>;
179
180    fn variant() -> Vec<(&'static str, &'static str)> {
181        crate::variant_params![G, W]
182    }
183
184    fn dims(&self) -> Vec<usize> {
185        vec![2; self.graph.num_edges()]
186    }
187
188    fn evaluate(&self, config: &[usize]) -> Min<W::Sum> {
189        if config.len() != self.graph.num_edges() {
190            return Min(None);
191        }
192        let selected: Vec<bool> = config.iter().map(|&s| s == 1).collect();
193        if !is_steiner_tree(&self.graph, &self.terminals, &selected) {
194            return Min(None);
195        }
196        let mut total = W::Sum::zero();
197        for (idx, &sel) in config.iter().enumerate() {
198            if sel == 1 {
199                if let Some(w) = self.edge_weights.get(idx) {
200                    total += w.to_sum();
201                }
202            }
203        }
204        Min(Some(total))
205    }
206}
207
208/// Check if a selection of edges forms a valid Steiner tree (connected subgraph spanning all terminals).
209///
210/// A valid Steiner tree requires:
211/// 1. All terminal vertices are reachable from each other through selected edges.
212/// 2. The selected edges form a connected subgraph that includes all terminals.
213///
214/// Note: The optimal solution is always a tree, but we accept any connected subgraph
215/// spanning all terminals (the brute-force solver will find the minimum-weight one).
216///
217/// # Panics
218/// Panics if `selected.len() != graph.num_edges()`.
219pub(crate) fn is_steiner_tree<G: Graph>(graph: &G, terminals: &[usize], selected: &[bool]) -> bool {
220    assert_eq!(
221        selected.len(),
222        graph.num_edges(),
223        "selected length must match num_edges"
224    );
225
226    // If no terminals, any selection is trivially valid (including empty)
227    if terminals.is_empty() {
228        return true;
229    }
230
231    // If only one terminal, it's valid as long as that terminal exists
232    // (no edges needed to connect a single vertex)
233    if terminals.len() == 1 {
234        return true;
235    }
236
237    // Build adjacency list from selected edges
238    let n = graph.num_vertices();
239    let edges = graph.edges();
240    let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
241
242    let mut has_any_edge = false;
243    for (idx, &sel) in selected.iter().enumerate() {
244        if sel {
245            let (u, v) = edges[idx];
246            adj[u].push(v);
247            adj[v].push(u);
248            has_any_edge = true;
249        }
250    }
251
252    if !has_any_edge {
253        return false;
254    }
255
256    // BFS from the first terminal to check connectivity of all terminals
257    let start = terminals[0];
258    let mut visited = vec![false; n];
259    let mut queue = std::collections::VecDeque::new();
260    visited[start] = true;
261    queue.push_back(start);
262
263    while let Some(node) = queue.pop_front() {
264        for &neighbor in &adj[node] {
265            if !visited[neighbor] {
266                visited[neighbor] = true;
267                queue.push_back(neighbor);
268            }
269        }
270    }
271
272    // All terminals must be reachable
273    terminals.iter().all(|&t| visited[t])
274}
275
276crate::declare_variants! {
277    default SteinerTreeInGraphs<SimpleGraph, i32> => "2^num_terminals * num_vertices^3",
278    SteinerTreeInGraphs<SimpleGraph, One> => "2^num_terminals * num_vertices^3",
279}
280
/// Canonical example instance for the example database (only compiled with
/// the `example-db` feature): a 6-vertex, 7-edge `SimpleGraph` with `i32`
/// weights and terminals {0, 3, 5}.
#[cfg(feature = "example-db")]
pub(crate) fn canonical_model_example_specs() -> Vec<crate::example_db::specs::ModelExampleSpec> {
    // `weights[i]` is the weight given to `edge_list[i]`.
    let edge_list = vec![(0, 1), (0, 2), (1, 3), (2, 3), (2, 5), (3, 4), (4, 5)];
    let weights = vec![3, 2, 4, 1, 2, 3, 1];
    let terminals = vec![0, 3, 5];
    let problem = SteinerTreeInGraphs::new(SimpleGraph::new(6, edge_list), terminals, weights);

    let spec = crate::example_db::specs::ModelExampleSpec {
        id: "steiner_tree_in_graphs_simplegraph_i32",
        instance: Box::new(problem),
        // Optimal: edges {0,2}(w=2), {2,3}(w=1), {2,5}(w=2) = weight 5
        optimal_config: vec![0, 1, 0, 1, 1, 0, 0],
        optimal_value: serde_json::json!(5),
    };
    vec![spec]
}
298
299#[cfg(test)]
300#[path = "../../unit_tests/models/graph/steiner_tree_in_graphs.rs"]
301mod tests;