diff --git a/src/config.rs b/src/config.rs index 4d8f21e..172226f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,9 @@ +use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl, basic::BasicClient}; use serde::{Deserialize, Serialize}; use std::path::Path; +use crate::OidcClient; + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { pub oidc: OidcConfig, @@ -18,6 +21,24 @@ pub struct OidcConfig { pub redirect_url: String, } +impl OidcConfig { + pub fn to_client( + &self, + ) -> OidcClient { + let client_id = ClientId::new(self.client_id.clone()); + let client_secret = ClientSecret::new(self.client_secret.clone()); + let auth_url = AuthUrl::new(self.auth_url.clone()).unwrap(); + let token_url = TokenUrl::new(self.token_url.clone()).unwrap(); + let redirect_url = RedirectUrl::new(self.redirect_url.clone()).unwrap(); + + BasicClient::new(client_id) + .set_client_secret(client_secret) + .set_auth_uri(auth_url) + .set_token_uri(token_url) + .set_redirect_uri(redirect_url) + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ServerConfig { pub host: String, diff --git a/src/handlers.rs b/src/handlers.rs index 49b6401..7004a71 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -3,8 +3,11 @@ use axum::response::{Html, Redirect}; use axum::{Form, extract::Query, extract::State}; use axum_session::Session; use axum_session_sqlx::SessionSqlitePool; -use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, HttpRequest, HttpResponse}; use oauth2::http; +use oauth2::{ + AuthorizationCode, CsrfToken, HttpRequest, HttpResponse, PkceCodeChallenge, PkceCodeVerifier, + Scope, TokenResponse, +}; use reqwest::Client; use serde::Deserialize; @@ -28,7 +31,10 @@ pub async fn index(State(state): State) -> Html { Html(template.render().unwrap()) } -pub async fn login(State(state): State, session: Session) -> Redirect { +pub async fn login( + State(state): State, + session: Session, +) -> Redirect { // Generate PKCE challenge let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); @@ -102,18 +108,16 @@ pub async fn callback( } // Async HTTP client for oauth2 -async fn async_http_client( - req: HttpRequest, -) -> Result { +async fn async_http_client(req: HttpRequest) -> Result { let client = Client::new(); - + // Convert http::Method to reqwest::Method let method_str = format!("{}", req.method()); let method = reqwest::Method::from_bytes(method_str.as_bytes()).unwrap(); - + // Clone the URI before consuming the request let uri = req.uri().clone(); - + let mut req_builder = client.request(method, uri.to_string()); for (name, value) in req.headers() { @@ -130,9 +134,8 @@ async fn async_http_client( let body = response.bytes().await?; // Construct an http::Response - let mut http_response = http::Response::builder() - .status(status_code); - + let mut http_response = http::Response::builder().status(status_code); + for (k, v) in headers.iter() { http_response = http_response.header(k, v); } @@ -146,8 +149,8 @@ pub async fn input_post( Form(form): Form, ) -> Result, Redirect> { // Check if user is authenticated - if let Some(user_id) = session.get::("user_id") { - super::models::insert_weight(&state.pool, &user_id, &form.date, form.weight) + // if let Some(user_id) = session.get::("user_id") { + super::models::insert_weight(&state.pool, "lukas", &form.date, form.weight) .await .unwrap(); let weights = super::models::get_all_weights(&state.pool) @@ -157,12 +160,12 @@ pub async fn input_post( for weight in weights { html.push_str(&format!( "

{}: {} kg by {}

\n", - weight.date, weight.weight, weight.user_id + weight.date, weight.weight, "lukas" )); } Ok(Html(html)) - } else { - // Redirect to login if not authenticated - Err(Redirect::to("/auth/login")) - } + // } else { + // // Redirect to login if not authenticated + // Err(Redirect::to("/auth/login")) + // } } diff --git a/src/lib.rs b/src/lib.rs index 8138106..6086663 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,10 +11,7 @@ use axum_session_sqlx::SessionSqlitePool; use sqlx::SqlitePool; use tower_http::services::ServeDir; -#[derive(Clone)] -pub struct AppState { - pub pool: SqlitePool, - pub oidc_client: oauth2::Client< +pub type OidcClient = oauth2::Client< oauth2::StandardErrorResponse, oauth2::StandardTokenResponse, oauth2::StandardTokenIntrospectionResponse< @@ -28,7 +25,12 @@ pub struct AppState { oauth2::EndpointNotSet, oauth2::EndpointNotSet, oauth2::EndpointSet, - >, + >; + +#[derive(Clone)] +pub struct AppState { + pub pool: SqlitePool, + pub oidc_client: OidcClient, } pub async fn create_app(state: AppState, session_secret: Vec, pool: SqlitePool) -> Router { diff --git a/src/main.rs b/src/main.rs index a256d1f..b11ba61 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ -use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl, basic::BasicClient}; +use std::str::FromStr; + use sqlx::SqlitePool; -use weight_tracker::{AppState, create_app, config::Config}; +use weight_tracker::{AppState, config::Config, create_app}; #[tokio::main] async fn main() { @@ -13,26 +14,16 @@ async fn main() { }); // Set up database - let pool = SqlitePool::connect(&config.database.url) - .await - .expect("Failed to connect to database"); + let pool = SqlitePool::connect_with( + sqlx::sqlite::SqliteConnectOptions::from_str(&config.database.url) + .expect("Could not parse database URL") + .create_if_missing(true), + ) + .await + .expect("Failed to connect to database"); // Set up OIDC client - let client_id = ClientId::new(config.oidc.client_id); - let client_secret = ClientSecret::new(config.oidc.client_secret); - let auth_url = AuthUrl::new(config.oidc.auth_url) - .unwrap(); - let token_url = TokenUrl::new(config.oidc.token_url) - .unwrap(); - let redirect_url = RedirectUrl::new(config.oidc.redirect_url) - .unwrap(); - - let oidc_client = BasicClient::new(client_id) - .set_client_secret(client_secret) - .set_auth_uri(auth_url) - .set_token_uri(token_url) - // Set the URL the user will be redirected to after the authorization process. - .set_redirect_uri(redirect_url); + let oidc_client = config.oidc.to_client(); let secret = config.session.secret.as_bytes().to_vec(); @@ -44,9 +35,7 @@ async fn main() { let app = create_app(app_state, secret, pool).await; let addr = format!("{}:{}", config.server.host, config.server.port); - let listener = tokio::net::TcpListener::bind(&addr) - .await - .unwrap(); + let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); println!("listening on {}", listener.local_addr().unwrap()); axum::serve(listener, app).await.unwrap(); }