From edfa82624c96f904c155592fc0f54112544042aa Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 20 May 2018 14:38:57 -0700 Subject: [PATCH] Switch protocol filters to min/max --- Cargo.toml | 4 + build.rs | 19 +++-- src/imp/openssl.rs | 134 ++++++++++++++++++++++++++-------- src/imp/schannel.rs | 62 ++++++++++------ src/imp/security_framework.rs | 62 +++++++++------- src/lib.rs | 40 ++++++---- src/test.rs | 10 ++- 7 files changed, 226 insertions(+), 105 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 33b1cdb..7f2f8c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,3 +21,7 @@ schannel = "0.1.12" [target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios")))'.dependencies] openssl = "0.10.6" openssl-sys = "0.9.30" + +[patch.crates-io] +openssl = { git = "https://github.com/sfackler/rust-openssl" } +openssl-sys = { git = "https://github.com/sfackler/rust-openssl" } diff --git a/build.rs b/build.rs index aa5fc10..cbac306 100644 --- a/build.rs +++ b/build.rs @@ -1,12 +1,19 @@ use std::env; fn main() { - let openssl_version = env::var("DEP_OPENSSL_VERSION_NUMBER") - .ok() - .map(|s| u64::from_str_radix(&s, 16).unwrap()); + if let Ok(version) = env::var("DEP_OPENSSL_VERSION_NUMBER") { + let version = u64::from_str_radix(&version, 16).unwrap(); - match openssl_version { - Some(version) if version >= 0x1_00_02_00_0 => println!("cargo:rustc-cfg=have_no_ssl_mask"), - _ => {} + if version >= 0x1_01_00_00_0 { + println!("cargo:rustc-cfg=have_min_max_version"); + } + } + + if let Ok(version) = env::var("DEP_OPENSSL_LIBRESSL_VERSION_NUMBER") { + let version = u64::from_str_radix(&version, 16).unwrap(); + + if version >= 0x2_06_01_00_0 { + println!("cargo:rustc-cfg=have_min_max_version"); + } } } diff --git a/src/imp/openssl.rs b/src/imp/openssl.rs index 0e84bd1..aac1b13 100644 --- a/src/imp/openssl.rs +++ b/src/imp/openssl.rs @@ -4,7 +4,7 @@ use self::openssl::error::ErrorStack; use self::openssl::pkcs12::{ParsedPkcs12, Pkcs12}; use self::openssl::ssl::{ self, MidHandshakeSslStream, SslAcceptor, SslAcceptorBuilder, SslConnector, - SslConnectorBuilder, SslContextBuilder, SslMethod, SslOptions, SslVerifyMode, + SslConnectorBuilder, SslContextBuilder, SslMethod, SslVerifyMode, }; use self::openssl::x509::X509; use std::error; @@ -13,26 +13,67 @@ use std::io; use Protocol; -fn supported_protocols(protocols: &[Protocol], ctx: &mut SslContextBuilder) { - #[cfg(not(have_no_ssl_mask))] - let no_ssl_mask = SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1 - | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2; - #[cfg(have_no_ssl_mask)] - let no_ssl_mask = SslOptions::NO_SSL_MASK; +#[cfg(have_min_max_version)] +fn supported_protocols( + min: Option, + max: Option, + ctx: &mut SslContextBuilder, +) -> Result<(), ErrorStack> { + use self::openssl::ssl::SslVersion; + + fn cvt(p: Protocol) -> SslVersion { + match p { + Protocol::Sslv3 => SslVersion::SSL3, + Protocol::Tlsv10 => SslVersion::TLS1, + Protocol::Tlsv11 => SslVersion::TLS1_1, + Protocol::Tlsv12 => SslVersion::TLS1_2, + Protocol::__NonExhaustive => unreachable!(), + } + } + + ctx.set_min_proto_version(min.map(cvt))?; + ctx.set_max_proto_version(max.map(cvt))?; + + Ok(()) +} + +#[cfg(not(have_min_max_version))] +fn supported_protocols( + min: Option, + max: Option, + ctx: &mut SslContextBuilder, +) -> Result<(), ErrorStack> { + use self::openssl::ssl::SslOptions; + + let no_ssl_mask = SslOptions::NO_SSLV2 + | SslOptions::NO_SSLV3 + | SslOptions::NO_TLSV1 + | SslOptions::NO_TLSV1_1 + | SslOptions::NO_TLSV1_2; ctx.clear_options(no_ssl_mask); - let mut options = no_ssl_mask; - for protocol in protocols { - let op = match *protocol { - Protocol::Sslv3 => SslOptions::NO_SSLV3, - Protocol::Tlsv10 => SslOptions::NO_TLSV1, - Protocol::Tlsv11 => SslOptions::NO_TLSV1_1, - Protocol::Tlsv12 => SslOptions::NO_TLSV1_2, - Protocol::__NonExhaustive => unreachable!(), - }; - options &= !op; + let mut options = SslOptions::NO_SSLV2; + match min { + None | Some(Protocol::Sslv3) => {} + Some(Protocol::Tlsv10) => options |= SslOptions::NO_SSLV3, + Some(Protocol::Tlsv11) => options |= SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1, + Some(Protocol::Tlsv12) => { + options |= SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1 + } + Some(Protocol::__NonExhaustive) => unreachable!(), } + match max { + None | Some(Protocol::Tlsv12) => {} + Some(Protocol::Tlsv11) => options |= SslOptions::NO_TLSV1_2, + Some(Protocol::Tlsv10) => options |= SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2, + Some(Protocol::Sslv3) => { + options |= SslOptions::NO_TLSV1_0 | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2 + } + } + ctx.set_options(options); + + Ok(()) } pub struct Error(ssl::Error); @@ -156,6 +197,8 @@ pub struct TlsConnectorBuilder { use_sni: bool, accept_invalid_hostnames: bool, accept_invalid_certs: bool, + min_protocol: Option, + max_protocol: Option, } impl TlsConnectorBuilder { @@ -189,12 +232,19 @@ impl TlsConnectorBuilder { self.accept_invalid_certs = accept_invalid_certs; } - pub fn supported_protocols(&mut self, protocols: &[Protocol]) -> Result<(), Error> { - supported_protocols(protocols, &mut self.connector); + pub fn min_protocol_version(&mut self, protocol: Option) -> Result<(), Error> { + self.min_protocol = protocol; Ok(()) } - pub fn build(self) -> Result { + pub fn max_protocol_version(&mut self, protocol: Option) -> Result<(), Error> { + self.max_protocol = protocol; + Ok(()) + } + + pub fn build(mut self) -> Result { + supported_protocols(self.min_protocol, self.max_protocol, &mut self.connector)?; + Ok(TlsConnector { connector: self.connector.build(), use_sni: self.use_sni, @@ -219,6 +269,8 @@ impl TlsConnector { use_sni: true, accept_invalid_hostnames: false, accept_invalid_certs: false, + min_protocol: None, + max_protocol: None, }) } @@ -226,7 +278,8 @@ impl TlsConnector { where S: io::Read + io::Write, { - let mut ssl = self.connector + let mut ssl = self + .connector .configure()? .use_server_name_indication(self.use_sni) .verify_hostname(!self.accept_invalid_hostnames); @@ -239,16 +292,27 @@ impl TlsConnector { } } -pub struct TlsAcceptorBuilder(SslAcceptorBuilder); +pub struct TlsAcceptorBuilder { + acceptor: SslAcceptorBuilder, + min_protocol: Option, + max_protocol: Option, +} impl TlsAcceptorBuilder { - pub fn supported_protocols(&mut self, protocols: &[Protocol]) -> Result<(), Error> { - supported_protocols(protocols, &mut self.0); + pub fn min_protocol_version(&mut self, protocol: Option) -> Result<(), Error> { + self.min_protocol = protocol; Ok(()) } - pub fn build(self) -> Result { - Ok(TlsAcceptor(self.0.build())) + pub fn max_protocol_version(&mut self, protocol: Option) -> Result<(), Error> { + self.max_protocol = protocol; + Ok(()) + } + + pub fn build(mut self) -> Result { + supported_protocols(self.min_protocol, self.max_protocol, &mut self.acceptor)?; + + Ok(TlsAcceptor(self.acceptor.build())) } } @@ -257,15 +321,20 @@ pub struct TlsAcceptor(SslAcceptor); impl TlsAcceptor { pub fn builder(identity: Identity) -> Result { - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())?; - builder.set_private_key(&identity.0.pkey)?; - builder.set_certificate(&identity.0.cert)?; + let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?; + acceptor.set_private_key(&identity.0.pkey)?; + acceptor.set_certificate(&identity.0.cert)?; if let Some(chain) = identity.0.chain { for cert in chain { - builder.add_extra_chain_cert(cert)?; + acceptor.add_extra_chain_cert(cert)?; } } - Ok(TlsAcceptorBuilder(builder)) + + Ok(TlsAcceptorBuilder { + acceptor, + min_protocol: None, + max_protocol: None, + }) } pub fn accept(&self, stream: S) -> Result, HandshakeError> @@ -294,7 +363,8 @@ impl TlsStream { match self.0.shutdown() { Ok(_) => Ok(()), Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(()), - Err(e) => Err(e.into_io_error() + Err(e) => Err(e + .into_io_error() .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))), } } diff --git a/src/imp/schannel.rs b/src/imp/schannel.rs index dd213ef..369a899 100644 --- a/src/imp/schannel.rs +++ b/src/imp/schannel.rs @@ -8,17 +8,22 @@ use std::error; use std::fmt; use std::io; -fn convert_protocols(protocols: &[::Protocol]) -> Vec { +static PROTOCOLS: &'static [Protocol] = &[ + Protocol::Ssl3, + Protocol::Tls10, + Protocol::Tls11, + Protocol::Tls12, +]; + +fn convert_protocols(min: Option<::Protocol>, max: Option<::Protocol>) -> &'static [Protocol] { + let mut protocols = PROTOCOLS; + if let Some(p) = max.and_then(|max| protocols.get(..max as usize)) { + protocols = p; + } + if let Some(p) = min.and_then(|min| protocols.get(min as usize..)) { + protocols = p; + } protocols - .iter() - .map(|p| match *p { - ::Protocol::Sslv3 => Protocol::Ssl3, - ::Protocol::Tlsv10 => Protocol::Tls10, - ::Protocol::Tlsv11 => Protocol::Tls11, - ::Protocol::Tlsv12 => Protocol::Tls12, - ::Protocol::__NonExhaustive => unreachable!(), - }) - .collect() } pub struct Error(io::Error); @@ -61,7 +66,8 @@ impl Identity { let mut identity = None; for cert in store.certs() { - if cert.private_key() + if cert + .private_key() .silent(true) .compare_key(true) .acquire() @@ -186,8 +192,13 @@ impl TlsConnectorBuilder { self.0.accept_invalid_certs = accept_invalid_certs; } - pub fn supported_protocols(&mut self, protocols: &[::Protocol]) -> Result<(), Error> { - self.0.protocols = convert_protocols(protocols); + pub fn min_protocol_version(&mut self, protocol: Option<::Protocol>) -> Result<(), Error> { + self.0.min_protocol = protocol; + Ok(()) + } + + pub fn max_protocol_version(&mut self, protocol: Option<::Protocol>) -> Result<(), Error> { + self.0.max_protocol = protocol; Ok(()) } @@ -200,7 +211,8 @@ impl TlsConnectorBuilder { pub struct TlsConnector { cert: Option, roots: CertStore, - protocols: Vec, + min_protocol: Option<::Protocol>, + max_protocol: Option<::Protocol>, use_sni: bool, accept_invalid_hostnames: bool, accept_invalid_certs: bool, @@ -211,7 +223,8 @@ impl TlsConnector { Ok(TlsConnectorBuilder(TlsConnector { cert: None, roots: Memory::new()?.into_store(), - protocols: vec![Protocol::Tls10, Protocol::Tls11, Protocol::Tls12], + min_protocol: None, + max_protocol: None, use_sni: true, accept_invalid_hostnames: false, accept_invalid_certs: false, @@ -223,7 +236,7 @@ impl TlsConnector { S: io::Read + io::Write, { let mut builder = SchannelCred::builder(); - builder.enabled_protocols(&self.protocols); + builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol)); if let Some(cert) = self.cert.as_ref() { builder.cert(cert.clone()); } @@ -247,8 +260,13 @@ impl TlsConnector { pub struct TlsAcceptorBuilder(TlsAcceptor); impl TlsAcceptorBuilder { - pub fn supported_protocols(&mut self, protocols: &[::Protocol]) -> Result<(), Error> { - self.0.protocols = convert_protocols(protocols); + pub fn min_protocol_version(&mut self, protocol: Option<::Protocol>) -> Result<(), Error> { + self.0.min_protocol = protocol; + Ok(()) + } + + pub fn max_protocol_version(&mut self, protocol: Option<::Protocol>) -> Result<(), Error> { + self.0.max_protocol = protocol; Ok(()) } @@ -260,14 +278,16 @@ impl TlsAcceptorBuilder { #[derive(Clone)] pub struct TlsAcceptor { cert: CertContext, - protocols: Vec, + min_protocol: Option<::Protocol>, + max_protocol: Option<::Protocol>, } impl TlsAcceptor { pub fn builder(identity: Identity) -> Result { Ok(TlsAcceptorBuilder(TlsAcceptor { cert: identity.cert, - protocols: vec![Protocol::Tls10, Protocol::Tls11, Protocol::Tls12], + min_protocol: None, + max_protocol: None, })) } @@ -276,7 +296,7 @@ impl TlsAcceptor { S: io::Read + io::Write, { let mut builder = SchannelCred::builder(); - builder.enabled_protocols(&self.protocols); + builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol)); builder.cert(self.cert.clone()); // FIXME we're probably missing the certificate chain? let cred = builder.acquire(Direction::Inbound)?; diff --git a/src/imp/security_framework.rs b/src/imp/security_framework.rs index 8bc0659..a1f35ad 100644 --- a/src/imp/security_framework.rs +++ b/src/imp/security_framework.rs @@ -44,20 +44,6 @@ fn convert_protocol(protocol: Protocol) -> SslProtocol { } } -fn protocol_min_max(protocols: &[Protocol]) -> (SslProtocol, SslProtocol) { - let mut min = Protocol::Tlsv12; - let mut max = Protocol::Sslv3; - for protocol in protocols { - if (*protocol as usize) < (min as usize) { - min = *protocol; - } - if (*protocol as usize) > (max as usize) { - max = *protocol; - } - } - (convert_protocol(min), convert_protocol(max)) -} - pub struct Error(base::Error); impl error::Error for Error { @@ -287,8 +273,13 @@ impl TlsConnectorBuilder { self.0.danger_accept_invalid_certs = accept_invalid_certs; } - pub fn supported_protocols(&mut self, protocols: &[Protocol]) -> Result<(), Error> { - self.0.protocols = protocols.to_vec(); + pub fn min_protocol_version(&mut self, protocol: Option) -> Result<(), Error> { + self.0.min_protocol = protocol; + Ok(()) + } + + pub fn max_protocol_version(&mut self, protocol: Option) -> Result<(), Error> { + self.0.max_protocol = protocol; Ok(()) } @@ -300,7 +291,8 @@ impl TlsConnectorBuilder { #[derive(Clone)] pub struct TlsConnector { identity: Option, - protocols: Vec, + min_protocol: Option, + max_protocol: Option, roots: Vec, use_sni: bool, danger_accept_invalid_hostnames: bool, @@ -311,7 +303,8 @@ impl TlsConnector { pub fn builder() -> Result { Ok(TlsConnectorBuilder(TlsConnector { identity: None, - protocols: vec![Protocol::Tlsv10, Protocol::Tlsv11, Protocol::Tlsv12], + min_protocol: None, + max_protocol: None, roots: vec![], use_sni: true, danger_accept_invalid_hostnames: false, @@ -324,9 +317,12 @@ impl TlsConnector { S: io::Read + io::Write, { let mut builder = ClientBuilder::new(); - let (min, max) = protocol_min_max(&self.protocols); - builder.protocol_min(min); - builder.protocol_max(max); + if let Some(min) = self.min_protocol { + builder.protocol_min(convert_protocol(min)); + } + if let Some(max) = self.max_protocol { + builder.protocol_max(convert_protocol(max)); + } if let Some(identity) = self.identity.as_ref() { builder.identity(&identity.identity, &identity.chain); } @@ -345,8 +341,13 @@ impl TlsConnector { pub struct TlsAcceptorBuilder(TlsAcceptor); impl TlsAcceptorBuilder { - pub fn supported_protocols(&mut self, protocols: &[Protocol]) -> Result<(), Error> { - self.0.protocols = protocols.to_vec(); + pub fn min_protocol_version(&mut self, protocol: Option) -> Result<(), Error> { + self.0.min_protocol = protocol; + Ok(()) + } + + pub fn max_protocol_version(&mut self, protocol: Option) -> Result<(), Error> { + self.0.max_protocol = protocol; Ok(()) } @@ -358,14 +359,16 @@ impl TlsAcceptorBuilder { #[derive(Clone)] pub struct TlsAcceptor { identity: Identity, - protocols: Vec, + min_protocol: Option, + max_protocol: Option, } impl TlsAcceptor { pub fn builder(identity: Identity) -> Result { Ok(TlsAcceptorBuilder(TlsAcceptor { identity, - protocols: vec![Protocol::Tlsv10, Protocol::Tlsv11, Protocol::Tlsv12], + min_protocol: None, + max_protocol: None, })) } @@ -375,9 +378,12 @@ impl TlsAcceptor { { let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?; - let (min, max) = protocol_min_max(&self.protocols); - ctx.set_protocol_version_min(min)?; - ctx.set_protocol_version_max(max)?; + if let Some(min) = self.min_protocol { + ctx.set_protocol_version_min(convert_protocol(min))?; + } + if let Some(max) = self.max_protocol { + ctx.set_protocol_version_max(convert_protocol(max))?; + } ctx.set_certificate(&self.identity.identity, &self.identity.chain)?; match ctx.handshake(stream) { Ok(s) => Ok(TlsStream(s)), diff --git a/src/lib.rs b/src/lib.rs index 4fc1e40..0209ff0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -336,15 +336,21 @@ impl TlsConnectorBuilder { Ok(self) } - /// Sets the protocols which the connector will support. - /// - /// The protocols supported by default are currently TLS 1.0, TLS 1.1, and TLS 1.2, though this - /// is subject to change. - pub fn supported_protocols( + /// Sets the minimum supported protocol version. + pub fn min_protocol_version( &mut self, - protocols: &[Protocol], + protocol: Option, ) -> Result<&mut TlsConnectorBuilder> { - self.0.supported_protocols(protocols)?; + self.0.min_protocol_version(protocol)?; + Ok(self) + } + + /// Sets the minimum supported protocol version. + pub fn max_protocol_version( + &mut self, + protocol: Option, + ) -> Result<&mut TlsConnectorBuilder> { + self.0.max_protocol_version(protocol)?; Ok(self) } @@ -452,15 +458,21 @@ impl TlsConnector { pub struct TlsAcceptorBuilder(imp::TlsAcceptorBuilder); impl TlsAcceptorBuilder { - /// Sets the protocols which the acceptor will support. - /// - /// The protocols supported by default are currently TLS 1.0, TLS 1.1, and TLS 1.2, though this - /// is subject to change. - pub fn supported_protocols( + /// Sets the minimum supported protocol version. + pub fn min_protocol_version( &mut self, - protocols: &[Protocol], + protocol: Option, ) -> Result<&mut TlsAcceptorBuilder> { - self.0.supported_protocols(protocols)?; + self.0.min_protocol_version(protocol)?; + Ok(self) + } + + /// Sets the minimum supported protocol version. + pub fn max_protocol_version( + &mut self, + protocol: Option, + ) -> Result<&mut TlsAcceptorBuilder> { + self.0.max_protocol_version(protocol)?; Ok(self) } diff --git a/src/test.rs b/src/test.rs index 8748aac..fd4e346 100644 --- a/src/test.rs +++ b/src/test.rs @@ -133,7 +133,8 @@ mod tests { let buf = include_bytes!("../test/identity.p12"); let identity = p!(Identity::from_pkcs12(buf, "mypass")); let mut builder = p!(TlsAcceptor::builder(identity)); - p!(builder.supported_protocols(&[Protocol::Tlsv11])); + p!(builder.min_protocol_version(Some(Protocol::Tlsv11))); + p!(builder.max_protocol_version(Some(Protocol::Tlsv11))); let builder = p!(builder.build()); let listener = p!(TcpListener::bind("0.0.0.0:0")); @@ -156,7 +157,8 @@ mod tests { let socket = p!(TcpStream::connect(("localhost", port))); let mut builder = p!(TlsConnector::builder()); p!(builder.add_root_certificate(root_ca)); - p!(builder.supported_protocols(&[Protocol::Tlsv11])); + p!(builder.min_protocol_version(Some(Protocol::Tlsv11))); + p!(builder.max_protocol_version(Some(Protocol::Tlsv11))); let builder = p!(builder.build()); let mut socket = p!(builder.connect("foobar.com", socket)); @@ -173,7 +175,7 @@ mod tests { let buf = include_bytes!("../test/identity.p12"); let identity = p!(Identity::from_pkcs12(buf, "mypass")); let mut builder = p!(TlsAcceptor::builder(identity)); - p!(builder.supported_protocols(&[Protocol::Tlsv12])); + p!(builder.min_protocol_version(Some(Protocol::Tlsv12))); let builder = p!(builder.build()); let listener = p!(TcpListener::bind("0.0.0.0:0")); @@ -190,7 +192,7 @@ mod tests { let socket = p!(TcpStream::connect(("localhost", port))); let mut builder = p!(TlsConnector::builder()); p!(builder.add_root_certificate(root_ca)); - p!(builder.supported_protocols(&[Protocol::Sslv3, Protocol::Tlsv10, Protocol::Tlsv11],)); + p!(builder.max_protocol_version(Some(Protocol::Tlsv11))); let builder = p!(builder.build()); assert!(builder.connect("foobar.com", socket).is_err());