// // Web Headers and caching // use std::{collections::HashMap, io::Cursor, ops::Deref, path::Path}; use num_traits::ToPrimitive; use once_cell::sync::Lazy; use rocket::{ fairing::{Fairing, Info, Kind}, http::{ContentType, Header, HeaderMap, Method, Status}, request::FromParam, response::{self, Responder}, Data, Orbit, Request, Response, Rocket, }; use tokio::{ runtime::Handle, time::{sleep, Duration}, }; use crate::CONFIG; pub struct AppHeaders(); #[rocket::async_trait] impl Fairing for AppHeaders { fn info(&self) -> Info { Info { name: "Application Headers", kind: Kind::Response, } } async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { let req_uri_path = req.uri().path(); let req_headers = req.headers(); // Check if this connection is an Upgrade/WebSocket connection and return early // We do not want add any extra headers, this could cause issues with reverse proxies or CloudFlare if req_uri_path.ends_with("notifications/hub") || req_uri_path.ends_with("notifications/anonymous-hub") { match (req_headers.get_one("connection"), req_headers.get_one("upgrade")) { (Some(c), Some(u)) if c.to_lowercase().contains("upgrade") && u.to_lowercase().contains("websocket") => { // Remove headers which could cause websocket connection issues res.remove_header("X-Frame-Options"); res.remove_header("X-Content-Type-Options"); res.remove_header("Permissions-Policy"); return; } (_, _) => (), } } res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()"); res.set_raw_header("Referrer-Policy", "same-origin"); res.set_raw_header("X-Content-Type-Options", 
"nosniff"); // Obsolete in modern browsers, unsafe (XS-Leak), and largely replaced by CSP res.set_raw_header("X-XSS-Protection", "0"); // Do not send the Content-Security-Policy (CSP) Header and X-Frame-Options for the *-connector.html files. // This can cause issues when some MFA requests needs to open a popup or page within the clients like WebAuthn, or Duo. // This is the same behavior as upstream Bitwarden. if !req_uri_path.ends_with("connector.html") { // # Frame Ancestors: // Chrome Web Store: https://chrome.google.com/webstore/detail/bitwarden-free-password-m/nngceckbapebfimnlniiiahkandclblb // Edge Add-ons: https://microsoftedge.microsoft.com/addons/detail/bitwarden-free-password/jbkfoedolllekgbhcbcoahefnbanhhlh?hl=en-US // Firefox Browser Add-ons: https://addons.mozilla.org/en-US/firefox/addon/bitwarden-password-manager/ // # img/child/frame src: // Have I Been Pwned to allow those calls to work. // # Connect src: // Leaked Passwords check: api.pwnedpasswords.com // 2FA/MFA Site check: api.2fa.directory // # Mail Relay: https://bitwarden.com/blog/add-privacy-and-security-using-email-aliases-with-bitwarden/ // app.simplelogin.io, app.addy.io, api.fastmail.com, quack.duckduckgo.com let csp = format!( "default-src 'self'; \ base-uri 'self'; \ form-action 'self'; \ object-src 'self' blob:; \ script-src 'self' 'wasm-unsafe-eval'; \ style-src 'self' 'unsafe-inline'; \ child-src 'self' https://*.duosecurity.com https://*.duofederal.com; \ frame-src 'self' https://*.duosecurity.com https://*.duofederal.com; \ frame-ancestors 'self' \ chrome-extension://nngceckbapebfimnlniiiahkandclblb \ chrome-extension://jbkfoedolllekgbhcbcoahefnbanhhlh \ moz-extension://* \ {allowed_iframe_ancestors}; \ img-src 'self' data: \ https://haveibeenpwned.com \ {icon_service_csp}; \ connect-src 'self' \ https://api.pwnedpasswords.com \ https://api.2fa.directory \ https://app.simplelogin.io/api/ \ https://app.addy.io/api/ \ https://api.fastmail.com/ \ https://api.forwardemail.net \ ;\ 
", icon_service_csp = CONFIG._icon_service_csp(), allowed_iframe_ancestors = CONFIG.allowed_iframe_ancestors() ); res.set_raw_header("Content-Security-Policy", csp); res.set_raw_header("X-Frame-Options", "SAMEORIGIN"); } else { // It looks like this header get's set somewhere else also, make sure this is not sent for these files, it will cause MFA issues. res.remove_header("X-Frame-Options"); } // Disable cache unless otherwise specified if !res.headers().contains("cache-control") { res.set_raw_header("Cache-Control", "no-cache, no-store, max-age=0"); } } } pub struct Cors(); impl Cors { fn get_header(headers: &HeaderMap<'_>, name: &str) -> String { match headers.get_one(name) { Some(h) => h.to_string(), _ => String::new(), } } // Check a request's `Origin` header against the list of allowed origins. // If a match exists, return it. Otherwise, return None. fn get_allowed_origin(headers: &HeaderMap<'_>) -> Option { let origin = Cors::get_header(headers, "Origin"); let domain_origin = CONFIG.domain_origin(); let safari_extension_origin = "file://"; if origin == domain_origin || origin == safari_extension_origin { Some(origin) } else { None } } } #[rocket::async_trait] impl Fairing for Cors { fn info(&self) -> Info { Info { name: "Cors", kind: Kind::Response, } } async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) { let req_headers = request.headers(); if let Some(origin) = Cors::get_allowed_origin(req_headers) { response.set_header(Header::new("Access-Control-Allow-Origin", origin)); } // Preflight request if request.method() == Method::Options { let req_allow_headers = Cors::get_header(req_headers, "Access-Control-Request-Headers"); let req_allow_method = Cors::get_header(req_headers, "Access-Control-Request-Method"); response.set_header(Header::new("Access-Control-Allow-Methods", req_allow_method)); response.set_header(Header::new("Access-Control-Allow-Headers", req_allow_headers)); 
response.set_header(Header::new("Access-Control-Allow-Credentials", "true")); response.set_status(Status::Ok); response.set_header(ContentType::Plain); response.set_sized_body(Some(0), Cursor::new("")); } } } pub struct Cached { response: R, is_immutable: bool, ttl: u64, } impl Cached { pub fn long(response: R, is_immutable: bool) -> Cached { Self { response, is_immutable, ttl: 604800, // 7 days } } pub fn short(response: R, is_immutable: bool) -> Cached { Self { response, is_immutable, ttl: 600, // 10 minutes } } pub fn ttl(response: R, ttl: u64, is_immutable: bool) -> Cached { Self { response, is_immutable, ttl, } } } impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached { fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> { let mut res = self.response.respond_to(request)?; let cache_control_header = if self.is_immutable { format!("public, immutable, max-age={}", self.ttl) } else { format!("public, max-age={}", self.ttl) }; res.set_raw_header("Cache-Control", cache_control_header); let time_now = chrono::Local::now(); let expiry_time = time_now + chrono::TimeDelta::try_seconds(self.ttl.try_into().unwrap()).unwrap(); res.set_raw_header("Expires", format_datetime_http(&expiry_time)); Ok(res) } } pub struct SafeString(String); impl std::fmt::Display for SafeString { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } impl Deref for SafeString { type Target = String; fn deref(&self) -> &Self::Target { &self.0 } } impl AsRef for SafeString { #[inline] fn as_ref(&self) -> &Path { Path::new(&self.0) } } impl<'r> FromParam<'r> for SafeString { type Error = (); #[inline(always)] fn from_param(param: &'r str) -> Result { if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { Ok(SafeString(param.to_string())) } else { Err(()) } } } // Log all the routes from the main paths list, and the attachments endpoint // Effectively ignores, any static file route, and the 
alive endpoint const LOGGED_ROUTES: [&str; 7] = ["/api", "/admin", "/identity", "/icons", "/attachments", "/events", "/notifications"]; // Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts pub struct BetterLogging(pub bool); #[rocket::async_trait] impl Fairing for BetterLogging { fn info(&self) -> Info { Info { name: "Better Logging", kind: Kind::Liftoff | Kind::Request | Kind::Response, } } async fn on_liftoff(&self, rocket: &Rocket) { if self.0 { info!(target: "routes", "Routes loaded:"); let mut routes: Vec<_> = rocket.routes().collect(); routes.sort_by_key(|r| r.uri.path()); for route in routes { if route.rank < 0 { info!(target: "routes", "{:<6} {}", route.method, route.uri); } else { info!(target: "routes", "{:<6} {} [{}]", route.method, route.uri, route.rank); } } } let config = rocket.config(); let scheme = if config.tls_enabled() { "https" } else { "http" }; let addr = format!("{}://{}:{}", &scheme, &config.address, &config.port); info!(target: "start", "Rocket has launched from {}", addr); } async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) { let method = request.method(); if !self.0 && method == Method::Options { return; } let uri = request.uri(); let uri_path = uri.path(); let uri_path_str = uri_path.url_decode_lossy(); let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str); if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) { match uri.query() { Some(q) => info!(target: "request", "{} {}?{}", method, uri_path_str, &q[..q.len().min(30)]), None => info!(target: "request", "{} {}", method, uri_path_str), }; } } async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) { if !self.0 && request.method() == Method::Options { return; } let uri_path = request.uri().path(); let uri_path_str = uri_path.url_decode_lossy(); let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str); 
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) { let status = response.status(); if let Some(ref route) = request.route() { info!(target: "response", "{} => {}", route, status) } else { info!(target: "response", "{}", status) } } } } pub fn get_display_size(size: i64) -> String { const UNITS: [&str; 6] = ["bytes", "KB", "MB", "GB", "TB", "PB"]; // If we're somehow too big for a f64, just return the size in bytes let Some(mut size) = size.to_f64() else { return format!("{size} bytes"); }; let mut unit_counter = 0; loop { if size > 1024. { size /= 1024.; unit_counter += 1; } else { break; } } format!("{:.2} {}", size, UNITS[unit_counter]) } pub fn get_uuid() -> String { uuid::Uuid::new_v4().to_string() } // // String util methods // use std::str::FromStr; #[inline] pub fn upcase_first(s: &str) -> String { let mut c = s.chars(); match c.next() { None => String::new(), Some(f) => f.to_uppercase().collect::() + c.as_str(), } } #[inline] pub fn lcase_first(s: &str) -> String { let mut c = s.chars(); match c.next() { None => String::new(), Some(f) => f.to_lowercase().collect::() + c.as_str(), } } pub fn try_parse_string(string: Option) -> Option where S: AsRef, T: FromStr, { if let Some(Ok(value)) = string.map(|s| s.as_ref().parse::()) { Some(value) } else { None } } // // Env methods // use std::env; pub fn get_env_str_value(key: &str) -> Option { let key_file = format!("{key}_FILE"); let value_from_env = env::var(key); let value_file = env::var(&key_file); match (value_from_env, value_file) { (Ok(_), Ok(_)) => panic!("You should not define both {key} and {key_file}!"), (Ok(v_env), Err(_)) => Some(v_env), (Err(_), Ok(v_file)) => match std::fs::read_to_string(v_file) { Ok(content) => Some(content.trim().to_string()), Err(e) => panic!("Failed to load {key}: {e:?}"), }, _ => None, } } pub fn get_env(key: &str) -> Option where V: FromStr, { try_parse_string(get_env_str_value(key)) } pub fn get_env_bool(key: &str) -> Option { const TRUE_VALUES: &[&str] 
= &["true", "t", "yes", "y", "1"]; const FALSE_VALUES: &[&str] = &["false", "f", "no", "n", "0"]; match get_env_str_value(key) { Some(val) if TRUE_VALUES.contains(&val.to_lowercase().as_ref()) => Some(true), Some(val) if FALSE_VALUES.contains(&val.to_lowercase().as_ref()) => Some(false), _ => None, } } // // Date util methods // use chrono::{DateTime, Local, NaiveDateTime, TimeZone}; // Format used by Bitwarden API const DATETIME_FORMAT: &str = "%Y-%m-%dT%H:%M:%S%.6fZ"; /// Formats a UTC-offset `NaiveDateTime` in the format used by Bitwarden API /// responses with "date" fields (`CreationDate`, `RevisionDate`, etc.). pub fn format_date(dt: &NaiveDateTime) -> String { dt.format(DATETIME_FORMAT).to_string() } /// Formats a `DateTime` using the specified format string. /// /// For a `DateTime`, the `%Z` specifier normally formats as the /// time zone's UTC offset (e.g., `+00:00`). In this function, if the /// `TZ` environment variable is set, then `%Z` instead formats as the /// abbreviation for that time zone (e.g., `UTC`). pub fn format_datetime_local(dt: &DateTime, fmt: &str) -> String { // Try parsing the `TZ` environment variable to enable formatting `%Z` as // a time zone abbreviation. if let Ok(tz) = env::var("TZ") { if let Ok(tz) = tz.parse::() { return dt.with_timezone(&tz).format(fmt).to_string(); } } // Otherwise, fall back to formatting `%Z` as a UTC offset. dt.format(fmt).to_string() } /// Formats a UTC-offset `NaiveDateTime` as a datetime in the local time zone. /// /// This function basically converts the `NaiveDateTime` to a `DateTime`, /// and then calls [format_datetime_local](crate::util::format_datetime_local). 
pub fn format_naive_datetime_local(dt: &NaiveDateTime, fmt: &str) -> String {
    format_datetime_local(&Local.from_utc_datetime(dt), fmt)
}

