1
// SPDX-License-Identifier: GPL-3.0-or-later
2
// Copyright (C) 2023 Kunal Mehta <legoktm@debian.org>
3
//! This crate provides the proc_macro for [`mwbot`](https://docs.rs/mwbot/),
4
//! please refer to its documentation for usage.
5
#![deny(clippy::all)]
6
#![deny(rustdoc::all)]
7

            
8
use proc_macro::TokenStream;
9
use proc_macro2::{Ident, TokenStream as TokenStream2};
10
use quote::{format_ident, quote};
11
use std::collections::HashMap;
12
use syn::meta::ParseNestedMeta;
13
use syn::spanned::Spanned;
14
use syn::{
15
    parse_macro_input, Data, DeriveInput, Error, Expr, ExprLit, Field, Fields,
16
    GenericArgument, Lit, LitBool, LitStr, Path, PathArguments, Result, Type,
17
};
18

            
19
// TODO: figure out how to link to the trait in the docs. See
20
// <https://doc.rust-lang.org/rustdoc/write-documentation/linking-to-items-by-name.html#namespaces-and-disambiguators>
21
/// macro to implement the `Generator` trait
22
///
23
/// See the `Generator` trait documentation for details on usage.
24
#[proc_macro_derive(
25
    Generator,
26
    attributes(generator, params, param, templated_param)
