1use indexmap::{map::Entry, IndexMap, IndexSet};
2use itertools::Itertools;
3use smtlib_lowlevel::{
4 ast::{self, Identifier, QualIdentifier},
5 backend,
6 lexicon::{Numeral, Symbol},
7 Driver, Logger, Storage,
8};
9
10use crate::{
11 funs, sorts,
12 terms::{qual_ident, Dynamic},
13 Bool, Error, Logic, Model, SatResult, SatResultWithModel, Sorted,
14};
15
/// A high-level handle to an SMT solver `backend`, driven through a
/// low-level [`Driver`].
///
/// Besides forwarding commands, the solver records which constants and sorts
/// it has already declared, so that methods like `assert` and `simplify` can
/// declare the symbols occurring in a term automatically before using it.
#[derive(Debug)]
pub struct Solver<'st, B> {
    /// Low-level driver that actually exchanges commands with the backend.
    driver: Driver<'st, B>,
    /// Snapshots of the lengths of `decls` and `declared_sorts` taken on
    /// `push`, used to roll those tables back on `pop`.
    push_pop_stack: Vec<StackSizes>,
    /// Constants declared through `declare_const`, keyed by identifier and
    /// mapped to the sort they were declared with.
    decls: IndexMap<Identifier<'st>, ast::Sort<'st>>,
    /// Sorts already handled by `declare_sort`, including ones that needed no
    /// `declare-sort` command (built-in or indexed sorts).
    declared_sorts: IndexSet<ast::Sort<'st>>,
}
51
/// Lengths of a [`Solver`]'s declaration tables at the moment a `push` was
/// issued; `pop` truncates the tables back to these lengths.
#[derive(Debug)]
struct StackSizes {
    /// Length of `Solver::decls` at push time.
    decls: usize,
    /// Length of `Solver::declared_sorts` at push time.
    declared_sorts: usize,
}
57
58impl<'st, B> Solver<'st, B>
59where
60 B: backend::Backend,
61{
62 pub fn new(st: &'st Storage, backend: B) -> Result<Self, Error> {
67 Ok(Self {
68 driver: Driver::new(st, backend)?,
69 push_pop_stack: Vec::new(),
70 decls: Default::default(),
71 declared_sorts: Default::default(),
72 })
73 }
74 pub fn st(&self) -> &'st Storage {
76 self.driver.st()
77 }
78 pub fn set_logger(&mut self, logger: impl Logger) {
81 self.driver.set_logger(logger)
82 }
83 pub fn set_timeout(&mut self, ms: usize) -> Result<(), Error> {
85 let cmd = ast::Command::SetOption(ast::Option::Attribute(ast::Attribute::WithValue(
86 smtlib_lowlevel::lexicon::Keyword(":timeout"),
87 ast::AttributeValue::SpecConstant(ast::SpecConstant::Numeral(Numeral::from_usize(ms))),
88 )));
89 match self.driver.exec(cmd)? {
90 ast::GeneralResponse::Success => Ok(()),
91 ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
92 _ => todo!(),
93 }
94 }
95 pub fn set_logic(&mut self, logic: Logic) -> Result<(), Error> {
100 let cmd = ast::Command::SetLogic(Symbol(self.st().alloc_str(&logic.name())));
101 match self.driver.exec(cmd)? {
102 ast::GeneralResponse::Success => Ok(()),
103 ast::GeneralResponse::SpecificSuccessResponse(_) => todo!(),
104 ast::GeneralResponse::Unsupported => todo!(),
105 ast::GeneralResponse::Error(_) => todo!(),
106 }
107 }
108 pub fn run_command(
110 &mut self,
111 cmd: ast::Command<'st>,
112 ) -> Result<ast::GeneralResponse<'st>, Error> {
113 Ok(self.driver.exec(cmd)?)
114 }
115 pub fn assert(&mut self, b: Bool<'st>) -> Result<(), Error> {
119 let term = b.term();
120
121 self.declare_all_consts(term)?;
122
123 let cmd = ast::Command::Assert(term);
124 match self.driver.exec(cmd)? {
125 ast::GeneralResponse::Success => Ok(()),
126 ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
127 _ => todo!(),
128 }
129 }
130 pub fn check_sat(&mut self) -> Result<SatResult, Error> {
136 let cmd = ast::Command::CheckSat;
137 match self.driver.exec(cmd)? {
138 ast::GeneralResponse::SpecificSuccessResponse(
139 ast::SpecificSuccessResponse::CheckSatResponse(res),
140 ) => Ok(match res {
141 ast::CheckSatResponse::Sat => SatResult::Sat,
142 ast::CheckSatResponse::Unsat => SatResult::Unsat,
143 ast::CheckSatResponse::Unknown => SatResult::Unknown,
144 }),
145 ast::GeneralResponse::Error(msg) => Err(Error::Smt(msg.to_string(), format!("{cmd}"))),
146 res => todo!("{res:?}"),
147 }
148 }
149 pub fn check_sat_with_model(&mut self) -> Result<SatResultWithModel<'st>, Error> {
155 match self.check_sat()? {
156 SatResult::Unsat => Ok(SatResultWithModel::Unsat),
157 SatResult::Sat => Ok(SatResultWithModel::Sat(self.get_model()?)),
158 SatResult::Unknown => Ok(SatResultWithModel::Unknown),
159 }
160 }
161 pub fn get_model(&mut self) -> Result<Model<'st>, Error> {
168 match self.driver.exec(ast::Command::GetModel)? {
169 ast::GeneralResponse::SpecificSuccessResponse(
170 ast::SpecificSuccessResponse::GetModelResponse(model),
171 ) => Ok(Model::new(self.st(), model)),
172 res => todo!("{res:?}"),
173 }
174 }
175 pub fn declare_fun(&mut self, fun: &funs::Fun<'st>) -> Result<(), Error> {
178 for var in fun.vars {
179 self.declare_sort(&var.ast())?;
180 }
181 self.declare_sort(&fun.return_sort.ast())?;
182
183 if fun.vars.is_empty() {
184 return self.declare_const(&qual_ident(fun.name, Some(fun.return_sort.ast())));
185 }
186
187 let cmd = ast::Command::DeclareFun(
188 Symbol(fun.name),
189 self.st()
190 .alloc_slice(&fun.vars.iter().map(|s| s.ast()).collect_vec()),
191 fun.return_sort.ast(),
192 );
193 match self.driver.exec(cmd)? {
194 ast::GeneralResponse::Success => Ok(()),
195 ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
196 _ => todo!(),
197 }
198 }
199 pub fn simplify(
201 &mut self,
202 t: Dynamic<'st>,
203 ) -> Result<&'st smtlib_lowlevel::ast::Term<'st>, Error> {
204 self.declare_all_consts(t.term())?;
205
206 let cmd = ast::Command::Simplify(t.term());
207
208 match self.driver.exec(cmd)? {
209 ast::GeneralResponse::SpecificSuccessResponse(
210 ast::SpecificSuccessResponse::SimplifyResponse(t),
211 ) => Ok(t.0),
212 res => todo!("{res:?}"),
213 }
214 }
215
216 pub fn scope<T>(
221 &mut self,
222 f: impl FnOnce(&mut Solver<'st, B>) -> Result<T, Error>,
223 ) -> Result<T, Error> {
224 self.push(1)?;
225 let res = f(self)?;
226 self.pop(1)?;
227 Ok(res)
228 }
229
230 fn push(&mut self, levels: usize) -> Result<(), Error> {
231 self.push_pop_stack.push(StackSizes {
232 decls: self.decls.len(),
233 declared_sorts: self.declared_sorts.len(),
234 });
235
236 let cmd = ast::Command::Push(Numeral::from_usize(levels));
237 match self.driver.exec(cmd)? {
238 ast::GeneralResponse::Success => {}
239 ast::GeneralResponse::Error(e) => {
240 return Err(Error::Smt(e.to_string(), cmd.to_string()))
241 }
242 _ => todo!(),
243 };
244 Ok(())
245 }
246
247 fn pop(&mut self, levels: usize) -> Result<(), Error> {
248 if let Some(sizes) = self.push_pop_stack.pop() {
249 self.decls.truncate(sizes.decls);
250 self.declared_sorts.truncate(sizes.declared_sorts);
251 }
252
253 let cmd = ast::Command::Pop(Numeral::from_usize(levels));
254 match self.driver.exec(cmd)? {
255 ast::GeneralResponse::Success => {}
256 ast::GeneralResponse::Error(e) => {
257 return Err(Error::Smt(e.to_string(), cmd.to_string()))
258 }
259 _ => todo!(),
260 };
261 Ok(())
262 }
263
264 fn declare_all_consts(&mut self, t: &'st ast::Term<'st>) -> Result<(), Error> {
265 for q in t.all_consts() {
266 self.declare_const(q)?;
267 }
268 Ok(())
269 }
270
271 fn declare_const(&mut self, q: &QualIdentifier<'st>) -> Result<(), Error> {
272 match q {
273 QualIdentifier::Identifier(_) => {}
274 QualIdentifier::Sorted(i, s) => {
275 self.declare_sort(s)?;
276
277 match self.decls.entry(*i) {
278 Entry::Occupied(stored) => assert_eq!(s, stored.get()),
279 Entry::Vacant(v) => {
280 v.insert(*s);
281 match i {
282 Identifier::Simple(sym) => {
283 self.driver.exec(ast::Command::DeclareConst(*sym, *s))?;
284 }
285 Identifier::Indexed(_, _) => todo!(),
286 }
287 }
288 }
289 }
290 };
291 Ok(())
292 }
293
294 fn declare_sort(&mut self, s: &ast::Sort<'st>) -> Result<(), Error> {
295 if self.declared_sorts.contains(s) {
296 return Ok(());
297 }
298 self.declared_sorts.insert(*s);
299
300 let cmd = match s {
301 ast::Sort::Sort(ident) => {
302 let sym = match ident {
303 Identifier::Simple(sym) => sym,
304 Identifier::Indexed(_, _) => {
305 return Ok(());
308 }
309 };
310 if sorts::is_built_in_sort(sym.0) {
311 return Ok(());
312 }
313 ast::Command::DeclareSort(*sym, Numeral::from_usize(0))
314 }
315 ast::Sort::Parametric(ident, params) => {
316 let sym = match ident {
317 Identifier::Simple(sym) => sym,
318 Identifier::Indexed(_, _) => {
319 return Ok(());
322 }
323 };
324 if sorts::is_built_in_sort(sym.0) {
325 return Ok(());
326 }
327 ast::Command::DeclareSort(*sym, Numeral::from_usize(params.len()))
328 }
329 };
330 match self.driver.exec(cmd)? {
331 ast::GeneralResponse::Success => Ok(()),
332 ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
333 _ => todo!(),
334 }
335 }
336}
337
#[cfg(test)]
mod tests {
    use smtlib_lowlevel::{backend::z3_binary::Z3Binary, Storage};

    use super::Solver;
    use crate::{terms::StaticSorted, Int, SatResult, Sorted};

    /// Assertions made inside `Solver::scope` are retracted when it returns.
    #[test]
    fn scope() -> Result<(), crate::Error> {
        let storage = Storage::new();
        let z3 = Z3Binary::new("z3").unwrap();
        let mut solver = Solver::new(&storage, z3)?;

        let x = Int::new_const(&storage, "x");
        solver.assert(x._eq(10))?;

        // Inside the scope, `x = 10` and `x = 20` contradict each other...
        solver.scope(|inner| {
            inner.assert(x._eq(20))?;
            assert_eq!(inner.check_sat()?, SatResult::Unsat);
            Ok(())
        })?;

        // ...but after the scope is popped only `x = 10` remains.
        assert_eq!(solver.check_sat()?, SatResult::Sat);

        Ok(())
    }
}