This commit is contained in:
@@ -3,7 +3,8 @@ use axum::response::{Html, Redirect};
|
||||
use axum::{Form, extract::Query, extract::State};
|
||||
use axum_session::Session;
|
||||
use axum_session_sqlx::SessionSqlitePool;
|
||||
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope};
|
||||
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, HttpRequest, HttpResponse};
|
||||
use oauth2::http;
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
|
||||
@@ -85,25 +86,16 @@ pub async fn callback(
|
||||
.oidc_client
|
||||
.exchange_code(AuthorizationCode::new(query.code))
|
||||
.set_pkce_verifier(pkce_verifier)
|
||||
.request_async(async_http_client)
|
||||
.request_async(&async_http_client)
|
||||
.await;
|
||||
|
||||
match token_result {
|
||||
Ok(token) => {
|
||||
// For OIDC, the ID token contains user info
|
||||
if let Some(id_token) = token.id_token() {
|
||||
// Decode ID token (simplified, in practice you'd verify signature)
|
||||
// For now, assume it's valid and extract sub as user_id
|
||||
let claims = id_token.payload().clone();
|
||||
if let Some(sub) = claims.subject() {
|
||||
session.set("user_id", sub.to_string());
|
||||
Redirect::to("/")
|
||||
} else {
|
||||
Redirect::to("/?error=no_subject")
|
||||
}
|
||||
} else {
|
||||
Redirect::to("/?error=no_id_token")
|
||||
}
|
||||
// For OIDC, extract user info from token
|
||||
// Using the access token as a simple user identifier
|
||||
let user_id = token.access_token().secret().clone();
|
||||
session.set("user_id", user_id);
|
||||
Redirect::to("/")
|
||||
}
|
||||
Err(_) => Redirect::to("/?error=token_exchange_failed"),
|
||||
}
|
||||
@@ -111,41 +103,41 @@ pub async fn callback(
|
||||
|
||||
// Async HTTP client for oauth2
|
||||
async fn async_http_client(
|
||||
request: oauth2::HttpRequest,
|
||||
) -> Result<oauth2::HttpResponse, reqwest::Error> {
|
||||
req: HttpRequest,
|
||||
) -> Result<HttpResponse, reqwest::Error> {
|
||||
let client = Client::new();
|
||||
let method_str = request.method.as_str();
|
||||
|
||||
// Convert http::Method to reqwest::Method
|
||||
let method_str = format!("{}", req.method());
|
||||
let method = reqwest::Method::from_bytes(method_str.as_bytes()).unwrap();
|
||||
let mut req_builder = client.request(method, request.url);
|
||||
|
||||
// Clone the URI before consuming the request
|
||||
let uri = req.uri().clone();
|
||||
|
||||
let mut req_builder = client.request(method, uri.to_string());
|
||||
|
||||
for (name, value) in request.headers {
|
||||
let Some(header_name_str) = name.and_then(|f| Some(f.as_str().clone())) else {
|
||||
continue;
|
||||
};
|
||||
for (name, value) in req.headers() {
|
||||
let header_name =
|
||||
reqwest::header::HeaderName::from_bytes(header_name_str.as_bytes()).unwrap();
|
||||
reqwest::header::HeaderName::from_bytes(name.as_str().as_bytes()).unwrap();
|
||||
let header_value = reqwest::header::HeaderValue::from_bytes(value.as_bytes()).unwrap();
|
||||
req_builder = req_builder.header(header_name, header_value);
|
||||
}
|
||||
|
||||
let response = req_builder.body(request.body).send().await?;
|
||||
let status_code = response.status().as_u16();
|
||||
let body = req.into_body();
|
||||
let response = req_builder.body(body).send().await?;
|
||||
let status_code = response.status();
|
||||
let headers = response.headers().clone();
|
||||
let body = response.bytes().await?;
|
||||
|
||||
// Convert headers
|
||||
let mut oauth_headers = oauth2::http::HeaderMap::new();
|
||||
// Construct an http::Response
|
||||
let mut http_response = http::Response::builder()
|
||||
.status(status_code);
|
||||
|
||||
for (k, v) in headers.iter() {
|
||||
let name = oauth2::http::HeaderName::from_bytes(k.as_str().as_bytes()).unwrap();
|
||||
let value = oauth2::http::HeaderValue::from_bytes(v.as_bytes()).unwrap();
|
||||
oauth_headers.insert(name, value);
|
||||
http_response = http_response.header(k, v);
|
||||
}
|
||||
|
||||
Ok(oauth2::HttpResponse {
|
||||
status_code: oauth2::http::StatusCode::from_u16(status_code).unwrap(),
|
||||
headers: oauth_headers,
|
||||
body: body.to_vec(),
|
||||
})
|
||||
Ok(http_response.body(body.to_vec()).unwrap())
|
||||
}
|
||||
|
||||
pub async fn input_post(
|
||||
|
||||
Reference in New Issue
Block a user