27
)]
28
162
pub fn generator(input: TokenStream) -> TokenStream {
29
    // Parse the input tokens into a syntax tree
30
162
    let input = parse_macro_input!(input as DeriveInput);
31
162
    let tokens = build(input).unwrap_or_else(|err| err.into_compile_error());
32
162
    tokens.into()
33
162
}
34

            
35
162
/// Expand the derive input into the inherent impl (constructor and setters)
/// followed by the `Generator` trait impl.
///
/// Returns the first parsing error encountered, if any.
fn build(input: DeriveInput) -> Result<TokenStream2> {
    let params = parse(&input)?;
    let fixed_params = parse_fixed_params(&input)?;
    let options = parse_generator_option(&input)?;

    let ident = &input.ident;
    let inherent_impl = build_impl(ident, &params, &options.exports);
    let trait_impl = build_trait_impl(ident, &params, fixed_params, &options);

    Ok(quote! {
        #inherent_impl
        #trait_impl
    })
}
53

            
54
struct Parameter {
55
    rust_name: Ident,
56
    rust_type: Type,
57
    doc: Option<LitStr>,
58
    mw_name: LitStr,
59
    required: bool,
60
}
61

            
62
impl Parameter {
63
    /// Whether rust_type is "bool"
64
620
    fn is_boolean(&self) -> bool {
65
620
        if let Type::Path(type_path) = &self.rust_type {
66
620
            if let Some(first) = type_path.path.segments.first() {
67
620
                return first.ident == "bool";
68
            }
69
        }
70

            
71
        false
72
620
    }
73
}
74

            
75
/// Peeks through Option<T> to determine the underlying type and if its required
76
/// Given `Option<foo>`, return `(foo, false)`
77
/// Given `foo`, return `(foo, true)`
78
718
fn parse_through_option(ty: &Type) -> (Type, bool) {
79
718
    if let Type::Path(type_path) = ty {
80
718
        if let Some(first) = type_path.path.segments.first() {
81
718
            if first.ident == "Option" {
82
620
                if let PathArguments::AngleBracketed(inside) = &first.arguments
83
                {
84
620
                    if let Some(GenericArgument::Type(ty)) = inside.args.first()
85
                    {
86
620
                        return (ty.clone(), false);
87
                    }
88
                }
89
98
            }
90
        }
91
    }
92
98
    (ty.clone(), true)
93
718
}
94

            
95
/// Build the `Generator` trait implementation
96
162
fn build_trait_impl(
97
162
    name: &Ident,
98
162
    params: &[Parameter],
99
162
    fixed_params: HashMap<String, LitStr>,
100
162
    generator_option: &GeneratorOption,
101
162
) -> TokenStream2 {
102
162
    let GeneratorOption {
103
162
        exports,
104
162
        return_type,
105
162
        response_type,
106
162
        transform_fn,
107
162
        wrap_in_vec,
108
162
    } = generator_option;
109
162

            
110
162
    let fixed_params: Vec<_> = fixed_params
111
162
        .into_iter()
112
447
        .map(|(key, value)| {
113
366
            quote! {
114
366
                map.insert(#key, #value.to_string());
115
366
            }
116
447
        })
117
162
        .collect();
118
162
    let params: Vec<_> = params
119
162
        .iter()
120
799
        .map(|param| {
121
718
            let mw_name = &param.mw_name;
122
718
            let rust_name = &param.rust_name;
123
718
            if param.required {
124
98
                quote! {
125
98
                    map.insert(#mw_name, #exports::ParamValue::stringify(&self.#rust_name));
126
98
                }
127
620
            } else if param.is_boolean() {
128
                // For boolean parameters, include it if it's true, otherwise
129
                // omit it entirely.
130
18
                quote! {
131
18
                    if self.#rust_name.unwrap_or(false) {
132
18
                        map.insert(#mw_name, "1".to_string());
133
18
                    }
134
18
                }
135
            } else {
136
602
                quote! {
137
602
                    if let Some(value) = &self.#rust_name {
138
602
                        map.insert(#mw_name, #exports::ParamValue::stringify(value));
139
602
                    }
140
602
                }
141
            }
142
799
        })
143
162
        .collect();
144

            
145
162
    let wrap_vec = if wrap_in_vec.value {
146
6
        quote! {
147
6
            values
148
6
        }
149
    } else {
150
156
        quote! {
151
156
            {
152
156
                let mut vec = Vec::new();
153
156
                vec.push(values);
154
156
                vec
155
156
            };
156
156
        }
157
    };
158

            
159
162
    quote! {
160
162
        impl #exports::Generator for #name {
161
162
            type Output = #exports::Result<#return_type>;
162
162

            
163
162
            fn params(&self) -> #exports::HashMap<&'static str, String> {
164
162
                let mut map = #exports::HashMap::new();
165
162
                #(#fixed_params)*
166
162
                #(#params)*
167
162
                map
168
162
            }
169
162

            
170
162
            fn generate(self, bot: &#exports::Bot) -> #exports::tokio::Receiver<Self::Output> {
171
162
                let (tx, rx) = #exports::tokio::channel(50);
172
162
                let bot: #exports::Bot = #exports::Clone::clone(&bot);
173
162
                #exports::tokio::spawn(async move {
174
162
                    let params = #exports::IntoIterator::into_iter(#exports::Generator::params(&self));
175
162
                    let params = #exports::Iterator::map(params, |(k, v)| (k.to_string(), v));
176
162
                    let params = #exports::Iterator::collect::<#exports::HashMap<String, String>>(params);
177
162

            
178
162
                    let mut params = #exports::Params {
179
162
                        main: params,
180
162
                        ..#exports::Default::default()
181
162
                    };
182
162
                    loop {
183
162
                        let resp: #response_type =
184
162
                            match #exports::mwapi_responses::query_api(&bot.api(), params.merged())
185
162
                                .await
186
162
                            {
187
162
                                Ok(resp) => resp,
188
162
                                Err(err) => {
189
162
                                    match tx.send(Err(<#exports::Error as #exports::From<_>>::from(err))).await {
190
162
                                        Ok(_) => break,
191
162
                                        // Receiver hung up, just abort
192
162
                                        Err(_) => return,
193
162
                                    }
194
162
                                }
195
162
                            };
196
162
                        params.continue_ = #exports::Clone::clone(&resp.continue_);
197
162
                        for item in #exports::mwapi_responses::ApiResponse::<_>::into_items(resp) {
198
162
                            let values = match #transform_fn(&bot, item)  {
199
162
                                Ok(values) => values,
200
162
                                Err(err) => {
201
162
                                    if tx.send(Err(err)).await.is_err() {
202
162
                                        return;
203
162
                                    };
204
162
                                    continue;
205
162
                                },
206
162
                            };
207
162
                            let values = #wrap_vec;
208
162

            
209
162
                            for value in values {
210
162
                                if tx.send(Ok(value)).await.is_err() {
211
162
                                    // Receiver hung up, just abort
212
162
                                    return;
213
162
                                }
214
162
                            }
215
162
                        }
216
162
                        if params.continue_.is_empty() {
217
162
                            // All done
218
162
                            break;
219
162
                        }
220
162
                    }
221
162
                });
222
162
                rx
223
162
            }
224
162
        }
225
162
    }
