use crate::params::RequestParams;
use crate::responses::LoginResponse;
use crate::tokens::TokenStore;
#[cfg(feature = "upload")]
use crate::upload;
use crate::{ApiError, Assert, Error, ErrorFormat, Method, Params, Result};
use reqwest::{
header, Client as HttpClient, ClientBuilder as HttpClientBuilder,
StatusCode,
};
use serde::de::DeserializeOwned;
use serde_json::Value;
#[cfg(feature = "upload")]
use std::path::PathBuf;
use std::{fmt::Debug, sync::Arc};
use tokio::sync::{RwLock, Semaphore};
use tracing::{debug, error, warn};
/// Builder for a [`Client`], collecting configuration before any request is made.
///
/// Create one with [`Builder::new`] or [`Client::builder`], chain the `set_*`
/// methods, then call [`Builder::build`].
///
/// `Debug` is implemented by hand so that credentials (the OAuth 2 token and
/// the bot password) never appear in logs or `dbg!` output.
#[derive(Clone)]
pub struct Builder {
    /// Full URL of the `api.php` endpoint.
    api_url: String,
    /// Explicit `assert` override; `None` lets `build` derive it from the credentials.
    assert: Option<Assert>,
    /// Number of semaphore permits limiting concurrently executed HTTP requests.
    concurrency: usize,
    /// Value sent as the `maxlag` parameter, if any.
    maxlag: Option<u32>,
    /// Retry limit; `None` means `build` picks its default.
    retry_limit: Option<u32>,
    /// Custom user agent; `None` means `build` uses `mwapi-rs/<version>`.
    user_agent: Option<String>,
    /// OAuth 2 bearer token sent in the `Authorization` header.
    oauth2_token: Option<String>,
    /// Value sent as the `errorformat` parameter.
    errorformat: ErrorFormat,
    /// Bot password credentials used to log in during `build`.
    botpassword: Option<BotPassword>,
    /// Factory for the underlying reqwest client builder.
    http_client: Option<HttpClientProvider>,
}

impl Debug for Builder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Builder")
            .field("api_url", &self.api_url)
            .field("assert", &self.assert)
            .field("concurrency", &self.concurrency)
            .field("maxlag", &self.maxlag)
            .field("retry_limit", &self.retry_limit)
            .field("user_agent", &self.user_agent)
            // Reveal only whether credentials are configured, never their values.
            .field(
                "oauth2_token",
                &self.oauth2_token.as_ref().map(|_| "<redacted>"),
            )
            .field("errorformat", &self.errorformat)
            .field(
                "botpassword",
                &self.botpassword.as_ref().map(|_| "<redacted>"),
            )
            .field("http_client", &self.http_client)
            .finish()
    }
}
/// Shareable factory that produces a fresh [`HttpClientBuilder`] on demand.
///
/// The closure sits behind an `Arc`, so cloning a provider (as happens when a
/// `Builder` is cloned) only bumps a reference count.
#[derive(Clone)]
pub struct HttpClientProvider(Arc<dyn Fn() -> HttpClientBuilder + Send + Sync>);

impl Debug for HttpClientProvider {
    /// Closures cannot be printed, so only the type name is written.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.debug_tuple("HttpClientProvider").finish()
    }
}

impl From<HttpClientProvider> for HttpClientBuilder {
    /// Runs the wrapped factory to obtain a new client builder.
    fn from(provider: HttpClientProvider) -> Self {
        (provider.0)()
    }
}

impl<F> From<F> for HttpClientProvider
where
    F: Fn() -> HttpClientBuilder + Send + Sync + 'static,
{
    /// Wraps any suitable closure or function item into a provider.
    fn from(factory: F) -> Self {
        HttpClientProvider(Arc::new(factory))
    }
}
/// Credentials for logging in with a bot password via `action=login`.
///
/// `Debug` is written by hand so that the password is never printed.
#[derive(Clone)]
struct BotPassword {
    /// Account name, sent as `lgname`.
    username: String,
    /// Bot password, sent as `lgpassword`.
    password: String,
}

