Switch protocol filters to min/max

This commit is contained in:
Steven Fackler
2018-05-20 14:38:57 -07:00
parent 5e915be384
commit edfa82624c
7 changed files with 226 additions and 105 deletions
+4
View File
@@ -21,3 +21,7 @@ schannel = "0.1.12"
[target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios")))'.dependencies]
openssl = "0.10.6"
openssl-sys = "0.9.30"
[patch.crates-io]
openssl = { git = "https://github.com/sfackler/rust-openssl" }
openssl-sys = { git = "https://github.com/sfackler/rust-openssl" }
+13 -6
View File
@@ -1,12 +1,19 @@
use std::env;
fn main() {
let openssl_version = env::var("DEP_OPENSSL_VERSION_NUMBER")
.ok()
.map(|s| u64::from_str_radix(&s, 16).unwrap());
if let Ok(version) = env::var("DEP_OPENSSL_VERSION_NUMBER") {
let version = u64::from_str_radix(&version, 16).unwrap();
match openssl_version {
Some(version) if version >= 0x1_00_02_00_0 => println!("cargo:rustc-cfg=have_no_ssl_mask"),
_ => {}
if version >= 0x1_01_00_00_0 {
println!("cargo:rustc-cfg=have_min_max_version");
}
}
if let Ok(version) = env::var("DEP_OPENSSL_LIBRESSL_VERSION_NUMBER") {
let version = u64::from_str_radix(&version, 16).unwrap();
if version >= 0x2_06_01_00_0 {
println!("cargo:rustc-cfg=have_min_max_version");
}
}
}
+102 -32
View File
@@ -4,7 +4,7 @@ use self::openssl::error::ErrorStack;
use self::openssl::pkcs12::{ParsedPkcs12, Pkcs12};
use self::openssl::ssl::{
self, MidHandshakeSslStream, SslAcceptor, SslAcceptorBuilder, SslConnector,
SslConnectorBuilder, SslContextBuilder, SslMethod, SslOptions, SslVerifyMode,
SslConnectorBuilder, SslContextBuilder, SslMethod, SslVerifyMode,
};
use self::openssl::x509::X509;
use std::error;
@@ -13,26 +13,67 @@ use std::io;
use Protocol;
fn supported_protocols(protocols: &[Protocol], ctx: &mut SslContextBuilder) {
#[cfg(not(have_no_ssl_mask))]
let no_ssl_mask = SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1
| SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2;
#[cfg(have_no_ssl_mask)]
let no_ssl_mask = SslOptions::NO_SSL_MASK;
#[cfg(have_min_max_version)]
fn supported_protocols(
min: Option<Protocol>,
max: Option<Protocol>,
ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
use self::openssl::ssl::SslVersion;
fn cvt(p: Protocol) -> SslVersion {
match p {
Protocol::Sslv3 => SslVersion::SSL3,
Protocol::Tlsv10 => SslVersion::TLS1,
Protocol::Tlsv11 => SslVersion::TLS1_1,
Protocol::Tlsv12 => SslVersion::TLS1_2,
Protocol::__NonExhaustive => unreachable!(),
}
}
ctx.set_min_proto_version(min.map(cvt))?;
ctx.set_max_proto_version(max.map(cvt))?;
Ok(())
}
#[cfg(not(have_min_max_version))]
fn supported_protocols(
min: Option<Protocol>,
max: Option<Protocol>,
ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
use self::openssl::ssl::SslOptions;
let no_ssl_mask = SslOptions::NO_SSLV2
| SslOptions::NO_SSLV3
| SslOptions::NO_TLSV1
| SslOptions::NO_TLSV1_1
| SslOptions::NO_TLSV1_2;
ctx.clear_options(no_ssl_mask);
let mut options = no_ssl_mask;
for protocol in protocols {
let op = match *protocol {
Protocol::Sslv3 => SslOptions::NO_SSLV3,
Protocol::Tlsv10 => SslOptions::NO_TLSV1,
Protocol::Tlsv11 => SslOptions::NO_TLSV1_1,
Protocol::Tlsv12 => SslOptions::NO_TLSV1_2,
Protocol::__NonExhaustive => unreachable!(),
};
options &= !op;
let mut options = SslOptions::NO_SSLV2;
match min {
None | Some(Protocol::Sslv3) => {}
Some(Protocol::Tlsv10) => options |= SslOptions::NO_SSLV3,
Some(Protocol::Tlsv11) => options |= SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1,
Some(Protocol::Tlsv12) => {
options |= SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1
}
Some(Protocol::__NonExhaustive) => unreachable!(),
}
match max {
None | Some(Protocol::Tlsv12) => {}
Some(Protocol::Tlsv11) => options |= SslOptions::NO_TLSV1_2,
Some(Protocol::Tlsv10) => options |= SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2,
Some(Protocol::Sslv3) => {
options |= SslOptions::NO_TLSV1_0 | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2
}
}
ctx.set_options(options);
Ok(())
}
pub struct Error(ssl::Error);
@@ -156,6 +197,8 @@ pub struct TlsConnectorBuilder {
use_sni: bool,
accept_invalid_hostnames: bool,
accept_invalid_certs: bool,
min_protocol: Option<Protocol>,
max_protocol: Option<Protocol>,
}
impl TlsConnectorBuilder {
@@ -189,12 +232,19 @@ impl TlsConnectorBuilder {
self.accept_invalid_certs = accept_invalid_certs;
}
pub fn supported_protocols(&mut self, protocols: &[Protocol]) -> Result<(), Error> {
supported_protocols(protocols, &mut self.connector);
pub fn min_protocol_version(&mut self, protocol: Option<Protocol>) -> Result<(), Error> {
self.min_protocol = protocol;
Ok(())
}
pub fn build(self) -> Result<TlsConnector, Error> {
pub fn max_protocol_version(&mut self, protocol: Option<Protocol>) -> Result<(), Error> {
self.max_protocol = protocol;
Ok(())
}
pub fn build(mut self) -> Result<TlsConnector, Error> {
supported_protocols(self.min_protocol, self.max_protocol, &mut self.connector)?;
Ok(TlsConnector {
connector: self.connector.build(),
use_sni: self.use_sni,
@@ -219,6 +269,8 @@ impl TlsConnector {
use_sni: true,
accept_invalid_hostnames: false,
accept_invalid_certs: false,
min_protocol: None,
max_protocol: None,
})
}
@@ -226,7 +278,8 @@ impl TlsConnector {
where
S: io::Read + io::Write,
{
let mut ssl = self.connector
let mut ssl = self
.connector
.configure()?
.use_server_name_indication(self.use_sni)
.verify_hostname(!self.accept_invalid_hostnames);
@@ -239,16 +292,27 @@ impl TlsConnector {
}
}
pub struct TlsAcceptorBuilder(SslAcceptorBuilder);
pub struct TlsAcceptorBuilder {
acceptor: SslAcceptorBuilder,
min_protocol: Option<Protocol>,
max_protocol: Option<Protocol>,
}
impl TlsAcceptorBuilder {
pub fn supported_protocols(&mut self, protocols: &[Protocol]) -> Result<(), Error> {
supported_protocols(protocols, &mut self.0);
pub fn min_protocol_version(&mut self, protocol: Option<Protocol>) -> Result<(), Error> {
self.min_protocol = protocol;
Ok(())
}
pub fn build(self) -> Result<TlsAcceptor, Error> {
Ok(TlsAcceptor(self.0.build()))
pub fn max_protocol_version(&mut self, protocol: Option<Protocol>) -> Result<(), Error> {
self.max_protocol = protocol;
Ok(())
}
pub fn build(mut self) -> Result<TlsAcceptor, Error> {
supported_protocols(self.min_protocol, self.max_protocol, &mut self.acceptor)?;
Ok(TlsAcceptor(self.acceptor.build()))
}
}
@@ -257,15 +321,20 @@ pub struct TlsAcceptor(SslAcceptor);
impl TlsAcceptor {
pub fn builder(identity: Identity) -> Result<TlsAcceptorBuilder, Error> {
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
builder.set_private_key(&identity.0.pkey)?;
builder.set_certificate(&identity.0.cert)?;
let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
acceptor.set_private_key(&identity.0.pkey)?;
acceptor.set_certificate(&identity.0.cert)?;
if let Some(chain) = identity.0.chain {
for cert in chain {
builder.add_extra_chain_cert(cert)?;
acceptor.add_extra_chain_cert(cert)?;
}
}
Ok(TlsAcceptorBuilder(builder))
Ok(TlsAcceptorBuilder {
acceptor,
min_protocol: None,
max_protocol: None,
})
}
pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
@@ -294,7 +363,8 @@ impl<S: io::Read + io::Write> TlsStream<S> {
match self.0.shutdown() {
Ok(_) => Ok(()),
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(()),
Err(e) => Err(e.into_io_error()
Err(e) => Err(e
.into_io_error()
.unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))),
}
}
+41 -21
View File
@@ -8,17 +8,22 @@ use std::error;
use std::fmt;
use std::io;
fn convert_protocols(protocols: &[::Protocol]) -> Vec<Protocol> {
static PROTOCOLS: &'static [Protocol] = &[
Protocol::Ssl3,
Protocol::Tls10,
Protocol::Tls11,
Protocol::Tls12,
];
fn convert_protocols(min: Option<::Protocol>, max: Option<::Protocol>) -> &'static [Protocol] {
let mut protocols = PROTOCOLS;
if let Some(p) = max.and_then(|max| protocols.get(..max as usize)) {
protocols = p;
}
if let Some(p) = min.and_then(|min| protocols.get(min as usize..)) {
protocols = p;
}
protocols
.iter()
.map(|p| match *p {
::Protocol::Sslv3 => Protocol::Ssl3,
::Protocol::Tlsv10 => Protocol::Tls10,
::Protocol::Tlsv11 => Protocol::Tls11,
::Protocol::Tlsv12 => Protocol::Tls12,
::Protocol::__NonExhaustive => unreachable!(),
})
.collect()
}
pub struct Error(io::Error);
@@ -61,7 +66,8 @@ impl Identity {
let mut identity = None;
for cert in store.certs() {
if cert.private_key()
if cert
.private_key()
.silent(true)
.compare_key(true)
.acquire()
@@ -186,8 +192,13 @@ impl TlsConnectorBuilder {
self.0.accept_invalid_certs = accept_invalid_certs;
}
pub fn supported_protocols(&mut self, protocols: &[::Protocol]) -> Result<(), Error> {
self.0.protocols = convert_protocols(protocols);
pub fn min_protocol_version(&mut self, protocol: Option<::Protocol>) -> Result<(), Error> {
self.0.min_protocol = protocol;
Ok(())
}
pub fn max_protocol_version(&mut self, protocol: Option<::Protocol>) -> Result<(), Error> {
self.0.max_protocol = protocol;
Ok(())
}
@@ -200,7 +211,8 @@ impl TlsConnectorBuilder {
pub struct TlsConnector {
cert: Option<CertContext>,
roots: CertStore,
protocols: Vec<Protocol>,
min_protocol: Option<::Protocol>,
max_protocol: Option<::Protocol>,
use_sni: bool,
accept_invalid_hostnames: bool,
accept_invalid_certs: bool,
@@ -211,7 +223,8 @@ impl TlsConnector {
Ok(TlsConnectorBuilder(TlsConnector {
cert: None,
roots: Memory::new()?.into_store(),
protocols: vec![Protocol::Tls10, Protocol::Tls11, Protocol::Tls12],
min_protocol: None,
max_protocol: None,
use_sni: true,
accept_invalid_hostnames: false,
accept_invalid_certs: false,
@@ -223,7 +236,7 @@ impl TlsConnector {
S: io::Read + io::Write,
{
let mut builder = SchannelCred::builder();
builder.enabled_protocols(&self.protocols);
builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol));
if let Some(cert) = self.cert.as_ref() {
builder.cert(cert.clone());
}
@@ -247,8 +260,13 @@ impl TlsConnector {
pub struct TlsAcceptorBuilder(TlsAcceptor);
impl TlsAcceptorBuilder {
pub fn supported_protocols(&mut self, protocols: &[::Protocol]) -> Result<(), Error> {
self.0.protocols = convert_protocols(protocols);
pub fn min_protocol_version(&mut self, protocol: Option<::Protocol>) -> Result<(), Error> {
self.0.min_protocol = protocol;
Ok(())
}
pub fn max_protocol_version(&mut self, protocol: Option<::Protocol>) -> Result<(), Error> {
self.0.max_protocol = protocol;
Ok(())
}
@@ -260,14 +278,16 @@ impl TlsAcceptorBuilder {
#[derive(Clone)]
pub struct TlsAcceptor {
cert: CertContext,
protocols: Vec<Protocol>,
min_protocol: Option<::Protocol>,
max_protocol: Option<::Protocol>,
}
impl TlsAcceptor {
pub fn builder(identity: Identity) -> Result<TlsAcceptorBuilder, Error> {
Ok(TlsAcceptorBuilder(TlsAcceptor {
cert: identity.cert,
protocols: vec![Protocol::Tls10, Protocol::Tls11, Protocol::Tls12],
min_protocol: None,
max_protocol: None,
}))
}
@@ -276,7 +296,7 @@ impl TlsAcceptor {
S: io::Read + io::Write,
{
let mut builder = SchannelCred::builder();
builder.enabled_protocols(&self.protocols);
builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol));
builder.cert(self.cert.clone());
// FIXME we're probably missing the certificate chain?
let cred = builder.acquire(Direction::Inbound)?;
+34 -28
View File
@@ -44,20 +44,6 @@ fn convert_protocol(protocol: Protocol) -> SslProtocol {
}
}
fn protocol_min_max(protocols: &[Protocol]) -> (SslProtocol, SslProtocol) {
let mut min = Protocol::Tlsv12;
let mut max = Protocol::Sslv3;
for protocol in protocols {
if (*protocol as usize) < (min as usize) {
min = *protocol;
}
if (*protocol as usize) > (max as usize) {
max = *protocol;
}
}
(convert_protocol(min), convert_protocol(max))
}
pub struct Error(base::Error);
impl error::Error for Error {
@@ -287,8 +273,13 @@ impl TlsConnectorBuilder {
self.0.danger_accept_invalid_certs = accept_invalid_certs;
}
pub fn supported_protocols(&mut self, protocols: &[Protocol]) -> Result<(), Error> {
self.0.protocols = protocols.to_vec();
pub fn min_protocol_version(&mut self, protocol: Option<Protocol>) -> Result<(), Error> {
self.0.min_protocol = protocol;
Ok(())
}
pub fn max_protocol_version(&mut self, protocol: Option<Protocol>) -> Result<(), Error> {
self.0.max_protocol = protocol;
Ok(())
}
@@ -300,7 +291,8 @@ impl TlsConnectorBuilder {
#[derive(Clone)]
pub struct TlsConnector {
identity: Option<Identity>,
protocols: Vec<Protocol>,
min_protocol: Option<Protocol>,
max_protocol: Option<Protocol>,
roots: Vec<SecCertificate>,
use_sni: bool,
danger_accept_invalid_hostnames: bool,
@@ -311,7 +303,8 @@ impl TlsConnector {
pub fn builder() -> Result<TlsConnectorBuilder, Error> {
Ok(TlsConnectorBuilder(TlsConnector {
identity: None,
protocols: vec![Protocol::Tlsv10, Protocol::Tlsv11, Protocol::Tlsv12],
min_protocol: None,
max_protocol: None,
roots: vec![],
use_sni: true,
danger_accept_invalid_hostnames: false,
@@ -324,9 +317,12 @@ impl TlsConnector {
S: io::Read + io::Write,
{
let mut builder = ClientBuilder::new();
let (min, max) = protocol_min_max(&self.protocols);
builder.protocol_min(min);
builder.protocol_max(max);
if let Some(min) = self.min_protocol {
builder.protocol_min(convert_protocol(min));
}
if let Some(max) = self.max_protocol {
builder.protocol_max(convert_protocol(max));
}
if let Some(identity) = self.identity.as_ref() {
builder.identity(&identity.identity, &identity.chain);
}
@@ -345,8 +341,13 @@ impl TlsConnector {
pub struct TlsAcceptorBuilder(TlsAcceptor);
impl TlsAcceptorBuilder {
pub fn supported_protocols(&mut self, protocols: &[Protocol]) -> Result<(), Error> {
self.0.protocols = protocols.to_vec();
pub fn min_protocol_version(&mut self, protocol: Option<Protocol>) -> Result<(), Error> {
self.0.min_protocol = protocol;
Ok(())
}
pub fn max_protocol_version(&mut self, protocol: Option<Protocol>) -> Result<(), Error> {
self.0.max_protocol = protocol;
Ok(())
}
@@ -358,14 +359,16 @@ impl TlsAcceptorBuilder {
#[derive(Clone)]
pub struct TlsAcceptor {
identity: Identity,
protocols: Vec<Protocol>,
min_protocol: Option<Protocol>,
max_protocol: Option<Protocol>,
}
impl TlsAcceptor {
pub fn builder(identity: Identity) -> Result<TlsAcceptorBuilder, Error> {
Ok(TlsAcceptorBuilder(TlsAcceptor {
identity,
protocols: vec![Protocol::Tlsv10, Protocol::Tlsv11, Protocol::Tlsv12],
min_protocol: None,
max_protocol: None,
}))
}
@@ -375,9 +378,12 @@ impl TlsAcceptor {
{
let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
let (min, max) = protocol_min_max(&self.protocols);
ctx.set_protocol_version_min(min)?;
ctx.set_protocol_version_max(max)?;
if let Some(min) = self.min_protocol {
ctx.set_protocol_version_min(convert_protocol(min))?;
}
if let Some(max) = self.max_protocol {
ctx.set_protocol_version_max(convert_protocol(max))?;
}
ctx.set_certificate(&self.identity.identity, &self.identity.chain)?;
match ctx.handshake(stream) {
Ok(s) => Ok(TlsStream(s)),
+26 -14
View File
@@ -336,15 +336,21 @@ impl TlsConnectorBuilder {
Ok(self)
}
/// Sets the protocols which the connector will support.
///
/// The protocols supported by default are currently TLS 1.0, TLS 1.1, and TLS 1.2, though this
/// is subject to change.
pub fn supported_protocols(
/// Sets the minimum supported protocol version.
pub fn min_protocol_version(
&mut self,
protocols: &[Protocol],
protocol: Option<Protocol>,
) -> Result<&mut TlsConnectorBuilder> {
self.0.supported_protocols(protocols)?;
self.0.min_protocol_version(protocol)?;
Ok(self)
}
/// Sets the minimum supported protocol version.
pub fn max_protocol_version(
&mut self,
protocol: Option<Protocol>,
) -> Result<&mut TlsConnectorBuilder> {
self.0.max_protocol_version(protocol)?;
Ok(self)
}
@@ -452,15 +458,21 @@ impl TlsConnector {
pub struct TlsAcceptorBuilder(imp::TlsAcceptorBuilder);
impl TlsAcceptorBuilder {
/// Sets the protocols which the acceptor will support.
///
/// The protocols supported by default are currently TLS 1.0, TLS 1.1, and TLS 1.2, though this
/// is subject to change.
pub fn supported_protocols(
/// Sets the minimum supported protocol version.
pub fn min_protocol_version(
&mut self,
protocols: &[Protocol],
protocol: Option<Protocol>,
) -> Result<&mut TlsAcceptorBuilder> {
self.0.supported_protocols(protocols)?;
self.0.min_protocol_version(protocol)?;
Ok(self)
}
/// Sets the minimum supported protocol version.
pub fn max_protocol_version(
&mut self,
protocol: Option<Protocol>,
) -> Result<&mut TlsAcceptorBuilder> {
self.0.max_protocol_version(protocol)?;
Ok(self)
}
+6 -4
View File
@@ -133,7 +133,8 @@ mod tests {
let buf = include_bytes!("../test/identity.p12");
let identity = p!(Identity::from_pkcs12(buf, "mypass"));
let mut builder = p!(TlsAcceptor::builder(identity));
p!(builder.supported_protocols(&[Protocol::Tlsv11]));
p!(builder.min_protocol_version(Some(Protocol::Tlsv11)));
p!(builder.max_protocol_version(Some(Protocol::Tlsv11)));
let builder = p!(builder.build());
let listener = p!(TcpListener::bind("0.0.0.0:0"));
@@ -156,7 +157,8 @@ mod tests {
let socket = p!(TcpStream::connect(("localhost", port)));
let mut builder = p!(TlsConnector::builder());
p!(builder.add_root_certificate(root_ca));
p!(builder.supported_protocols(&[Protocol::Tlsv11]));
p!(builder.min_protocol_version(Some(Protocol::Tlsv11)));
p!(builder.max_protocol_version(Some(Protocol::Tlsv11)));
let builder = p!(builder.build());
let mut socket = p!(builder.connect("foobar.com", socket));
@@ -173,7 +175,7 @@ mod tests {
let buf = include_bytes!("../test/identity.p12");
let identity = p!(Identity::from_pkcs12(buf, "mypass"));
let mut builder = p!(TlsAcceptor::builder(identity));
p!(builder.supported_protocols(&[Protocol::Tlsv12]));
p!(builder.min_protocol_version(Some(Protocol::Tlsv12)));
let builder = p!(builder.build());
let listener = p!(TcpListener::bind("0.0.0.0:0"));
@@ -190,7 +192,7 @@ mod tests {
let socket = p!(TcpStream::connect(("localhost", port)));
let mut builder = p!(TlsConnector::builder());
p!(builder.add_root_certificate(root_ca));
p!(builder.supported_protocols(&[Protocol::Sslv3, Protocol::Tlsv10, Protocol::Tlsv11],));
p!(builder.max_protocol_version(Some(Protocol::Tlsv11)));
let builder = p!(builder.build());
assert!(builder.connect("foobar.com", socket).is_err());