Compare commits

...

2 Commits

Author SHA1 Message Date
Lukas Wölfer
7010aee5b2 feat: automatically run migrations on database file
Some checks failed
Rust / build_and_test (push) Failing after 1m44s
2026-04-11 14:36:42 +02:00
Lukas Wölfer
614f044160 fix: database creation 2026-04-11 14:33:36 +02:00
4 changed files with 64 additions and 45 deletions

View File

@@ -1,6 +1,9 @@
use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl, basic::BasicClient};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::path::Path; use std::path::Path;
use crate::OidcClient;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config { pub struct Config {
pub oidc: OidcConfig, pub oidc: OidcConfig,
@@ -18,6 +21,24 @@ pub struct OidcConfig {
pub redirect_url: String, pub redirect_url: String,
} }
impl OidcConfig {
pub fn to_client(
&self,
) -> OidcClient {
let client_id = ClientId::new(self.client_id.clone());
let client_secret = ClientSecret::new(self.client_secret.clone());
let auth_url = AuthUrl::new(self.auth_url.clone()).unwrap();
let token_url = TokenUrl::new(self.token_url.clone()).unwrap();
let redirect_url = RedirectUrl::new(self.redirect_url.clone()).unwrap();
BasicClient::new(client_id)
.set_client_secret(client_secret)
.set_auth_uri(auth_url)
.set_token_uri(token_url)
.set_redirect_uri(redirect_url)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig { pub struct ServerConfig {
pub host: String, pub host: String,

View File

@@ -3,8 +3,11 @@ use axum::response::{Html, Redirect};
use axum::{Form, extract::Query, extract::State}; use axum::{Form, extract::Query, extract::State};
use axum_session::Session; use axum_session::Session;
use axum_session_sqlx::SessionSqlitePool; use axum_session_sqlx::SessionSqlitePool;
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, HttpRequest, HttpResponse};
use oauth2::http; use oauth2::http;
use oauth2::{
AuthorizationCode, CsrfToken, HttpRequest, HttpResponse, PkceCodeChallenge, PkceCodeVerifier,
Scope, TokenResponse,
};
use reqwest::Client; use reqwest::Client;
use serde::Deserialize; use serde::Deserialize;
@@ -28,7 +31,10 @@ pub async fn index(State(state): State<crate::AppState>) -> Html<String> {
Html(template.render().unwrap()) Html(template.render().unwrap())
} }
pub async fn login(State(state): State<crate::AppState>, session: Session<SessionSqlitePool>) -> Redirect { pub async fn login(
State(state): State<crate::AppState>,
session: Session<SessionSqlitePool>,
) -> Redirect {
// Generate PKCE challenge // Generate PKCE challenge
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
@@ -102,18 +108,16 @@ pub async fn callback(
} }
// Async HTTP client for oauth2 // Async HTTP client for oauth2
async fn async_http_client( async fn async_http_client(req: HttpRequest) -> Result<HttpResponse, reqwest::Error> {
req: HttpRequest,
) -> Result<HttpResponse, reqwest::Error> {
let client = Client::new(); let client = Client::new();
// Convert http::Method to reqwest::Method // Convert http::Method to reqwest::Method
let method_str = format!("{}", req.method()); let method_str = format!("{}", req.method());
let method = reqwest::Method::from_bytes(method_str.as_bytes()).unwrap(); let method = reqwest::Method::from_bytes(method_str.as_bytes()).unwrap();
// Clone the URI before consuming the request // Clone the URI before consuming the request
let uri = req.uri().clone(); let uri = req.uri().clone();
let mut req_builder = client.request(method, uri.to_string()); let mut req_builder = client.request(method, uri.to_string());
for (name, value) in req.headers() { for (name, value) in req.headers() {
@@ -130,9 +134,8 @@ async fn async_http_client(
let body = response.bytes().await?; let body = response.bytes().await?;
// Construct an http::Response // Construct an http::Response
let mut http_response = http::Response::builder() let mut http_response = http::Response::builder().status(status_code);
.status(status_code);
for (k, v) in headers.iter() { for (k, v) in headers.iter() {
http_response = http_response.header(k, v); http_response = http_response.header(k, v);
} }
@@ -146,8 +149,8 @@ pub async fn input_post(
Form(form): Form<WeightForm>, Form(form): Form<WeightForm>,
) -> Result<Html<String>, Redirect> { ) -> Result<Html<String>, Redirect> {
// Check if user is authenticated // Check if user is authenticated
if let Some(user_id) = session.get::<String>("user_id") { // if let Some(user_id) = session.get::<String>("user_id") {
super::models::insert_weight(&state.pool, &user_id, &form.date, form.weight) super::models::insert_weight(&state.pool, "lukas", &form.date, form.weight)
.await .await
.unwrap(); .unwrap();
let weights = super::models::get_all_weights(&state.pool) let weights = super::models::get_all_weights(&state.pool)
@@ -157,12 +160,12 @@ pub async fn input_post(
for weight in weights { for weight in weights {
html.push_str(&format!( html.push_str(&format!(
"<p>{}: {} kg by {}</p>\n", "<p>{}: {} kg by {}</p>\n",
weight.date, weight.weight, weight.user_id weight.date, weight.weight, "lukas"
)); ));
} }
Ok(Html(html)) Ok(Html(html))
} else { // } else {
// Redirect to login if not authenticated // // Redirect to login if not authenticated
Err(Redirect::to("/auth/login")) // Err(Redirect::to("/auth/login"))
} // }
} }

View File

@@ -11,10 +11,7 @@ use axum_session_sqlx::SessionSqlitePool;
use sqlx::SqlitePool; use sqlx::SqlitePool;
use tower_http::services::ServeDir; use tower_http::services::ServeDir;
#[derive(Clone)] pub type OidcClient = oauth2::Client<
pub struct AppState {
pub pool: SqlitePool,
pub oidc_client: oauth2::Client<
oauth2::StandardErrorResponse<oauth2::basic::BasicErrorResponseType>, oauth2::StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, oauth2::basic::BasicTokenType>, oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, oauth2::basic::BasicTokenType>,
oauth2::StandardTokenIntrospectionResponse< oauth2::StandardTokenIntrospectionResponse<
@@ -28,7 +25,12 @@ pub struct AppState {
oauth2::EndpointNotSet, oauth2::EndpointNotSet,
oauth2::EndpointNotSet, oauth2::EndpointNotSet,
oauth2::EndpointSet, oauth2::EndpointSet,
>, >;
#[derive(Clone)]
pub struct AppState {
pub pool: SqlitePool,
pub oidc_client: OidcClient,
} }
pub async fn create_app(state: AppState, session_secret: Vec<u8>, pool: SqlitePool) -> Router { pub async fn create_app(state: AppState, session_secret: Vec<u8>, pool: SqlitePool) -> Router {

View File

@@ -1,6 +1,7 @@
use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl, basic::BasicClient}; use std::str::FromStr;
use sqlx::SqlitePool; use sqlx::SqlitePool;
use weight_tracker::{AppState, create_app, config::Config}; use weight_tracker::{AppState, config::Config, create_app};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@@ -13,26 +14,20 @@ async fn main() {
}); });
// Set up database // Set up database
let pool = SqlitePool::connect(&config.database.url) let pool = SqlitePool::connect_with(
sqlx::sqlite::SqliteConnectOptions::from_str(&config.database.url)
.expect("Could not parse database URL")
.create_if_missing(true),
)
.await
.expect("Failed to connect to database");
sqlx::migrate!()
.run(&pool)
.await .await
.expect("Failed to connect to database"); .expect("Could not run database migrations");
// Set up OIDC client // Set up OIDC client
let client_id = ClientId::new(config.oidc.client_id); let oidc_client = config.oidc.to_client();
let client_secret = ClientSecret::new(config.oidc.client_secret);
let auth_url = AuthUrl::new(config.oidc.auth_url)
.unwrap();
let token_url = TokenUrl::new(config.oidc.token_url)
.unwrap();
let redirect_url = RedirectUrl::new(config.oidc.redirect_url)
.unwrap();
let oidc_client = BasicClient::new(client_id)
.set_client_secret(client_secret)
.set_auth_uri(auth_url)
.set_token_uri(token_url)
// Set the URL the user will be redirected to after the authorization process.
.set_redirect_uri(redirect_url);
let secret = config.session.secret.as_bytes().to_vec(); let secret = config.session.secret.as_bytes().to_vec();
@@ -44,9 +39,7 @@ async fn main() {
let app = create_app(app_state, secret, pool).await; let app = create_app(app_state, secret, pool).await;
let addr = format!("{}:{}", config.server.host, config.server.port); let addr = format!("{}:{}", config.server.host, config.server.port);
let listener = tokio::net::TcpListener::bind(&addr) let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
.await
.unwrap();
println!("listening on {}", listener.local_addr().unwrap()); println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap(); axum::serve(listener, app).await.unwrap();
} }