/// Formats a `DateTime<Local>` as required for HTTP
///
/// https://httpwg.org/specs/rfc7231.html#http.date
pub fn format_datetime_http(dt: &DateTime<Local>) -> String {
    // HTTP expects the preferred IMF-fixdate form: always GMT (UTC) and a
    // zero-padded two-digit day, e.g. `Sun, 06 Nov 1994 08:49:37 GMT`.
    // The format is spelled out explicitly so the day is always two digits
    // (`%d`) and the zone is the literal `GMT` rather than a numeric offset.
    dt.with_timezone(&chrono::Utc).format("%a, %d %b %Y %H:%M:%S GMT").to_string()
}

/// Parses a date in the format used by the Bitwarden API.
///
/// # Panics
///
/// Panics when `date` does not match that format.
pub fn parse_date(date: &str) -> NaiveDateTime {
    NaiveDateTime::parse_from_str(date, DATETIME_FORMAT).unwrap()
}

//
// Deployment environment methods
//

/// Returns true if the program is running in Docker, Podman or Kubernetes.
pub fn is_running_in_container() -> bool {
    Path::new("/.dockerenv").exists()
        || Path::new("/run/.containerenv").exists()
        || Path::new("/run/secrets/kubernetes.io").exists()
        || Path::new("/var/run/secrets/kubernetes.io").exists()
}

