Compare commits

...

3 Commits

Author SHA1 Message Date
Sebastian H. Gabrielli
b5cbe304e5 Make authroization use shared state, giving ~200 speedup 2023-12-31 16:53:11 +01:00
Sebastian H. Gabrielli
a68730ac52 Authorization works 2023-12-31 15:28:35 +01:00
Sebastian H. Gabrielli
40714e4215 Add jwt_validation.rs from JWT validation test repo 2023-12-31 14:12:37 +01:00
6 changed files with 531 additions and 3 deletions

193
Cargo.lock generated
View File

@ -763,8 +763,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"js-sys",
"libc", "libc",
"wasi", "wasi",
"wasm-bindgen",
] ]
[[package]] [[package]]
@ -932,6 +934,19 @@ dependencies = [
"want", "want",
] ]
[[package]]
name = "hyper-tls"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
dependencies = [
"bytes",
"hyper",
"native-tls",
"tokio",
"tokio-native-tls",
]
[[package]] [[package]]
name = "iana-time-zone" name = "iana-time-zone"
version = "0.1.58" version = "0.1.58"
@ -993,6 +1008,12 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "ipnet"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
[[package]] [[package]]
name = "is-terminal" name = "is-terminal"
version = "0.4.9" version = "0.4.9"
@ -1028,6 +1049,21 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "jsonwebtoken"
version = "9.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c7ea04a7c5c055c175f189b6dc6ba036fd62306b58c66c9f6389036c503a3f4"
dependencies = [
"base64",
"js-sys",
"pem",
"ring",
"serde",
"serde_json",
"simple_asn1",
]
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.4.0" version = "1.4.0"
@ -1442,6 +1478,16 @@ dependencies = [
"syn 2.0.42", "syn 2.0.42",
] ]
[[package]]
name = "pem"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b8fcc794035347fb64beda2d3b462595dd2753e3f268d89c5aae77e8cf2c310"
dependencies = [
"base64",
"serde",
]
[[package]] [[package]]
name = "pem-rfc7468" name = "pem-rfc7468"
version = "0.7.0" version = "0.7.0"
@ -1710,6 +1756,58 @@ dependencies = [
"bytecheck", "bytecheck",
] ]
[[package]]
name = "reqwest"
version = "0.11.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b1ae8d9ac08420c66222fb9096fc5de435c3c48542bc5336c51892cffafb41"
dependencies = [
"base64",
"bytes",
"encoding_rs",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"hyper",
"hyper-tls",
"ipnet",
"js-sys",
"log",
"mime",
"native-tls",
"once_cell",
"percent-encoding",
"pin-project-lite",
"serde",
"serde_json",
"serde_urlencoded",
"system-configuration",
"tokio",
"tokio-native-tls",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"winreg",
]
[[package]]
name = "ring"
version = "0.17.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74"
dependencies = [
"cc",
"getrandom",
"libc",
"spin 0.9.8",
"untrusted",
"windows-sys 0.48.0",
]
[[package]] [[package]]
name = "rkyv" name = "rkyv"
version = "0.7.43" version = "0.7.43"
@ -1782,6 +1880,8 @@ name = "rocket-test"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"futures", "futures",
"jsonwebtoken",
"reqwest",
"rocket", "rocket",
"rocket_cors", "rocket_cors",
"sea-orm", "sea-orm",
@ -2094,6 +2194,18 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd"
dependencies = [
"form_urlencoded",
"itoa",
"ryu",
"serde",
]
[[package]] [[package]]
name = "sha1" name = "sha1"
version = "0.10.6" version = "0.10.6"
@ -2150,6 +2262,18 @@ version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a"
[[package]]
name = "simple_asn1"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085"
dependencies = [
"num-bigint",
"num-traits",
"thiserror",
"time",
]
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.9" version = "0.4.9"
@ -2509,6 +2633,27 @@ dependencies = [
"syn 2.0.42", "syn 2.0.42",
] ]
[[package]]
name = "system-configuration"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "tap" name = "tap"
version = "1.0.1" version = "1.0.1"
@ -2631,6 +2776,16 @@ dependencies = [
"syn 2.0.42", "syn 2.0.42",
] ]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]] [[package]]
name = "tokio-stream" name = "tokio-stream"
version = "0.1.14" version = "0.1.14"
@ -2858,6 +3013,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]] [[package]]
name = "url" name = "url"
version = "2.5.0" version = "2.5.0"
@ -2942,6 +3103,18 @@ dependencies = [
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.39"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac36a15a220124ac510204aec1c3e5db8a22ab06fd6706d881dc6149f8ed9a12"
dependencies = [
"cfg-if",
"js-sys",
"wasm-bindgen",
"web-sys",
]
[[package]] [[package]]
name = "wasm-bindgen-macro" name = "wasm-bindgen-macro"
version = "0.2.89" version = "0.2.89"
@ -2971,6 +3144,16 @@ version = "0.2.89"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f"
[[package]]
name = "web-sys"
version = "0.3.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]] [[package]]
name = "whoami" name = "whoami"
version = "1.4.1" version = "1.4.1"
@ -3158,6 +3341,16 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "winreg"
version = "0.50.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1"
dependencies = [
"cfg-if",
"windows-sys 0.48.0",
]
[[package]] [[package]]
name = "wyz" name = "wyz"
version = "0.5.1" version = "0.5.1"

View File

@ -11,3 +11,5 @@ serde = { version = "1.0", features = ["derive"] }
sea-orm = { version = "^0.12.0", features = [ "sqlx-sqlite", "runtime-tokio-native-tls", "macros", "mock" ] } sea-orm = { version = "^0.12.0", features = [ "sqlx-sqlite", "runtime-tokio-native-tls", "macros", "mock" ] }
futures = "0.3.28" futures = "0.3.28"
rocket_cors = "0.6.0" rocket_cors = "0.6.0"
jsonwebtoken = "9.2.0"
reqwest = "0.11.23"

73
src/authorization.rs Normal file
View File

@ -0,0 +1,73 @@
// Import the Rocket requirements
use rocket::http::Status;
use rocket::State;
use rocket::request::{Outcome, Request, FromRequest};
// Import async types to share jwt_info
use crate::SharedJwtInfo;
use crate::jwt_validation::*;
#[derive(Debug)]
pub struct BoardMember {
pub username: String
}
#[derive(Debug)]
pub enum AuthenticationError {
InvalidJWT,
FailedToGrabJWTInfo,
MissingAuthenticationHeader,
InvalidAuthenticationHeader,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for BoardMember {
type Error = AuthenticationError;
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
// Extract the autorization header
let autorization_header: &str;
match req.headers().get_one("Authorization") {
Some(data) => {
autorization_header = data;
},
// Missing header, return unauthroized
None => return Outcome::Error((Status::Unauthorized, AuthenticationError::MissingAuthenticationHeader))
}
// Extract the JWT token from the authroization header
let jwt_token: &str;
match autorization_header.split("Bearer ").collect::<Vec<&str>>().get(1) {
Some(token) => jwt_token = token,
// Header is not structured correctly, return unauthroized
None => return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidAuthenticationHeader))
}
// Grab the JWT info struct from rocket
let state = req.guard::<&State<SharedJwtInfo>>().await;
// Handle the state, and upon success store it in the jwt_info
let jwt_info;
match state {
Outcome::Success(fetched_state) => jwt_info = fetched_state.inner().0.as_ref(),
Outcome::Error(err) => return Outcome::Error((err.0, AuthenticationError::FailedToGrabJWTInfo)),
Outcome::Forward(status) => return Outcome::Error((status, AuthenticationError::FailedToGrabJWTInfo)),
}
// Grab the variable lock
let mut jwt_info_lock = jwt_info.lock().await;
// Validate the token and store the result
let valid_token: MyClaims;
match validate_jwt(jwt_token, &mut jwt_info_lock).await {
Ok(data) => valid_token = data,
Err(e) => {
println!("{:?}", e);
return Outcome::Error((Status::Unauthorized, AuthenticationError::InvalidJWT))
}
}
let username = valid_token.preferred_username.unwrap().clone();
Outcome::Success(BoardMember{username})
}
}

