/*
Copyright (C) 2021 Kunal Mehta <legoktm@debian.org>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */
17

            
18
use crate::params::RequestParams;
19
use crate::responses::LoginResponse;
20
use crate::tokens::TokenStore;
21
#[cfg(feature = "upload")]
22
use crate::upload;
23
use crate::{ApiError, Assert, Error, ErrorFormat, Method, Params, Result};
24
use reqwest::{
25
    header, Client as HttpClient, ClientBuilder as HttpClientBuilder,
26
    StatusCode,
27
};
28
use serde::de::DeserializeOwned;
29
use serde_json::Value;
30
#[cfg(feature = "upload")]
31
use std::path::PathBuf;
32
use std::{fmt::Debug, sync::Arc};
33
use tokio::sync::{RwLock, Semaphore};
34
use tracing::{debug, error, warn};
35

            
36
/// Build a new API client.
37
/// ```
38
/// # use mwapi::{Client, Result};
39
/// # async fn doc() -> Result<()> {
40
/// let client: Client = Client::builder("https://example.org/w/api.php")
41
///     .set_oauth2_token("foobar")
42
///     .set_errorformat(mwapi::ErrorFormat::Html)
43
///     .build().await?;
44
/// # Ok(())
45
/// # }
46
/// ```
47
#[derive(Clone, Debug)]
48
pub struct Builder {
49
    api_url: String,
50
    assert: Option<Assert>,
51
    concurrency: usize,
52
    maxlag: Option<u32>,
53
    retry_limit: Option<u32>,
54
    user_agent: Option<String>,
55
    oauth2_token: Option<String>,
56
    errorformat: ErrorFormat,
57
    botpassword: Option<BotPassword>,
58
    http_client: Option<HttpClientProvider>,
59
}
60

            
61
#[derive(Clone)]
62
pub struct HttpClientProvider(Arc<dyn Fn() -> HttpClientBuilder + Send + Sync>);
63

            
64
impl Debug for HttpClientProvider {
65
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66
        f.write_str("HttpClientProvider")
67
    }
68
}
69

            
70
impl From<HttpClientProvider> for HttpClientBuilder {
71
2
    fn from(provider: HttpClientProvider) -> Self {
72
2
        provider.0()
73
2
    }
74
}
75

            
76
impl<F> From<F> for HttpClientProvider
77
where
78
    F: Fn() -> HttpClientBuilder + Send + Sync + 'static,
79
{
80
2
    fn from(value: F) -> Self {
81
2
        Self(Arc::new(value))
82
2
    }
83
}
84

            
85
/// Credentials for logging in with a
/// [bot password](https://www.mediawiki.org/wiki/Manual:Bot_passwords).
#[derive(Clone)]
struct BotPassword {
    /// Bot password username, sent as `lgname`.
    username: String,
    /// Bot password secret, sent as `lgpassword`.
    password: String,
}