impl std::fmt::Debug for BotPassword {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BotPassword")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
impl Builder {
    /// Creates a builder for the Action API endpoint at `api_url`
    /// (the full URL of `api.php`).
    ///
    /// Starts with a concurrency of 1 and with no credentials, maxlag,
    /// custom user agent, assert override or custom HTTP client.
    pub fn new(api_url: &str) -> Self {
        Self {
            api_url: api_url.to_string(),
            assert: Default::default(),
            concurrency: 1,
            maxlag: None,
            retry_limit: None,
            user_agent: None,
            oauth2_token: None,
            errorformat: Default::default(),
            botpassword: None,
            http_client: None,
        }
    }

    /// Builds the [`Client`], logging in first if a bot password was set.
    ///
    /// When no assert was set explicitly, `Assert::User` is used if an OAuth 2
    /// token or bot password is configured, otherwise `Assert::None`. The
    /// retry limit defaults to 10.
    ///
    /// A concurrency of 0 is treated as 1: a semaphore with zero permits
    /// would make every request wait forever.
    ///
    /// # Errors
    /// Fails if the user agent is not a valid header value (wasm32 only), if
    /// the HTTP client cannot be built, or if logging in fails.
    pub async fn build(self) -> Result<Client> {
        let assert = match self.assert {
            Some(assert) => assert,
            None => {
                if self.oauth2_token.is_some() || self.botpassword.is_some() {
                    Assert::User
                } else {
                    Assert::None
                }
            }
        };
        let config = ClientConfig {
            api_url: self.api_url,
            assert,
            oauth2_token: self.oauth2_token,
            errorformat: self.errorformat,
            maxlag: self.maxlag,
            retry_limit: self.retry_limit.unwrap_or(10),
        };
        let mut http = self
            .http_client
            .map(Into::<HttpClientBuilder>::into)
            .unwrap_or_default();
        // Lazily build the default so no String is allocated when a UA was set.
        let ua = self
            .user_agent
            .unwrap_or_else(|| format!("mwapi-rs/{}", crate::VERSION));
        #[cfg(target_arch = "wasm32")]
        {
            // On wasm32 the agent goes into the `Api-User-Agent` header
            // instead of reqwest's `user_agent` setting.
            let mut headers = header::HeaderMap::new();
            headers.insert("Api-User-Agent", header::HeaderValue::from_str(&ua)?);
            http = http.default_headers(headers);
        }
        #[cfg(not(target_arch = "wasm32"))]
        {
            // The cookie store keeps the session cookie set by `action=login`.
            http = http.cookie_store(true).user_agent(ua);
        }
        let client = Client {
            inner: Arc::new(InnerClient {
                config,
                http: http.build()?,
                tokens: Default::default(),
                // Zero permits would block every request forever.
                semaphore: Semaphore::new(self.concurrency.max(1)),
            }),
        };
        if let Some(botpassword) = self.botpassword {
            client.login(&botpassword).await?;
        }
        Ok(client)
    }

    /// Sets the user agent (sent as `Api-User-Agent` on wasm32).
    pub fn set_user_agent(mut self, user_agent: &str) -> Self {
        self.user_agent = Some(user_agent.to_string());
        self
    }

    /// Authenticates every request with this OAuth 2 bearer token.
    ///
    /// Unless [`Builder::set_assert`] is used, this also makes the default
    /// assert `Assert::User`.
    pub fn set_oauth2_token(mut self, oauth2_token: &str) -> Self {
        self.oauth2_token = Some(oauth2_token.to_string());
        self
    }

    /// Sets the value sent as the `errorformat` parameter.
    pub fn set_errorformat(mut self, errorformat: ErrorFormat) -> Self {
        self.errorformat = errorformat;
        self
    }

    /// Sets how many HTTP requests may be executed at once (0 is treated as 1).
    pub fn set_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Sends `maxlag` with this value on every request.
    pub fn set_maxlag(mut self, maxlag: u32) -> Self {
        self.maxlag = Some(maxlag);
        self
    }

    /// Sets how many times a request is retried after a retryable error
    /// (default 10).
    pub fn set_retry_limit(mut self, limit: u32) -> Self {
        self.retry_limit = Some(limit);
        self
    }