229
src/jwt_validation.rs Normal file
View File

@ -0,0 +1,229 @@
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation, TokenData};
use reqwest;
use rocket::serde::{Deserialize, json::serde_json};
use std::collections::HashMap;
// JWT Shared state async stuff
use rocket::tokio::sync::Mutex;
use std::sync::Arc;
// Define a struct for the claims you expect in your token
#[derive(Debug, Deserialize)]
pub struct MyClaims {
pub sub: String,
pub exp: usize,
pub aud: String,
pub iss: String,
pub preferred_username: Option<String>,
}
#[derive(Deserialize)]
struct AuthorizationWellKnown {
issuer: String,
jwks_uri: String,
}
#[derive(Deserialize)]
struct JwksContent {
kid: String,
x5c: Vec<String>,
}
#[derive(Deserialize)]
struct Jwks {
keys: Vec<JwksContent>,
}
pub struct JwtInfo {
pub jwks_uri: String,
pub audience: Vec<String>,
pub issuer: Vec<String>,
pub public_keys: HashMap<String, String>,
}
#[derive(Debug)]
pub enum MyCustomErrorType {
NetworkError,
JwtError,
JsonParseError,
}
pub async fn validate_jwt(token: &str, jwt_info: &mut JwtInfo) -> Result<MyClaims, jsonwebtoken::errors::Error> {
// Decode the header to give info about the crypto
let jwt_header = decode_header(token)?;
// Create a new validation
let mut validation = Validation::new(jwt_header.alg);
// Set the expected audience and issuer
validation.set_audience(&jwt_info.audience);
validation.set_issuer(&jwt_info.issuer);
// Extract the JWT kid
let kid: String;
match jwt_header.kid {
Some(fetched_kid) => kid = fetched_kid,
None => {
eprintln!("Unable to extract KID from jwt header");
return Err(jsonwebtoken::errors::ErrorKind::InvalidToken.into());
}
}
// Fetch the corresponding public key
let public_key_pem: &String;
match jwt_info.public_keys.get(&kid) {
Some(key) => public_key_pem = key,
None => {
// If the key doesn't exist look up the keys again
match fetch_jwt_certificates(jwt_info).await {
Some(key_map) => jwt_info.public_keys = key_map,
None => {
eprintln!("Failed to fetch jwt pem certificates");
}
}
// Try to get the keys once more
match jwt_info.public_keys.get(&kid) {
Some(key) => public_key_pem = key,
None => {
eprintln!("Failed to fetch find matching certificates for given KID. {}", kid);
return Err(jsonwebtoken::errors::ErrorKind::InvalidToken.into());
}
}
}
}
// Decode the JWT token
let token_data: TokenData<MyClaims>;
match jwt_header.alg {
Algorithm::RS256 => {
token_data = decode::<MyClaims>(
token,
&DecodingKey::from_rsa_pem(public_key_pem.as_bytes())?,
&validation,
)?;
},
Algorithm::ES256 => {
token_data = decode::<MyClaims>(
token,
&DecodingKey::from_ec_pem(public_key_pem.as_bytes())?,
&validation,
)?;
},
_ => {
eprintln!("JWT Public key algoritm not handled");
return Err(jsonwebtoken::errors::ErrorKind::InvalidAlgorithm.into());
}
}
Ok(token_data.claims)
}
async fn fetch_jwt_certificates(jwt_info: &JwtInfo) -> Option<HashMap<String, String>> {
// Fetch the JWKS endpoint
let jwks_body: String;
match reqwest::get(&jwt_info.jwks_uri).await {
Ok(response) => {
match response.text().await {
Ok(text) => jwks_body = text,
Err(e) => {
eprintln!("Failed to extract text from response body with error:\n{}", e);
return None;
}
}
},
Err(e) => {
eprintln!("Failed to get the jwks_uri with error:\n{}", e);
return None;
}
}
// Parse the data into the struct
let jwks_data: Jwks;
match serde_json::from_str(&jwks_body) {
Ok(jwks) => jwks_data = jwks,
Err(e) => {
eprintln!("Failed to parse fetched jwks body to Jwks struct with error:\n{}", e);
return None;
}
}
// Create the output hashmap
let mut output_map: HashMap<String, String> = HashMap::new();
// Go through each pair of keys and add them to the output jwt info
for key in jwks_data.keys {
// Extract the x5c key data
let x5c = key.x5c.get(0)?;
// Add the PEM info in to the x5c
let pem_data = format!("-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----", x5c);
// Add the resulting key to the hashmap
output_map.insert(key.kid, pem_data);
}
// Check that we got any keys
if output_map.is_empty() {
eprintln!("Failed to fetch any public keys");
return None;
}
Some(output_map)
}
pub async fn fetch_jwt_info(well_known_uri: &str, expected_issuer: Vec<String>) -> Result<JwtInfo, MyCustomErrorType> {
// Fetch the info from the well known endpoint
let well_known_body;
match reqwest::get(well_known_uri).await {
Ok(response) => {
match response.text().await {
Ok(text) => well_known_body = text,
Err(e) => {
eprintln!("Failed to extract text from response body with error:\n{}", e);
return Err(MyCustomErrorType::NetworkError);
}
}
},
Err(e) => {
eprintln!("Failed to get the well known with error:\n{}", e);
return Err(MyCustomErrorType::NetworkError);
}
}
// Parse the data into the well known struct
let well_known_data: AuthorizationWellKnown;
match serde_json::from_str(&well_known_body) {
Ok(data) => well_known_data = data,
Err(e) => {
eprintln!("Failed to parse well known data into struct with err:\n{}", e);
return Err(MyCustomErrorType::JsonParseError);
}
}
// Validate the issuer
if !expected_issuer.contains(&well_known_data.issuer) {
eprintln!(
"Expected issuer does not contain fetched issuer.\n{} ∉ {:?}",
well_known_data.issuer, expected_issuer
);
return Err(MyCustomErrorType::JwtError);
}
// Create a JwtInfo variable
let mut jwt_info: JwtInfo = JwtInfo {
jwks_uri: well_known_data.jwks_uri,
audience: Vec::new(),
issuer: expected_issuer,
public_keys: HashMap::new(),
};
// Fetch the valid public keys
match fetch_jwt_certificates(&jwt_info).await {
Some(map) => jwt_info.public_keys = map,
None => {
return Err(MyCustomErrorType::JwtError);
}
}
Ok(jwt_info)
}