// Hand-written instead of derived so the password is never printed in
// logs or panic messages.
impl Debug for BotPassword {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BotPassword")
            .field("username", &self.username)
            .field("password", &"[redacted]")
            .finish()
    }
}
90

            
91
impl Builder {
92
    /// Create a new `Builder` instance. Typically you will use
93
    /// [`Client::builder()`] instead.
94
112
    pub fn new(api_url: &str) -> Self {
95
112
        Self {
96
112
            api_url: api_url.to_string(),
97
112
            assert: Default::default(),
98
112
            concurrency: 1,
99
112
            maxlag: None,
100
112
            retry_limit: None,
101
112
            user_agent: None,
102
112
            oauth2_token: None,
103
112
            errorformat: Default::default(),
104
112
            botpassword: None,
105
112
            http_client: None,
106
112
        }
107
112
    }
108

            
109
    /// Actually build the `Client` instance.
110
168
    pub async fn build(self) -> Result<Client> {
111
        // If some auth method is set, override assert to be assert=user unless
112
        // a user specified something else.
113
112
        let assert = match self.assert {
114
4
            Some(assert) => assert,
115
            None => {
116
108
                if self.oauth2_token.is_some() || self.botpassword.is_some() {
117
4
                    Assert::User
118
                } else {
119
104
                    Assert::None
120
                }
121
            }
122
        };
123

            
124
112
        let config = ClientConfig {
125
112
            api_url: self.api_url,
126
112
            assert,
127
112
            oauth2_token: self.oauth2_token,
128
112
            errorformat: self.errorformat,
129
112
            maxlag: self.maxlag,
130
112
            retry_limit: self.retry_limit.unwrap_or(10),
131
112
        };
132
112

            
133
112
        let mut http = self
134
112
            .http_client
135
112
            .map(Into::<HttpClientBuilder>::into)
136
112
            .unwrap_or_default();
137
112
        let ua = self
138
112
            .user_agent
139
112
            .unwrap_or(format!("mwapi-rs/{}", crate::VERSION));
140
112

            
141
112
        #[cfg(target_arch = "wasm32")]
142
112
        {
143
112
            let mut headers = header::HeaderMap::new();
144
112
            headers
145
112
                .insert("Api-User-Agent", header::HeaderValue::from_str(&ua)?);
146
112
            http = http.default_headers(headers);
147
112
        }
148
112

            
149
112
        #[cfg(not(target_arch = "wasm32"))]
150
112
        {
151
112
            http = http.cookie_store(true).user_agent(ua);
152
112
        }
153

            
154
112
        let client = Client {
155
            inner: Arc::new(InnerClient {
156
112
                config,
157
112
                http: http.build()?,
158
112
                tokens: Default::default(),
159
112
                semaphore: Semaphore::new(self.concurrency),
160
            }),
161
        };
162
112
        if let Some(botpassword) = self.botpassword {
163
2
            client.login(&botpassword).await?;
164
110
        }
165
110
        Ok(client)
166
112
    }
167

            
168
    /// Set a custom User-agent. Ideally follow the [Wikimedia User-agent policy](https://meta.wikimedia.org/wiki/User-Agent_policy).
169
90
    pub fn set_user_agent(mut self, user_agent: &str) -> Self {
170
90
        self.user_agent = Some(user_agent.to_string());
171
90
        self
172
90
    }
173

            
174
    /// Set an [OAuth2 token](https://www.mediawiki.org/wiki/OAuth/For_Developers#OAuth_2)
175
    /// for authentication
176
2
    pub fn set_oauth2_token(mut self, oauth2_token: &str) -> Self {
177
2
        self.oauth2_token = Some(oauth2_token.to_string());
178
2
        self
179
2
    }
180

            
181
    /// Set the format error messages from the API should be in
182
90
    pub fn set_errorformat(mut self, errorformat: ErrorFormat) -> Self {
183
90
        self.errorformat = errorformat;
184
90
        self
185
90
    }
186

            
187
    /// Set how many requests should be processed in parallel. On Wikimedia
188
    /// wikis, you shouldn't exceed the default of 1 without getting permission
189
    /// from a sysadmin.
190
    pub fn set_concurrency(mut self, concurrency: usize) -> Self {
191
        self.concurrency = concurrency;
192
        self
193
    }
194

            
195
    /// Pause when the servers are lagged for how many seconds?
196
    /// Typically bots should set this to 5, while interactive
197
    /// usage should be much higher.
198
    ///
199
    /// See [mediawiki.org](https://www.mediawiki.org/wiki/Special:MyLanguage/Manual:Maxlag_parameter)
200
    /// for more details.
201
92
    pub fn set_maxlag(mut self, maxlag: u32) -> Self {
202
92
        self.maxlag = Some(maxlag);
203
92
        self
204
92
    }
205

            
206
2
    pub fn set_retry_limit(mut self, limit: u32) -> Self {
207
2
        self.retry_limit = Some(limit);
208
2
        self
209
2
    }
210

            
211
2
    pub fn set_botpassword(mut self, username: &str, password: &str) -> Self {
212
2
        self.botpassword = Some(BotPassword {
213
2
            username: username.to_string(),
214
2
            password: password.to_string(),
215
2
        });
216
2
        self
217
2
    }
218

            
219
2
    pub fn set_http_client<P>(mut self, provider: P) -> Self
220
2
    where
221
2
        P: Into<HttpClientProvider>,
222
2
    {
223
2
        self.http_client = Some(provider.into());
224
2
        self
225
2
    }
226

            
227
4
    pub fn set_assert(mut self, assert: Assert) -> Self {
228
4
        self.assert = Some(assert);
229
4
        self
230
4
    }
231
}
232

            
233
/// Internal configuration options for a Client
234
#[derive(Clone, Debug)]
235
struct ClientConfig {
236
    api_url: String,
237
    assert: Assert,
238
    oauth2_token: Option<String>,
239
    errorformat: ErrorFormat,
240
    maxlag: Option<u32>,
241
    retry_limit: u32,
242
}
243

            
244
/// API Client
245
#[derive(Clone, Debug)]
246
pub struct Client {
247
    pub(crate) inner: Arc<InnerClient>,
248
}
249

            
250
#[derive(Debug)]
251
pub(crate) struct InnerClient {
252
    config: ClientConfig,
253
    http: HttpClient,
254
    tokens: RwLock<TokenStore>,
255
    semaphore: Semaphore,
256
}
257

            
258
impl InnerClient {
259
226
    fn fix_params(&self, params: &mut Params) {
260
226
        params.insert("format", "json");
261
226
        params.insert("formatversion", "2");
262
226
        params.insert("errorformat", self.config.errorformat);
263
226
        if let Some(maxlag) = self.config.maxlag {
264
204
            params.insert("maxlag", maxlag);
265
222
        }
266
        // Set assert if this is not a login or login token request
267
226
        if !(params.get("action") == Some(&"login".to_string())
268
224
            || (params.get("meta") == Some(&"tokens".to_string())
269
6
                && params.get("type") == Some(&"login".to_string())))
270
        {
271
222
            if let Some(value) = self.config.assert.value() {
272
4
                params.insert("assert", value);
273
218
            }
274
4
        }
275
226
    }
276

            
277
    /// Get headers that should be applied to every request
278
226
    fn headers(&self) -> Result<header::HeaderMap> {
279
226
        let mut headers = header::HeaderMap::new();
280
226
        if let Some(token) = &self.config.oauth2_token {
281
            let mut value =
282
                header::HeaderValue::from_str(&format!("Bearer {token}"))?;
283
            value.set_sensitive(true);
284
            headers.insert(header::AUTHORIZATION, value);
285
226
        }
286

            
287
226
        Ok(headers)
288
226
    }
289

            
290
    /// Do an HTTP request.
291
226
    pub(crate) async fn do_request(
292
226
        &self,
293
226
        req_params: RequestParams,
294
339
    ) -> Result<Value> {
295
226
        let req = match req_params {
296
224
            RequestParams::Get(mut params) => {
297
224
                self.fix_params(&mut params);
298
224
                self.http.get(&self.config.api_url).query(params.as_map())
299
            }
300
2
            RequestParams::Post(mut params) => {
301
2
                self.fix_params(&mut params);
302
2
                self.http.post(&self.config.api_url).form(params.as_map())
303
            }
304
            #[cfg(feature = "upload")]
305
            RequestParams::Multipart(mut params) => {
306
                self.fix_params(&mut params.params);
307
                self.http
308
                    .post(&self.config.api_url)
309
                    .multipart(params.into_form().await?)
310
            }
311
        };
312
226
        let req = req.headers(self.headers()?).build()?;
313
226
        let _lock = self.semaphore.acquire().await?;
314
226
        debug!(?req);
315
226
        let result = self.http.execute(req).await;
316
226
        debug!(?result);
317
226
        drop(_lock);
318
226
        let resp = result?;
319
        // Silly, we have to get the headers first, because error_for_status()
320
        // takes back ownership. But most of the time we don't even need it
321
226
        let retry_after = extract_retry_after(resp.headers());
322
226
        if resp.status() == StatusCode::TOO_MANY_REQUESTS {
323
            return Err(Error::TooManyRequests {
324
                retry_after: Some(retry_after),
325
            });
326
226
        }
327
226
        let value: Value = resp.error_for_status()?.json().await?;
328
224
        handle_response(value, retry_after)
329
226
    }
330

            
331
224
    pub(crate) async fn request<P: Into<Params>, T: DeserializeOwned>(
332
224
        &self,
333
224
        method: Method,
334
224
        params: P,
335
224
    ) -> Result<T> {
336
224
        let mut retry_counter = 0;
337
224
        let params = params.into();
338

            
339
        loop {
340
226
            let params = params.clone();
341
226
            let resp = self
342
226
                .do_request(match method {
343
224
                    Method::Get => RequestParams::Get(params),
344
2
                    Method::Post => RequestParams::Post(params),
345
                })
346
226
                .await;
347
226
            match resp {
348
216
                Ok(value) => {
349
216
                    return Ok(serde_json::from_value(value)?);
350
                }
351
10
                Err(err) => {
352
10
                    if let Some(retry_after) = err.retry_after() {
353
4
                        if retry_counter >= self.config.retry_limit {
354
2
                            return Err(err);
355
2
                        }
356
2
                        // We should retry, see if there's a retry-after header
357
2
                        if retry_after != 0 {
358
                            // XXX: Should we be holding the concurrency lock here?
359
                            // Currently all the retry errors are wiki-level issues
360
                            // like read-only mode or maxlag, but in the future they
361
                            // could be just ratelimits
362
2
                            crate::time::sleep(retry_after).await;
363
                        }
364
                        // Loop again!
365
2
                        retry_counter += 1;
366
                    } else {
367
6
                        return Err(err);
368
                    }
369
                }
370
            }
371
        }
372
224
    }
373
}
374

            
375
impl Client {
376
    /// Get a `Builder` instance to further customize the API `Client`.
377
    /// The API URL should be the absolute path to [api.php](https://www.mediawiki.org/wiki/API:Main_page).
378
104
    pub fn builder(api_url: &str) -> Builder {
379
104
        Builder::new(api_url)
380
104
    }
381

            
382
    /// Get an API `Client` instance. The API URL should be the absolute
383
    /// path to [api.php](https://www.mediawiki.org/wiki/API:Main_page).
384
12
    pub async fn new(api_url: &str) -> Result<Self> {
385
8
        Builder::new(api_url).build().await
386
8
    }
387

            
388
3
    async fn login(&self, botpassword: &BotPassword) -> Result<()> {
389
        // Don't use a cached token, we need a fresh one
390
2
        let token = self.inner.tokens.write().await.load("login", self).await?;
391
2
        let resp = self
392
2
            .post(&[
393
2
                ("action", "login"),
394
2
                ("lgname", &botpassword.username),
395
2
                ("lgpassword", &botpassword.password),
396
2
                ("lgtoken", &token),
397
2
            ])
398
2
            .await?;
399
2
        let login_resp: LoginResponse = serde_json::from_value(resp)?;
400
        // Convert "result": "Failed" into API errors
401
2
        if login_resp.login.result == "Failed" {
402
2
            Err(match login_resp.login.reason {
403
2
                Some(reason) => Error::from(reason),
404
                None => Error::Unknown("Login failed".to_string()),
405
            })
406
        } else {
407
            Ok(())
408
        }
409
2
    }
410

            
411
    /// Same as [`Client::get()`], but return a [`serde_json::Value`]
412
222
    pub async fn get_value<P: Into<Params>>(&self, params: P) -> Result<Value> {
413
222
        let params = params.into();
414
222
        self.inner.request(Method::Get, params).await
415
222
    }
416

            
417
    /// Make an arbitrary API request using HTTP GET.
418
200
    pub async fn get<P: Into<Params>, T: DeserializeOwned>(
419
200
        &self,
420
200
        params: P,
421
200
    ) -> Result<T> {
422
200
        match self.get_value(params).await {
423
200
            Ok(value) => Ok(serde_json::from_value(value)?),
424
            Err(err) => Err(err),
425
        }
426
200
    }
427

            
428
    /// Get the specified token, fetching it if necessary
429
    pub(crate) async fn token(&self, token_type: &str) -> Result<String> {
430
        let get = self.inner.tokens.read().await.get(token_type);
431
        match get {
432
            Some(token) => Ok(token),
433
            None => {
434
                self.inner.tokens.write().await.load(token_type, self).await
435
            }
436
        }
437
    }
438

            
439
    /// Make an API POST request with a [CSRF token](https://www.mediawiki.org/wiki/API:Tokens).
440
    /// The correct token will automatically be fetched, and in case of a
441
    /// bad token error (if it expired), a new one will automatically be
442
    /// fetched and the request retried.
443
    pub async fn post_with_token<P: Into<Params>, T: DeserializeOwned>(
444
        &self,
445
        token_type: &str,
446
        params: P,
447
    ) -> Result<T> {
448
        let mut params = params.into();
449
        // Note: This is in a separate line to avoid holding the read lock
450
        // while also trying to get the write lock in the None clause.
451
        params.insert("token", self.token(token_type).await?);
452
        match self.post(params.clone()).await {
453
            Err(Error::BadToken) => {
454
                // badtoken error, let's try one more time
455
                let token = self
456
                    .inner
457
                    .tokens
458
                    .write()
459
                    .await
460
                    .load(token_type, self)
461
                    .await?;
462
                params.insert("token", token);
463
                self.post(params).await
464
            }
465
            // Pass through any Ok() or other Err()
466
            result => result,
467
        }
468
    }
469

            
470
    /// Make an API POST request
471
2
    pub async fn post<P: Into<Params>, T: DeserializeOwned>(
472
2
        &self,
473
2
        params: P,
474
2
    ) -> Result<T> {
475
2
        match self.inner.request(Method::Post, params).await {
476
2
            Ok(value) => Ok(serde_json::from_value(value)?),
477
            Err(err) => Err(err),
478
        }
479
2
    }
480

            
481
    /// Same as [`Client::post()`], but return a [`serde_json::Value`]
482
    pub async fn post_value<P: Into<Params>>(
483
        &self,
484
        params: P,
485
    ) -> Result<Value> {
486
        self.post(params).await
487
    }
488

            
489
    /// Upload a file under with the given filename
490
    /// from a path.
491
    ///
492
    /// * The `chunk_size` should be in bytes, 5MB (`5_000_000`)
493
    ///   is a reasonable default if you're unsure.
494
    /// * Warnings will be returned as an error unless `ignore_warnings`
495
    ///   is true.
496
    /// * Any extra parameters can be passed in the standard format.
497
    #[cfg(feature = "upload")]
498
    #[cfg_attr(docsrs, doc(cfg(feature = "upload")))]
499
    pub async fn upload<P: Into<Params>>(
500
        &self,
501
        filename: &str,
502
        path: PathBuf,
503
        chunk_size: usize,
504
        ignore_warnings: bool,
505
        params: P,
506
    ) -> Result<String> {
507
        let mut base_params =
508
            Params::from(&[("action", "upload"), ("filename", filename)]);
509
        if ignore_warnings {
510
            base_params.insert("ignorewarnings", 1);
511
        }
512
        let req = upload::UploadRequest {
513
            filename: filename.to_string(),
514
            file: path,
515
            chunk_size,
516
            base_params,
517
            upload_params: params.into(),
518
        };
519
        upload::upload(self, req).await
520
    }
521

            
522
    /// Get access to the underlying [`reqwest::Client`] to make arbitrary
523
    /// GET/POST requests, sharing the connection pool and cookie storage.
524
    /// For example, if you wanted to download images from the wiki.
525
90
    pub fn http_client(&self) -> &HttpClient {
526
90
        &self.inner.http
527
90
    }
528
}
529

            
530
224
fn handle_response(mut value: Value, retry_after: u64) -> Result<Value> {
531
224
    if let Some(warnings) = value.get("warnings") {
532
6
        let warnings: Vec<ApiError> = serde_json::from_value(warnings.clone())?;
533
12
        for warning in warnings {
534
6
            warn!("API warning: {}", warning);
535
        }
536
218
    }
537
224
    let errors = value["errors"].take();
538
224
    if !errors.is_null() {
539
8
        let errors: Vec<ApiError> = serde_json::from_value(errors)?;
540
        // Log all received API errors
541
16
        for error in &errors {
542
8
            error!("API error: {}", error);
543
        }
544

            
545
        // We can only return one error, so return the first.
546
8
        let err = match errors.into_iter().next() {
547
8
            Some(err) => Error::from(err),
548
            // Empty errors array? Shouldn't happen, but return an unknown error
549
            None => Error::Unknown("No error specified".to_string()),
550
        };
551
8
        Err(err.with_retry_after(retry_after))
552
    } else {
553
216
        Ok(value)
554
    }
555
224
}
556

            
557
232
fn extract_retry_after(headers: &header::HeaderMap) -> u64 {
558
232
    if let Some(header) = headers.get("retry-after") {
559
10
        header.to_str().unwrap_or("").parse().unwrap_or(0)
560
    } else {
561
222
        0
562
    }
563
232
}
564

            
565
#[cfg(test)]
mod tests {
    use super::*;

