// JWT-validation/src/main.rs
// 2023-12-30 22:56:54 +01:00
//
// 222 lines
// 7.5 KiB
// Rust

use jsonwebtoken::{decode, decode_header, errors::Result, Algorithm, DecodingKey, Validation, TokenData};
use reqwest;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Claims expected in the payload of a token accepted by `validate_jwt`.
///
/// Every non-`Option` field must be present in the token payload for
/// deserialization to succeed.
#[derive(Debug, Deserialize)]
struct MyClaims {
    /// Value of the `sub` claim.
    sub: String,
    /// Value of the `exp` claim.
    exp: usize,
    /// Value of the `aud` claim, read as a single string.
    aud: String,
    /// Value of the `iss` claim.
    iss: String,
    /// Value of the `preferred_username` claim, when the token carries one.
    preferred_username: Option<String>,
}
/// Fields read from the provider's well-known configuration document.
///
/// All fields are plain `String`s, so each one has to be present in the
/// document for parsing to succeed. Within this file only `issuer` and
/// `jwks_uri` are read after parsing (see `fetch_jwt_info`).
#[derive(Deserialize)]
struct AuthorizationWellKnown {
    /// Issuer identifier announced by the provider.
    issuer: String,
    /// Location of the provider's JWKS document.
    jwks_uri: String,
    /// `authorization_endpoint` entry of the document.
    authorization_endpoint: String,
    /// `token_endpoint` entry of the document.
    token_endpoint: String,
    /// `userinfo_endpoint` entry of the document.
    userinfo_endpoint: String,
    /// `end_session_endpoint` entry of the document.
    end_session_endpoint: String,
    /// `introspection_endpoint` entry of the document.
    introspection_endpoint: String,
    /// `revocation_endpoint` entry of the document.
    revocation_endpoint: String,
    /// `device_authorization_endpoint` entry of the document.
    device_authorization_endpoint: String,
}
#[derive(Deserialize)]
struct JwksContent {
kid: String,
x5c: Vec<String>,
}
/// Top-level shape of a JWKS document: a list of key entries under `keys`.
#[derive(Deserialize)]
struct Jwks {
    /// All key entries published in the document.
    keys: Vec<JwksContent>,
}
/// Everything `validate_jwt` needs to check a token: where to find the
/// signing keys, which audiences and issuers are accepted, and the keys
/// fetched so far.
struct JwtInfo {
    /// URI of the JWKS document the signing certificates are fetched from.
    jwks_uri: String,
    /// Accepted values for the token audience.
    audience: Vec<String>,
    /// Accepted values for the token issuer.
    issuer: Vec<String>,
    /// Cache of PEM-wrapped certificates, keyed by key id (`kid`).
    public_keys: HashMap<String, String>,
}
/// Validates `token` and returns its claims.
///
/// The token header selects the algorithm and the key id (`kid`). The
/// matching certificate is taken from `jwt_info.public_keys`; when the `kid`
/// is not cached, the JWKS document is fetched again and the cache replaced
/// before looking once more. Audience and issuer are checked against
/// `jwt_info.audience` and `jwt_info.issuer`.
///
/// # Errors
///
/// * `InvalidAlgorithm` when the header algorithm is neither RS256 nor ES256.
/// * `InvalidToken` when the header has no `kid` or no certificate matches it.
/// * Any error produced by `jsonwebtoken` while parsing the header, loading
///   the key or decoding and validating the token.
fn validate_jwt(token: &str, jwt_info: &mut JwtInfo) -> Result<MyClaims> {
    // The header is read before any signature check, so everything in it is
    // still untrusted input at this point.
    let jwt_header = decode_header(token)?;

    // Only RS256 and ES256 are supported. Check this first so a token with
    // any other algorithm is rejected before it can cause a key lookup or a
    // JWKS re-fetch over the network.
    if !matches!(jwt_header.alg, Algorithm::RS256 | Algorithm::ES256) {
        eprintln!("JWT Public key algoritm not handled");
        return Err(jsonwebtoken::errors::ErrorKind::InvalidAlgorithm.into());
    }

    // Expected algorithm, audience and issuer for the final decode.
    let mut validation = Validation::new(jwt_header.alg);
    validation.set_audience(&jwt_info.audience);
    validation.set_issuer(&jwt_info.issuer);

    // The key id tells us which cached certificate signed the token.
    let kid = match jwt_header.kid {
        Some(kid) => kid,
        None => {
            eprintln!("Unable to extract KID from jwt header");
            return Err(jsonwebtoken::errors::ErrorKind::InvalidToken.into());
        }
    };

    // Unknown key id: refresh the cache once, the provider may have rotated
    // its keys. A failed refresh keeps the old cache and falls through to the
    // lookup below, which then reports the missing key.
    if !jwt_info.public_keys.contains_key(&kid) {
        match fetch_jwt_certificates(jwt_info) {
            Some(key_map) => jwt_info.public_keys = key_map,
            None => eprintln!("Failed to fetch jwt pem certificates"),
        }
    }

    let public_key_pem = match jwt_info.public_keys.get(&kid) {
        Some(pem) => pem,
        None => {
            eprintln!("Failed to fetch find matching certificates for given KID. {}", kid);
            return Err(jsonwebtoken::errors::ErrorKind::InvalidToken.into());
        }
    };

    // Build the decoding key with the loader that matches the algorithm.
    let decoding_key = match jwt_header.alg {
        Algorithm::RS256 => DecodingKey::from_rsa_pem(public_key_pem.as_bytes())?,
        Algorithm::ES256 => DecodingKey::from_ec_pem(public_key_pem.as_bytes())?,
        // Already rejected by the guard above; kept so the match stays
        // total without introducing a panic path.
        _ => {
            eprintln!("JWT Public key algoritm not handled");
            return Err(jsonwebtoken::errors::ErrorKind::InvalidAlgorithm.into());
        }
    };

    // Verify the signature and the registered claims, then hand back the payload.
    let token_data: TokenData<MyClaims> = decode::<MyClaims>(token, &decoding_key, &validation)?;
    Ok(token_data.claims)
}
/// Downloads the JWKS document at `jwt_info.jwks_uri` and returns a map from
/// key id (`kid`) to the key's first `x5c` certificate wrapped in PEM
/// `CERTIFICATE` markers.
///
/// Keys that carry no `x5c` entry are skipped (and logged) rather than
/// failing the whole fetch.
///
/// Returns `None`, after logging the reason to stderr, when the request
/// fails, the body cannot be read, the body is not a valid JWKS document, or
/// no key yields a certificate.
fn fetch_jwt_certificates(jwt_info: &JwtInfo) -> Option<HashMap<String, String>> {
    // Fetch the JWKS endpoint.
    let response = match reqwest::blocking::get(&jwt_info.jwks_uri) {
        Ok(response) => response,
        Err(e) => {
            eprintln!("Failed to get the jwks_uri with error:\n{}", e);
            return None;
        }
    };
    let jwks_body = match response.text() {
        Ok(text) => text,
        Err(e) => {
            eprintln!("Failed to extract text from response body with error:\n{}", e);
            return None;
        }
    };

    // Parse the body into the JWKS structure.
    let jwks_data: Jwks = match serde_json::from_str(&jwks_body) {
        Ok(jwks) => jwks,
        Err(e) => {
            eprintln!("Failed to parse fetched jwks body to Jwks struct with error:\n{}", e);
            return None;
        }
    };

    // Build the kid -> PEM map from every key that has a certificate.
    let mut output_map: HashMap<String, String> = HashMap::new();
    for key in jwks_data.keys {
        let x5c = match key.x5c.first() {
            Some(cert) => cert,
            None => {
                // One unusable key must not discard the usable ones.
                eprintln!("Skipping jwks key without x5c certificate. {}", key.kid);
                continue;
            }
        };
        // Wrap the base64 certificate in PEM markers so it can be loaded as a PEM.
        let pem_data = format!("-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----", x5c);
        output_map.insert(key.kid, pem_data);
    }

    // An empty result is treated as a failed fetch.
    if output_map.is_empty() {
        eprintln!("Failed to fetch any public keys");
        return None;
    }
    Some(output_map)
}
fn fetch_jwt_info(well_known_uri: &str, expected_issuer: Vec<String>) -> Result<JwtInfo> {
// Fetch the info from the well known endpoint
let well_known_body = reqwest::blocking::get(well_known_uri)
.unwrap()
.text()
.unwrap();
// Parse the data into the well known struct
let well_known_data: AuthorizationWellKnown = serde_json::from_str(&well_known_body).unwrap();
// Validate the issuer
if !expected_issuer.contains(&well_known_data.issuer) {
eprintln!(
"Expected issuer does not contain fetched issuer.\n{}{:?}",
well_known_data.issuer, expected_issuer
);
// TODO: Return Err properly
//Err("Invalid issuer");
}
// Create a JwtInfo variable
let mut jwt_info: JwtInfo = JwtInfo {
jwks_uri: well_known_data.jwks_uri,
audience: Vec::new(),
issuer: expected_issuer,
public_keys: HashMap::new(),
};
// Fetch the valid public keys
match fetch_jwt_certificates(&jwt_info) {
Some(map) => jwt_info.public_keys = map,
None => {
// TODO: Return err properly
}
}
Ok(jwt_info)
}
fn main() {
let token = "eyJhbGciOiJFUzI1NiIsImtpZCI6IjVkM2JkMDcxOGQ4ZWM3NWQ3ZDg1MjlmNDQwMzRiYTc1IiwidHlwIjoiSldUIn0.eyJpc3MiOiJodHRwczovL3Nzby5naXRnYWxzLmNvbS9hcHBsaWNhdGlvbi9vL3NlYnRlc3QvIiwic3ViIjoiZjJiNzIwOGY2MTcwYWI0NWNlZGM1OGUzMTM0NGNjNGY3MGQzZWRjMjhkYWZkMmJlNDZkNzIxMzM1ZDQxZDk2NCIsImF1ZCI6IkNMYUxyOHNpa0VpTjdOQ3JQTWhqaGJ0TFpnblpKNkpaVnpQZFZONVAiLCJleHAiOjE3MDM5NjUxNjIsImlhdCI6MTcwMzk0NzE2MiwiYXV0aF90aW1lIjoxNzAzODU3NzMwLCJhY3IiOiJnb2F1dGhlbnRpay5pby9wcm92aWRlcnMvb2F1dGgyL2RlZmF1bHQiLCJlbWFpbCI6IiIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJuYW1lIjoiSW5zb21uaWEiLCJnaXZlbl9uYW1lIjoiSW5zb21uaWEiLCJwcmVmZXJyZWRfdXNlcm5hbWUiOiJpbnNvbW5pYS10ZXN0Iiwibmlja25hbWUiOiJpbnNvbW5pYS10ZXN0IiwiZ3JvdXBzIjpbXSwiYXpwIjoiQ0xhTHI4c2lrRWlON05DclBNaGpoYnRMWmduWko2SlpWelBkVk41UCIsInVpZCI6IkdPQTNRdTBIOW5TUUx4WHJZVXJ4RHUzTTVkVDhmNTZIQ0l5QlhZdnYifQ.AuVjuJXApMq1vPS48gK5htGSv8KzcCQZlerc82adiCNVb789w2lBoiLjbKotHvAPQOLTQ3qWv2yHNPgBE3dhVA";
let well_known_uri = "https://sso.gitgals.com/application/o/sebtest/.well-known/openid-configuration";
let mut jwt_info = fetch_jwt_info(well_known_uri, vec!("https://sso.gitgals.com/application/o/sebtest/".into())).unwrap();
jwt_info.audience = vec!("CLaLr8sikEiN7NCrPMhjhbtLZgnZJ6JZVzPdVN5P".to_string());
let result: MyClaims;
match validate_jwt(token, &jwt_info) {
Ok(claims) => result = claims,
Err(err) => panic!("Error validating token: {:?}", err),
}
println!("Token is valid! Claims: {:#?}", result);
}