    /// Logs in with this bot password when [`Builder::build`] runs.
    ///
    /// Unless [`Builder::set_assert`] is used, this also makes the default
    /// assert `Assert::User`.
    pub fn set_botpassword(mut self, username: &str, password: &str) -> Self {
        self.botpassword = Some(BotPassword {
            username: username.to_string(),
            password: password.to_string(),
        });
        self
    }

    /// Supplies a factory for the reqwest client builder.
    ///
    /// The user agent (and on native targets the cookie store) is still
    /// applied on top of whatever the factory returns.
    pub fn set_http_client<P>(mut self, provider: P) -> Self
    where
        P: Into<HttpClientProvider>,
    {
        self.http_client = Some(provider.into());
        self
    }

    /// Overrides the `assert` value sent with (non-login) requests.
    pub fn set_assert(mut self, assert: Assert) -> Self {
        self.assert = Some(assert);
        self
    }
}
/// Settings fixed when a [`Client`] is built and shared by all its requests.
///
/// `Debug` is implemented by hand so that the OAuth 2 token is not printed
/// when a `Client` is debug-formatted.
#[derive(Clone)]
struct ClientConfig {
    /// Full URL of the `api.php` endpoint.
    api_url: String,
    /// Assert sent with every non-login request.
    assert: Assert,
    /// OAuth 2 bearer token for the `Authorization` header.
    oauth2_token: Option<String>,
    /// Value sent as the `errorformat` parameter.
    errorformat: ErrorFormat,
    /// Value sent as the `maxlag` parameter, if any.
    maxlag: Option<u32>,
    /// Maximum number of retries for retryable errors.
    retry_limit: u32,
}

impl Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientConfig")
            .field("api_url", &self.api_url)
            .field("assert", &self.assert)
            // Reveal only whether a token is configured, never its value.
            .field(
                "oauth2_token",
                &self.oauth2_token.as_ref().map(|_| "<redacted>"),
            )
            .field("errorformat", &self.errorformat)
            .field("maxlag", &self.maxlag)
            .field("retry_limit", &self.retry_limit)
            .finish()
    }
}
/// Handle for making requests to a MediaWiki Action API endpoint.
///
/// Cloning is cheap: every clone shares the same configuration, HTTP client,
/// token cache and concurrency limit through one `Arc<InnerClient>`.
#[derive(Clone, Debug)]
pub struct Client {
    pub(crate) inner: Arc<InnerClient>,
}
/// State shared by all clones of a [`Client`].
#[derive(Debug)]
pub(crate) struct InnerClient {
    // Settings fixed at build time.
    config: ClientConfig,
    // reqwest client used to execute every request.
    http: HttpClient,
    // Token cache; read-locked for lookups, write-locked while (re)loading.
    tokens: RwLock<TokenStore>,
    // Bounds how many HTTP requests `do_request` executes at once.
    semaphore: Semaphore,
}
impl InnerClient {
    /// Adds the parameters every request needs: `format=json`,
    /// `formatversion=2`, the configured `errorformat`, and `maxlag` when set.
    ///
    /// The configured `assert` is added to everything except login-related
    /// requests (`action=login`, and `meta=tokens` with `type=login`), so that
    /// logging in is not rejected by an `assert=user` check. NOTE(review): the
    /// `meta=tokens&type=login` case presumably matches what
    /// `TokenStore::load("login", ..)` sends; confirm in `tokens`.
    fn fix_params(&self, params: &mut Params) {
        params.insert("format", "json");
        params.insert("formatversion", "2");
        params.insert("errorformat", self.config.errorformat);
        if let Some(maxlag) = self.config.maxlag {
            params.insert("maxlag", maxlag);
        }
        if !(params.get("action") == Some(&"login".to_string())
            || (params.get("meta") == Some(&"tokens".to_string())
                && params.get("type") == Some(&"login".to_string())))
        {
            if let Some(value) = self.config.assert.value() {
                params.insert("assert", value);
            }
        }
    }

    /// Builds per-request headers: `Authorization: Bearer <token>` when an
    /// OAuth 2 token is configured.
    ///
    /// The value is marked sensitive so the `http` crate masks it in `Debug`
    /// output, such as the `debug!(?req)` log in `do_request`.
    fn headers(&self) -> Result<header::HeaderMap> {
        let mut headers = header::HeaderMap::new();
        if let Some(token) = &self.config.oauth2_token {
            let mut value =
                header::HeaderValue::from_str(&format!("Bearer {token}"))?;
            value.set_sensitive(true);
            headers.insert(header::AUTHORIZATION, value);
        }
        Ok(headers)
    }