    fn assert_send_sync<T: Send + Sync>() {}

    /// Assert all these types are Send + Sync
    #[test]
    fn test_send_sync() {
        assert_send_sync::<Builder>();
        assert_send_sync::<Client>();
        assert_send_sync::<ClientConfig>();
        assert_send_sync::<InnerClient>();
    }

    #[tokio::test]
    async fn test_basic_get() {
        let client = Client::new("https://www.mediawiki.org/w/api.php")
            .await
            .unwrap();
        let resp = client
            .get_value(&[("action", "query"), ("meta", "siteinfo")])
            .await
            .unwrap();
        assert_eq!(
            resp["query"]["general"]["sitename"].as_str().unwrap(),
            "MediaWiki"
        );
    }

    #[tokio::test]
    async fn test_basic_errors() {
        let client = Client::new("https://www.mediawiki.org/w/api.php")
            .await
            .unwrap();
        let error = client
            .get_value(&[("action", "nonexistent")])
            .await
            .unwrap_err();
        assert_eq!(
            &error.to_string(),
            "API error: (code: badvalue): Unrecognized value for parameter \"action\": nonexistent."
        );
    }

    #[tokio::test]
    async fn test_builder() {
        let client = Client::builder("https://www.mediawiki.org/w/api.php")
            .set_oauth2_token("foobarbaz")
            .build()
            .await
            .unwrap();
        assert_eq!(
            client.inner.config.oauth2_token,
            Some("foobarbaz".to_string())
        );
    }