226
162
}
227

            
228
/// Build the type's implementation with `new()` and parameter builders
229
162
fn build_impl(
230
162
    name: &Ident,
231
162
    params: &[Parameter],
232
162
    exports: &Path,
233
162
) -> TokenStream2 {
234
162
    let type_def: Vec<_> = params
235
162
        .iter()
236
799
        .filter(|param| param.required)
237
211
        .map(|param| {
238
98
            let name = &param.rust_name;
239
98
            let ty = &param.rust_type;
240
98
            quote! { #name: impl #exports::Into<#ty> }
241
211
        })
242
162
        .collect();
243
162
    let fields: Vec<_> = params
244
162
        .iter()
245
799
        .map(|param| {
246
718
            let name = &param.rust_name;
247
718
            if param.required {
248
98
                quote! { #name: #exports::Into::into(#name) }
249
            } else {
250
620
                quote! { #name: #exports::Option::None }
251
            }
252
799
        })
253
162
        .collect();
254
162

            
255
162
    let setters: Vec<_> = params
256
162
        .iter()
257
799
        .filter(|param| !param.required)
258
701
        .map(|param| setter(param, exports))
259
162
        .collect();
260
162

            
261
162
    quote! {
262
162
        impl #name {
263
162
            pub fn new( #(#type_def),* ) -> Self {
264
162
                Self {
265
162
                    #(#fields),*
266
162
                }
267
162
            }
268
162

            
269
162
            #(#setters)*
270
162
        }
271
162
    }
272
162
}
273

            
274
/// Build a parameter setter function
275
620
fn setter(param: &Parameter, exports: &Path) -> TokenStream2 {
276
620
    let name = &param.rust_name;
277
620
    let ty = &param.rust_type;
278
620
    let doc = match &param.doc {
279
84
        Some(doc) => quote! {
280
84
            #[doc = #doc]
281
84
        },
282
536
        None => quote! {},
283
    };
284
620
    assert!(!param.required);
285
620
    let with_name = format_ident!("with_{}", name);
286
620
    let set_name = format_ident!("set_{}", name);
287
620
    quote! {
288
620
        #doc
289
620
        pub fn #name(mut self, value: impl #exports::Into<#ty>) -> Self {
290
620
            self.#name = Some(#exports::Into::into(value));
291
620
            self
292
620
        }
293
620

            
294
620
        pub fn #with_name(mut self, value: impl #exports::Into<Option<#ty>>) -> Self {
295
620
            self.#name = #exports::Into::into(value);
296
620
            self
297
620
        }
298
620

            
299
620
        pub fn #set_name(&mut self, value: impl #exports::Into<Option<#ty>>) {
300
620
            self.#name = #exports::Into::into(value);
301
620
        }
302
620
    }
