use askama::Template;
use axum::response::{Html, Redirect};
use axum::{Form, extract::Query, extract::State};
use axum_sessions::extractors::{ReadableSession, WritableSession};
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope};
use reqwest::Client;
use serde::Deserialize;

/// Askama context for the `index.html` page.
#[derive(Template)]
#[template(path = "index.html")]
pub struct IndexTemplate {
    /// Entries to list; [`index`] fills this from `super::models::get_all_weights`.
    pub weights: Vec,
}

/// Form body accepted by [`input_post`] when a user records a new weight entry.
#[derive(Deserialize)]
pub struct WeightForm {
    /// Measurement date, forwarded to `super::models::insert_weight` as a string;
    /// no format validation happens in this file.
    date: String,
    /// Measured weight; the fragment built by [`input_post`] labels it "kg", so
    /// presumably kilograms — confirm against the model layer.
    weight: f64,
}

/// Renders the index page listing every stored weight entry.
///
/// If loading the entries fails, `unwrap_or_default` substitutes an empty list, so
/// the page still renders (with no error indication to the user).
///
/// # Panics
/// Panics if the Askama template fails to render.
pub async fn index(State(state): State) -> Html {
    let weights = super::models::get_all_weights(&state.pool)
        .await
        .unwrap_or_default();
    let template = IndexTemplate { weights };
    Html(template.render().unwrap())
}

/// Starts the OpenID Connect authorization-code flow with PKCE.
///
/// Builds the provider's authorization URL (scopes `openid` and `profile`, a random
/// `state` value from `CsrfToken::new_random`, and a SHA-256 PKCE challenge). Stores
/// that `state` value and the PKCE verifier in the session under `csrf_token` and
/// `pkce_verifier` for [`callback`] to check. Then redirects the browser to the provider.
///
/// NOTE(review): no OIDC `nonce` is requested here or checked in `callback`.
///
/// # Panics
/// Panics if either session insert fails.
pub async fn login(State(state): State, mut session: WritableSession) -> Redirect {
    // Fresh PKCE pair: the challenge is sent to the provider now; the verifier stays
    // server-side and is only revealed during the code exchange in `callback`.
    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();

    // Build the authorization URL; `csrf_token` is the random `state` value embedded
    // in it, which `callback` expects to receive back as the `state` query parameter.
    let (auth_url, csrf_token) = state
        .oidc_client
        .authorize_url(CsrfToken::new_random)
        .add_scope(Scope::new("openid".to_string()))
        .add_scope(Scope::new("profile".to_string()))
        .set_pkce_challenge(pkce_challenge)
        .url();

    // Store the CSRF token and PKCE verifier in the session so `callback` can
    // validate the returned `state` and complete the PKCE exchange.
    session
        .insert("csrf_token", csrf_token.secret().clone())
        .unwrap();
    session
        .insert("pkce_verifier", pkce_verifier.secret().clone())
        .unwrap();

    Redirect::to(auth_url.as_str())
}

/// Test-only login shortcut: marks the session as authenticated as `"test-user"` and
/// redirects to `/`, skipping the OIDC round trip. Compiled only under `cfg(test)`.
///
/// # Panics
/// Panics if the session insert fails.
#[cfg(test)]
pub async fn test_login(mut session: WritableSession) -> Redirect {
    session.insert("user_id", "test-user").unwrap();
    Redirect::to("/")
}

/// Query parameters that [`callback`] reads from the provider's redirect.
///
/// NOTE(review): both fields are required. A redirect without `code` (e.g. an OAuth
/// error response carrying `error=...`) presumably fails `Query` extraction before
/// `callback` runs, instead of reaching one of its `/?error=` redirects.
#[derive(Deserialize)]
pub struct AuthCallbackQuery {
    /// Authorization code to exchange for tokens.
    code: String,
    /// `state` value; must equal the `csrf_token` stored by [`login`].
    state: String,
}

/// Completes the login flow started by [`login`].
///
/// Each failure ends in a redirect to `/` with an `error` query value:
/// 1. Compare the returned `state` with the session's `csrf_token`
///    (`no_state` if none is stored, `invalid_state` on mismatch).
/// 2. Load the session's PKCE verifier (`no_verifier` if missing).
/// 3. Exchange the code, together with the verifier, via [`async_http_client`]
///    (`token_exchange_failed` on any error; the error detail is discarded).
/// 4. Take the subject of the returned ID token (`no_id_token` / `no_subject`),
///    store it in the session as `user_id`, and redirect to `/`.
///
/// NOTE(review): `csrf_token` and `pkce_verifier` are only read, never removed, so
/// they remain in the session after the flow completes.
///
/// # Panics
/// Panics if storing `user_id` in the session fails.
pub async fn callback(
    State(state): State,
    Query(query): Query,
    mut session: WritableSession,
) -> Redirect {
    // Verify CSRF token: the echoed `state` must match what `login` stored.
    if let Some(stored_csrf) = session.get::("csrf_token") {
        if stored_csrf != query.state {
            return Redirect::to("/?error=invalid_state");
        }
    } else {
        return Redirect::to("/?error=no_state");
    }

    // Get the PKCE verifier saved by `login`.
    let pkce_verifier = if let Some(verifier) = session.get::("pkce_verifier") {
        PkceCodeVerifier::new(verifier)
    } else {
        return Redirect::to("/?error=no_verifier");
    };

    // Exchange the authorization code (plus PKCE verifier) for tokens.
    let token_result = state
        .oidc_client
        .exchange_code(AuthorizationCode::new(query.code))
        .set_pkce_verifier(pkce_verifier)
        .request_async(async_http_client)
        .await;

    match token_result {
        Ok(token) => {
            // For OIDC, the ID token carries the user's identity claims.
            if let Some(id_token) = token.id_token() {
                // SECURITY NOTE(review): the ID token is used without verifying its
                // signature or its claims (issuer, audience, expiry, nonce) before
                // `sub` is trusted as the user id. Verify it before relying on it.
                // NOTE(review): confirm that the token type returned by `oidc_client`
                // exposes `id_token()` / `payload()` / `subject()`; the `oauth2` crate
                // imported above does not parse ID tokens on its own.
                let claims = id_token.payload().clone();
                if let Some(sub) = claims.subject() {
                    session.insert("user_id", sub.to_string()).unwrap();
                    Redirect::to("/")
                } else {
                    Redirect::to("/?error=no_subject")
                }
            } else {
                Redirect::to("/?error=no_id_token")
            }
        }
        Err(_) => Redirect::to("/?error=token_exchange_failed"),
    }
}

