1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
extern crate proc_macro;

#[proc_macro_derive(SimpleUserFacingError, attributes(user_facing))]
pub fn derive_simple_user_facing_error(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = syn::parse_macro_input!(input as syn::DeriveInput);

    let data = match &input.data {
        syn::Data::Struct(data) => data,
        _ => {
            return syn::Error::new_spanned(input, "derive works only on structs")
                .to_compile_error()
                .into()
        }
    };

    if !data.fields.is_empty() {
        return syn::Error::new_spanned(&data.fields, "SimpleUserFacingError implementors cannot have fields")
            .to_compile_error()
            .into();
    }

    let UserErrorDeriveInput { ident, code, message } = match UserErrorDeriveInput::new(&input) {
        Ok(input) => input,
        Err(err) => return err.into_compile_error().into(),
    };

    proc_macro::TokenStream::from(quote::quote! {
        impl serde::Serialize for #ident {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
                where S: serde::Serializer
            {
                serializer.serialize_none()
            }
        }

        impl crate::SimpleUserFacingError for #ident {
            const ERROR_CODE: &'static str = #code;
            const MESSAGE: &'static str = #message;
        }
    })
}

#[proc_macro_derive(UserFacingError, attributes(user_facing))]
pub fn derive_user_facing_error(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = syn::parse_macro_input!(input as syn::DeriveInput);

    let data = match &input.data {
        syn::Data::Struct(data) => data,
        _ => {
            return syn::Error::new_spanned(input, "derive works only on structs")
                .to_compile_error()
                .into()
        }
    };

    let UserErrorDeriveInput { ident, code, message } = match UserErrorDeriveInput::new(&input) {
        Ok(input) => input,
        Err(err) => return err.into_compile_error().into(),
    };

    let template_variables: Box<dyn Iterator<Item = _>> = match &data.fields {
        syn::Fields::Named(named) => Box::new(named.named.iter().map(|field| field.ident.as_ref().unwrap())),
        syn::Fields::Unit => Box::new(std::iter::empty()),
        syn::Fields::Unnamed(unnamed) => {
            return syn::Error::new_spanned(unnamed, "The error fields must be named")
                .to_compile_error()
                .into()
        }
    };

    proc_macro::TokenStream::from(quote::quote! {
        impl crate::UserFacingError for #ident {
            const ERROR_CODE: &'static str = #code;

            fn message(&self) -> String {
                format!(
                    #message,
                    #(
                        #template_variables = self.#template_variables
                    ),*
                )
            }
        }
    })
}

struct UserErrorDeriveInput<'a> {
    /// The name of the struct.
    ident: &'a syn::Ident,
    /// The error code.
    code: syn::LitStr,
    /// The error message format string.
    message: syn::LitStr,
}

impl<'a> UserErrorDeriveInput<'a> {
    fn new(input: &'a syn::DeriveInput) -> Result<Self, syn::Error> {
        let mut code = None;
        let mut message = None;

        for attr in &input.attrs {
            if !attr
                .path
                .get_ident()
                .map(|ident| ident == "user_facing")
                .unwrap_or(false)
            {
                continue;
            }

            for namevalue in attr.parse_args_with(|stream: &'_ syn::parse::ParseBuffer| {
                syn::punctuated::Punctuated::<syn::MetaNameValue, syn::Token![,]>::parse_terminated(stream)
            })? {
                let litstr = match namevalue.lit {
                    syn::Lit::Str(litstr) => litstr,
                    other => {
                        return Err(syn::Error::new_spanned(
                            other,
                            "Expected attribute of the form `#[user_facing(code = \"...\", message = \"...\")]`",
                        ))
                    }
                };

                match namevalue.path.get_ident() {
                    Some(ident) if ident == "code" => {
                        code = Some(litstr);
                    }
                    Some(ident) if ident == "message" => {
                        message = Some(litstr);
                    }
                    other => {
                        return Err(syn::Error::new_spanned(
                            other,
                            "Expected attribute of the form `#[user_facing(code = \"...\", message = \"...\")]`",
                        ))
                    }
                }
            }
        }

        match (message, code) {
            (Some(message), Some(code)) => Ok(UserErrorDeriveInput {
                ident: &input.ident,
                message,
                code,
            }),
            _ => Err(syn::Error::new_spanned(
                input,
                "Expected attribute of the form `#[user_facing(code = \"...\", message = \"...\")]`",
            )),
        }
    }
}