    /// Sends one request (no retries) and returns the decoded JSON body.
    ///
    /// GET parameters go in the query string, POST parameters in a form body,
    /// and (with the `upload` feature) multipart requests in a multipart form.
    ///
    /// # Errors
    /// * `Error::TooManyRequests` on HTTP 429, carrying the `Retry-After` value.
    /// * reqwest errors for transport failures and non-success statuses.
    /// * API errors from the response's `errors` array (see `handle_response`).
    pub(crate) async fn do_request(
        &self,
        req_params: RequestParams,
    ) -> Result<Value> {
        // Build the request before taking a permit, so slow multipart form
        // construction does not occupy a concurrency slot.
        let req = match req_params {
            RequestParams::Get(mut params) => {
                self.fix_params(&mut params);
                self.http.get(&self.config.api_url).query(params.as_map())
            }
            RequestParams::Post(mut params) => {
                self.fix_params(&mut params);
                self.http.post(&self.config.api_url).form(params.as_map())
            }
            #[cfg(feature = "upload")]
            RequestParams::Multipart(mut params) => {
                self.fix_params(&mut params.params);
                self.http
                    .post(&self.config.api_url)
                    .multipart(params.into_form().await?)
            }
        };
        let req = req.headers(self.headers()?).build()?;
        // The permit covers only `execute`; it is released explicitly below,
        // before the response body is read.
        let _lock = self.semaphore.acquire().await?;
        debug!(?req);
        let result = self.http.execute(req).await;
        debug!(?result);
        drop(_lock);
        let resp = result?;
        // 0 when the header is missing or not an integer.
        // NOTE(review): an HTTP-date Retry-After also yields 0 here.
        let retry_after = extract_retry_after(resp.headers());
        if resp.status() == StatusCode::TOO_MANY_REQUESTS {
            return Err(Error::TooManyRequests {
                retry_after: Some(retry_after),
            });
        }
        let value: Value = resp.error_for_status()?.json().await?;
        handle_response(value, retry_after)
    }

    /// Sends a request with retries and deserializes the response into `T`.
    ///
    /// An error is retried only if `err.retry_after()` returns `Some`. Before
    /// retrying, the task sleeps for that amount (skipped when it is 0; units
    /// presumably seconds, per `crate::time::sleep` — confirm). At most
    /// `retry_limit` retries are made, i.e. `retry_limit + 1` attempts.
    /// Deserialization errors are never retried.
    pub(crate) async fn request<P: Into<Params>, T: DeserializeOwned>(
        &self,
        method: Method,
        params: P,
    ) -> Result<T> {
        let mut retry_counter = 0;
        let params = params.into();
        loop {
            // Each attempt consumes its own copy of the parameters.
            let params = params.clone();
            let resp = self
                .do_request(match method {
                    Method::Get => RequestParams::Get(params),
                    Method::Post => RequestParams::Post(params),
                })
                .await;
            match resp {
                Ok(value) => {
                    return Ok(serde_json::from_value(value)?);
                }
                Err(err) => {
                    if let Some(retry_after) = err.retry_after() {
                        if retry_counter >= self.config.retry_limit {
                            return Err(err);
                        }
                        if retry_after != 0 {
                            crate::time::sleep(retry_after).await;
                        }
                        retry_counter += 1;
                    } else {
                        return Err(err);
                    }
                }
            }
        }
    }
}
impl Client {
    /// Returns a [`Builder`] for configuring a client for `api_url`.
    pub fn builder(api_url: &str) -> Builder {
        Builder::new(api_url)
    }

    /// Creates a client for `api_url` with default settings.
    ///
    /// # Errors
    /// Fails if the underlying HTTP client cannot be built.
    pub async fn new(api_url: &str) -> Result<Self> {
        Builder::new(api_url).build().await
    }

