smtlib/theories/
fixed_size_bit_vectors.rs

1#![doc = concat!("```ignore\n", include_str!("./FixedSizeBitVectors.smt2"), "```")]
2
3use itertools::Itertools;
4use smtlib_lowlevel::{
5    ast::{self, Term},
6    lexicon, Storage,
7};
8
9use crate::{
10    sorts::Sort,
11    terms::{app, qual_ident, Const, Dynamic, IntoWithStorage, STerm, Sorted, StaticSorted},
12    Bool,
13};
14
15/// A bit-vec is a fixed size sequence of boolean values. You can [read more
16/// about it
17/// here](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml), among
18/// other places.
19#[derive(Debug, Clone, Copy)]
20pub struct BitVec<'st, const M: usize>(STerm<'st>);
21impl<'st, const M: usize> From<Const<'st, BitVec<'st, M>>> for BitVec<'st, M> {
22    fn from(c: Const<'st, BitVec<'st, M>>) -> Self {
23        c.1
24    }
25}
26impl<const M: usize> std::fmt::Display for BitVec<'_, M> {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        self.term().fmt(f)
29    }
30}
31
32impl<'st, const M: usize> From<BitVec<'st, M>> for Dynamic<'st> {
33    fn from(i: BitVec<'st, M>) -> Self {
34        i.into_dynamic()
35    }
36}
37
38impl<'st, const M: usize> From<BitVec<'st, M>> for STerm<'st> {
39    fn from(i: BitVec<'st, M>) -> Self {
40        i.0
41    }
42}
43impl<'st, const M: usize> From<STerm<'st>> for BitVec<'st, M> {
44    fn from(t: STerm<'st>) -> Self {
45        BitVec(t)
46    }
47}
48
/// Converts `i` into its `M`-bit two's-complement representation, most
/// significant bit first (index `0` is bit `M - 1`).
///
/// For `M < 64` only the low `M` bits of `i` are kept (the value is truncated).
/// For `M > 64` the extra high bits are filled with the sign bit of `i`, so the
/// array still denotes the same integer in two's complement.
fn i64_to_bit_array<const M: usize>(i: i64) -> [bool; M] {
    std::array::from_fn(|idx| {
        // Bit position counted from the least significant end.
        let shift = M - idx - 1;
        // A plain `i >> shift` panics in debug builds (and masks the shift amount
        // in release builds) once `shift >= 64`; `checked_shr` reports that case
        // as `None`, where the bit is simply the sign bit.
        match u32::try_from(shift).ok().and_then(|s| i.checked_shr(s)) {
            Some(shifted) => shifted & 1 == 1,
            None => i < 0,
        }
    })
}
52
// #[test]
// fn test_bit_array() {
//     assert_eq!(i64_to_bit_array::<4>(8), [true, false, false, false]);
// }
57
58impl<'st, const M: usize> TryFrom<BitVec<'st, M>> for i64 {
59    type Error = std::num::ParseIntError;
60
61    fn try_from(value: BitVec<'st, M>) -> Result<Self, Self::Error> {
62        match value.term() {
63            Term::SpecConstant(c) => match c {
64                ast::SpecConstant::Numeral(_) => todo!(),
65                ast::SpecConstant::Decimal(_) => todo!(),
66                ast::SpecConstant::Hexadecimal(h) => h.parse(),
67                ast::SpecConstant::Binary(b) => b.parse(),
68                ast::SpecConstant::String(_) => todo!(),
69            },
70            _ => todo!(),
71        }
72    }
73}
74impl<'st, const M: usize> TryFrom<BitVec<'st, M>> for [bool; M] {
75    type Error = std::num::ParseIntError;
76
77    fn try_from(value: BitVec<'st, M>) -> Result<Self, Self::Error> {
78        Ok(i64_to_bit_array(value.try_into()?))
79    }
80}
81
82impl<'st, const M: usize> StaticSorted<'st> for BitVec<'st, M> {
83    type Inner = Self;
84    const AST_SORT: ast::Sort<'static> = ast::Sort::new_indexed(
85        "BitVec",
86        &[ast::Index::Numeral(lexicon::Numeral::from_usize(M))],
87    );
88    fn static_st(&self) -> &'st Storage {
89        self.sterm().st()
90    }
91
92    fn sort() -> Sort<'st> {
93        Self::AST_SORT.into()
94    }
95
96    fn new_const(st: &'st Storage, name: &str) -> Const<'st, Self> {
97        let name = st.alloc_str(name);
98        let bv = Term::Identifier(qual_ident(name, Some(Self::AST_SORT)));
99        let bv = STerm::new(st, bv);
100        Const(name, bv.into())
101    }
102}
103impl<'st, const M: usize> IntoWithStorage<'st, BitVec<'st, M>> for [bool; M] {
104    fn into_with_storage(self, st: &'st Storage) -> BitVec<'st, M> {
105        let term = Term::Identifier(qual_ident(
106            st.alloc_str(&format!("#b{}", self.iter().map(|x| *x as u8).format(""))),
107            None,
108        ));
109        STerm::new(st, term).into()
110    }
111}
112impl<'st, const M: usize> IntoWithStorage<'st, BitVec<'st, M>> for i64 {
113    fn into_with_storage(self, st: &'st Storage) -> BitVec<'st, M> {
114        i64_to_bit_array(self).into_with_storage(st)
115    }
116}
117impl<'st, const M: usize> BitVec<'st, M> {
118    /// Construct a new bit-vec.
119    pub fn new(
120        st: &'st Storage,
121        value: impl IntoWithStorage<'st, BitVec<'st, M>>,
122    ) -> BitVec<'st, M> {
123        value.into_with_storage(st)
124    }
125    fn binop<T: From<STerm<'st>>>(self, op: &'st str, other: BitVec<'st, M>) -> T {
126        app(self.st(), op, (self.term(), other.term())).into()
127    }
128    fn unop<T: From<STerm<'st>>>(self, op: &'st str) -> T {
129        app(self.st(), op, self.term()).into()
130    }
131
132    #[cfg(feature = "const-bit-vec")]
133    /// Extract a slice of the bit-vec.
134    ///
135    /// The constraints `I`, `J`, and `M` are:
136    ///
137    /// ```ignore
138    /// M > I >= J
139    /// ```
140    pub fn extract<const I: usize, const J: usize>(self) -> BitVec<'st, { I - J + 1 }> {
141        assert!(M > I);
142        assert!(I >= J);
143
144        Term::Application(
145            ast::QualIdentifier::Identifier(ast::Identifier::Indexed(
146                Symbol("extract".to_string()),
147                vec![
148                    Index::Numeral(Numeral(I.to_string())),
149                    Index::Numeral(Numeral(J.to_string())),
150                ],
151            )),
152            vec![self.into()],
153        )
154        .into()
155    }
156    #[cfg(feature = "const-bit-vec")]
157    /// Concatenates `self` and `other` bit-vecs to a single contiguous bit-vec
158    /// with length `N + M`
159    pub fn concat<const N: usize>(
160        self,
161        other: impl Into<BitVec<'st, N>>,
162    ) -> BitVec<'st, { N + M }> {
163        Term::Application(
164            qual_ident("concat".to_string(), None),
165            vec![self.into(), other.into().into()],
166        )
167        .into()
168    }
169
170    // Unary
171    /// Calls `(bvnot self)` i.e. bitwise not
172    pub fn bvnot(self) -> Self {
173        self.unop("bvnot")
174    }
175    /// Calls `(bvneg self)` i.e. two's complement negation
176    pub fn bvneg(self) -> Self {
177        self.unop("bvneg")
178    }
179
180    // Binary
181    /// Calls `(bvnand self other)` i.e. bitwise nand
182    pub fn bvnand(self, other: impl Into<Self>) -> Self {
183        self.binop("bvnand", other.into())
184    }
185    /// Calls `(bvnor self other)` i.e. bitwise nor
186    pub fn bvnor(self, other: impl Into<Self>) -> Self {
187        self.binop("bvnor", other.into())
188    }
189    /// Calls `(bvxnor self other)` i.e. bitwise xnor
190    pub fn bvxnor(self, other: impl Into<Self>) -> Self {
191        self.binop("bvxnor", other.into())
192    }
193    /// Calls `(bvult self other)`
194    pub fn bvult(self, other: impl Into<Self>) -> Bool<'st> {
195        self.binop("bvult", other.into())
196    }
197    /// Calls `(bvule self other)` i.e. unsigned less or equal
198    pub fn bvule(self, other: impl Into<Self>) -> Bool<'st> {
199        self.binop("bvule", other.into())
200    }
201    /// Calls `(bvugt self other)` i.e. unsigned greater than
202    pub fn bvugt(self, other: impl Into<Self>) -> Bool<'st> {
203        self.binop("bvugt", other.into())
204    }
205    /// Calls `(bvuge self other)` i.e. unsigned greater or equal
206    pub fn bvuge(self, other: impl Into<Self>) -> Bool<'st> {
207        self.binop("bvuge", other.into())
208    }
209    /// Calls `(bvslt self other)` i.e. signed less than
210    pub fn bvslt(self, other: impl Into<Self>) -> Bool<'st> {
211        self.binop("bvslt", other.into())
212    }
213    /// Calls `(bvsle self other)` i.e. signed less or equal
214    pub fn bvsle(self, other: impl Into<Self>) -> Bool<'st> {
215        self.binop("bvsle", other.into())
216    }
217    /// Calls `(bvsgt self other)` i.e. signed greater than
218    pub fn bvsgt(self, other: impl Into<Self>) -> Bool<'st> {
219        self.binop("bvsgt", other.into())
220    }
221    /// Calls `(bvsge self other)` i.e. signed greater or equal
222    pub fn bvsge(self, other: impl Into<Self>) -> Bool<'st> {
223        self.binop("bvsge", other.into())
224    }
225}
226
227impl<const M: usize> std::ops::Not for BitVec<'_, M> {
228    type Output = Self;
229    fn not(self) -> Self::Output {
230        self.bvnot()
231    }
232}
233impl<const M: usize> std::ops::Neg for BitVec<'_, M> {
234    type Output = Self;
235    fn neg(self) -> Self::Output {
236        self.bvneg()
237    }
238}
239
/// Generates a binary SMT-LIB operator for `$ty`:
///
/// - an inherent method `$op` (e.g. `bv.bvadd(x)`) that builds `($op self other)`;
/// - `std::ops::$trait` for `$ty` and for `Const<$ty>`, with any right-hand side
///   convertible into `$ty`;
/// - `std::ops::$a_trait` (the compound-assignment form) for `$ty`.
///
/// `$other` and `$a_op` are accepted for call-site compatibility but are
/// currently unused. Impls with a raw `$other` value on the left-hand side are
/// not generated.
macro_rules! impl_op {
    ($ty:ty, $other:ty, $trait:tt, $fn:ident, $op:ident, $a_trait:tt, $a_fn:tt, $a_op:tt) => {
        // The named method; every operator impl below is routed through it.
        impl<'st, const M: usize> $ty {
            #[doc = concat!("Calls `(", stringify!($op), " self other)`")]
            pub fn $op(self, other: impl Into<Self>) -> Self {
                self.binop(stringify!($op), other.into())
            }
        }
        // `bv <op> rhs`
        impl<'st, const M: usize, R> std::ops::$trait<R> for $ty
        where
            R: Into<$ty>,
        {
            type Output = Self;
            fn $fn(self, rhs: R) -> Self::Output {
                self.$op(rhs)
            }
        }
        // `konst <op> rhs`, where `konst` is a declared constant.
        impl<'st, const M: usize, R> std::ops::$trait<R> for Const<'st, $ty>
        where
            R: Into<$ty>,
        {
            type Output = $ty;
            fn $fn(self, rhs: R) -> Self::Output {
                self.1.$op(rhs)
            }
        }
        // `bv <op>= rhs`
        impl<'st, const M: usize, R> std::ops::$a_trait<R> for $ty
        where
            R: Into<$ty>,
        {
            fn $a_fn(&mut self, rhs: R) {
                *self = (*self).$op(rhs);
            }
        }
    };
}
288
289impl_op!(BitVec<'st, M>, [bool; M], BitAnd, bitand, bvand, BitAndAssign, bitand_assign, &);
290impl_op!(BitVec<'st, M>, [bool; M], BitOr, bitor, bvor, BitOrAssign, bitor_assign, |);
291impl_op!(BitVec<'st, M>, [bool; M], BitXor, bitxor, bvxor, BitXorAssign, bitxor_assign, ^);
292impl_op!(BitVec<'st, M>, [bool; M], Add, add, bvadd, AddAssign, add_assign, +);
293// impl_op!(BitVec<'st, M>, [bool; M], Sub, sub, bvsub, SubAssign, sub_assign,
294// -);
295impl_op!(BitVec<'st, M>, [bool; M], Mul, mul, bvmul, MulAssign, mul_assign, *);
296impl_op!(BitVec<'st, M>, [bool; M], Div, div, bvudiv, DivAssign, div_assign, /);
297impl_op!(BitVec<'st, M>, [bool; M], Rem, rem, bvurem, RemAssign, rem_assign, %);
298impl_op!(BitVec<'st, M>, [bool; M], Shr, shr, bvlshr, ShrAssign, shr_assign, >>);
299impl_op!(BitVec<'st, M>, [bool; M], Shl, shl, bvshl, ShlAssign, shl_assign, <<);
300
// Solver-backed tests for the feature-gated `extract`/`concat` methods. This
// module is compiled only when both the `const-bit-vec` feature and
// `cfg(test)` are active. It drives an external solver via
// `Z3Binary::new("z3")`, presumably a `z3` executable reachable from PATH
// (TODO confirm).
//
// NOTE(review): the tests call `BitVec::<6>::from_name(..)`,
// `BitVec::from([bool; 6])` and `Solver::new(..)` without any `Storage`. By
// contrast, the constructors in this file (`BitVec::new`, `new_const`,
// `IntoWithStorage`) all take a `&'st Storage`. The tests look written against
// an older API and may not compile with the feature enabled. TODO: confirm.
#[cfg(feature = "const-bit-vec")]
#[cfg(test)]
mod tests {
    use smtlib_lowlevel::backend::Z3Binary;

    use super::BitVec;
    use crate::{terms::Sorted, Solver};

    /// Asserts `a = !d`, `b = a[5:2]` (via `extract::<5, 2>`) and `c = a ++ b`
    /// (via `concat`), then snapshots the model values as boolean arrays, most
    /// significant bit first.
    #[test]
    fn bit_vec_extract_concat() -> Result<(), Box<dyn std::error::Error>> {
        let a = BitVec::<6>::from_name("a");
        let b = BitVec::from_name("b");
        let c = BitVec::from_name("c");
        let d = BitVec::from([true, false, true, true, false, true]);

        let mut solver = Solver::new(Z3Binary::new("z3")?)?;

        solver.assert(a._eq(!d))?;
        solver.assert(b._eq(a.extract::<5, 2>()))?;
        solver.assert(c._eq(a.concat(b)))?;

        let model = solver.check_sat_with_model()?.expect_sat()?;

        let a: [bool; 6] = model.eval(a).unwrap().try_into()?;
        let b: [bool; 4] = model.eval(b).unwrap().try_into()?;
        let c: [bool; 10] = model.eval(c).unwrap().try_into()?;
        insta::assert_ron_snapshot!(a, @"(false, true, false, false, true, false)");
        insta::assert_ron_snapshot!(b, @"(false, true, false, false)");
        insta::assert_ron_snapshot!(c, @"(false, true, false, false, true, false, false, true, false, false)");

        Ok(())
    }

    // Disabled arithmetic round-trip test (`a = 10`, `b = 3`, `c = a + b`). Its
    // inline snapshot is still empty (`@""`).
    // #[test]
    // fn bit_vec_math() -> Result<(), Box<dyn std::error::Error>> {
    //     let a = BitVec::<6>::from_name("a");
    //     let b = BitVec::<6>::from_name("b");
    //     let c = BitVec::<6>::from_name("c");

    //     let mut solver = Solver::new(Z3Binary::new("z3")?)?;

    //     solver.assert(a._eq(BitVec::<6>::from(10)))?;
    //     solver.assert(b._eq(BitVec::<6>::from(3)))?;
    //     // solver.assert(c._eq(a % b))?;
    //     solver.assert(c._eq(a + b))?;

    //     solver.check_sat()?;
    //     let model = solver.get_model()?;

    //     let a: i64 = model.eval(a).unwrap().try_into()?;
    //     let b: i64 = model.eval(b).unwrap().try_into()?;
    //     let c: i64 = model.eval(c).unwrap().try_into()?;
    //     insta::assert_ron_snapshot!(c, @"");
    //     // insta::assert_ron_snapshot!(b, @"");
    //     // insta::assert_ron_snapshot!(c, @"");

    //     Ok(())
    // }
}