smtlib/
solver.rs

1use indexmap::{map::Entry, IndexMap, IndexSet};
2use itertools::Itertools;
3use smtlib_lowlevel::{
4    ast::{self, Identifier, QualIdentifier},
5    backend,
6    lexicon::{Numeral, Symbol},
7    Driver, Logger, Storage,
8};
9
10use crate::{
11    funs, sorts,
12    terms::{qual_ident, Dynamic},
13    Bool, Error, Logic, Model, SatResult, SatResultWithModel, Sorted,
14};
15
16/// The [`Solver`] type is the primary entrypoint to interaction with the
17/// solver. Checking for validity of a set of assertions requires:
18/// ```
19/// # use smtlib::{Int, prelude::*};
20/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
21/// // 1. Set up storage (TODO: document)
22/// let st = smtlib::Storage::new();
23/// // 2. Set up the backend (in this case z3)
24/// let backend = smtlib::backend::z3_binary::Z3Binary::new("z3")?;
25/// // 3. Set up the solver
26/// let mut solver = smtlib::Solver::new(&st, backend)?;
27/// // 4. Declare the necessary constants
28/// let x = Int::new_const(&st, "x");
29/// // 5. Add assertions to the solver
30/// solver.assert(x._eq(12))?;
31/// // 6. Check for validity, and optionally construct a model
32/// let sat_res = solver.check_sat_with_model()?;
33/// // 7. In this case we expect sat, and thus want to extract the model
34/// let model = sat_res.expect_sat()?;
35/// // 8. Interpret the result by extract the values of constants which
36/// //    satisfy the assertions.
37/// match model.eval(x) {
38///     Some(x) => println!("This is the value of x: {x}"),
39///     None => panic!("Oh no! This should never happen, as x was part of an assert"),
40/// }
41/// # Ok(())
42/// # }
43/// ```
44#[derive(Debug)]
45pub struct Solver<'st, B> {
46    driver: Driver<'st, B>,
47    push_pop_stack: Vec<StackSizes>,
48    decls: IndexMap<Identifier<'st>, ast::Sort<'st>>,
49    declared_sorts: IndexSet<ast::Sort<'st>>,
50}
51
/// Snapshot of the solver's declaration bookkeeping, taken when a scope is
/// pushed so it can be rolled back on the matching pop.
///
/// It only holds two counters, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct StackSizes {
    /// Number of entries in the solver's `decls` map at push time.
    decls: usize,
    /// Number of entries in the solver's `declared_sorts` set at push time.
    declared_sorts: usize,
}
57
58impl<'st, B> Solver<'st, B>
59where
60    B: backend::Backend,
61{
62    /// Construct a new solver provided with the backend to use.
63    ///
64    /// The read more about which backends are available, check out the
65    /// documentation of the [`backend`] module.
66    pub fn new(st: &'st Storage, backend: B) -> Result<Self, Error> {
67        Ok(Self {
68            driver: Driver::new(st, backend)?,
69            push_pop_stack: Vec::new(),
70            decls: Default::default(),
71            declared_sorts: Default::default(),
72        })
73    }
74    /// Get the smtlib storage.
75    pub fn st(&self) -> &'st Storage {
76        self.driver.st()
77    }
78    /// Set the logger for the solver. This is useful for debugging or tracing
79    /// purposes.
80    pub fn set_logger(&mut self, logger: impl Logger) {
81        self.driver.set_logger(logger)
82    }
83    /// Set the timeout for the solver. The timeout is in milliseconds.
84    pub fn set_timeout(&mut self, ms: usize) -> Result<(), Error> {
85        let cmd = ast::Command::SetOption(ast::Option::Attribute(ast::Attribute::WithValue(
86            smtlib_lowlevel::lexicon::Keyword(":timeout"),
87            ast::AttributeValue::SpecConstant(ast::SpecConstant::Numeral(Numeral::from_usize(ms))),
88        )));
89        match self.driver.exec(cmd)? {
90            ast::GeneralResponse::Success => Ok(()),
91            ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
92            _ => todo!(),
93        }
94    }
95    /// Explicitly sets the logic for the solver. For some backends this is not
96    /// required, as they will infer what ever logic fits the current program.
97    ///
98    /// To read more about logics read the documentation of [`Logic`].
99    pub fn set_logic(&mut self, logic: Logic) -> Result<(), Error> {
100        let cmd = ast::Command::SetLogic(Symbol(self.st().alloc_str(&logic.name())));
101        match self.driver.exec(cmd)? {
102            ast::GeneralResponse::Success => Ok(()),
103            ast::GeneralResponse::SpecificSuccessResponse(_) => todo!(),
104            ast::GeneralResponse::Unsupported => todo!(),
105            ast::GeneralResponse::Error(_) => todo!(),
106        }
107    }
108    /// Runs the given command on the solver, and returns the result.
109    pub fn run_command(
110        &mut self,
111        cmd: ast::Command<'st>,
112    ) -> Result<ast::GeneralResponse<'st>, Error> {
113        Ok(self.driver.exec(cmd)?)
114    }
115    /// Adds the constraint of `b` as an assertion to the solver. To check for
116    /// satisfiability call [`Solver::check_sat`] or
117    /// [`Solver::check_sat_with_model`].
118    pub fn assert(&mut self, b: Bool<'st>) -> Result<(), Error> {
119        let term = b.term();
120
121        self.declare_all_consts(term)?;
122
123        let cmd = ast::Command::Assert(term);
124        match self.driver.exec(cmd)? {
125            ast::GeneralResponse::Success => Ok(()),
126            ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
127            _ => todo!(),
128        }
129    }
130    /// Checks for satisfiability of the assertions sent to the solver using
131    /// [`Solver::assert`].
132    ///
133    /// If you are interested in producing a model satisfying the assertions
134    /// check out [`Solver::check_sat`].
135    pub fn check_sat(&mut self) -> Result<SatResult, Error> {
136        let cmd = ast::Command::CheckSat;
137        match self.driver.exec(cmd)? {
138            ast::GeneralResponse::SpecificSuccessResponse(
139                ast::SpecificSuccessResponse::CheckSatResponse(res),
140            ) => Ok(match res {
141                ast::CheckSatResponse::Sat => SatResult::Sat,
142                ast::CheckSatResponse::Unsat => SatResult::Unsat,
143                ast::CheckSatResponse::Unknown => SatResult::Unknown,
144            }),
145            ast::GeneralResponse::Error(msg) => Err(Error::Smt(msg.to_string(), format!("{cmd}"))),
146            res => todo!("{res:?}"),
147        }
148    }
149    /// Checks for satisfiability of the assertions sent to the solver using
150    /// [`Solver::assert`], and produces a [model](Model) in case of `sat`.
151    ///
152    /// If you are not interested in the produced model, check out
153    /// [`Solver::check_sat`].
154    pub fn check_sat_with_model(&mut self) -> Result<SatResultWithModel<'st>, Error> {
155        match self.check_sat()? {
156            SatResult::Unsat => Ok(SatResultWithModel::Unsat),
157            SatResult::Sat => Ok(SatResultWithModel::Sat(self.get_model()?)),
158            SatResult::Unknown => Ok(SatResultWithModel::Unknown),
159        }
160    }
161    /// Produces the model for satisfying the assertions. If you are looking to
162    /// retrieve a model after calling [`Solver::check_sat`], consider using
163    /// [`Solver::check_sat_with_model`] instead.
164    ///
165    /// > **NOTE:** This must only be called after having called
166    /// > [`Solver::check_sat`] and it returning [`SatResult::Sat`].
167    pub fn get_model(&mut self) -> Result<Model<'st>, Error> {
168        match self.driver.exec(ast::Command::GetModel)? {
169            ast::GeneralResponse::SpecificSuccessResponse(
170                ast::SpecificSuccessResponse::GetModelResponse(model),
171            ) => Ok(Model::new(self.st(), model)),
172            res => todo!("{res:?}"),
173        }
174    }
175    /// Declares a function to the solver. For more details refer to the
176    /// [`funs`] module.
177    pub fn declare_fun(&mut self, fun: &funs::Fun<'st>) -> Result<(), Error> {
178        for var in fun.vars {
179            self.declare_sort(&var.ast())?;
180        }
181        self.declare_sort(&fun.return_sort.ast())?;
182
183        if fun.vars.is_empty() {
184            return self.declare_const(&qual_ident(fun.name, Some(fun.return_sort.ast())));
185        }
186
187        let cmd = ast::Command::DeclareFun(
188            Symbol(fun.name),
189            self.st()
190                .alloc_slice(&fun.vars.iter().map(|s| s.ast()).collect_vec()),
191            fun.return_sort.ast(),
192        );
193        match self.driver.exec(cmd)? {
194            ast::GeneralResponse::Success => Ok(()),
195            ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
196            _ => todo!(),
197        }
198    }
199    /// Simplifies the given term
200    pub fn simplify(
201        &mut self,
202        t: Dynamic<'st>,
203    ) -> Result<&'st smtlib_lowlevel::ast::Term<'st>, Error> {
204        self.declare_all_consts(t.term())?;
205
206        let cmd = ast::Command::Simplify(t.term());
207
208        match self.driver.exec(cmd)? {
209            ast::GeneralResponse::SpecificSuccessResponse(
210                ast::SpecificSuccessResponse::SimplifyResponse(t),
211            ) => Ok(t.0),
212            res => todo!("{res:?}"),
213        }
214    }
215
216    /// Start a new scope, execute the given closure, and then pop the scope.
217    ///
218    /// A scope is a way to group a set of assertions together, and then later
219    /// rollback all the assertions to the state before the scope was started.
220    pub fn scope<T>(
221        &mut self,
222        f: impl FnOnce(&mut Solver<'st, B>) -> Result<T, Error>,
223    ) -> Result<T, Error> {
224        self.push(1)?;
225        let res = f(self)?;
226        self.pop(1)?;
227        Ok(res)
228    }
229
230    fn push(&mut self, levels: usize) -> Result<(), Error> {
231        self.push_pop_stack.push(StackSizes {
232            decls: self.decls.len(),
233            declared_sorts: self.declared_sorts.len(),
234        });
235
236        let cmd = ast::Command::Push(Numeral::from_usize(levels));
237        match self.driver.exec(cmd)? {
238            ast::GeneralResponse::Success => {}
239            ast::GeneralResponse::Error(e) => {
240                return Err(Error::Smt(e.to_string(), cmd.to_string()))
241            }
242            _ => todo!(),
243        };
244        Ok(())
245    }
246
247    fn pop(&mut self, levels: usize) -> Result<(), Error> {
248        if let Some(sizes) = self.push_pop_stack.pop() {
249            self.decls.truncate(sizes.decls);
250            self.declared_sorts.truncate(sizes.declared_sorts);
251        }
252
253        let cmd = ast::Command::Pop(Numeral::from_usize(levels));
254        match self.driver.exec(cmd)? {
255            ast::GeneralResponse::Success => {}
256            ast::GeneralResponse::Error(e) => {
257                return Err(Error::Smt(e.to_string(), cmd.to_string()))
258            }
259            _ => todo!(),
260        };
261        Ok(())
262    }
263
264    fn declare_all_consts(&mut self, t: &'st ast::Term<'st>) -> Result<(), Error> {
265        for q in t.all_consts() {
266            self.declare_const(q)?;
267        }
268        Ok(())
269    }
270
271    fn declare_const(&mut self, q: &QualIdentifier<'st>) -> Result<(), Error> {
272        match q {
273            QualIdentifier::Identifier(_) => {}
274            QualIdentifier::Sorted(i, s) => {
275                self.declare_sort(s)?;
276
277                match self.decls.entry(*i) {
278                    Entry::Occupied(stored) => assert_eq!(s, stored.get()),
279                    Entry::Vacant(v) => {
280                        v.insert(*s);
281                        match i {
282                            Identifier::Simple(sym) => {
283                                self.driver.exec(ast::Command::DeclareConst(*sym, *s))?;
284                            }
285                            Identifier::Indexed(_, _) => todo!(),
286                        }
287                    }
288                }
289            }
290        };
291        Ok(())
292    }
293
294    fn declare_sort(&mut self, s: &ast::Sort<'st>) -> Result<(), Error> {
295        if self.declared_sorts.contains(s) {
296            return Ok(());
297        }
298        self.declared_sorts.insert(*s);
299
300        let cmd = match s {
301            ast::Sort::Sort(ident) => {
302                let sym = match ident {
303                    Identifier::Simple(sym) => sym,
304                    Identifier::Indexed(_, _) => {
305                        // TODO: is it correct that only sorts from theores can
306                        // be indexed, and thus does not need to be declared?
307                        return Ok(());
308                    }
309                };
310                if sorts::is_built_in_sort(sym.0) {
311                    return Ok(());
312                }
313                ast::Command::DeclareSort(*sym, Numeral::from_usize(0))
314            }
315            ast::Sort::Parametric(ident, params) => {
316                let sym = match ident {
317                    Identifier::Simple(sym) => sym,
318                    Identifier::Indexed(_, _) => {
319                        // TODO: is it correct that only sorts from theores can
320                        // be indexed, and thus does not need to be declared?
321                        return Ok(());
322                    }
323                };
324                if sorts::is_built_in_sort(sym.0) {
325                    return Ok(());
326                }
327                ast::Command::DeclareSort(*sym, Numeral::from_usize(params.len()))
328            }
329        };
330        match self.driver.exec(cmd)? {
331            ast::GeneralResponse::Success => Ok(()),
332            ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
333            _ => todo!(),
334        }
335    }
336}
337
338#[cfg(test)]
339mod tests {
340    use smtlib_lowlevel::{backend::z3_binary::Z3Binary, Storage};
341
342    use super::Solver;
343    use crate::{terms::StaticSorted, Int, SatResult, Sorted};
344
345    #[test]
346    fn scope() -> Result<(), crate::Error> {
347        let st = Storage::new();
348        let mut solver = Solver::new(&st, Z3Binary::new("z3").unwrap())?;
349
350        let x = Int::new_const(&st, "x");
351
352        solver.assert(x._eq(10))?;
353
354        solver.scope(|solver| {
355            solver.assert(x._eq(20))?;
356
357            assert_eq!(solver.check_sat()?, SatResult::Unsat);
358
359            Ok(())
360        })?;
361
362        assert_eq!(solver.check_sat()?, SatResult::Sat);
363
364        Ok(())
365    }
366}