use jsonwebtoken::{decode, decode_header, errors::Result, Algorithm, DecodingKey, Validation, TokenData}; use reqwest; use serde::{Deserialize, Serialize}; use std::collections::HashMap; // Define a struct for the claims you expect in your token #[derive(Debug, Deserialize)] struct MyClaims { sub: String, exp: usize, aud: String, iss: String, preferred_username: Option, } #[derive(Deserialize)] struct AuthorizationWellKnown { issuer: String, jwks_uri: String, authorization_endpoint: String, token_endpoint: String, userinfo_endpoint: String, end_session_endpoint: String, introspection_endpoint: String, revocation_endpoint: String, device_authorization_endpoint: String, } #[derive(Deserialize)] struct JwksContent { kid: String, x5c: Vec, } #[derive(Deserialize)] struct Jwks { keys: Vec, } struct JwtInfo { jwks_uri: String, audience: Vec, issuer: Vec, public_keys: HashMap, } fn validate_jwt(token: &str, jwt_info: &mut JwtInfo) -> Result { // Decode the header to give info about the crypto let jwt_header = decode_header(token)?; // Create a new validation let mut validation = Validation::new(jwt_header.alg); // Set the expected audience and issuer validation.set_audience(&jwt_info.audience); validation.set_issuer(&jwt_info.issuer); // Extract the JWT kid let kid: String; match jwt_header.kid { Some(fetched_kid) => kid = fetched_kid, None => { eprintln!("Unable to extract KID from jwt header"); return Err(jsonwebtoken::errors::ErrorKind::InvalidToken.into()); } } // Fetch the corresponding public key let public_key_pem: &String; match jwt_info.public_keys.get(&kid) { Some(key) => public_key_pem = key, None => { // If the key doesn't exist look up the keys again match fetch_jwt_certificates(jwt_info) { Some(key_map) => jwt_info.public_keys = key_map, None => { eprintln!("Failed to fetch jwt pem certificates"); } } // Try to get the keys once more match jwt_info.public_keys.get(&kid) { Some(key) => public_key_pem = key, None => { eprintln!("Failed to fetch find matching certificates for given KID. {}", kid); return Err(jsonwebtoken::errors::ErrorKind::InvalidToken.into()); } } } } // Decode the JWT token let token_data: TokenData; match jwt_header.alg { Algorithm::RS256 => { token_data = decode::( token, &DecodingKey::from_rsa_pem(public_key_pem.as_bytes())?, &validation, )?; }, Algorithm::ES256 => { token_data = decode::( token, &DecodingKey::from_ec_pem(public_key_pem.as_bytes())?, &validation, )?; }, _ => { eprintln!("JWT Public key algoritm not handled"); return Err(jsonwebtoken::errors::ErrorKind::InvalidAlgorithm.into()); } } Ok(token_data.claims) } fn fetch_jwt_certificates(jwt_info: &JwtInfo) -> Option> { // Fetch the JWKS endpoint let jwks_body: String; match reqwest::blocking::get(&jwt_info.jwks_uri) { Ok(response) => { match response.text() { Ok(text) => jwks_body = text, Err(e) => { eprintln!("Failed to extract text from response body with error:\n{}", e); return None; } } }, Err(e) => { eprintln!("Failed to get the jwks_uri with error:\n{}", e); return None; } } // Parse the data into the struct let jwks_data: Jwks; match serde_json::from_str(&jwks_body) { Ok(jwks) => jwks_data = jwks, Err(e) => { eprintln!("Failed to parse fetched jwks body to Jwks struct with error:\n{}", e); return None; } } // Create the output hashmap let mut output_map: HashMap = HashMap::new(); // Go through each pair of keys and add them to the output jwt info for key in jwks_data.keys { // Extract the x5c key data let x5c = key.x5c.get(0)?; // Add the PEM info in to the x5c let pem_data = format!("-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----", x5c); // Add the resulting key to the hashmap output_map.insert(key.kid, pem_data); } // Check that we got any keys if output_map.is_empty() { eprintln!("Failed to fetch any public keys"); return None; } Some(output_map) } fn fetch_jwt_info(well_known_uri: &str, expected_issuer: Vec) -> Result { // Fetch the info from the well known endpoint let well_known_body = reqwest::blocking::get(well_known_uri) .unwrap() .text() .unwrap(); // Parse the data into the well known struct let well_known_data: AuthorizationWellKnown = serde_json::from_str(&well_known_body).unwrap(); // Validate the issuer if !expected_issuer.contains(&well_known_data.issuer) { eprintln!( "Expected issuer does not contain fetched issuer.\n{} ∉ {:?}", well_known_data.issuer, expected_issuer ); // TODO: Return Err properly //Err("Invalid issuer"); } // Create a JwtInfo variable let mut jwt_info: JwtInfo = JwtInfo { jwks_uri: well_known_data.jwks_uri, audience: Vec::new(), issuer: expected_issuer, public_keys: HashMap::new(), }; // Fetch the valid public keys match fetch_jwt_certificates(&jwt_info) { Some(map) => jwt_info.public_keys = map, None => { // TODO: Return err properly } } Ok(jwt_info) } fn main() { let token = "eyJhbGciOiJFUzI1NiIsImtpZCI6IjVkM2JkMDcxOGQ4ZWM3NWQ3ZDg1MjlmNDQwMzRiYTc1IiwidHlwIjoiSldUIn0.eyJpc3MiOiJodHRwczovL3Nzby5naXRnYWxzLmNvbS9hcHBsaWNhdGlvbi9vL3NlYnRlc3QvIiwic3ViIjoiZjJiNzIwOGY2MTcwYWI0NWNlZGM1OGUzMTM0NGNjNGY3MGQzZWRjMjhkYWZkMmJlNDZkNzIxMzM1ZDQxZDk2NCIsImF1ZCI6IkNMYUxyOHNpa0VpTjdOQ3JQTWhqaGJ0TFpnblpKNkpaVnpQZFZONVAiLCJleHAiOjE3MDM5NjUxNjIsImlhdCI6MTcwMzk0NzE2MiwiYXV0aF90aW1lIjoxNzAzODU3NzMwLCJhY3IiOiJnb2F1dGhlbnRpay5pby9wcm92aWRlcnMvb2F1dGgyL2RlZmF1bHQiLCJlbWFpbCI6IiIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJuYW1lIjoiSW5zb21uaWEiLCJnaXZlbl9uYW1lIjoiSW5zb21uaWEiLCJwcmVmZXJyZWRfdXNlcm5hbWUiOiJpbnNvbW5pYS10ZXN0Iiwibmlja25hbWUiOiJpbnNvbW5pYS10ZXN0IiwiZ3JvdXBzIjpbXSwiYXpwIjoiQ0xhTHI4c2lrRWlON05DclBNaGpoYnRMWmduWko2SlpWelBkVk41UCIsInVpZCI6IkdPQTNRdTBIOW5TUUx4WHJZVXJ4RHUzTTVkVDhmNTZIQ0l5QlhZdnYifQ.AuVjuJXApMq1vPS48gK5htGSv8KzcCQZlerc82adiCNVb789w2lBoiLjbKotHvAPQOLTQ3qWv2yHNPgBE3dhVA"; let well_known_uri = "https://sso.gitgals.com/application/o/sebtest/.well-known/openid-configuration"; let mut jwt_info = fetch_jwt_info(well_known_uri, vec!("https://sso.gitgals.com/application/o/sebtest/".into())).unwrap(); jwt_info.audience = vec!("CLaLr8sikEiN7NCrPMhjhbtLZgnZJ6JZVzPdVN5P".to_string()); let result: MyClaims; match validate_jwt(token, &jwt_info) { Ok(claims) => result = claims, Err(err) => panic!("Error validating token: {:?}", err), } println!("Token is valid! Claims: {:#?}", result); }