1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/*
Copyright (C) 2021 Kunal Mehta <legoktm@debian.org>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

use crate::responses::TokenResponse;
use crate::{Client, Error, Result};
use std::collections::HashMap;

/// A raw token value, stored as returned by the API's `meta=tokens` query.
type Token = String;

/// In-memory cache of API tokens, keyed by token type (e.g. `"csrf"`).
///
/// Tokens are only added by [`TokenStore::load`]; nothing here expires or
/// evicts them, so callers are responsible for reloading stale tokens.
#[derive(Debug, Default)]
pub(crate) struct TokenStore {
    /// Token type name (the `type=` value sent to the API) -> most recently
    /// loaded token for that type.
    map: HashMap<String, Token>,
}

impl TokenStore {
    /// Get a token that's already loaded. It's up to the caller
    /// to lazy-load the token as a fallback and gracefully
    /// handle expired tokens
    pub(crate) fn get(&self, name: &str) -> Option<Token> {
        self.map.get(name).map(|token| token.to_string())
    }

    pub(crate) async fn load(
        &mut self,
        name: &str,
        api: &Client,
    ) -> Result<Token> {
        let resp: TokenResponse = api
            .get(&[("action", "query"), ("meta", "tokens"), ("type", name)])
            .await?;
        match resp.query.tokens.get(&format!("{name}token")) {
            Some(token) => {
                self.map.insert(name.to_string(), token.to_string());
                Ok(token.to_string())
            }
            None => Err(Error::TokenFailure(name.to_string())),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a CSRF token from the live test.wikipedia.org API, checks that
    /// it is cached, and checks that an unknown token type is reported as
    /// `Error::TokenFailure`. Requires network access.
    #[tokio::test]
    async fn test_tokenstore() {
        let mut cache = TokenStore::default();
        // Nothing is cached before the first load.
        assert!(cache.get("csrf").is_none());

        let api = Client::new("https://test.wikipedia.org/w/api.php")
            .await
            .unwrap();

        // A client created without logging in is expected to receive the
        // anonymous CSRF token `+\`.
        let csrf = cache.load("csrf", &api).await.unwrap();
        assert_eq!(csrf, "+\\");
        // The freshly loaded token is now served from the cache.
        assert_eq!(cache.get("csrf").as_deref(), Some("+\\"));

        // A bogus token type yields no `invalidtoken` entry.
        let failure = cache.load("invalid", &api).await.unwrap_err();
        assert!(matches!(failure, Error::TokenFailure(_)));
    }
}