/// Simple check to determine on which container base image vaultwarden is running.
/// We build images based upon Debian or Alpine, so these we check here.
pub fn container_base_image() -> &'static str { if Path::new("/etc/debian_version").exists() { "Debian" } else if Path::new("/etc/alpine-release").exists() { "Alpine" } else { "Unknown" } } // // Deserialization methods // use std::fmt; use serde::de::{self, DeserializeOwned, Deserializer, MapAccess, SeqAccess, Visitor}; use serde_json::Value; pub type JsonMap = serde_json::Map; #[derive(Serialize, Deserialize)] pub struct UpCase { #[serde(deserialize_with = "upcase_deserialize")] #[serde(flatten)] pub data: T, } // https://github.com/serde-rs/serde/issues/586 pub fn upcase_deserialize<'de, T, D>(deserializer: D) -> Result where T: DeserializeOwned, D: Deserializer<'de>, { let d = deserializer.deserialize_any(UpCaseVisitor)?; T::deserialize(d).map_err(de::Error::custom) } struct UpCaseVisitor; impl<'de> Visitor<'de> for UpCaseVisitor { type Value = Value; fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { formatter.write_str("an object or an array") } fn visit_map(self, mut map: A) -> Result where A: MapAccess<'de>, { let mut result_map = JsonMap::new(); while let Some((key, value)) = map.next_entry()? { result_map.insert(upcase_first(key), upcase_value(value)); } Ok(Value::Object(result_map)) } fn visit_seq(self, mut seq: A) -> Result where A: SeqAccess<'de>, { let mut result_seq = Vec::::new(); while let Some(value) = seq.next_element()? 
{ result_seq.push(upcase_value(value)); } Ok(Value::Array(result_seq)) } } fn upcase_value(value: Value) -> Value { if let Value::Object(map) = value { let mut new_value = Value::Object(serde_json::Map::new()); for (key, val) in map.into_iter() { let processed_key = _process_key(&key); new_value[processed_key] = upcase_value(val); } new_value } else if let Value::Array(array) = value { // Initialize array with null values let mut new_value = Value::Array(vec![Value::Null; array.len()]); for (index, val) in array.into_iter().enumerate() { new_value[index] = upcase_value(val); } new_value } else { value } } // Inner function to handle a special case for the 'ssn' key. // This key is part of the Identity Cipher (Social Security Number) fn _process_key(key: &str) -> String { match key.to_lowercase().as_ref() { "ssn" => "SSN".into(), _ => self::upcase_first(key), } } #[derive(Deserialize, Debug, Clone)] #[serde(untagged)] pub enum NumberOrString { Number(i64), String(String), } impl NumberOrString { pub fn into_string(self) -> String { match self { NumberOrString::Number(n) => n.to_string(), NumberOrString::String(s) => s, } } #[allow(clippy::wrong_self_convention)] pub fn into_i32(&self) -> Result { use std::num::ParseIntError as PIE; match self { NumberOrString::Number(n) => match n.to_i32() { Some(n) => Ok(n), None => err!("Number does not fit in i32"), }, NumberOrString::String(s) => { s.parse().map_err(|e: PIE| crate::Error::new("Can't convert to number", e.to_string())) } } } #[allow(clippy::wrong_self_convention)] pub fn into_i64(&self) -> Result { use std::num::ParseIntError as PIE; match self { NumberOrString::Number(n) => Ok(*n), NumberOrString::String(s) => { s.parse().map_err(|e: PIE| crate::Error::new("Can't convert to number", e.to_string())) } } } } // // Retry methods // pub fn retry(mut func: F, max_tries: u32) -> Result where F: FnMut() -> Result, { let mut tries = 0; loop { match func() { ok @ Ok(_) => return ok, err @ Err(_) => { tries += 1; if tries 
>= max_tries { return err; } Handle::current().block_on(sleep(Duration::from_millis(500))); } } } } pub async fn retry_db(mut func: F, max_tries: u32) -> Result where F: FnMut() -> Result, E: std::error::Error, { let mut tries = 0; loop { match func() { ok @ Ok(_) => return ok, Err(e) => { tries += 1; if tries >= max_tries && max_tries > 0 { return Err(e); } warn!("Can't connect to database, retrying: {:?}", e); sleep(Duration::from_millis(1_000)).await; } } } } use reqwest::{header, Client, ClientBuilder}; pub fn get_reqwest_client() -> &'static Client { static INSTANCE: Lazy = Lazy::new(|| get_reqwest_client_builder().build().expect("Failed to build client")); &INSTANCE } pub fn get_reqwest_client_builder() -> ClientBuilder { let mut headers = header::HeaderMap::new(); headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden")); Client::builder().default_headers(headers).timeout(Duration::from_secs(10)) } pub fn convert_json_key_lcase_first(src_json: Value) -> Value { match src_json { Value::Array(elm) => { let mut new_array: Vec = Vec::with_capacity(elm.len()); for obj in elm { new_array.push(convert_json_key_lcase_first(obj)); } Value::Array(new_array) } Value::Object(obj) => { let mut json_map = JsonMap::new(); for (key, value) in obj.iter() { match (key, value) { (key, Value::Object(elm)) => { let inner_value = convert_json_key_lcase_first(Value::Object(elm.clone())); json_map.insert(lcase_first(key), inner_value); } (key, Value::Array(elm)) => { let mut inner_array: Vec = Vec::with_capacity(elm.len()); for inner_obj in elm { inner_array.push(convert_json_key_lcase_first(inner_obj.clone())); } json_map.insert(lcase_first(key), Value::Array(inner_array)); } (key, value) => { json_map.insert(lcase_first(key), value.clone()); } } } Value::Object(json_map) } value => value, } } /// Parses the experimental client feature flags string into a HashMap. 
pub fn parse_experimental_client_feature_flags(experimental_client_feature_flags: &str) -> HashMap { let feature_states = experimental_client_feature_flags.to_lowercase().split(',').map(|f| (f.trim().to_owned(), true)).collect(); feature_states } mod dns_resolver { use std::{ fmt, net::{IpAddr, SocketAddr}, sync::Arc, }; use hickory_resolver::{system_conf::read_system_conf, TokioAsyncResolver}; use once_cell::sync::Lazy; use reqwest::dns::{Name, Resolve, Resolving}; use crate::{util::is_global, CONFIG}; #[derive(Debug, Clone)] pub enum CustomResolverError { Blacklist { domain: String, }, NonGlobalIp { domain: String, ip: IpAddr, }, } impl CustomResolverError { pub fn downcast_ref(e: &dyn std::error::Error) -> Option<&Self> { let mut source = e.source(); while let Some(err) = source { source = err.source(); if let Some(err) = err.downcast_ref::() { return Some(err); } } None } } impl fmt::Display for CustomResolverError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Blacklist { domain, } => write!(f, "Blacklisted domain: {domain} matched ICON_BLACKLIST_REGEX"), Self::NonGlobalIp { domain, ip, } => write!(f, "IP {ip} for domain '{domain}' is not a global IP!"), } } } impl std::error::Error for CustomResolverError {} #[derive(Debug, Clone)] pub enum CustomDnsResolver { Default(), Hickory(Arc), } type BoxError = Box; impl CustomDnsResolver { pub fn instance() -> Arc { static INSTANCE: Lazy> = Lazy::new(CustomDnsResolver::new); Arc::clone(&*INSTANCE) } fn new() -> Arc { match read_system_conf() { Ok((config, opts)) => { let resolver = TokioAsyncResolver::tokio(config.clone(), opts.clone()); Arc::new(Self::Hickory(Arc::new(resolver))) } Err(e) => { warn!("Error creating Hickory resolver, falling back to default: {e:?}"); Arc::new(Self::Default()) } } } // Note that we get an iterator of addresses, but we only grab the first one for convenience async fn resolve_domain(&self, name: &str) -> Result, BoxError> { pre_resolve(name)?; let result 
= match self { Self::Default() => tokio::net::lookup_host(name).await?.next(), Self::Hickory(r) => r.lookup_ip(name).await?.iter().next().map(|a| SocketAddr::new(a, 0)), }; if let Some(addr) = &result { post_resolve(name, addr.ip())?; } Ok(result) } } fn pre_resolve(name: &str) -> Result<(), CustomResolverError> { if crate::api::is_domain_blacklisted(name) { return Err(CustomResolverError::Blacklist { domain: name.to_string(), }); } Ok(()) } fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomResolverError> { if CONFIG.icon_blacklist_non_global_ips() && !is_global(ip) { Err(CustomResolverError::NonGlobalIp { domain: name.to_string(), ip, }) } else { Ok(()) } } impl Resolve for CustomDnsResolver { fn resolve(&self, name: Name) -> Resolving { let this = self.clone(); Box::pin(async move { let name = name.as_str(); let result = this.resolve_domain(name).await?; Ok::(Box::new(result.into_iter())) }) } } } pub use dns_resolver::{CustomDnsResolver, CustomResolverError}; /// TODO: This is extracted from IpAddr::is_global, which is unstable: /// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global /// Remove once https://github.com/rust-lang/rust/issues/27709 is merged #[allow(clippy::nonminimal_bool)] #[cfg(any(not(feature = "unstable"), test))] pub fn is_global_hardcoded(ip: std::net::IpAddr) -> bool { match ip { std::net::IpAddr::V4(ip) => { !(ip.octets()[0] == 0 // "This network" || ip.is_private() || (ip.octets()[0] == 100 && (ip.octets()[1] & 0b1100_0000 == 0b0100_0000)) //ip.is_shared() || ip.is_loopback() || ip.is_link_local() // addresses reserved for future protocols (`192.0.0.0/24`) ||(ip.octets()[0] == 192 && ip.octets()[1] == 0 && ip.octets()[2] == 0) || ip.is_documentation() || (ip.octets()[0] == 198 && (ip.octets()[1] & 0xfe) == 18) // ip.is_benchmarking() || (ip.octets()[0] & 240 == 240 && !ip.is_broadcast()) //ip.is_reserved() || ip.is_broadcast()) } std::net::IpAddr::V6(ip) => { !(ip.is_unspecified() || ip.is_loopback() // 
IPv4-mapped Address (`::ffff:0:0/96`) || matches!(ip.segments(), [0, 0, 0, 0, 0, 0xffff, _, _]) // IPv4-IPv6 Translat. (`64:ff9b:1::/48`) || matches!(ip.segments(), [0x64, 0xff9b, 1, _, _, _, _, _]) // Discard-Only Address Block (`100::/64`) || matches!(ip.segments(), [0x100, 0, 0, 0, _, _, _, _]) // IETF Protocol Assignments (`2001::/23`) || (matches!(ip.segments(), [0x2001, b, _, _, _, _, _, _] if b < 0x200) && !( // Port Control Protocol Anycast (`2001:1::1`) u128::from_be_bytes(ip.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0001 // Traversal Using Relays around NAT Anycast (`2001:1::2`) || u128::from_be_bytes(ip.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0002 // AMT (`2001:3::/32`) || matches!(ip.segments(), [0x2001, 3, _, _, _, _, _, _]) // AS112-v6 (`2001:4:112::/48`) || matches!(ip.segments(), [0x2001, 4, 0x112, _, _, _, _, _]) // ORCHIDv2 (`2001:20::/28`) || matches!(ip.segments(), [0x2001, b, _, _, _, _, _, _] if (0x20..=0x2F).contains(&b)) )) || ((ip.segments()[0] == 0x2001) && (ip.segments()[1] == 0xdb8)) // ip.is_documentation() || ((ip.segments()[0] & 0xfe00) == 0xfc00) //ip.is_unique_local() || ((ip.segments()[0] & 0xffc0) == 0xfe80)) //ip.is_unicast_link_local() } } } #[cfg(not(feature = "unstable"))] pub use is_global_hardcoded as is_global; #[cfg(feature = "unstable")] #[inline(always)] pub fn is_global(ip: std::net::IpAddr) -> bool { ip.is_global() } /// These are some tests to check that the implementations match /// The IPv4 can be all checked in 30 seconds or so and they are correct as of nightly 2023-07-17 /// The IPV6 can't be checked in a reasonable time, so we check over a hundred billion random ones, so far correct /// Note that the is_global implementation is subject to change as new IP RFCs are created /// /// To run while showing progress output: /// cargo +nightly test --release --features sqlite,unstable -- --nocapture --ignored #[cfg(test)] #[cfg(feature = "unstable")] mod tests { use super::*; use std::net::IpAddr; 
#[test] #[ignore] fn test_ipv4_global() { for a in 0..u8::MAX { println!("Iter: {}/255", a); for b in 0..u8::MAX { for c in 0..u8::MAX { for d in 0..u8::MAX { let ip = IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d)); assert_eq!(ip.is_global(), is_global_hardcoded(ip), "IP mismatch: {}", ip) } } } } } #[test] #[ignore] fn test_ipv6_global() { use rand::Rng; std::thread::scope(|s| { for t in 0..16 { let handle = s.spawn(move || { let mut v = [0u8; 16]; let mut rng = rand::thread_rng(); for i in 0..20 { println!("Thread {t} Iter: {i}/50"); for _ in 0..500_000_000 { rng.fill(&mut v); let ip = IpAddr::V6(std::net::Ipv6Addr::from(v)); assert_eq!(ip.is_global(), is_global_hardcoded(ip), "IP mismatch: {ip}"); } } }); } }); } }