View File

@ -15,24 +15,55 @@ use webserver_member::*;
// Handle CORS // Handle CORS
use rocket_cors::{AllowedOrigins, CorsOptions}; use rocket_cors::{AllowedOrigins, CorsOptions};
// Handle autorization
// Async types to share jwt_info
use rocket::tokio::sync::Mutex;
use std::sync::Arc;
// Actual authorization functions
mod jwt_validation;
use jwt_validation::{JwtInfo, fetch_jwt_info};
mod authorization;
use authorization::BoardMember;
struct SharedJwtInfo(Arc<Mutex<JwtInfo>>);
// Serve the very exiting main page // Serve the very exiting main page
#[get("/")] #[get("/")]
fn index() -> &'static str { fn index(board_member: BoardMember) -> String {
"Hello, world!\nNothing useful is served here." format!("Hello, world!\nThe autorized user's preffered username is: {:?}", board_member.username)
} }
#[launch] #[launch]
async fn rocket() -> _ { async fn rocket() -> _ {
// Create a database connection
let db = match set_up_db().await { let db = match set_up_db().await {
Ok(db) => db, Ok(db) => db,
Err(err) => panic!("{}", err) Err(err) => panic!("{}", err)
}; };
// Fetch the JWT info from the authentication server
let jwt_info: Arc<Mutex<JwtInfo>> = Arc::new(Mutex::new(fetch_jwt_info(
"https://sso.gitgals.com/application/o/sebtest/.well-known/openid-configuration",
vec!["https://sso.gitgals.com/application/o/sebtest/".into()]
).await.expect("Failed to fetch authorization info")));
// Set the expected audience field in a scoped block to make sure the lock is released as soon as we are done
{
// Grab the mutex lock
let mut jwt_info_lock = jwt_info.lock().await;
// Set the expected audience field
jwt_info_lock.audience = vec!("CLaLr8sikEiN7NCrPMhjhbtLZgnZJ6JZVzPdVN5P".into());
}
// Configure CORS
let cors = CorsOptions::default().allowed_origins(AllowedOrigins::all()) let cors = CorsOptions::default().allowed_origins(AllowedOrigins::all())
.to_cors().unwrap(); .to_cors().unwrap();
// Start the webserver
rocket::build() rocket::build()
.manage(db) .manage(db)
.manage(SharedJwtInfo(jwt_info))
.attach(cors) .attach(cors)
.mount("/", routes![ .mount("/", routes![
index, index,

BIN
test.db

Binary file not shown.