problemreductions/models/algebraic/
closest_vector_problem.rs1use crate::registry::{FieldInfo, ProblemSchemaEntry};
7use crate::traits::{OptimizationProblem, Problem};
8use crate::types::{Direction, SolutionSize};
9use serde::{Deserialize, Serialize};
10
11inventory::submit! {
12 ProblemSchemaEntry {
13 name: "ClosestVectorProblem",
14 module_path: module_path!(),
15 description: "Find the closest lattice point to a target vector",
16 fields: &[
17 FieldInfo { name: "basis", type_name: "Vec<Vec<T>>", description: "Basis matrix B as column vectors" },
18 FieldInfo { name: "target", type_name: "Vec<f64>", description: "Target vector t" },
19 FieldInfo { name: "bounds", type_name: "Vec<VarBounds>", description: "Integer bounds per variable" },
20 ],
21 }
22}
23
24#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
29pub struct VarBounds {
30 pub lower: Option<i64>,
32 pub upper: Option<i64>,
34}
35
36impl VarBounds {
37 pub fn binary() -> Self {
39 Self {
40 lower: Some(0),
41 upper: Some(1),
42 }
43 }
44
45 pub fn non_negative() -> Self {
47 Self {
48 lower: Some(0),
49 upper: None,
50 }
51 }
52
53 pub fn unbounded() -> Self {
55 Self {
56 lower: None,
57 upper: None,
58 }
59 }
60
61 pub fn bounded(lo: i64, hi: i64) -> Self {
63 Self {
64 lower: Some(lo),
65 upper: Some(hi),
66 }
67 }
68
69 pub fn contains(&self, value: i64) -> bool {
71 if let Some(lo) = self.lower {
72 if value < lo {
73 return false;
74 }
75 }
76 if let Some(hi) = self.upper {
77 if value > hi {
78 return false;
79 }
80 }
81 true
82 }
83
84 pub fn num_values(&self) -> Option<usize> {
87 match (self.lower, self.upper) {
88 (Some(lo), Some(hi)) => {
89 if hi >= lo {
90 Some((hi - lo + 1) as usize)
91 } else {
92 Some(0)
93 }
94 }
95 _ => None,
96 }
97 }
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ClosestVectorProblem<T> {
109 basis: Vec<Vec<T>>,
111 target: Vec<f64>,
113 bounds: Vec<VarBounds>,
115}
116
117impl<T> ClosestVectorProblem<T> {
118 pub fn new(basis: Vec<Vec<T>>, target: Vec<f64>, bounds: Vec<VarBounds>) -> Self {
128 let n = basis.len();
129 assert_eq!(
130 bounds.len(),
131 n,
132 "bounds length must match number of basis vectors"
133 );
134 let m = target.len();
135 for (i, col) in basis.iter().enumerate() {
136 assert_eq!(
137 col.len(),
138 m,
139 "basis vector {i} has length {}, expected {m}",
140 col.len()
141 );
142 }
143 Self {
144 basis,
145 target,
146 bounds,
147 }
148 }
149
150 pub fn num_basis_vectors(&self) -> usize {
152 self.basis.len()
153 }
154
155 pub fn ambient_dimension(&self) -> usize {
157 self.target.len()
158 }
159
160 pub fn basis(&self) -> &[Vec<T>] {
162 &self.basis
163 }
164
165 pub fn target(&self) -> &[f64] {
167 &self.target
168 }
169
170 pub fn bounds(&self) -> &[VarBounds] {
172 &self.bounds
173 }
174
175 fn config_to_values(&self, config: &[usize]) -> Vec<i64> {
177 config
178 .iter()
179 .enumerate()
180 .map(|(i, &c)| {
181 let lo = self.bounds.get(i).and_then(|b| b.lower).unwrap_or(0);
182 lo + c as i64
183 })
184 .collect()
185 }
186}
187
188impl<T> Problem for ClosestVectorProblem<T>
189where
190 T: Clone
191 + Into<f64>
192 + crate::variant::VariantParam
193 + Serialize
194 + for<'de> Deserialize<'de>
195 + std::fmt::Debug
196 + 'static,
197{
198 const NAME: &'static str = "ClosestVectorProblem";
199 type Metric = SolutionSize<f64>;
200
201 fn dims(&self) -> Vec<usize> {
202 self.bounds
203 .iter()
204 .map(|b| {
205 b.num_values().expect(
206 "CVP brute-force enumeration requires all variables to have finite bounds",
207 )
208 })
209 .collect()
210 }
211
212 fn evaluate(&self, config: &[usize]) -> SolutionSize<f64> {
213 let values = self.config_to_values(config);
214 let m = self.ambient_dimension();
215 let mut diff = vec![0.0f64; m];
216 for (i, &x_i) in values.iter().enumerate() {
217 for (j, b_ji) in self.basis[i].iter().enumerate() {
218 diff[j] += x_i as f64 * b_ji.clone().into();
219 }
220 }
221 for (d, t) in diff.iter_mut().zip(self.target.iter()) {
222 *d -= t;
223 }
224 let norm = diff.iter().map(|d| d * d).sum::<f64>().sqrt();
225 SolutionSize::Valid(norm)
226 }
227
228 fn variant() -> Vec<(&'static str, &'static str)> {
229 crate::variant_params![T]
230 }
231}
232
233impl<T> OptimizationProblem for ClosestVectorProblem<T>
234where
235 T: Clone
236 + Into<f64>
237 + crate::variant::VariantParam
238 + Serialize
239 + for<'de> Deserialize<'de>
240 + std::fmt::Debug
241 + 'static,
242{
243 type Value = f64;
244
245 fn direction(&self) -> Direction {
246 Direction::Minimize
247 }
248}
249
// Declare the concrete instantiations of this model: integer (`i32`) and
// floating-point (`f64`) basis entries.
// NOTE(review): the string on each entry appears to be a complexity expression
// in terms of the `num_basis_vectors` getter — confirm its exact meaning
// against the `declare_variants!` macro definition.
crate::declare_variants! {
    ClosestVectorProblem<i32> => "2^num_basis_vectors",
    ClosestVectorProblem<f64> => "2^num_basis_vectors",
}
254
// Unit tests for this model live in a separate file under `unit_tests/`;
// `#[path]` points the test-only `tests` module at that file.
#[cfg(test)]
#[path = "../../unit_tests/models/algebraic/closest_vector_problem.rs"]
mod tests;