use async_redis_session::RedisSessionStore; use async_session::{Session, SessionStore}; use axum::{ async_trait, extract::{Extension, FromRequest, Query, RequestParts, TypedHeader}, prelude::*, response::IntoResponse, routing::BoxRoute, AddExtensionLayer, }; use http::{header::SET_COOKIE, StatusCode}; use hyper::Body; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl, }; use serde::{Deserialize, Serialize}; use std::env; static COOKIE_NAME: &str = "SESSION"; pub fn oauth_client() -> BasicClient { // Environment variables (* = required): // *"CLIENT_ID" "123456789123456789"; // *"CLIENT_SECRET" "rAn60Mch4ra-CTErsSf-r04utHcLienT"; // "REDIRECT_URL" "http://127.0.0.1:3000/auth/authorized"; // "AUTH_URL" "https://discord.com/api/oauth2/authorize?response_type=code"; // "TOKEN_URL" "https://discord.com/api/oauth2/token"; let client_id = env::var("CLIENT_ID").expect("Missing CLIENT_ID!"); let client_secret = env::var("CLIENT_SECRET").expect("Missing CLIENT_SECRET!"); let redirect_url = env::var("REDIRECT_URL") .unwrap_or_else(|_| "http://127.0.0.1:3000/auth/authorized".to_string()); let auth_url = env::var("AUTH_URL").unwrap_or_else(|_| { "https://discord.com/api/oauth2/authorize?response_type=code".to_string() }); let token_url = env::var("TOKEN_URL") .unwrap_or_else(|_| "https://discord.com/api/oauth2/token".to_string()); let client = BasicClient::new( ClientId::new(client_id), Some(ClientSecret::new(client_secret)), AuthUrl::new(auth_url).unwrap(), Some(TokenUrl::new(token_url).unwrap()), ) .set_redirect_uri(RedirectUrl::new(redirect_url).unwrap()); tracing::debug!("client: {:?}", client); client } // The user data we'll get back from Discord. // https://discord.com/developers/docs/resources/user#user-object-user-structure #[derive(Debug, Serialize, Deserialize)] struct DiscordUser { id: String, avatar: Option, username: String, discriminator: String, } // Session is optional async fn index(user: Option) -> impl IntoResponse { match user { Some(u) => format!( "Hey {}! You're logged in!\nYou may now access `/protected`.\nLog out with `/logout`.", u.username ), None => "You're not logged in.\nVisit `/auth/discord` to do so.".to_string(), } } async fn discord_auth(Extension(client): Extension) -> impl IntoResponse { let (auth_url, _csrf_token) = client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("identify".to_string())) .url(); // Redirect to Discord's oauth service Redirect(auth_url.into()) } // Valid user session required. If there is none, redirect to the auth page async fn protected(user: DiscordUser) -> impl IntoResponse { serde_json::to_string(&user).expect("could not serialize user") } async fn avatar_url(user: DiscordUser) -> impl IntoResponse { let cdn_url = env::var("CDN_URL").unwrap_or_else(|_| "https://cdn.discordapp.com".to_string()); match user.avatar { Some(id) => format!("{}/avatars/{}/{}.webp?size=256", cdn_url, user.id, id), None => format!("{}/embed/avatars/0.png?size=256", cdn_url), } } async fn logout( Extension(store): Extension, TypedHeader(cookies): TypedHeader, ) -> impl IntoResponse { let cookie = cookies.get(COOKIE_NAME).unwrap(); let session = match store.load_session(cookie.to_string()).await.unwrap() { Some(s) => s, // No session active, just redirect None => return Redirect("/".to_string()), }; store.destroy_session(session).await.unwrap(); Redirect("/".to_string()) } #[derive(Debug, Deserialize)] struct AuthRequest { code: String, state: String, } async fn login_authorized( Query(query): Query, Extension(store): Extension, Extension(oauth_client): Extension, ) -> impl IntoResponse { // Get an auth token let token = oauth_client .exchange_code(AuthorizationCode::new(query.code.clone())) .request_async(async_http_client) .await .unwrap(); // Fetch user data from discord let client = reqwest::Client::new(); let user_data: DiscordUser = client // https://discord.com/developers/docs/resources/user#get-current-user .get("https://discordapp.com/api/users/@me") .bearer_auth(token.access_token().secret()) .send() .await .unwrap() .json::() .await .unwrap(); // Create a new session filled with user data let mut session = Session::new(); session.insert("user", &user_data).unwrap(); // Store session and get corresponding cookie let cookie = store.store_session(session).await.unwrap().unwrap(); // Build the cookie let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie); // Set cookie let r = http::Response::builder() .header("Location", "/") .header(SET_COOKIE, cookie) .status(302); r.body(Body::empty()).unwrap() } // Utility to save some lines of code struct Redirect(String); impl IntoResponse for Redirect { fn into_response(self) -> http::Response { let builder = http::Response::builder() .header("Location", self.0) .status(StatusCode::FOUND); builder.body(Body::empty()).unwrap() } } struct AuthRedirect; impl IntoResponse for AuthRedirect { fn into_response(self) -> http::Response { Redirect("/auth/discord".to_string()).into_response() } } #[async_trait] impl FromRequest for DiscordUser where B: Send, { // If anything goes wrong or no session is found, redirect to the auth page type Rejection = AuthRedirect; async fn from_request(req: &mut RequestParts) -> Result { let extract::Extension(store) = extract::Extension::::from_request(req) .await .expect("`RedisSessionStore` extension is missing"); let cookies: TypedHeader = TypedHeader::::from_request(req) .await .expect("could not get cookies"); let session_cookie = cookies.0.get(COOKIE_NAME).ok_or(AuthRedirect)?; let session = store .load_session(session_cookie.to_string()) .await .unwrap() .ok_or(AuthRedirect)?; let user = session.get::("user").ok_or(AuthRedirect)?; Ok(user) } } pub fn get_routes() -> BoxRoute { route("/", get(index)) .route("/discord", get(discord_auth)) .route("/authorized", get(login_authorized)) .route("/protected", get(protected)) .route("/avatar", get(avatar_url)) .route("/logout", get(logout)) .layer(AddExtensionLayer::new(oauth_client())) .boxed() }