Make authroization use shared state, giving ~200 speedup
This commit is contained in:
parent
a68730ac52
commit
b5cbe304e5
@ -1,12 +1,11 @@
|
|||||||
// Import the Rocket requirements
|
// Import the Rocket requirements
|
||||||
use rocket::http::Status;
|
use rocket::http::Status;
|
||||||
use rocket::request::{self, Outcome, Request, FromRequest};
|
use rocket::State;
|
||||||
|
use rocket::request::{Outcome, Request, FromRequest};
|
||||||
|
|
||||||
// Import the jwt validation functions
|
// Import async types to share jwt_info
|
||||||
mod jwt_validation {
|
use crate::SharedJwtInfo;
|
||||||
include!("jwt_validation.rs");
|
use crate::jwt_validation::*;
|
||||||
}
|
|
||||||
use crate::authorization::jwt_validation::*;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct BoardMember {
|
pub struct BoardMember {
|
||||||
@ -16,6 +15,7 @@ pub struct BoardMember {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum AuthenticationError {
|
pub enum AuthenticationError {
|
||||||
InvalidJWT,
|
InvalidJWT,
|
||||||
|
FailedToGrabJWTInfo,
|
||||||
MissingAuthenticationHeader,
|
MissingAuthenticationHeader,
|
||||||
InvalidAuthenticationHeader,
|
InvalidAuthenticationHeader,
|
||||||
}
|
}
|
||||||
@ -43,20 +43,23 @@ impl<'r> FromRequest<'r> for BoardMember {
|
|||||||
None => return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidAuthenticationHeader))
|
None => return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidAuthenticationHeader))
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is temporary, this should be saved and not called on each validation
|
// Grab the JWT info struct from rocket
|
||||||
let mut jwt_info: JwtInfo;
|
let state = req.guard::<&State<SharedJwtInfo>>().await;
|
||||||
match fetch_jwt_info("https://sso.gitgals.com/application/o/sebtest/.well-known/openid-configuration", vec!("https://sso.gitgals.com/application/o/sebtest/".into())).await {
|
|
||||||
Ok(data) => jwt_info = data,
|
// Handle the state, and upon success store it in the jwt_info
|
||||||
Err(e) => {
|
let jwt_info;
|
||||||
println!("{:?}", e);
|
match state {
|
||||||
return Outcome::Error((Status::InternalServerError, AuthenticationError::InvalidJWT))
|
Outcome::Success(fetched_state) => jwt_info = fetched_state.inner().0.as_ref(),
|
||||||
},
|
Outcome::Error(err) => return Outcome::Error((err.0, AuthenticationError::FailedToGrabJWTInfo)),
|
||||||
|
Outcome::Forward(status) => return Outcome::Error((status, AuthenticationError::FailedToGrabJWTInfo)),
|
||||||
}
|
}
|
||||||
jwt_info.audience = vec!("CLaLr8sikEiN7NCrPMhjhbtLZgnZJ6JZVzPdVN5P".into());
|
|
||||||
|
// Grab the variable lock
|
||||||
|
let mut jwt_info_lock = jwt_info.lock().await;
|
||||||
|
|
||||||
// Validate the token and store the result
|
// Validate the token and store the result
|
||||||
let valid_token: MyClaims;
|
let valid_token: MyClaims;
|
||||||
match validate_jwt(jwt_token, &mut jwt_info).await {
|
match validate_jwt(jwt_token, &mut jwt_info_lock).await {
|
||||||
Ok(data) => valid_token = data,
|
Ok(data) => valid_token = data,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
println!("{:?}", e);
|
println!("{:?}", e);
|
||||||
|
|||||||
@ -3,6 +3,10 @@ use reqwest;
|
|||||||
use rocket::serde::{Deserialize, json::serde_json};
|
use rocket::serde::{Deserialize, json::serde_json};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
// JWT Shared state async stuff
|
||||||
|
use rocket::tokio::sync::Mutex;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
// Define a struct for the claims you expect in your token
|
// Define a struct for the claims you expect in your token
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct MyClaims {
|
pub struct MyClaims {
|
||||||
|
|||||||
29
src/main.rs
29
src/main.rs
@ -15,10 +15,18 @@ use webserver_member::*;
|
|||||||
// Handle CORS
|
// Handle CORS
|
||||||
use rocket_cors::{AllowedOrigins, CorsOptions};
|
use rocket_cors::{AllowedOrigins, CorsOptions};
|
||||||
|
|
||||||
// Handle authentication
|
// Handle autorization
|
||||||
|
// Async types to share jwt_info
|
||||||
|
use rocket::tokio::sync::Mutex;
|
||||||
|
use std::sync::Arc;
|
||||||
|
// Actual authorization functions
|
||||||
|
mod jwt_validation;
|
||||||
|
use jwt_validation::{JwtInfo, fetch_jwt_info};
|
||||||
mod authorization;
|
mod authorization;
|
||||||
use authorization::BoardMember;
|
use authorization::BoardMember;
|
||||||
|
|
||||||
|
struct SharedJwtInfo(Arc<Mutex<JwtInfo>>);
|
||||||
|
|
||||||
// Serve the very exiting main page
|
// Serve the very exiting main page
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn index(board_member: BoardMember) -> String {
|
fn index(board_member: BoardMember) -> String {
|
||||||
@ -27,16 +35,35 @@ fn index(board_member: BoardMember) -> String {
|
|||||||
|
|
||||||
#[launch]
|
#[launch]
|
||||||
async fn rocket() -> _ {
|
async fn rocket() -> _ {
|
||||||
|
// Create a database connection
|
||||||
let db = match set_up_db().await {
|
let db = match set_up_db().await {
|
||||||
Ok(db) => db,
|
Ok(db) => db,
|
||||||
Err(err) => panic!("{}", err)
|
Err(err) => panic!("{}", err)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// Fetch the JWT info from the authentication server
|
||||||
|
let jwt_info: Arc<Mutex<JwtInfo>> = Arc::new(Mutex::new(fetch_jwt_info(
|
||||||
|
"https://sso.gitgals.com/application/o/sebtest/.well-known/openid-configuration",
|
||||||
|
vec!["https://sso.gitgals.com/application/o/sebtest/".into()]
|
||||||
|
).await.expect("Failed to fetch authorization info")));
|
||||||
|
|
||||||
|
// Set the expected audience field in a scoped block to make sure the lock is released as soon as we are done
|
||||||
|
{
|
||||||
|
// Grab the mutex lock
|
||||||
|
let mut jwt_info_lock = jwt_info.lock().await;
|
||||||
|
// Set the expected audience field
|
||||||
|
jwt_info_lock.audience = vec!("CLaLr8sikEiN7NCrPMhjhbtLZgnZJ6JZVzPdVN5P".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure CORS
|
||||||
let cors = CorsOptions::default().allowed_origins(AllowedOrigins::all())
|
let cors = CorsOptions::default().allowed_origins(AllowedOrigins::all())
|
||||||
.to_cors().unwrap();
|
.to_cors().unwrap();
|
||||||
|
|
||||||
|
// Start the webserver
|
||||||
rocket::build()
|
rocket::build()
|
||||||
.manage(db)
|
.manage(db)
|
||||||
|
.manage(SharedJwtInfo(jwt_info))
|
||||||
.attach(cors)
|
.attach(cors)
|
||||||
.mount("/", routes![
|
.mount("/", routes![
|
||||||
index,
|
index,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user