use axum::{ extract::{Query, State}, response::Redirect, }; use base64::Engine; use openid::Options; use rand::Rng; use serde::Deserialize; use tower_sessions::Session; use crate::state::AppState; use irc_now_common::auth::UserClaims; fn extract_custom_claims(raw_jwt: &str) -> (Option, Option) { let payload_b64 = match raw_jwt.split('.').nth(1) { Some(p) => p, None => return (None, None), }; let bytes = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64) { Ok(b) => b, Err(_) => return (None, None), }; let val: serde_json::Value = match serde_json::from_slice(&bytes) { Ok(v) => v, Err(_) => return (None, None), }; let plan = val.get("plan").and_then(|v| v.as_str()).map(String::from); let content_expires = val .get("content_expires") .and_then(|v| match v { serde_json::Value::Bool(b) => Some(*b), serde_json::Value::String(s) => s.parse().ok(), _ => None, }); (plan, content_expires) } fn random_string(len: usize) -> String { let mut rng = rand::thread_rng(); (0..len) .map(|_| { let idx = rng.gen_range(0..36); if idx < 10 { (b'0' + idx) as char } else { (b'a' + idx - 10) as char } }) .collect() } pub async fn login(State(state): State, session: Session) -> Redirect { let csrf_state = random_string(32); let nonce = random_string(32); session .insert("oidc_state", &csrf_state) .await .expect("session insert failed"); session .insert("oidc_nonce", &nonce) .await .expect("session insert failed"); let options = Options { scope: Some("openid email profile".to_string()), state: Some(csrf_state), nonce: Some(nonce), ..Options::default() }; let auth_url = state.oidc_client.auth_url(&options); Redirect::temporary(auth_url.as_str()) } #[derive(Deserialize)] pub struct CallbackParams { code: String, state: String, } pub async fn callback( State(state): State, session: Session, Query(params): Query, ) -> Result { let stored_state: Option = session .get("oidc_state") .await .unwrap_or(None); let stored_nonce: Option = session .get("oidc_nonce") .await .unwrap_or(None); session.remove::("oidc_state").await.ok(); session.remove::("oidc_nonce").await.ok(); let Some(stored_state) = stored_state else { tracing::warn!("no oidc_state in session"); return Err(Redirect::temporary("/auth/login")); }; if params.state != stored_state { tracing::warn!("CSRF state mismatch"); return Err(Redirect::temporary("/auth/login")); } let nonce_ref = stored_nonce.as_deref(); let token = state .oidc_client .authenticate(¶ms.code, nonce_ref, None) .await .map_err(|e| { tracing::error!("token exchange failed: {e}"); Redirect::temporary("/auth/login") })?; let id_token = token.id_token.ok_or_else(|| { tracing::error!("no id_token in response"); Redirect::temporary("/auth/login") })?; let payload = id_token.payload().map_err(|e| { tracing::error!("failed to decode id_token payload: {e}"); Redirect::temporary("/auth/login") })?; let sub = &payload.userinfo.sub; let email = payload.userinfo.email.as_deref(); let (plan, content_expires) = token .bearer .id_token .as_deref() .map(extract_custom_claims) .unwrap_or((None, None)); let claims = UserClaims { sub: sub.clone(), email: email.map(String::from), display_name: None, plan, stripe_customer_id: None, content_expires, }; session .insert("user", &claims) .await .map_err(|e| { tracing::error!("failed to store user claims in session: {e}"); Redirect::temporary("/auth/login") })?; Ok(Redirect::temporary("/my")) } pub async fn logout(State(state): State, session: Session) -> Redirect { session.flush().await.ok(); let origin = url::Url::parse(&state.oidc.redirect_url) .map(|u| format!("{}://{}", u.scheme(), u.host_str().unwrap_or("irc.bot"))) .unwrap_or_else(|_| "https://irc.bot".to_string()); let logout_url = format!( "{}/protocol/openid-connect/logout?post_logout_redirect_uri={}&client_id={}", state.oidc.issuer_url, urlencoding::encode(&origin), urlencoding::encode(&state.oidc.client_id), ); Redirect::temporary(&logout_url) }