diff --git a/src/authorization.rs b/src/authorization.rs index 8724367..6db410d 100644 --- a/src/authorization.rs +++ b/src/authorization.rs @@ -12,10 +12,19 @@ pub struct BoardMember { pub username: String } +impl From for BoardMember { + fn from(claims: MyClaims) -> BoardMember { + BoardMember { + username: claims.preferred_username.unwrap() + } + } +} + #[derive(Debug)] pub enum AuthenticationError { InvalidJWT, FailedToGrabJWTInfo, + FailedToUpdateCerts, MissingAuthenticationHeader, InvalidAuthenticationHeader, } @@ -54,20 +63,63 @@ impl<'r> FromRequest<'r> for BoardMember { Outcome::Forward(status) => return Outcome::Error((status, AuthenticationError::FailedToGrabJWTInfo)), } - // Grab the variable lock - let mut jwt_info_lock = jwt_info.lock().await; + // Create a clone of the data + let mut jwt_info_clone: JwtInfo; + { + // Grab the variable lock + let jwt_info_lock = jwt_info.lock().await; + // Clone the data + jwt_info_clone = jwt_info_lock.clone(); + } - // Validate the token and store the result - let valid_token: MyClaims; - match validate_jwt(jwt_token, &mut jwt_info_lock).await { - Ok(data) => valid_token = data, + // Validate the token, if it is okay return the data + let _missing_cert: bool; + match validate_jwt(jwt_token, &jwt_info_clone).await { + Ok(returned_claims) => { + // Return the data + return Outcome::Success(returned_claims.into()) + }, Err(e) => { - println!("{:?}", e); - return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidJWT)) + match e { + MyCustomErrorType::MissingCert => _missing_cert = true, // We handle the missing cert error further down + _ => { + // Validation failed, reject the JWT + eprintln!("Validating JWT failed with error: {:?}", e); + return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidJWT)) + } + } } } - let username = valid_token.preferred_username.unwrap().clone(); - Outcome::Success(BoardMember{username}) + // Since we are missing the cert we fetch the list of certs again + match fetch_jwt_certificates(&jwt_info_clone).await { + Some(data) => { + if jwt_info_clone.public_keys != data { + // We have a new set of keys, update the JwtInfo + + let mut jwt_info_lock = jwt_info.lock().await; // Grab the lock on jwt_info + jwt_info_lock.public_keys = data.clone(); // Update jwt_info with the new keys + jwt_info_clone.public_keys = data; // Update the local copy to mirror the recently updated jwt_info + } else { + // Keys are the same, validation failed + return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidJWT)) + } + }, + // We failed to fetch the new certs, fail the validation + None => return Outcome::Error((Status::Unauthorized, AuthenticationError::FailedToUpdateCerts)) + } + + // Validate the token again, now with the updated certs. If it is okay return the data + match validate_jwt(jwt_token, &jwt_info_clone).await { + Ok(returned_claims) => { + // Return the data + return Outcome::Success(returned_claims.into()) + }, + Err(e) => { + // Validation failed, reject the JWT + eprintln!("Validating JWT failed with error: {:?}", e); + return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidJWT)) + } + } } } \ No newline at end of file diff --git a/src/jwt_validation.rs b/src/jwt_validation.rs index d5ae52a..385587b 100644 --- a/src/jwt_validation.rs +++ b/src/jwt_validation.rs @@ -3,10 +3,6 @@ use reqwest; use rocket::serde::{Deserialize, json::serde_json}; use std::collections::HashMap; -// JWT Shared state async stuff -use rocket::tokio::sync::Mutex; -use std::sync::Arc; - // Define a struct for the claims you expect in your token #[derive(Debug, Deserialize)] pub struct MyClaims { @@ -34,6 +30,7 @@ struct Jwks { keys: Vec, } +#[derive(Clone)] pub struct JwtInfo { pub jwks_uri: String, pub audience: Vec, @@ -44,13 +41,23 @@ pub struct JwtInfo { #[derive(Debug)] pub enum MyCustomErrorType { NetworkError, - JwtError, JsonParseError, + MissingCert, + CertTypeNotImplemented, + FailedToDecodeCert, + JwtError, + FailedToDecodeJwtToken, + FailedToDecodeJwtHeader, + FailedToExtractKid, } -pub async fn validate_jwt(token: &str, jwt_info: &mut JwtInfo) -> Result { +pub async fn validate_jwt(token: &str, jwt_info: &JwtInfo) -> Result { // Decode the header to give info about the crypto - let jwt_header = decode_header(token)?; + let jwt_header: jsonwebtoken::Header; + match decode_header(token) { + Ok(data) => jwt_header = data, + Err(_) => return Err(MyCustomErrorType::FailedToDecodeJwtHeader), + } // Create a new validation let mut validation = Validation::new(jwt_header.alg); @@ -64,7 +71,7 @@ pub async fn validate_jwt(token: &str, jwt_info: &mut JwtInfo) -> Result kid = fetched_kid, None => { eprintln!("Unable to extract KID from jwt header"); - return Err(jsonwebtoken::errors::ErrorKind::InvalidToken.into()); + return Err(MyCustomErrorType::FailedToExtractKid); } } @@ -72,53 +79,50 @@ pub async fn validate_jwt(token: &str, jwt_info: &mut JwtInfo) -> Result public_key_pem = key, - None => { - // If the key doesn't exist look up the keys again - match fetch_jwt_certificates(jwt_info).await { - Some(key_map) => jwt_info.public_keys = key_map, - None => { - eprintln!("Failed to fetch jwt pem certificates"); - } - } - - // Try to get the keys once more - match jwt_info.public_keys.get(&kid) { - Some(key) => public_key_pem = key, - None => { - eprintln!("Failed to fetch find matching certificates for given KID. {}", kid); - return Err(jsonwebtoken::errors::ErrorKind::InvalidToken.into()); - } - } - } + None => return Err(MyCustomErrorType::MissingCert) } // Decode the JWT token let token_data: TokenData; match jwt_header.alg { Algorithm::RS256 => { - token_data = decode::( - token, - &DecodingKey::from_rsa_pem(public_key_pem.as_bytes())?, - &validation, - )?; + // Extract the key + let key: DecodingKey; + match DecodingKey::from_rsa_pem(public_key_pem.as_bytes()) { + Ok(data) => key = data, + Err(_) => return Err(MyCustomErrorType::FailedToDecodeCert) + } + + // Decode the token + match decode::(token, &key, &validation,) { + Ok(data) => token_data = data, + Err(_) => return Err(MyCustomErrorType::FailedToDecodeJwtToken) + } }, Algorithm::ES256 => { - token_data = decode::( - token, - &DecodingKey::from_ec_pem(public_key_pem.as_bytes())?, - &validation, - )?; + // Extract the key + let key: DecodingKey; + match DecodingKey::from_ec_pem(public_key_pem.as_bytes()) { + Ok(data) => key = data, + Err(_) => return Err(MyCustomErrorType::FailedToDecodeCert) + } + + // Decode the token + match decode::(token, &key, &validation,) { + Ok(data) => token_data = data, + Err(_) => return Err(MyCustomErrorType::FailedToDecodeJwtToken) + } }, _ => { eprintln!("JWT Public key algoritm not handled"); - return Err(jsonwebtoken::errors::ErrorKind::InvalidAlgorithm.into()); + return Err(MyCustomErrorType::CertTypeNotImplemented); } } Ok(token_data.claims) } -async fn fetch_jwt_certificates(jwt_info: &JwtInfo) -> Option> { +pub async fn fetch_jwt_certificates(jwt_info: &JwtInfo) -> Option> { // Fetch the JWKS endpoint let jwks_body: String; match reqwest::get(&jwt_info.jwks_uri).await {