/// Adapter that lets the `oauth2` client send its HTTP requests through `reqwest`.
///
/// Copies the method, URL, headers and body of `request` into a `reqwest` request and
/// sends it. Then copies the status, headers and body back into an `oauth2::HttpResponse`.
/// Errors from `send` / `bytes` are propagated with `?`; presumably the error type is
/// `reqwest::Error`, but confirm against the return type.
///
/// NOTE(review): a new `Client` is built on every call. `reqwest` follows redirects
/// by default, and the `oauth2` docs recommend disabling redirects for the token
/// endpoint to avoid SSRF.
///
/// # Panics
/// Panics if the method, a header name/value, or the status code cannot be converted
/// between the two libraries' types.
async fn async_http_client(
    request: oauth2::HttpRequest,
) -> Result {
    let client = Client::new();
    let method_str = request.method.as_str();
    let method = reqwest::Method::from_bytes(method_str.as_bytes()).unwrap();
    let mut req_builder = client.request(method, request.url);
    for (name, value) in request.headers {
        // NOTE(review): entries whose name is `None` are skipped. If `request.headers`
        // is an `http::HeaderMap`, its owning iterator yields `None` for the 2nd and
        // later values of a repeated header, so those values are silently dropped.
        // The closure also returns a `&str` borrowed from its by-value argument `f`;
        // confirm this compiles (`name.as_ref().map(|n| n.as_str())` borrows `name`).
        let Some(header_name_str) = name.and_then(|f| Some(f.as_str().clone())) else {
            continue;
        };
        let header_name = reqwest::header::HeaderName::from_bytes(header_name_str.as_bytes()).unwrap();
        let header_value = reqwest::header::HeaderValue::from_bytes(value.as_bytes()).unwrap();
        req_builder = req_builder.header(header_name, header_value);
    }
    let response = req_builder.body(request.body).send().await?;
    let status_code = response.status().as_u16();
    // Headers are cloned before `bytes()` consumes the response.
    let headers = response.headers().clone();
    let body = response.bytes().await?;

    // Convert the reqwest response headers into the `http` types `oauth2` expects.
    // NOTE(review): `insert` replaces earlier values, so only the last value of a
    // repeated response header survives; `append` would keep all of them.
    let mut oauth_headers = oauth2::http::HeaderMap::new();
    for (k, v) in headers.iter() {
        let name = oauth2::http::HeaderName::from_bytes(k.as_str().as_bytes()).unwrap();
        let value = oauth2::http::HeaderValue::from_bytes(v.as_bytes()).unwrap();
        oauth_headers.insert(name, value);
    }
    Ok(oauth2::HttpResponse {
        status_code: oauth2::http::StatusCode::from_u16(status_code).unwrap(),
        headers: oauth_headers,
        body: body.to_vec(),
    })
}

/// Records a weight entry for the logged-in user and returns an HTML fragment listing
/// all stored entries.
///
/// The user is taken from the session's `user_id`. Without it, the handler returns
/// `Err` holding a redirect to `/auth/login`. The listing covers every user's entries
/// (`get_all_weights` takes no user filter), each followed by its `user_id`. If
/// re-loading the entries fails, `unwrap_or_default` yields an empty fragment.
///
/// # Panics
/// Panics if `insert_weight` returns an error, because its result is unwrapped.
pub async fn input_post(
    State(state): State,
    session: ReadableSession,
    Form(form): Form,
) -> Result, Redirect> {
    // Check if user is authenticated.
    if let Some(user_id) = session.get::("user_id") {
        super::models::insert_weight(&state.pool, &user_id, &form.date, form.weight)
            .await
            .unwrap();
        let weights = super::models::get_all_weights(&state.pool)
            .await
            .unwrap_or_default();
        let mut html = String::new();
        for weight in weights {
            // SECURITY NOTE(review): `date` and `user_id` are interpolated into HTML
            // without escaping. `date` is presumably the raw `WeightForm.date` string
            // stored above, so a crafted submission can inject markup or script into
            // the fragment shown to any user who later submits a weight. HTML-escape
            // these values before interpolating them.
            html.push_str(&format!(
                "

{}: {} kg by {}

\n",
                weight.date, weight.weight, weight.user_id
            ));
        }
        Ok(Html(html))
    } else {
        // Redirect to login if not authenticated.
        Err(Redirect::to("/auth/login"))
    }
}