smtlib/
funs.rs

1//! Function declarations.
2
3use itertools::Itertools;
4use smtlib_lowlevel::{ast, lexicon::Symbol, Storage};
5
6use crate::{
7    sorts::Sort,
8    terms::{qual_ident, Dynamic, STerm},
9    Sorted,
10};
11
12/// A function declaration.
13#[derive(Debug)]
14pub struct Fun<'st> {
15    /// smtlib storage
16    pub st: &'st Storage,
17    /// The name of the function.
18    pub name: &'st str,
19    /// The sorts of the arguments.
20    pub vars: &'st [Sort<'st>],
21    /// The sort of the return value.
22    pub return_sort: Sort<'st>,
23}
24
25impl<'st> Fun<'st> {
26    /// Create a new function declaration.
27    pub fn new(
28        st: &'st Storage,
29        name: impl Into<String>,
30        vars: Vec<Sort<'st>>,
31        return_ty: Sort<'st>,
32    ) -> Self {
33        Self {
34            st,
35            name: st.alloc_str(&name.into()),
36            vars: st.alloc_slice(&vars),
37            return_sort: return_ty,
38        }
39    }
40
41    /// Call the function with the given arguments.
42    ///
43    /// The arguments must be sorted in the same order as the function
44    /// declaration and checked for both arity and sort.
45    pub fn call(&self, args: &[Dynamic<'st>]) -> Result<Dynamic<'st>, crate::Error> {
46        if self.vars.len() != args.len() {
47            todo!()
48        }
49        for (expected, given) in self.vars.iter().zip(args) {
50            if *expected != given.sort() {
51                todo!("expected {expected:?} given {:?}", given.sort())
52            }
53        }
54        let term = if args.is_empty() {
55            ast::Term::Identifier(qual_ident(self.name, None))
56        } else {
57            ast::Term::Application(
58                qual_ident(self.name, None),
59                self.st
60                    .alloc_slice(&args.iter().map(|arg| arg.term()).collect_vec()),
61            )
62        };
63        Ok(Dynamic::from_term_sort(
64            STerm::new(self.st, term),
65            self.return_sort,
66        ))
67    }
68
69    /// Get the lowlevel AST representation of the function declaration.
70    pub fn ast(&self) -> ast::FunctionDec {
71        ast::FunctionDec(
72            Symbol(self.name),
73            self.st
74                .alloc_slice(&self.vars.iter().map(|sort| sort.ast()).collect_vec()),
75            self.return_sort.ast(),
76        )
77    }
78}