diff --git a/src/auth.rs b/src/auth.rs index 85b6359e..7eabbc1e 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -2,9 +2,10 @@ // use chrono::{Duration, Utc}; use num_traits::FromPrimitive; -use once_cell::sync::Lazy; +use once_cell::sync::{Lazy, OnceCell}; use jsonwebtoken::{self, errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header}; +use openssl::rsa::Rsa; use serde::de::DeserializeOwned; use serde::ser::Serialize; @@ -26,23 +27,45 @@ static JWT_SEND_ISSUER: Lazy = Lazy::new(|| format!("{}|send", CONFIG.do static JWT_ORG_API_KEY_ISSUER: Lazy = Lazy::new(|| format!("{}|api.organization", CONFIG.domain_origin())); static JWT_FILE_DOWNLOAD_ISSUER: Lazy = Lazy::new(|| format!("{}|file_download", CONFIG.domain_origin())); -static PRIVATE_RSA_KEY: Lazy = Lazy::new(|| { - let key = - std::fs::read(CONFIG.private_rsa_key()).unwrap_or_else(|e| panic!("Error loading private RSA Key. \n{e}")); - EncodingKey::from_rsa_pem(&key).unwrap_or_else(|e| panic!("Error decoding private RSA Key.\n{e}")) -}); -static PUBLIC_RSA_KEY: Lazy = Lazy::new(|| { - let key = std::fs::read(CONFIG.public_rsa_key()).unwrap_or_else(|e| panic!("Error loading public RSA Key. \n{e}")); - DecodingKey::from_rsa_pem(&key).unwrap_or_else(|e| panic!("Error decoding public RSA Key.\n{e}")) -}); +static PRIVATE_RSA_KEY: OnceCell = OnceCell::new(); +static PUBLIC_RSA_KEY: OnceCell = OnceCell::new(); -pub fn load_keys() { - Lazy::force(&PRIVATE_RSA_KEY); - Lazy::force(&PUBLIC_RSA_KEY); +pub fn initialize_keys() -> Result<(), crate::error::Error> { + let mut priv_key_buffer = Vec::with_capacity(2048); + + let priv_key = { + let mut priv_key_file = File::options().create(true).read(true).write(true).open(CONFIG.private_rsa_key())?; + + #[allow(clippy::verbose_file_reads)] + let bytes_read = priv_key_file.read_to_end(&mut priv_key_buffer)?; + + if bytes_read > 0 { + Rsa::private_key_from_pem(&priv_key_buffer[..bytes_read])? + } else { + // Only create the key if the file doesn't exist or is empty + let rsa_key = openssl::rsa::Rsa::generate(2048)?; + priv_key_buffer = rsa_key.private_key_to_pem()?; + priv_key_file.write_all(&priv_key_buffer)?; + info!("Private key created correctly."); + rsa_key + } + }; + + let pub_key_buffer = priv_key.public_key_to_pem()?; + + let enc = EncodingKey::from_rsa_pem(&priv_key_buffer)?; + let dec: DecodingKey = DecodingKey::from_rsa_pem(&pub_key_buffer)?; + if PRIVATE_RSA_KEY.set(enc).is_err() { + err!("PRIVATE_RSA_KEY must only be initialized once") + } + if PUBLIC_RSA_KEY.set(dec).is_err() { + err!("PUBLIC_RSA_KEY must only be initialized once") + } + Ok(()) } pub fn encode_jwt(claims: &T) -> String { - match jsonwebtoken::encode(&JWT_HEADER, claims, &PRIVATE_RSA_KEY) { + match jsonwebtoken::encode(&JWT_HEADER, claims, PRIVATE_RSA_KEY.wait()) { Ok(token) => token, Err(e) => panic!("Error encoding jwt {e}"), } @@ -56,7 +79,7 @@ fn decode_jwt(token: &str, issuer: String) -> Result Ok(d.claims), Err(err) => match *err.kind() { ErrorKind::InvalidToken => err!("Token is invalid"), @@ -799,7 +822,11 @@ impl<'r> FromRequest<'r> for OwnerHeaders { // // Client IP address detection // -use std::net::IpAddr; +use std::{ + fs::File, + io::{Read, Write}, + net::IpAddr, +}; pub struct ClientIp { pub ip: IpAddr, diff --git a/src/config.rs b/src/config.rs index 2f0e9264..e174c66b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1164,7 +1164,7 @@ impl Config { } pub fn delete_user_config(&self) -> Result<(), Error> { - crate::util::delete_file(&CONFIG_FILE)?; + std::fs::remove_file(&*CONFIG_FILE)?; // Empty user config let usr = ConfigBuilder::default(); @@ -1189,9 +1189,6 @@ impl Config { pub fn private_rsa_key(&self) -> String { format!("{}.pem", CONFIG.rsa_key_filename()) } - pub fn public_rsa_key(&self) -> String { - format!("{}.pub.pem", CONFIG.rsa_key_filename()) - } pub fn mail_enabled(&self) -> bool { let inner = &self.inner.read().unwrap().config; inner._enable_smtp && (inner.smtp_host.is_some() || inner.use_sendmail) diff --git a/src/db/models/attachment.rs b/src/db/models/attachment.rs index 8f05e6b4..f8eca72f 100644 --- a/src/db/models/attachment.rs +++ b/src/db/models/attachment.rs @@ -103,7 +103,7 @@ impl Attachment { let file_path = &self.get_file_path(); - match crate::util::delete_file(file_path) { + match std::fs::remove_file(file_path) { // Ignore "file not found" errors. This can happen when the // upstream caller has already cleaned up the file as part of // its own error handling. diff --git a/src/main.rs b/src/main.rs index 05f43c5a..53b72606 100644 --- a/src/main.rs +++ b/src/main.rs @@ -71,7 +71,7 @@ async fn main() -> Result<(), Error> { let extra_debug = matches!(level, LF::Trace | LF::Debug); check_data_folder().await; - check_rsa_keys().unwrap_or_else(|_| { + auth::initialize_keys().unwrap_or_else(|_| { error!("Error creating keys, exiting..."); exit(1); }); @@ -444,31 +444,6 @@ async fn container_data_folder_is_persistent(data_folder: &str) -> bool { true } -fn check_rsa_keys() -> Result<(), crate::error::Error> { - // If the RSA keys don't exist, try to create them - let priv_path = CONFIG.private_rsa_key(); - let pub_path = CONFIG.public_rsa_key(); - - if !util::file_exists(&priv_path) { - let rsa_key = openssl::rsa::Rsa::generate(2048)?; - - let priv_key = rsa_key.private_key_to_pem()?; - crate::util::write_file(&priv_path, &priv_key)?; - info!("Private key created correctly."); - } - - if !util::file_exists(&pub_path) { - let rsa_key = openssl::rsa::Rsa::private_key_from_pem(&std::fs::read(&priv_path)?)?; - - let pub_key = rsa_key.public_key_to_pem()?; - crate::util::write_file(&pub_path, &pub_key)?; - info!("Public key created correctly."); - } - - auth::load_keys(); - Ok(()) -} - fn check_web_vault() { if !CONFIG.web_vault_enabled() { return; diff --git a/src/util.rs b/src/util.rs index 0bf37959..2f04fe34 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,11 +1,7 @@ // // Web Headers and caching // -use std::{ - collections::HashMap, - io::{Cursor, ErrorKind}, - ops::Deref, -}; +use std::{collections::HashMap, io::Cursor, ops::Deref, path::Path}; use num_traits::ToPrimitive; use rocket::{ @@ -334,40 +330,6 @@ impl Fairing for BetterLogging { } } -// -// File handling -// -use std::{ - fs::{self, File}, - io::Result as IOResult, - path::Path, -}; - -pub fn file_exists(path: &str) -> bool { - Path::new(path).exists() -} - -pub fn write_file(path: &str, content: &[u8]) -> Result<(), crate::error::Error> { - use std::io::Write; - let mut f = match File::create(path) { - Ok(file) => file, - Err(e) => { - if e.kind() == ErrorKind::PermissionDenied { - error!("Can't create '{}': Permission denied", path); - } - return Err(From::from(e)); - } - }; - - f.write_all(content)?; - f.flush()?; - Ok(()) -} - -pub fn delete_file(path: &str) -> IOResult<()> { - fs::remove_file(path) -} - pub fn get_display_size(size: i64) -> String { const UNITS: [&str; 6] = ["bytes", "KB", "MB", "GB", "TB", "PB"]; @@ -444,7 +406,7 @@ pub fn get_env_str_value(key: &str) -> Option { match (value_from_env, value_file) { (Ok(_), Ok(_)) => panic!("You should not define both {key} and {key_file}!"), (Ok(v_env), Err(_)) => Some(v_env), - (Err(_), Ok(v_file)) => match fs::read_to_string(v_file) { + (Err(_), Ok(v_file)) => match std::fs::read_to_string(v_file) { Ok(content) => Some(content.trim().to_string()), Err(e) => panic!("Failed to load {key}: {e:?}"), },