Authorization fetches certs once if no matching cert is found

This commit is contained in:
Sebastian H. Gabrielli 2024-01-01 14:07:47 +01:00
parent b5cbe304e5
commit 687ab16d14
2 changed files with 104 additions and 48 deletions

View File

@ -12,10 +12,19 @@ pub struct BoardMember {
pub username: String pub username: String
} }
impl From<MyClaims> for BoardMember {
fn from(claims: MyClaims) -> BoardMember {
BoardMember {
username: claims.preferred_username.unwrap()
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub enum AuthenticationError { pub enum AuthenticationError {
InvalidJWT, InvalidJWT,
FailedToGrabJWTInfo, FailedToGrabJWTInfo,
FailedToUpdateCerts,
MissingAuthenticationHeader, MissingAuthenticationHeader,
InvalidAuthenticationHeader, InvalidAuthenticationHeader,
} }
@ -54,20 +63,63 @@ impl<'r> FromRequest<'r> for BoardMember {
Outcome::Forward(status) => return Outcome::Error((status, AuthenticationError::FailedToGrabJWTInfo)), Outcome::Forward(status) => return Outcome::Error((status, AuthenticationError::FailedToGrabJWTInfo)),
} }
// Create a clone of the data
let mut jwt_info_clone: JwtInfo;
{
// Grab the variable lock // Grab the variable lock
let mut jwt_info_lock = jwt_info.lock().await; let jwt_info_lock = jwt_info.lock().await;
// Clone the data
jwt_info_clone = jwt_info_lock.clone();
}
// Validate the token and store the result // Validate the token, if it is okay return the data
let valid_token: MyClaims; let _missing_cert: bool;
match validate_jwt(jwt_token, &mut jwt_info_lock).await { match validate_jwt(jwt_token, &jwt_info_clone).await {
Ok(data) => valid_token = data, Ok(returned_claims) => {
// Return the data
return Outcome::Success(returned_claims.into())
},
Err(e) => { Err(e) => {
println!("{:?}", e); match e {
MyCustomErrorType::MissingCert => _missing_cert = true, // We handle the missing cert error further down
_ => {
// Validation failed, reject the JWT
eprintln!("Validating JWT failed with error: {:?}", e);
return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidJWT)) return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidJWT))
} }
} }
}
}
let username = valid_token.preferred_username.unwrap().clone(); // Since we are missing the cert we fetch the list of certs again
Outcome::Success(BoardMember{username}) match fetch_jwt_certificates(&jwt_info_clone).await {
Some(data) => {
if jwt_info_clone.public_keys != data {
// We have a new set of keys, update the JwtInfo
let mut jwt_info_lock = jwt_info.lock().await; // Grab the lock on jwt_info
jwt_info_lock.public_keys = data.clone(); // Update jwt_info with the new keys
jwt_info_clone.public_keys = data; // Update the local copy to mirror the recently updated jwt_info
} else {
// Keys are the same, validation failed
return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidJWT))
}
},
// We failed to fetch the new certs, fail the validation
None => return Outcome::Error((Status::Unauthorized, AuthenticationError::FailedToUpdateCerts))
}
// Validate the token again, now with the updated certs. If it is okay return the data
match validate_jwt(jwt_token, &jwt_info_clone).await {
Ok(returned_claims) => {
// Return the data
return Outcome::Success(returned_claims.into())
},
Err(e) => {
// Validation failed, reject the JWT
eprintln!("Validating JWT failed with error: {:?}", e);
return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidJWT))
}
}
} }
} }

View File

