use askama::Template; use axum::response::{Html, Redirect}; use axum::{Form, extract::Query, extract::State}; use axum_session::Session; use axum_session_sqlx::SessionSqlitePool; use oauth2::http; use oauth2::{ AuthorizationCode, CsrfToken, HttpRequest, HttpResponse, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, }; use reqwest::Client; use serde::Deserialize; #[derive(Template)] #[template(path = "index.html")] pub struct IndexTemplate { pub weights: Vec, } #[derive(Deserialize)] pub struct WeightForm { date: String, time: String, weight: f64, } pub async fn index(State(state): State) -> Html { let weights = super::models::get_all_weights(&state.pool) .await .unwrap_or_default(); let template = IndexTemplate { weights }; Html(template.render().unwrap()) } pub async fn login( State(state): State, session: Session, ) -> Redirect { // Generate PKCE challenge let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); // Generate the authorization URL let (auth_url, csrf_token) = state .oidc_client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("openid".to_string())) .add_scope(Scope::new("profile".to_string())) .set_pkce_challenge(pkce_challenge) .url(); // Store the CSRF token and PKCE verifier in the session session.set("csrf_token", csrf_token.secret().clone()); session.set("pkce_verifier", pkce_verifier.secret().clone()); Redirect::to(auth_url.as_str()) } #[cfg(test)] pub async fn test_login(session: Session) -> Redirect { session.set("user_id", "test-user"); Redirect::to("/") } #[derive(Deserialize)] pub struct AuthCallbackQuery { code: String, state: String, } pub async fn callback( State(state): State, Query(query): Query, session: Session, ) -> Redirect { // Verify CSRF token if let Some(stored_csrf) = session.get::("csrf_token") { if stored_csrf != query.state { return Redirect::to("/?error=invalid_state"); } } else { return Redirect::to("/?error=no_state"); } // Get PKCE verifier let pkce_verifier = if let Some(verifier) = session.get::("pkce_verifier") { PkceCodeVerifier::new(verifier) } else { return Redirect::to("/?error=no_verifier"); }; // Exchange code for token let token_result = state .oidc_client .exchange_code(AuthorizationCode::new(query.code)) .set_pkce_verifier(pkce_verifier) .request_async(&async_http_client) .await; match token_result { Ok(token) => { // For OIDC, extract user info from token // Using the access token as a simple user identifier let user_id = token.access_token().secret().clone(); session.set("user_id", user_id); Redirect::to("/") } Err(_) => Redirect::to("/?error=token_exchange_failed"), } } // Async HTTP client for oauth2 async fn async_http_client(req: HttpRequest) -> Result { let client = Client::new(); // Convert http::Method to reqwest::Method let method_str = format!("{}", req.method()); let method = reqwest::Method::from_bytes(method_str.as_bytes()).unwrap(); // Clone the URI before consuming the request let uri = req.uri().clone(); let mut req_builder = client.request(method, uri.to_string()); for (name, value) in req.headers() { let header_name = reqwest::header::HeaderName::from_bytes(name.as_str().as_bytes()).unwrap(); let header_value = reqwest::header::HeaderValue::from_bytes(value.as_bytes()).unwrap(); req_builder = req_builder.header(header_name, header_value); } let body = req.into_body(); let response = req_builder.body(body).send().await?; let status_code = response.status(); let headers = response.headers().clone(); let body = response.bytes().await?; // Construct an http::Response let mut http_response = http::Response::builder().status(status_code); for (k, v) in headers.iter() { http_response = http_response.header(k, v); } Ok(http_response.body(body.to_vec()).unwrap()) } pub async fn input_post( State(state): State, session: Session, Form(form): Form, ) -> Result, Redirect> { // Check if user is authenticated if let Some(user_id) = session.get::("user_id") { let datetime = format!("{}T{}", form.date, form.time); super::models::insert_weight(&state.pool, &user_id, &datetime, form.weight) .await .unwrap(); let weights = super::models::get_all_weights(&state.pool) .await .unwrap_or_default(); let mut html = String::new(); for weight in weights { html.push_str(&format!( "

{}: {} kg by {}

\n", weight.date, weight.weight, &user_id )); } Ok(Html(html)) } else { // Redirect to login if not authenticated Err(Redirect::to("/auth/login")) } }