    #[tokio::test]
    async fn test_login() {
        let (username, token) = match (
            std::env::var("MWAPI_USERNAME"),
            std::env::var("MWAPI_TOKEN"),
        ) {
            (Ok(username), Ok(token)) => (username, token),
            // Credentials not configured: skip
            _ => return,
        };
        let client = Client::builder("https://test.wikipedia.org/w/api.php")
            .set_oauth2_token(&token)
            .build()
            .await
            .unwrap();
        let resp = client
            .get_value(&[("action", "query"), ("meta", "userinfo")])
            .await
            .unwrap();
        let name = resp["query"]["userinfo"]["name"].as_str().unwrap();
        // Check the botpassword username ("Foo@something") starts with the real wiki username ("Foo")
        // TODO: can we re-use mwbot's normalization here?
        let normalized = username.replace('_', " ");
        assert!(
            normalized.starts_with(name),
            "{normalized:?} does not start with {name:?}; response: {resp}"
        );
    }

    #[tokio::test]
    async fn test_good_assert() {
        let client = Client::builder("https://test.wikipedia.org/w/api.php")
            .set_assert(Assert::Anonymous)
            .build()
            .await
            .unwrap();
        // No error
        client.get_value(&[("action", "query")]).await.unwrap();
    }

    #[tokio::test]
    async fn test_bad_assert() {
        let client = Client::builder("https://test.wikipedia.org/w/api.php")
            .set_assert(Assert::User)
            .build()
            .await
            .unwrap();
        let error = client.get_value(&[("action", "query")]).await.unwrap_err();
        assert!(matches!(error, Error::NotLoggedIn));
    }