@ -3,10 +3,6 @@ use reqwest;
use rocket::serde::{Deserialize, json::serde_json}; use rocket::serde::{Deserialize, json::serde_json};
use std::collections::HashMap; use std::collections::HashMap;
// JWT Shared state async stuff
use rocket::tokio::sync::Mutex;
use std::sync::Arc;
// Define a struct for the claims you expect in your token // Define a struct for the claims you expect in your token
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct MyClaims { pub struct MyClaims {
@ -34,6 +30,7 @@ struct Jwks {
keys: Vec<JwksContent>, keys: Vec<JwksContent>,
} }
#[derive(Clone)]
pub struct JwtInfo { pub struct JwtInfo {
pub jwks_uri: String, pub jwks_uri: String,
pub audience: Vec<String>, pub audience: Vec<String>,
@ -44,13 +41,23 @@ pub struct JwtInfo {
#[derive(Debug)] #[derive(Debug)]
pub enum MyCustomErrorType { pub enum MyCustomErrorType {
NetworkError, NetworkError,
JwtError,
JsonParseError, JsonParseError,
MissingCert,
CertTypeNotImplemented,
FailedToDecodeCert,
JwtError,
FailedToDecodeJwtToken,
FailedToDecodeJwtHeader,
FailedToExtractKid,
} }
pub async fn validate_jwt(token: &str, jwt_info: &mut JwtInfo) -> Result<MyClaims, jsonwebtoken::errors::Error> { pub async fn validate_jwt(token: &str, jwt_info: &JwtInfo) -> Result<MyClaims, MyCustomErrorType> {
// Decode the header to give info about the crypto // Decode the header to give info about the crypto
let jwt_header = decode_header(token)?; let jwt_header: jsonwebtoken::Header;
match decode_header(token) {
Ok(data) => jwt_header = data,
Err(_) => return Err(MyCustomErrorType::FailedToDecodeJwtHeader),
}
// Create a new validation // Create a new validation
let mut validation = Validation::new(jwt_header.alg); let mut validation = Validation::new(jwt_header.alg);
@ -64,7 +71,7 @@ pub async fn validate_jwt(token: &str, jwt_info: &mut JwtInfo) -> Result<MyClaim
Some(fetched_kid) => kid = fetched_kid, Some(fetched_kid) => kid = fetched_kid,
None => { None => {
eprintln!("Unable to extract KID from jwt header"); eprintln!("Unable to extract KID from jwt header");
return Err(jsonwebtoken::errors::ErrorKind::InvalidToken.into()); return Err(MyCustomErrorType::FailedToExtractKid);
} }
} }
@ -72,53 +79,50 @@ pub async fn validate_jwt(token: &str, jwt_info: &mut JwtInfo) -> Result<MyClaim
let public_key_pem: &String; let public_key_pem: &String;
match jwt_info.public_keys.get(&kid) { match jwt_info.public_keys.get(&kid) {
Some(key) => public_key_pem = key, Some(key) => public_key_pem = key,
None => { None => return Err(MyCustomErrorType::MissingCert)
// If the key doesn't exist look up the keys again
match fetch_jwt_certificates(jwt_info).await {
Some(key_map) => jwt_info.public_keys = key_map,
None => {
eprintln!("Failed to fetch jwt pem certificates");
}
}
// Try to get the keys once more
match jwt_info.public_keys.get(&kid) {
Some(key) => public_key_pem = key,
None => {
eprintln!("Failed to fetch find matching certificates for given KID. {}", kid);
return Err(jsonwebtoken::errors::ErrorKind::InvalidToken.into());
}
}
}
} }
// Decode the JWT token // Decode the JWT token
let token_data: TokenData<MyClaims>; let token_data: TokenData<MyClaims>;
match jwt_header.alg { match jwt_header.alg {
Algorithm::RS256 => { Algorithm::RS256 => {
token_data = decode::<MyClaims>( // Extract the key
token, let key: DecodingKey;
&DecodingKey::from_rsa_pem(public_key_pem.as_bytes())?, match DecodingKey::from_rsa_pem(public_key_pem.as_bytes()) {
&validation, Ok(data) => key = data,
)?; Err(_) => return Err(MyCustomErrorType::FailedToDecodeCert)
}
// Decode the token
match decode::<MyClaims>(token, &key, &validation,) {
Ok(data) => token_data = data,
Err(_) => return Err(MyCustomErrorType::FailedToDecodeJwtToken)
}
}, },
Algorithm::ES256 => { Algorithm::ES256 => {
token_data = decode::<MyClaims>( // Extract the key
token, let key: DecodingKey;
&DecodingKey::from_ec_pem(public_key_pem.as_bytes())?, match DecodingKey::from_ec_pem(public_key_pem.as_bytes()) {
&validation, Ok(data) => key = data,
)?; Err(_) => return Err(MyCustomErrorType::FailedToDecodeCert)
}
// Decode the token
match decode::<MyClaims>(token, &key, &validation,) {
Ok(data) => token_data = data,
Err(_) => return Err(MyCustomErrorType::FailedToDecodeJwtToken)
}
}, },
_ => { _ => {
eprintln!("JWT Public key algoritm not handled"); eprintln!("JWT Public key algoritm not handled");
return Err(jsonwebtoken::errors::ErrorKind::InvalidAlgorithm.into()); return Err(MyCustomErrorType::CertTypeNotImplemented);
} }
} }
Ok(token_data.claims) Ok(token_data.claims)
} }
async fn fetch_jwt_certificates(jwt_info: &JwtInfo) -> Option<HashMap<String, String>> { pub async fn fetch_jwt_certificates(jwt_info: &JwtInfo) -> Option<HashMap<String, String>> {
// Fetch the JWKS endpoint // Fetch the JWKS endpoint
let jwks_body: String; let jwks_body: String;
match reqwest::get(&jwt_info.jwks_uri).await { match reqwest::get(&jwt_info.jwks_uri).await {