Make authroization use shared state, giving ~200 speedup

This commit is contained in:
Sebastian H. Gabrielli 2023-12-31 16:53:11 +01:00
parent a68730ac52
commit b5cbe304e5
3 changed files with 51 additions and 17 deletions

View File

@ -1,12 +1,11 @@
// Import the Rocket requirements // Import the Rocket requirements
use rocket::http::Status; use rocket::http::Status;
use rocket::request::{self, Outcome, Request, FromRequest}; use rocket::State;
use rocket::request::{Outcome, Request, FromRequest};
// Import the jwt validation functions // Import async types to share jwt_info
mod jwt_validation { use crate::SharedJwtInfo;
include!("jwt_validation.rs"); use crate::jwt_validation::*;
}
use crate::authorization::jwt_validation::*;
#[derive(Debug)] #[derive(Debug)]
pub struct BoardMember { pub struct BoardMember {
@ -16,6 +15,7 @@ pub struct BoardMember {
#[derive(Debug)] #[derive(Debug)]
pub enum AuthenticationError { pub enum AuthenticationError {
InvalidJWT, InvalidJWT,
FailedToGrabJWTInfo,
MissingAuthenticationHeader, MissingAuthenticationHeader,
InvalidAuthenticationHeader, InvalidAuthenticationHeader,
} }
@ -43,20 +43,23 @@ impl<'r> FromRequest<'r> for BoardMember {
None => return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidAuthenticationHeader)) None => return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidAuthenticationHeader))
} }
// This is temporary, this should be saved and not called on each validation // Grab the JWT info struct from rocket
let mut jwt_info: JwtInfo; let state = req.guard::<&State<SharedJwtInfo>>().await;
match fetch_jwt_info("https://sso.gitgals.com/application/o/sebtest/.well-known/openid-configuration", vec!("https://sso.gitgals.com/application/o/sebtest/".into())).await {
Ok(data) => jwt_info = data, // Handle the state, and upon success store it in the jwt_info
Err(e) => { let jwt_info;
println!("{:?}", e); match state {
return Outcome::Error((Status::InternalServerError, AuthenticationError::InvalidJWT)) Outcome::Success(fetched_state) => jwt_info = fetched_state.inner().0.as_ref(),
}, Outcome::Error(err) => return Outcome::Error((err.0, AuthenticationError::FailedToGrabJWTInfo)),
Outcome::Forward(status) => return Outcome::Error((status, AuthenticationError::FailedToGrabJWTInfo)),
} }
jwt_info.audience = vec!("CLaLr8sikEiN7NCrPMhjhbtLZgnZJ6JZVzPdVN5P".into());
// Grab the variable lock
let mut jwt_info_lock = jwt_info.lock().await;
// Validate the token and store the result // Validate the token and store the result
let valid_token: MyClaims; let valid_token: MyClaims;
match validate_jwt(jwt_token, &mut jwt_info).await { match validate_jwt(jwt_token, &mut jwt_info_lock).await {
Ok(data) => valid_token = data, Ok(data) => valid_token = data,
Err(e) => { Err(e) => {
println!("{:?}", e); println!("{:?}", e);

View File

@ -3,6 +3,10 @@ use reqwest;
use rocket::serde::{Deserialize, json::serde_json}; use rocket::serde::{Deserialize, json::serde_json};
use std::collections::HashMap; use std::collections::HashMap;
// JWT Shared state async stuff
use rocket::tokio::sync::Mutex;
use std::sync::Arc;
// Define a struct for the claims you expect in your token // Define a struct for the claims you expect in your token
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct MyClaims { pub struct MyClaims {

View File

@ -15,10 +15,18 @@ use webserver_member::*;
// Handle CORS // Handle CORS
use rocket_cors::{AllowedOrigins, CorsOptions}; use rocket_cors::{AllowedOrigins, CorsOptions};
// Handle authentication // Handle autorization
// Async types to share jwt_info
use rocket::tokio::sync::Mutex;
use std::sync::Arc;
// Actual authorization functions
mod jwt_validation;
use jwt_validation::{JwtInfo, fetch_jwt_info};
mod authorization; mod authorization;
use authorization::BoardMember; use authorization::BoardMember;
struct SharedJwtInfo(Arc<Mutex<JwtInfo>>);
// Serve the very exiting main page // Serve the very exiting main page
#[get("/")] #[get("/")]
fn index(board_member: BoardMember) -> String { fn index(board_member: BoardMember) -> String {
@ -27,16 +35,35 @@ fn index(board_member: BoardMember) -> String {
#[launch] #[launch]
async fn rocket() -> _ { async fn rocket() -> _ {
// Create a database connection
let db = match set_up_db().await { let db = match set_up_db().await {
Ok(db) => db, Ok(db) => db,
Err(err) => panic!("{}", err) Err(err) => panic!("{}", err)
}; };
// Fetch the JWT info from the authentication server
let jwt_info: Arc<Mutex<JwtInfo>> = Arc::new(Mutex::new(fetch_jwt_info(
"https://sso.gitgals.com/application/o/sebtest/.well-known/openid-configuration",
vec!["https://sso.gitgals.com/application/o/sebtest/".into()]
).await.expect("Failed to fetch authorization info")));
// Set the expected audience field in a scoped block to make sure the lock is released as soon as we are done
{
// Grab the mutex lock
let mut jwt_info_lock = jwt_info.lock().await;
// Set the expected audience field
jwt_info_lock.audience = vec!("CLaLr8sikEiN7NCrPMhjhbtLZgnZJ6JZVzPdVN5P".into());
}
// Configure CORS
let cors = CorsOptions::default().allowed_origins(AllowedOrigins::all()) let cors = CorsOptions::default().allowed_origins(AllowedOrigins::all())
.to_cors().unwrap(); .to_cors().unwrap();
// Start the webserver
rocket::build() rocket::build()
.manage(db) .manage(db)
.manage(SharedJwtInfo(jwt_info))
.attach(cors) .attach(cors)
.mount("/", routes![ .mount("/", routes![
index, index,