use std::{net::IpAddr, num::NonZeroU32, time::Duration}; use axum::{async_trait, extract::FromRequestParts, http::request::Parts}; use governor::{clock::DefaultClock, state::keyed::DashMapStateStore, Quota, RateLimiter}; use once_cell::sync::Lazy; use tracing::warn; use crate::server::error::{Error, Result}; type Limiter = RateLimiter, DefaultClock>; static LIMITER_LOGIN: Lazy = Lazy::new(|| { let seconds = Duration::from_secs(60); let burst = NonZeroU32::new(10).expect("Non-zero login ratelimit burst"); RateLimiter::keyed( Quota::with_period(seconds) .expect("Non-zero login ratelimit seconds") .allow_burst(burst), ) }); pub fn check_limit_login(ip: &IpAddr) -> Result<()> { match LIMITER_LOGIN.check_key(ip) { Ok(_) => Ok(()), Err(_e) => Err(Error::RateLimit), } } pub struct ClientIp(pub IpAddr); const X_FORWARDED_FOR: &str = "x-forwarded-for"; #[async_trait] impl FromRequestParts for ClientIp where S: Send + Sync, { type Rejection = Error; async fn from_request_parts(req: &mut Parts, _state: &S) -> Result { let addr = req .headers .get(X_FORWARDED_FOR) .and_then(|hv| hv.to_str().ok()) .and_then(|ip| { match ip.find(',') { Some(idx) => &ip[..idx], None => ip, } .parse() .map_err(|_| warn!("'{}' header is malformed: {}", X_FORWARDED_FOR, ip)) .ok() }) .unwrap_or_else(|| "0.0.0.0".parse().unwrap()); Ok(Self(addr)) } }