From b5cbe304e537c4fdf12ba4039b3cfb612433b6c1 Mon Sep 17 00:00:00 2001 From: "Sebastian H. Gabrielli" Date: Sun, 31 Dec 2023 16:53:11 +0100 Subject: [PATCH] Make authroization use shared state, giving ~200 speedup --- src/authorization.rs | 35 +++++++++++++++++++---------------- src/jwt_validation.rs | 4 ++++ src/main.rs | 29 ++++++++++++++++++++++++++++- 3 files changed, 51 insertions(+), 17 deletions(-) diff --git a/src/authorization.rs b/src/authorization.rs index a60f530..8724367 100644 --- a/src/authorization.rs +++ b/src/authorization.rs @@ -1,12 +1,11 @@ // Import the Rocket requirements use rocket::http::Status; -use rocket::request::{self, Outcome, Request, FromRequest}; +use rocket::State; +use rocket::request::{Outcome, Request, FromRequest}; -// Import the jwt validation functions -mod jwt_validation { - include!("jwt_validation.rs"); -} -use crate::authorization::jwt_validation::*; +// Import async types to share jwt_info +use crate::SharedJwtInfo; +use crate::jwt_validation::*; #[derive(Debug)] pub struct BoardMember { @@ -16,6 +15,7 @@ pub struct BoardMember { #[derive(Debug)] pub enum AuthenticationError { InvalidJWT, + FailedToGrabJWTInfo, MissingAuthenticationHeader, InvalidAuthenticationHeader, } @@ -43,20 +43,23 @@ impl<'r> FromRequest<'r> for BoardMember { None => return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidAuthenticationHeader)) } - // This is temporary, this should be saved and not called on each validation - let mut jwt_info: JwtInfo; - match fetch_jwt_info("https://sso.gitgals.com/application/o/sebtest/.well-known/openid-configuration", vec!("https://sso.gitgals.com/application/o/sebtest/".into())).await { - Ok(data) => jwt_info = data, - Err(e) => { - println!("{:?}", e); - return Outcome::Error((Status::InternalServerError, AuthenticationError::InvalidJWT)) - }, + // Grab the JWT info struct from rocket + let state = req.guard::<&State>().await; + + // Handle the state, and upon success store it in the jwt_info + let jwt_info; + match state { + Outcome::Success(fetched_state) => jwt_info = fetched_state.inner().0.as_ref(), + Outcome::Error(err) => return Outcome::Error((err.0, AuthenticationError::FailedToGrabJWTInfo)), + Outcome::Forward(status) => return Outcome::Error((status, AuthenticationError::FailedToGrabJWTInfo)), } - jwt_info.audience = vec!("CLaLr8sikEiN7NCrPMhjhbtLZgnZJ6JZVzPdVN5P".into()); + + // Grab the variable lock + let mut jwt_info_lock = jwt_info.lock().await; // Validate the token and store the result let valid_token: MyClaims; - match validate_jwt(jwt_token, &mut jwt_info).await { + match validate_jwt(jwt_token, &mut jwt_info_lock).await { Ok(data) => valid_token = data, Err(e) => { println!("{:?}", e); diff --git a/src/jwt_validation.rs b/src/jwt_validation.rs index 4770dca..d5ae52a 100644 --- a/src/jwt_validation.rs +++ b/src/jwt_validation.rs @@ -3,6 +3,10 @@ use reqwest; use rocket::serde::{Deserialize, json::serde_json}; use std::collections::HashMap; +// JWT Shared state async stuff +use rocket::tokio::sync::Mutex; +use std::sync::Arc; + // Define a struct for the claims you expect in your token #[derive(Debug, Deserialize)] pub struct MyClaims { diff --git a/src/main.rs b/src/main.rs index 10506a9..757d294 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,10 +15,18 @@ use webserver_member::*; // Handle CORS use rocket_cors::{AllowedOrigins, CorsOptions}; -// Handle authentication +// Handle autorization +// Async types to share jwt_info +use rocket::tokio::sync::Mutex; +use std::sync::Arc; +// Actual authorization functions +mod jwt_validation; +use jwt_validation::{JwtInfo, fetch_jwt_info}; mod authorization; use authorization::BoardMember; +struct SharedJwtInfo(Arc>); + // Serve the very exiting main page #[get("/")] fn index(board_member: BoardMember) -> String { @@ -27,16 +35,35 @@ fn index(board_member: BoardMember) -> String { #[launch] async fn rocket() -> _ { + // Create a database connection let db = match set_up_db().await { Ok(db) => db, Err(err) => panic!("{}", err) }; + + // Fetch the JWT info from the authentication server + let jwt_info: Arc> = Arc::new(Mutex::new(fetch_jwt_info( + "https://sso.gitgals.com/application/o/sebtest/.well-known/openid-configuration", + vec!["https://sso.gitgals.com/application/o/sebtest/".into()] + ).await.expect("Failed to fetch authorization info"))); + + // Set the expected audience field in a scoped block to make sure the lock is released as soon as we are done + { + // Grab the mutex lock + let mut jwt_info_lock = jwt_info.lock().await; + // Set the expected audience field + jwt_info_lock.audience = vec!("CLaLr8sikEiN7NCrPMhjhbtLZgnZJ6JZVzPdVN5P".into()); + } + + // Configure CORS let cors = CorsOptions::default().allowed_origins(AllowedOrigins::all()) .to_cors().unwrap(); + // Start the webserver rocket::build() .manage(db) + .manage(SharedJwtInfo(jwt_info)) .attach(cors) .mount("/", routes![ index,