diff --git a/src/handlers.rs b/src/handlers.rs index 324af64..49b6401 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -3,7 +3,8 @@ use axum::response::{Html, Redirect}; use axum::{Form, extract::Query, extract::State}; use axum_session::Session; use axum_session_sqlx::SessionSqlitePool; -use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope}; +use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, HttpRequest, HttpResponse}; +use oauth2::http; use reqwest::Client; use serde::Deserialize; @@ -85,25 +86,16 @@ pub async fn callback( .oidc_client .exchange_code(AuthorizationCode::new(query.code)) .set_pkce_verifier(pkce_verifier) - .request_async(async_http_client) + .request_async(&async_http_client) .await; match token_result { Ok(token) => { - // For OIDC, the ID token contains user info - if let Some(id_token) = token.id_token() { - // Decode ID token (simplified, in practice you'd verify signature) - // For now, assume it's valid and extract sub as user_id - let claims = id_token.payload().clone(); - if let Some(sub) = claims.subject() { - session.set("user_id", sub.to_string()); - Redirect::to("/") - } else { - Redirect::to("/?error=no_subject") - } - } else { - Redirect::to("/?error=no_id_token") - } + // For OIDC, extract user info from token + // Using the access token as a simple user identifier + let user_id = token.access_token().secret().clone(); + session.set("user_id", user_id); + Redirect::to("/") } Err(_) => Redirect::to("/?error=token_exchange_failed"), } @@ -111,41 +103,41 @@ pub async fn callback( // Async HTTP client for oauth2 async fn async_http_client( - request: oauth2::HttpRequest, -) -> Result { + req: HttpRequest, +) -> Result { let client = Client::new(); - let method_str = request.method.as_str(); + + // Convert http::Method to reqwest::Method + let method_str = format!("{}", req.method()); let method = reqwest::Method::from_bytes(method_str.as_bytes()).unwrap(); - let mut req_builder = client.request(method, request.url); + + // Clone the URI before consuming the request + let uri = req.uri().clone(); + + let mut req_builder = client.request(method, uri.to_string()); - for (name, value) in request.headers { - let Some(header_name_str) = name.and_then(|f| Some(f.as_str().clone())) else { - continue; - }; + for (name, value) in req.headers() { let header_name = - reqwest::header::HeaderName::from_bytes(header_name_str.as_bytes()).unwrap(); + reqwest::header::HeaderName::from_bytes(name.as_str().as_bytes()).unwrap(); let header_value = reqwest::header::HeaderValue::from_bytes(value.as_bytes()).unwrap(); req_builder = req_builder.header(header_name, header_value); } - let response = req_builder.body(request.body).send().await?; - let status_code = response.status().as_u16(); + let body = req.into_body(); + let response = req_builder.body(body).send().await?; + let status_code = response.status(); let headers = response.headers().clone(); let body = response.bytes().await?; - // Convert headers - let mut oauth_headers = oauth2::http::HeaderMap::new(); + // Construct an http::Response + let mut http_response = http::Response::builder() + .status(status_code); + for (k, v) in headers.iter() { - let name = oauth2::http::HeaderName::from_bytes(k.as_str().as_bytes()).unwrap(); - let value = oauth2::http::HeaderValue::from_bytes(v.as_bytes()).unwrap(); - oauth_headers.insert(name, value); + http_response = http_response.header(k, v); } - Ok(oauth2::HttpResponse { - status_code: oauth2::http::StatusCode::from_u16(status_code).unwrap(), - headers: oauth_headers, - body: body.to_vec(), - }) + Ok(http_response.body(body.to_vec()).unwrap()) } pub async fn input_post( diff --git a/src/main.rs b/src/main.rs index f124a9d..644b91b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,21 +28,7 @@ async fn main() { .unwrap(); let redirect_url = RedirectUrl::new("http://localhost:3000/auth/callback".to_string()).unwrap(); - let oidc_client: oauth2::Client< - oauth2::StandardErrorResponse, - oauth2::StandardTokenResponse, - oauth2::StandardTokenIntrospectionResponse< - oauth2::EmptyExtraTokenFields, - oauth2::basic::BasicTokenType, - >, - oauth2::StandardRevocableToken, - oauth2::StandardErrorResponse, - oauth2::EndpointSet, - oauth2::EndpointNotSet, - oauth2::EndpointNotSet, - oauth2::EndpointNotSet, - oauth2::EndpointSet, - > = BasicClient::new(client_id) + let oidc_client = BasicClient::new(client_id) .set_client_secret(client_secret) .set_auth_uri(auth_url) .set_token_uri(token_url) @@ -50,13 +36,16 @@ async fn main() { .set_redirect_uri(redirect_url); let secret = env::var("SESSION_SECRET") - .unwrap_or_else(|_| "your_secret_key".to_string()) + .unwrap_or_else(|_| "your_secret_key_that_is_long_enough_so_the_library_does_not_complain".to_string()) .as_bytes() .to_vec(); - let app_state = AppState { pool, oidc_client }; + let app_state = AppState { + pool: pool.clone(), + oidc_client, + }; - let app = create_app(app_state, secret, pool.clone()).await; + let app = create_app(app_state, secret, pool).await; let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") .await