    /// Logs in with a bot password via `action=login`.
    ///
    /// Only a `Success` result counts as logged in. MediaWiki documents other
    /// non-success results besides `Failed` (e.g. `Aborted`, `NeedToken`);
    /// previously those were silently treated as success.
    ///
    /// # Errors
    /// Returns the server-supplied reason when present, otherwise
    /// `Error::Unknown("Login failed")`.
    async fn login(&self, botpassword: &BotPassword) -> Result<()> {
        // The write guard lives until the end of this statement, i.e. while
        // the login token is being fetched.
        let token = self.inner.tokens.write().await.load("login", self).await?;
        let login_resp: LoginResponse = self
            .post(&[
                ("action", "login"),
                ("lgname", &botpassword.username),
                ("lgpassword", &botpassword.password),
                ("lgtoken", &token),
            ])
            .await?;
        if login_resp.login.result == "Success" {
            Ok(())
        } else {
            Err(match login_resp.login.reason {
                Some(reason) => Error::from(reason),
                None => Error::Unknown("Login failed".to_string()),
            })
        }
    }

    /// Sends a GET request and returns the raw JSON response.
    ///
    /// # Errors
    /// HTTP, decoding and API errors. Retryable errors are retried up to the
    /// configured retry limit first.
    pub async fn get_value<P: Into<Params>>(&self, params: P) -> Result<Value> {
        let params = params.into();
        self.inner.request(Method::Get, params).await
    }

    /// Sends a GET request and deserializes the response into `T`.
    ///
    /// # Errors
    /// Same as [`Client::get_value`], plus deserialization failures.
    pub async fn get<P: Into<Params>, T: DeserializeOwned>(
        &self,
        params: P,
    ) -> Result<T> {
        // Decode straight into `T` rather than via an intermediate `Value`.
        self.inner.request(Method::Get, params).await
    }

    /// Returns the cached token of type `token_type`, loading it if absent.
    pub(crate) async fn token(&self, token_type: &str) -> Result<String> {
        // Keep the lookup as its own statement so the read guard is dropped
        // before the write lock is requested; holding both would deadlock.
        let get = self.inner.tokens.read().await.get(token_type);
        match get {
            Some(token) => Ok(token),
            None => {
                self.inner.tokens.write().await.load(token_type, self).await
            }
        }
    }

    /// Sends a POST request with a `token` parameter of type `token_type`.
    ///
    /// If the server answers `Error::BadToken`, a fresh token is loaded and
    /// the request is sent exactly once more.
    pub async fn post_with_token<P: Into<Params>, T: DeserializeOwned>(
        &self,
        token_type: &str,
        params: P,
    ) -> Result<T> {
        let mut params = params.into();
        params.insert("token", self.token(token_type).await?);
        match self.post(params.clone()).await {
            Err(Error::BadToken) => {
                let token = self
                    .inner
                    .tokens
                    .write()
                    .await
                    .load(token_type, self)
                    .await?;
                params.insert("token", token);
                self.post(params).await
            }
            result => result,
        }
    }

    /// Sends a POST request and deserializes the response into `T`.
    ///
    /// # Errors
    /// HTTP, decoding, deserialization and API errors. Retryable errors are
    /// retried up to the configured retry limit first.
    pub async fn post<P: Into<Params>, T: DeserializeOwned>(
        &self,
        params: P,
    ) -> Result<T> {
        // Decode straight into `T` rather than via an intermediate `Value`.
        self.inner.request(Method::Post, params).await
    }

    /// Sends a POST request and returns the raw JSON response.
    pub async fn post_value<P: Into<Params>>(
        &self,
        params: P,
    ) -> Result<Value> {
        self.post(params).await
    }

    /// Uploads the file at `path` under the name `filename`.
    ///
    /// `params` are forwarded as extra upload parameters. `ignore_warnings`
    /// adds `ignorewarnings=1`. `chunk_size` is passed to the upload module
    /// (presumably bytes per chunk — confirm in `upload`). The returned string
    /// is whatever `upload::upload` yields.
    #[cfg(feature = "upload")]
    #[cfg_attr(docsrs, doc(cfg(feature = "upload")))]
    pub async fn upload<P: Into<Params>>(
        &self,
        filename: &str,
        path: PathBuf,
        chunk_size: usize,
        ignore_warnings: bool,
        params: P,
    ) -> Result<String> {
        let mut base_params =
            Params::from(&[("action", "upload"), ("filename", filename)]);
        if ignore_warnings {
            base_params.insert("ignorewarnings", 1);
        }
        let req = upload::UploadRequest {
            filename: filename.to_string(),
            file: path,
            chunk_size,
            base_params,
            upload_params: params.into(),
        };
        upload::upload(self, req).await
    }

