use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation, TokenData}; use reqwest; use rocket::serde::{Deserialize, json::serde_json}; use std::collections::HashMap; // Define a struct for the claims you expect in your token #[derive(Debug, Deserialize)] pub struct MyClaims { pub sub: String, pub exp: usize, pub aud: String, pub iss: String, pub preferred_username: Option, } #[derive(Deserialize)] struct AuthorizationWellKnown { issuer: String, jwks_uri: String, } #[derive(Deserialize)] struct JwksContent { kid: String, x5c: Vec, } #[derive(Deserialize)] struct Jwks { keys: Vec, } #[derive(Clone)] pub struct JwtInfo { pub jwks_uri: String, pub audience: Vec, pub issuer: Vec, pub public_keys: HashMap, } #[derive(Debug)] pub enum MyCustomErrorType { NetworkError, JsonParseError, MissingCert, CertTypeNotImplemented, FailedToDecodeCert, JwtError, FailedToDecodeJwtToken, FailedToDecodeJwtHeader, FailedToExtractKid, } pub async fn validate_jwt(token: &str, jwt_info: &JwtInfo) -> Result { // Decode the header to give info about the crypto let jwt_header: jsonwebtoken::Header; match decode_header(token) { Ok(data) => jwt_header = data, Err(_) => return Err(MyCustomErrorType::FailedToDecodeJwtHeader), } // Create a new validation let mut validation = Validation::new(jwt_header.alg); // Set the expected audience and issuer validation.set_audience(&jwt_info.audience); validation.set_issuer(&jwt_info.issuer); // Extract the JWT kid let kid: String; match jwt_header.kid { Some(fetched_kid) => kid = fetched_kid, None => { eprintln!("Unable to extract KID from jwt header"); return Err(MyCustomErrorType::FailedToExtractKid); } } // Fetch the corresponding public key let public_key_pem: &String; match jwt_info.public_keys.get(&kid) { Some(key) => public_key_pem = key, None => return Err(MyCustomErrorType::MissingCert) } // Decode the JWT token let token_data: TokenData; match jwt_header.alg { Algorithm::RS256 => { // Extract the key let key: DecodingKey; match DecodingKey::from_rsa_pem(public_key_pem.as_bytes()) { Ok(data) => key = data, Err(_) => return Err(MyCustomErrorType::FailedToDecodeCert) } // Decode the token match decode::(token, &key, &validation,) { Ok(data) => token_data = data, Err(_) => return Err(MyCustomErrorType::FailedToDecodeJwtToken) } }, Algorithm::ES256 => { // Extract the key let key: DecodingKey; match DecodingKey::from_ec_pem(public_key_pem.as_bytes()) { Ok(data) => key = data, Err(_) => return Err(MyCustomErrorType::FailedToDecodeCert) } // Decode the token match decode::(token, &key, &validation,) { Ok(data) => token_data = data, Err(_) => return Err(MyCustomErrorType::FailedToDecodeJwtToken) } }, _ => { eprintln!("JWT Public key algoritm not handled"); return Err(MyCustomErrorType::CertTypeNotImplemented); } } Ok(token_data.claims) } pub async fn fetch_jwt_certificates(jwt_info: &JwtInfo) -> Option> { // Fetch the JWKS endpoint let jwks_body: String; match reqwest::get(&jwt_info.jwks_uri).await { Ok(response) => { match response.text().await { Ok(text) => jwks_body = text, Err(e) => { eprintln!("Failed to extract text from response body with error:\n{}", e); return None; } } }, Err(e) => { eprintln!("Failed to get the jwks_uri with error:\n{}", e); return None; } } // Parse the data into the struct let jwks_data: Jwks; match serde_json::from_str(&jwks_body) { Ok(jwks) => jwks_data = jwks, Err(e) => { eprintln!("Failed to parse fetched jwks body to Jwks struct with error:\n{}", e); return None; } } // Create the output hashmap let mut output_map: HashMap = HashMap::new(); // Go through each pair of keys and add them to the output jwt info for key in jwks_data.keys { // Extract the x5c key data let x5c = key.x5c.get(0)?; // Add the PEM info in to the x5c let pem_data = format!("-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----", x5c); // Add the resulting key to the hashmap output_map.insert(key.kid, pem_data); } // Check that we got any keys if output_map.is_empty() { eprintln!("Failed to fetch any public keys"); return None; } Some(output_map) } pub async fn fetch_jwt_info(well_known_uri: &str, expected_issuer: Vec) -> Result { // Fetch the info from the well known endpoint let well_known_body; match reqwest::get(well_known_uri).await { Ok(response) => { match response.text().await { Ok(text) => well_known_body = text, Err(e) => { eprintln!("Failed to extract text from response body with error:\n{}", e); return Err(MyCustomErrorType::NetworkError); } } }, Err(e) => { eprintln!("Failed to get the well known with error:\n{}", e); return Err(MyCustomErrorType::NetworkError); } } // Parse the data into the well known struct let well_known_data: AuthorizationWellKnown; match serde_json::from_str(&well_known_body) { Ok(data) => well_known_data = data, Err(e) => { eprintln!("Failed to parse well known data into struct with err:\n{}", e); return Err(MyCustomErrorType::JsonParseError); } } // Validate the issuer if !expected_issuer.contains(&well_known_data.issuer) { eprintln!( "Expected issuer does not contain fetched issuer.\n{} ∉ {:?}", well_known_data.issuer, expected_issuer ); return Err(MyCustomErrorType::JwtError); } // Create a JwtInfo variable let mut jwt_info: JwtInfo = JwtInfo { jwks_uri: well_known_data.jwks_uri, audience: Vec::new(), issuer: expected_issuer, public_keys: HashMap::new(), }; // Fetch the valid public keys match fetch_jwt_certificates(&jwt_info).await { Some(map) => jwt_info.public_keys = map, None => { return Err(MyCustomErrorType::JwtError); } } Ok(jwt_info) }