303
620
}
304

            
305
/// Parse Rust struct types into our `Parameter` type
306
162
fn parse(input: &DeriveInput) -> Result<Vec<Parameter>> {
307
162
    let mut params = vec![];
308
162
    let data = if let Data::Struct(data) = &input.data {
309
162
        data
310
    } else {
311
        return Err(Error::new(input.ident.span(), "expected a struct"));
312
    };
313
162
    let fields = if let Fields::Named(fields) = &data.fields {
314
162
        fields
315
    } else {
316
        return Err(Error::new(
317
            input.ident.span(),
318
            "struct fields must have names",
319
        ));
320
    };
321
880
    for field in &fields.named {
322
718
        let (rust_type, required) = parse_through_option(&field.ty);
323

            
324
718
        let (doc, mw_name) = parse_mw_name(field)?;
325
718
        let param = Parameter {
326
718
            rust_name: field.ident.clone().ok_or_else(|| {
327
                Error::new(field.span(), "struct fields must have names")
328
718
            })?,
329
718
            doc,
330
718
            mw_name,
331
718
            rust_type,
332
718
            required,
333
718
        };
334
718
        params.push(param);
335
    }
336
162
    Ok(params)
337
162
}
338

            
339
/// Parse the generator attribute
340
162
fn parse_fixed_params(input: &DeriveInput) -> Result<HashMap<String, LitStr>> {
341
162
    let mut map = HashMap::new();
342
764
    for attr in &input.attrs {
343
602
        if attr.path().is_ident("params") {
344
447
            attr.parse_nested_meta(|meta| {
345
366
                let key = meta
346
366
                    .path
347
366
                    .get_ident()
348
366
                    .ok_or_else(|| meta.error("invalid parameter name"))?;
349
366
                let value: LitStr = meta.value()?.parse()?;
350
366
                map.insert(key.to_string(), value);
351
366
                Ok(())
352
447
            })?;
353
440
        }
354
    }
355
162
    if !map
356
162
        .keys()
357
333
        .any(|key| key == "generator" || key == "list" || key == "prop")
358
    {
359
        Err(Error::new(
360
            input.ident.span(),
361
            "Missing #[params(generator = \"...\")], #[params(list = \" ... \")] or #[params(prop = \" ... \")] attribute",
362
        ))
363
    } else {
364
162
        Ok(map)
365
    }
366
162
}
367

            
368
126
fn get_lit_str(ident: &str, meta: &ParseNestedMeta) -> Result<Option<LitStr>> {
369
126
    if !meta.path.is_ident(ident) {
370
        return Ok(None);
371
126
    }
372

            
373
126
    let expr: Expr = meta.value()?.parse()?;
374
126
    let mut value = &expr;
375
126
    while let Expr::Group(e) = value {
376
        value = &e.expr;
377
    }
378
    if let Expr::Lit(ExprLit {
379
126
        lit: Lit::Str(lit), ..
380
126
    }) = value
381
    {
382
126
        let suffix = lit.suffix();
383
126
        if !suffix.is_empty() {
384
            return Err(meta.error(format!(
385
                "unexpected suffix `{}` on string literal",
386
                suffix
387
            )));
388
126
        }
389
126
        Ok(Some(lit.clone()))
390
    } else {
391
        Err(meta.error(format!(
392
            "expected attribute to be a string: `{}` = \"...\"",
393
            ident
394
        )))
395
    }
396
126
}
397

            
398
178
fn get_path(meta: &ParseNestedMeta, ident: &str) -> Result<Option<Path>> {
399
178
    if !meta.path.is_ident(ident) {
400
132
        return Ok(None);
401
46
    }
402
46
    let string = match get_lit_str(ident, meta)? {
403
46
        Some(string) => string,
404
        None => return Ok(None),
405
    };
406

            
407
46
    string.parse().map(Some)
408
178
}
409

            
410
212
fn get_type(meta: &ParseNestedMeta, ident: &str) -> Result<Option<Type>> {
411
212
    if !meta.path.is_ident(ident) {
412
132
        return Ok(None);
413
80
    }
414
80
    let string = match get_lit_str(ident, meta)? {
415
80
        Some(string) => string,
416
        None => return Ok(None),
417
    };
418

            
419
80
    string.parse().map(Some)
420
212
}
421

            
422
struct GeneratorOption {
423
    exports: Path,
424
    return_type: Type,
425
    response_type: Type,
426
    transform_fn: Path,
427
    wrap_in_vec: LitBool,
428
}
429

            
430
162
fn parse_generator_option(input: &DeriveInput) -> Result<GeneratorOption> {
431
162
    let mut crate_: Option<Path> = None;
432
162
    let mut return_type: Option<Type> = None;
433
162
    let mut response_type: Option<Type> = None;
434
162
    let mut transform_fn: Option<Path> = None;
435
162
    let mut wrap_in_vec: Option<LitBool> = None;
436
764
    for attr in &input.attrs {
437
602
        if !attr.path().is_ident("generator") {
438
560
            continue;
439
42
        }
440
153
        attr.parse_nested_meta(|meta| {
441
132
            if let Some(path) = get_path(&meta, "crate")? {
442
6
                crate_ = Some(path);
443
6
                return Ok(());
444
126
            }
445
126
            if let Some(path) = get_type(&meta, "return_type")? {
446
40
                return_type = Some(path);
447
40
                return Ok(());
448
86
            }
449
86
            if let Some(path) = get_type(&meta, "response_type")? {
450
40
                response_type = Some(path);
451
40
                return Ok(());
452
46
            }
453
46
            if let Some(path) = get_path(&meta, "transform_fn")? {
454
40
                transform_fn = Some(path);
455
40
                return Ok(());
456
6
            }
457
6
            if meta.path.is_ident("wrap_in_vec") {
458
6
                wrap_in_vec = Some(meta.value()?.parse()?);
459
6
                return Ok(());
460
            }
461
            Ok(())
462
153
        })?;
463
    }
464
240
    let crate_ = crate_.unwrap_or_else(|| syn::parse_quote!(crate));
465
162
    let exports: Path = syn::parse_quote!(#crate_::generators::__exports);
466
162
    Ok(GeneratorOption {
467
162
        exports: exports.clone(),
468
162
        return_type: return_type
469
223
            .unwrap_or_else(|| syn::parse_quote!(#exports::Page)),
470
162
        response_type: response_type
471
223
            .unwrap_or_else(|| syn::parse_quote!(#exports::InfoResponse)),
472
162
        transform_fn: transform_fn
473
223
            .unwrap_or_else(|| syn::parse_quote!(#exports::transform_to_page)),
474
240
        wrap_in_vec: wrap_in_vec.unwrap_or_else(|| syn::parse_quote!(false)),
475
162
    })
476
162
}
477

            
478
/// Parse the parameter attribute
479
718
fn parse_mw_name(field: &Field) -> Result<(Option<LitStr>, LitStr)> {
480
718
    let mut doc = None;
481
718
    let mut param = None;
482
1532
    for attr in &field.attrs {
483
814
        if attr.path().is_ident("doc") {
484
96
            if let Expr::Lit(expr) = &attr.meta.require_name_value()?.value {
485
96
                if let Lit::Str(lit) = &expr.lit {
486
96
                    doc = Some(lit.clone());
487
96
                }
488
            }
489
            // else: error?
490
718
        } else if attr.path().is_ident("param") {
491
718
            let parsed: LitStr = attr.parse_args()?;
492
718
            param = Some(parsed);
493
        }
494
    }
495
718
    match param {
496
718
        Some(param) => Ok((doc, param)),
497
        None => Err(Error::new(
498
            field.ident.span(),
499
            "Missing #[param(\"...\")] attribute",
500
        )),
501
    }
502
718
}