    /// Borrows the reqwest client this `Client` uses for its requests.
    pub fn http_client(&self) -> &HttpClient {
        &self.inner.http
    }
}
fn handle_response(mut value: Value, retry_after: u64) -> Result<Value> {
if let Some(warnings) = value.get("warnings") {
let warnings: Vec<ApiError> = serde_json::from_value(warnings.clone())?;
for warning in warnings {
warn!("API warning: {}", warning);
}
}
let errors = value["errors"].take();
if !errors.is_null() {
let errors: Vec<ApiError> = serde_json::from_value(errors)?;
for error in &errors {
error!("API error: {}", error);
}
let err = match errors.into_iter().next() {
Some(err) => Error::from(err),
None => Error::Unknown("No error specified".to_string()),
};
Err(err.with_retry_after(retry_after))
} else {
Ok(value)
}
}
/// Reads the `Retry-After` header as a whole number.
///
/// Returns 0 when the header is missing, is not valid visible ASCII, or does
/// not parse as an unsigned integer (including the HTTP-date form).
fn extract_retry_after(headers: &header::HeaderMap) -> u64 {
    headers
        .get("retry-after")
        .and_then(|raw| raw.to_str().ok())
        .and_then(|text| text.parse::<u64>().ok())
        .unwrap_or(0)
}
// Unit tests plus integration tests that talk to live Wikimedia API endpoints
// (www.mediawiki.org and test.wikipedia.org); the latter need network access.
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check helper: only instantiable for `Send + Sync` types.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_send_sync() {
        assert_send_sync::<Builder>();
        assert_send_sync::<Client>();
        assert_send_sync::<ClientConfig>();
        assert_send_sync::<InnerClient>();
    }

    // A simple siteinfo query should report mediawiki.org's site name.
    #[tokio::test]
    async fn test_basic_get() {
        let client = Client::new("https://www.mediawiki.org/w/api.php")
            .await
            .unwrap();
        let resp = client
            .get_value(&[("action", "query"), ("meta", "siteinfo")])
            .await
            .unwrap();
        assert_eq!(
            resp["query"]["general"]["sitename"].as_str().unwrap(),
            "MediaWiki"
        );
    }

    // An unknown action must surface as an API error with the server's message.
    #[tokio::test]
    async fn test_basic_errors() {
        let client = Client::new("https://www.mediawiki.org/w/api.php")
            .await
            .unwrap();
        let error = client
            .get_value(&[("action", "nonexistent")])
            .await
            .unwrap_err();
        assert_eq!(
            &error.to_string(),
            "API error: (code: badvalue): Unrecognized value for parameter \"action\": nonexistent."
        );
    }

    // The configured OAuth 2 token is stored in the client's config.
    #[tokio::test]
    async fn test_builder() {
        let client = Client::builder("https://www.mediawiki.org/w/api.php")
            .set_oauth2_token("foobarbaz")
            .build()
            .await
            .unwrap();
        assert_eq!(
            client.inner.config.oauth2_token,
            Some("foobarbaz".to_string())
        );
    }

    // Skipped unless MWAPI_USERNAME and MWAPI_TOKEN are set; checks that the
    // OAuth 2 token authenticates as the expected user.
    #[tokio::test]
    async fn test_login() {
        let username = std::env::var("MWAPI_USERNAME");
        let token = std::env::var("MWAPI_TOKEN");
        if username.is_err() || token.is_err() {
            return;
        }
        let client = Client::builder("https://test.wikipedia.org/w/api.php")
            .set_oauth2_token(&token.unwrap())
            .build()
            .await
            .unwrap();
        let resp = client
            .get_value(&[("action", "query"), ("meta", "userinfo")])
            .await
            .unwrap();
        dbg!(&resp);
        // Usernames use spaces in API output; the env var may use underscores.
        let normalized = username.unwrap().replace('_', " ");
        assert!(&normalized
            .starts_with(resp["query"]["userinfo"]["name"].as_str().unwrap()));
    }

    // An anonymous client asserting "anonymous" must succeed.
    #[tokio::test]
    async fn test_good_assert() {
        let client = Client::builder("https://test.wikipedia.org/w/api.php")
            .set_assert(Assert::Anonymous)
            .build()
            .await
            .unwrap();
        client.get_value(&[("action", "query")]).await.unwrap();
    }

    // An anonymous client asserting "user" must fail with NotLoggedIn.
    #[tokio::test]
    async fn test_bad_assert() {
        let client = Client::builder("https://test.wikipedia.org/w/api.php")
            .set_assert(Assert::User)
            .build()
            .await
            .unwrap();
        let error = client.get_value(&[("action", "query")]).await.unwrap_err();
        assert!(matches!(error, Error::NotLoggedIn));
    }

    // Logging in with bogus credentials must fail during `build`.
    #[tokio::test]
    async fn test_bad_login() {
        let error = Client::builder("https://test.wikipedia.org/w/api.php")
            .set_botpassword("ThisAccountDoesNotExistPlease", "password")
            .build()
            .await
            .unwrap_err();
        if let Error::ApiError(api_err) = error {
            assert_eq!(&api_err.code, "wrongpassword");
        } else {
            panic!("wrong error type");
        }
    }

    // maxlag=0 should trigger a maxlag error; with a retry limit of 1 the
    // error is returned after one retry.
    #[tokio::test]
    async fn test_maxlag() {
        let client = Client::builder("https://test.wikipedia.org/w/api.php")
            .set_maxlag(0)
            .set_retry_limit(1)
            .build()
            .await
            .unwrap();
        let error = client.get_value(&[("action", "query")]).await.unwrap_err();
        if let Error::Maxlag { info, .. } = error {
            assert!(info.starts_with("Waiting for"));
        } else {
            dbg!(&error);
            panic!("Error did not match MaxlagError");
        }
    }

    // Expects the API to answer an unknown `list` module with a warning rather
    // than an error, so the call still succeeds.
    #[tokio::test]
    async fn test_warning() {
        let client = Client::builder("https://test.wikipedia.org/w/api.php")
            .build()
            .await
            .unwrap();
        client
            .get_value(&[("action", "query"), ("list", "unknown")])
            .await
            .unwrap();
    }

    // Unparseable values fall back to 0.
    #[test]
    fn test_extract_retry_after() {
        let mut map = header::HeaderMap::new();
        map.insert("retry-after", "0".parse().unwrap());
        assert_eq!(extract_retry_after(&map), 0);
        map.insert("retry-after", "abc".parse().unwrap());
        assert_eq!(extract_retry_after(&map), 0);
        map.insert("retry-after", "4".parse().unwrap());
        assert_eq!(extract_retry_after(&map), 4);
    }

    // A custom provider adding an `Origin` header lets the `origin` parameter
    // through; the same request from a default client must be rejected,
    // proving the provider's headers were applied.
    #[tokio::test]
    async fn test_http_client_provider() {
        fn provider() -> HttpClientBuilder {
            HttpClientBuilder::new().default_headers({
                let mut headers = reqwest::header::HeaderMap::new();
                headers.insert(
                    "Origin",
                    header::HeaderValue::from_static("meta.wikipedia.org"),
                );
                headers
            })
        }
        let client = Client::builder("https://www.mediawiki.org/w/api.php")
            .set_http_client(provider)
            .build()
            .await
            .unwrap();
        let resp = client
            .get_value(&[
                ("action", "query"),
                ("meta", "siteinfo"),
                ("origin", "meta.wikipedia.org"),
            ])
            .await
            .unwrap();
        assert_eq!(
            resp["query"]["general"]["sitename"].as_str().unwrap(),
            "MediaWiki"
        );
        assert!(Client::new("https://www.mediawiki.org/w/api.php")
            .await
            .unwrap()
            .get_value(&[
                ("action", "query"),
                ("meta", "siteinfo"),
                ("origin", "meta.wikipedia.org"),
            ])
            .await
            .is_err());
    }
}