    #[tokio::test]
    async fn test_bad_login() {
        let error = Client::builder("https://test.wikipedia.org/w/api.php")
            .set_botpassword("ThisAccountDoesNotExistPlease", "password")
            .build()
            .await
            .unwrap_err();
        match error {
            Error::ApiError(api_err) => {
                assert_eq!(&api_err.code, "wrongpassword")
            }
            other => panic!("wrong error type: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_maxlag() {
        let client = Client::builder("https://test.wikipedia.org/w/api.php")
            .set_maxlag(0)
            .set_retry_limit(1)
            .build()
            .await
            .unwrap();
        let error = client.get_value(&[("action", "query")]).await.unwrap_err();
        match error {
            Error::Maxlag { info, .. } => {
                assert!(info.starts_with("Waiting for"), "unexpected info: {info}")
            }
            other => panic!("Error did not match MaxlagError: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_warning() {
        let client = Client::builder("https://test.wikipedia.org/w/api.php")
            .build()
            .await
            .unwrap();
        // We can't really assert that we logged something, so just check it
        // doesn't obviously blow up
        client
            .get_value(&[("action", "query"), ("list", "unknown")])
            .await
            .unwrap();
    }

    #[test]
    fn test_extract_retry_after() {
        let mut map = header::HeaderMap::new();
        map.insert("retry-after", "0".parse().unwrap());
        assert_eq!(extract_retry_after(&map), 0);
        map.insert("retry-after", "abc".parse().unwrap());
        assert_eq!(extract_retry_after(&map), 0);
        map.insert("retry-after", "4".parse().unwrap());
        assert_eq!(extract_retry_after(&map), 4);
    }

    #[tokio::test]
    async fn test_http_client_provider() {
        fn provider() -> HttpClientBuilder {
            let mut headers = header::HeaderMap::new();
            headers.insert(
                "Origin",
                header::HeaderValue::from_static("meta.wikipedia.org"),
            );
            HttpClientBuilder::new().default_headers(headers)
        }
        let client = Client::builder("https://www.mediawiki.org/w/api.php")
            .set_http_client(provider)
            .build()
            .await
            .unwrap();
        let resp = client
            .get_value(&[
                ("action", "query"),
                ("meta", "siteinfo"),
                ("origin", "meta.wikipedia.org"),
            ])
            .await
            .unwrap();
        assert_eq!(
            resp["query"]["general"]["sitename"].as_str().unwrap(),
            "MediaWiki"
        );
        // Without the Origin header the same CORS request must be rejected.
        assert!(Client::new("https://www.mediawiki.org/w/api.php")
            .await
            .unwrap()
            .get_value(&[
                ("action", "query"),
                ("meta", "siteinfo"),
                ("origin", "meta.wikipedia.org"),
            ])
            .await
            .is_err());
    }
}