tapi_macro/
lib.rs

1use darling::FromMeta;
2use proc_macro2::Ident;
3use quote::format_ident;
4use serde_derive_internals::{ast, attr::TagType};
5use syn::parse_macro_input;
6
7#[derive(Debug)]
8struct Args {
9    path: String,
10    method: Ident,
11}
12
13impl syn::parse::Parse for Args {
14    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
15        syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated(input).map(
16            |punctuated| {
17                let mut path = None;
18                let mut method = None;
19                for meta in punctuated {
20                    match meta {
21                        syn::Meta::NameValue(syn::MetaNameValue {
22                            path: syn::Path { segments, .. },
23                            value,
24                            ..
25                        }) => {
26                            let ident = segments.first().unwrap().ident.to_string();
27                            match ident.as_str() {
28                                "path" => {
29                                    path = {
30                                        match value {
31                                            syn::Expr::Lit(syn::ExprLit {
32                                                lit: syn::Lit::Str(lit_str),
33                                                ..
34                                            }) => Some(lit_str.value()),
35                                            _ => panic!("unknown attribute"),
36                                        }
37                                    }
38                                }
39                                "method" => {
40                                    method = {
41                                        match value {
42                                            syn::Expr::Path(syn::ExprPath { path, .. }) => {
43                                                Some(path.segments.first().unwrap().ident.clone())
44                                            }
45                                            _ => panic!("unknown attribute"),
46                                        }
47                                    }
48                                }
49                                _ => panic!("unknown attribute"),
50                            }
51                        }
52                        _ => panic!("unknown attribute"),
53                    }
54                }
55                Args {
56                    path: path.unwrap(),
57                    method: method.unwrap(),
58                }
59            },
60        )
61    }
62}
63
64#[proc_macro_attribute]
65pub fn tapi(
66    attr: proc_macro::TokenStream,
67    item: proc_macro::TokenStream,
68) -> proc_macro::TokenStream {
69    let item = proc_macro2::TokenStream::from(item);
70
71    let Args { path, method } = parse_macro_input!(attr as Args);
72
73    let fn_ = syn::parse2::<syn::ItemFn>(item.clone()).unwrap();
74
75    let name = fn_.sig.ident;
76    let mut body_ty = Vec::new();
77    for inp in &fn_.sig.inputs {
78        match inp {
79            syn::FnArg::Receiver(_) => {
80                todo!("idk what to do with receivers")
81            }
82            syn::FnArg::Typed(t) => {
83                body_ty.push((*t.ty).clone());
84            }
85        }
86    }
87    let res_ty = match &fn_.sig.output {
88        syn::ReturnType::Default => None,
89        syn::ReturnType::Type(_, ty) => Some((**ty).clone()),
90    };
91
92    let res_ty = res_ty.unwrap_or_else(|| {
93        syn::parse2::<syn::Type>(quote::quote! {
94            ()
95        })
96        .unwrap()
97    });
98
99    let handler = match method.to_string().as_str() {
100        "Get" => format_ident!("get"),
101        "Post" => format_ident!("post"),
102        "Put" => format_ident!("put"),
103        "Delete" => format_ident!("delete"),
104        "Patch" => format_ident!("patch"),
105        _ => todo!("unknown method: {}", method.to_string()),
106    };
107
108    let output = quote::quote! {
109        mod #name {
110            #![allow(unused_parens)]
111
112            use super::*;
113            pub struct endpoint;
114            impl ::tapi::endpoints::Endpoint<AppState> for endpoint {
115                fn path(&self) -> &'static str {
116                    #path
117                }
118                fn method(&self) -> ::tapi::endpoints::Method {
119                    ::tapi::endpoints::Method::#method
120                }
121                fn bind_to(&self, router: ::axum::Router<AppState>) -> ::axum::Router<AppState> {
122                    router.route(#path, ::axum::routing::#handler(super::#name))
123                }
124                fn body(&self) -> ::tapi::endpoints::RequestStructure {
125                    let mut s = ::tapi::endpoints::RequestStructure::new(::tapi::endpoints::Method::#method);
126                    #(
127                        s.merge_with(
128                            <#body_ty as ::tapi::endpoints::RequestTapiExtractor>::extract_request()
129                        );
130                    )*
131                    s
132                }
133                fn res(&self) -> ::tapi::endpoints::ResponseTapi {
134                    <#res_ty as ::tapi::endpoints::ResponseTapiExtractor>::extract_response()
135                }
136            }
137        }
138
139        #[tracing::instrument(name = "route", skip_all, fields(path = #path, method = stringify!(#method)))]
140        #item
141    };
142    output.into()
143}
144
145#[derive(Debug, Default, FromMeta)]
146struct DeriveInput {
147    krate: Option<String>,
148    path: Option<String>,
149}
150
151#[proc_macro_derive(Tapi, attributes(serde, tapi))]
152pub fn tapi_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
153    let input = proc_macro2::TokenStream::from(input);
154
155    let derive_input = syn::parse2::<syn::DeriveInput>(input.clone()).unwrap();
156
157    let tapi_derive_input = derive_input
158        .attrs
159        .iter()
160        .find_map(|attr| {
161            if attr.meta.path().is_ident("tapi") {
162                Some(
163                    DeriveInput::from_meta(&attr.meta)
164                        .unwrap_or_else(|_| panic!("at: {}", line!())),
165                )
166            } else {
167                None
168            }
169        })
170        .unwrap_or_default();
171
172    let tapi_path = tapi_derive_input
173        .krate
174        .as_ref()
175        .map(|krate| {
176            syn::parse_str(krate)
177                .unwrap_or_else(|_| panic!("failed to parse krate path: {}", line!()))
178        })
179        .unwrap_or_else(|| quote::quote!(::tapi));
180
181    let path = match &tapi_derive_input.path {
182        Some(path) => {
183            let path = path.split("::");
184            quote::quote!(
185                fn path() -> Vec<&'static str> {
186                    vec![#(#path),*]
187                }
188            )
189        }
190        None => quote::quote!(),
191    };
192
193    let name = derive_input.ident.clone();
194    let generics = derive_input.generics.params.clone();
195    let mut sgenerics = Vec::new();
196    let mut life_times = Vec::new();
197    for g in &generics {
198        match g {
199            syn::GenericParam::Lifetime(l) => {
200                life_times.push(l.lifetime.clone());
201            }
202            syn::GenericParam::Type(ty) => {
203                let ident = &ty.ident;
204                sgenerics.push(quote::quote!(#ident))
205            }
206            syn::GenericParam::Const(_) => todo!("syn::GenericParam::Const"),
207        }
208    }
209    let container = {
210        let cx = serde_derive_internals::Ctxt::new();
211        let container = ast::Container::from_ast(
212            &cx,
213            &derive_input,
214            serde_derive_internals::Derive::Serialize,
215        )
216        .unwrap();
217        cx.check().unwrap();
218        container
219    };
220
221    let attr = build_container_attributes(&container, &tapi_path);
222
223    let result: proc_macro2::TokenStream = match &container.data {
224        ast::Data::Struct(_style, st_fields) => {
225            // TODO: rewrite this to use the `style`
226            let mut fields = Vec::new();
227            let mut kind_fields = Vec::new();
228            let mut tuple_fields = Vec::new();
229            for field in st_fields {
230                let ty = field.ty.clone();
231                let field_flags = &field;
232                let attr = build_field_attributes(&field_flags.attrs, &tapi_path);
233                let field_name = match field.original.ident.clone() {
234                    Some(_) => {
235                        let serialize_name =
236                            format_ident!("{}", field.attrs.name().serialize_name());
237                        let deserialize_name =
238                            format_ident!("{}", field.attrs.name().deserialize_name());
239
240                        quote::quote!(#tapi_path::kind::FieldName::Named(#tapi_path::kind::Name {
241                            serialize_name: stringify!(#serialize_name).to_string(),
242                            deserialize_name: stringify!(#deserialize_name).to_string(),
243                        }))
244                    }
245                    None => {
246                        tuple_fields.push(quote::quote!(#tapi_path::kind::TupleStructField {
247                            attr: #attr,
248                            ty: <#ty as #tapi_path::Tapi>::boxed(),
249                        }));
250                        continue;
251                    }
252                };
253                fields.push(field.ty.clone());
254                kind_fields.push(quote::quote!(
255                    #tapi_path::kind::Field {
256                        attr: #attr,
257                        name: #field_name,
258                        ty: <#ty as #tapi_path::Tapi>::boxed(),
259                    }
260                ));
261            }
262            if tuple_fields.is_empty() {
263                quote::quote! {
264                    #[allow(unused_parens)]
265                    impl<#(#life_times,)* #(#sgenerics: 'static + #tapi_path::Tapi),*> #tapi_path::Tapi for #name<#(#life_times,)* #(#sgenerics),*> {
266                        fn name() -> &'static str {
267                            stringify!(#name)
268                        }
269                        fn id() -> std::any::TypeId {
270                            std::any::TypeId::of::<#name<#(#sgenerics),*>>()
271                        }
272                        #path
273                        fn kind() -> #tapi_path::kind::TypeKind {
274                            #tapi_path::kind::TypeKind::Struct(#tapi_path::kind::Struct {
275                                attr: #attr,
276                                fields: [#(#kind_fields),*].to_vec(),
277                            })
278                        }
279                    }
280                }
281            } else {
282                assert!(kind_fields.is_empty());
283                quote::quote! {
284                    #[allow(unused_parens)]
285                    impl<#(#life_times,)* #(#sgenerics: 'static + #tapi_path::Tapi),*> #tapi_path::Tapi for #name<#(#life_times,)* #(#sgenerics),*> {
286                        fn name() -> &'static str {
287                            stringify!(#name)
288                        }
289                        fn id() -> std::any::TypeId {
290                            std::any::TypeId::of::<#name<#(#sgenerics),*>>()
291                        }
292                        #path
293                        fn kind() -> #tapi_path::kind::TypeKind {
294                            #tapi_path::kind::TypeKind::TupleStruct(#tapi_path::kind::TupleStruct {
295                                attr: #attr,
296                                fields: [#(#tuple_fields),*].to_vec(),
297                            })
298                        }
299                    }
300                }
301            }
302        }
303        ast::Data::Enum(en_variants) => {
304            let mut kind_variants = Vec::new();
305            for variant in en_variants {
306                let ident = &variant.ident;
307
308                match &variant.style {
309                    ast::Style::Unit => {
310                        assert!(variant.fields.is_empty(), "unit has no fields");
311
312                        kind_variants.push(quote::quote!(#tapi_path::kind::EnumVariant {
313                            name: stringify!(#ident).to_string(),
314                            kind: #tapi_path::kind::VariantKind::Unit,
315                        }))
316                    }
317                    ast::Style::Struct => {
318                        let fields = variant.fields.iter().map(|f| {
319                            let ty = f.ty.clone();
320                            let attr = build_field_attributes(&f.attrs, &tapi_path);
321
322                            let serialize_name =
323                                format_ident!("{}", f.attrs.name().serialize_name());
324                            let deserialize_name =
325                                format_ident!("{}", f.attrs.name().deserialize_name());
326
327                            quote::quote!(
328                                #tapi_path::kind::Field {
329                                    attr: #attr,
330                                    name: #tapi_path::kind::FieldName::Named(#tapi_path::kind::Name {
331                                        serialize_name: stringify!(#serialize_name).to_string(),
332                                        deserialize_name: stringify!(#deserialize_name).to_string(),
333                                    }),
334                                    ty: <#ty as #tapi_path::Tapi>::boxed(),
335                                }
336                            )
337                        });
338                        kind_variants.push(quote::quote!(#tapi_path::kind::EnumVariant {
339                            name: stringify!(#ident).to_string(),
340                            kind: #tapi_path::kind::VariantKind::Struct([#(#fields),*].to_vec()),
341                        }))
342                    }
343                    ast::Style::Tuple => {
344                        let fields = variant.fields.iter().map(|f| f.ty.clone());
345                        kind_variants.push(quote::quote!(#tapi_path::kind::EnumVariant {
346                            name: stringify!(#ident).to_string(),
347                            kind: #tapi_path::kind::VariantKind::Tuple([#(<#fields as #tapi_path::Tapi>::boxed()),*].to_vec()),
348                        }))
349                    }
350                    ast::Style::Newtype => {
351                        assert_eq!(variant.fields.len(), 1, "newtype has exactly one field");
352
353                        let fields = variant.fields.iter().map(|f| f.ty.clone());
354                        kind_variants.push(quote::quote!(#tapi_path::kind::EnumVariant {
355                            name: stringify!(#ident).to_string(),
356                            kind: #tapi_path::kind::VariantKind::Tuple([#(<#fields as #tapi_path::Tapi>::boxed()),*].to_vec()),
357                        }))
358                    }
359                }
360            }
361            quote::quote! {
362                #[allow(unused_parens)]
363                impl<#(#life_times,)* #(#sgenerics: 'static + #tapi_path::Tapi),*> #tapi_path::Tapi for #name<#(#life_times,)* #(#sgenerics),*> {
364                    fn name() -> &'static str {
365                        stringify!(#name)
366                    }
367                    fn id() -> std::any::TypeId {
368                        std::any::TypeId::of::<#name>()
369                    }
370                    #path
371                    fn kind() -> #tapi_path::kind::TypeKind {
372                        #tapi_path::kind::TypeKind::Enum(#tapi_path::kind::Enum {
373                            attr: #attr,
374                            variants: [#(#kind_variants),*].to_vec(),
375                        })
376                    }
377                }
378            }
379        }
380    };
381
382    // let pretty = prettyplease::unparse(&syn::parse2(result.clone()).unwrap());
383    // eprintln!("{pretty}");
384    result.into()
385}
386
387fn build_container_attributes(
388    serde_flags: &ast::Container<'_>,
389    tapi_path: &proc_macro2::TokenStream,
390) -> proc_macro2::TokenStream {
391    let name = {
392        let serialize_name = serde_flags.attrs.name().serialize_name();
393        let deserialize_name = serde_flags.attrs.name().deserialize_name();
394        quote::quote!(#tapi_path::kind::Name {
395            serialize_name: #serialize_name.to_string(),
396            deserialize_name: #deserialize_name.to_string(),
397        })
398    };
399    let transparent = serde_flags.attrs.transparent();
400    let deny_unknown_fields = serde_flags.attrs.deny_unknown_fields();
401    let default = match serde_flags.attrs.default() {
402        serde_derive_internals::attr::Default::None => {
403            quote::quote!(#tapi_path::kind::Default::None)
404        }
405        serde_derive_internals::attr::Default::Default => {
406            quote::quote!(#tapi_path::kind::Default::Default)
407        }
408        serde_derive_internals::attr::Default::Path(_) => {
409            quote::quote!(#tapi_path::kind::Default::Path)
410        }
411    };
412    let tag = {
413        let tag = serde_flags.attrs.tag();
414        match tag {
415            TagType::External => quote::quote!(#tapi_path::kind::TagType::External),
416            TagType::Internal { tag } => {
417                quote::quote!(#tapi_path::kind::TagType::Internal { tag: #tag.to_string() })
418            }
419            TagType::Adjacent { tag, content } => quote::quote!(
420                #tapi_path::kind::TagType::Adjacent {
421                    tag: #tag.to_string(),
422                    content: #content.to_string(),
423                }
424            ),
425            TagType::None => quote::quote!(#tapi_path::kind::TagType::None),
426        }
427    };
428    let type_from = match serde_flags.attrs.type_from() {
429        Some(type_from) => {
430            quote::quote!(Some(<#type_from as #tapi_path::Tapi>::boxed()))
431        }
432        None => quote::quote!(None),
433    };
434    let type_try_from = match serde_flags.attrs.type_try_from() {
435        Some(type_try_from) => {
436            quote::quote!(Some(<#type_try_from as #tapi_path::Tapi>::boxed()))
437        }
438        None => quote::quote!(None),
439    };
440    let type_into = match serde_flags.attrs.type_into() {
441        Some(type_into) => {
442            quote::quote!(Some(<#type_into as #tapi_path::Tapi>::boxed()))
443        }
444        None => quote::quote!(None),
445    };
446    let is_packed = serde_flags.attrs.is_packed();
447    let identifier = match serde_flags.attrs.identifier() {
448        serde_derive_internals::attr::Identifier::No => {
449            quote::quote!(#tapi_path::kind::Identifier::No)
450        }
451        serde_derive_internals::attr::Identifier::Field => {
452            quote::quote!(#tapi_path::kind::Identifier::Field)
453        }
454        serde_derive_internals::attr::Identifier::Variant => {
455            quote::quote!(#tapi_path::kind::Identifier::Variant)
456        }
457    };
458    let has_flatten = serde_flags.attrs.has_flatten();
459    let non_exhaustive = serde_flags.attrs.non_exhaustive();
460    quote::quote!(#tapi_path::kind::ContainerAttributes {
461        name: #name,
462        // rename_all_rules: todo!("rename_all_rules"),
463        // rename_all_fields_rules: todo!("rename_all_fields_rules"),
464        transparent: #transparent,
465        deny_unknown_fields: #deny_unknown_fields,
466        default: #default,
467        // ser_bound: todo!("ser_bound"),
468        // de_bound: todo!("de_bound"),
469        tag: #tag,
470        type_from: #type_from,
471        type_try_from: #type_try_from,
472        type_into: #type_into,
473        // remote: todo!("Pa"),
474        is_packed: #is_packed,
475        identifier: #identifier,
476        has_flatten: #has_flatten,
477        // custom_serde_path: todo!("custom_serde_path"),
478        // serde_path: todo!("serde_path"),
479        // /// Error message generated when type can’t be deserialized. If None, default message will be used
480        // expecting: todo!("expecting"),
481        non_exhaustive: #non_exhaustive,
482    })
483}
484
485fn build_field_attributes(
486    serde_flags: &serde_derive_internals::attr::Field,
487    tapi_path: &proc_macro2::TokenStream,
488) -> proc_macro2::TokenStream {
489    let name = {
490        let serialize_name = serde_flags.name().serialize_name();
491        let deserialize_name = serde_flags.name().deserialize_name();
492        quote::quote!(#tapi_path::kind::Name {
493            serialize_name: #serialize_name.to_string(),
494            deserialize_name: #deserialize_name.to_string(),
495        })
496    };
497    let aliases = {
498        let aliases = serde_flags.aliases();
499        quote::quote!([#(stringify!(#aliases).to_string()),*].into_iter().collect())
500    };
501    let skip_serializing = serde_flags.skip_serializing();
502    let skip_deserializing = serde_flags.skip_deserializing();
503    let default = match serde_flags.default() {
504        serde_derive_internals::attr::Default::None => {
505            quote::quote!(#tapi_path::kind::Default::None)
506        }
507        serde_derive_internals::attr::Default::Default => {
508            quote::quote!(#tapi_path::kind::Default::Default)
509        }
510        serde_derive_internals::attr::Default::Path(_) => {
511            quote::quote!(#tapi_path::kind::Default::Path)
512        }
513    };
514    let flatten = serde_flags.flatten();
515    let transparent = serde_flags.transparent();
516    quote::quote!(#tapi_path::kind::FieldAttributes {
517        name: #name,
518        aliases: #aliases,
519        skip_serializing: #skip_serializing,
520        skip_deserializing: #skip_deserializing,
521        // skip_serializing_if: #skip_serializing_if,
522        default: #default,
523        // serialize_with: #serialize_with,
524        // deserialize_with: #deserialize_with,
525        // ser_bound: #ser_bound,
526        // de_bound: #de_bound,
527        // borrowed_lifetimes: #borrowed_lifetimes,
528        // getter: #getter,
529        flatten: #flatten,
530        transparent: #transparent,
531    })
532}