From f6915a3abaeba260ce3bae001a0667cb44f560fb Mon Sep 17 00:00:00 2001 From: Marshall Pierce Date: Sun, 28 Oct 2018 13:44:48 -0600 Subject: [PATCH 1/2] Detect trailing symbols that encode unused bits and report an error. --- RELEASE-NOTES.md | 1 + src/decode.rs | 126 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 5179200..4d893f1 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -6,6 +6,7 @@ - Add a streaming encoder `Write` impl to transparently base64 as you write. - Remove the remaining `unsafe` code. - Remove whitespace stripping to simplify `no_std` support. No out of the box configs use it, and it's trivial to do yourself if needed: `filter(|b| !b" \n\t\r\x0b\x0c".contains(b)`. +- Detect invalid trailing symbols when decoding and return an error rather than silently ignoring them. # 0.9.3 diff --git a/src/decode.rs b/src/decode.rs index e869340..7e67c3e 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -25,6 +25,11 @@ pub enum DecodeError { InvalidByte(usize, u8), /// The length of the input is invalid. InvalidLength, + /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded. + /// This is indicative of corrupted or truncated Base64. + /// Unlike InvalidByte, which reports symbols that aren't in the alphabet, this error is for + /// symbols that are in the alphabet but represent nonsensical encodings. + InvalidLastSymbol(usize, u8), } impl fmt::Display for DecodeError { @@ -34,6 +39,9 @@ impl fmt::Display for DecodeError { write!(f, "Invalid byte {}, offset {}.", byte, index) } DecodeError::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."), + DecodeError::InvalidLastSymbol(index, byte) => { + write!(f, "Invalid last symbol {}, offset {}.", byte, index) + } } } } @@ -43,6 +51,7 @@ impl error::Error for DecodeError { match *self { DecodeError::InvalidByte(_, _) => "invalid byte", DecodeError::InvalidLength => "invalid length", + DecodeError::InvalidLastSymbol(_, _) => "invalid last symbol", } } @@ -302,6 +311,10 @@ fn decode_helper( output_index += DECODED_CHUNK_LEN; } + // always have one more (possibly partial) block of 8 input + debug_assert!(input.len() - input_index > 1 || input.len() == 0); + debug_assert!(input.len() - input_index <= 8); + // Stage 4 // Finally, decode any leftovers that aren't a complete input block of 8 bytes. // Use a u64 as a stack-resident 8 byte buffer. @@ -309,6 +322,7 @@ fn decode_helper( let mut morsels_in_leftover = 0; let mut padding_bytes = 0; let mut first_padding_index: usize = 0; + let mut last_symbol = 0_u8; let start_of_leftovers = input_index; for (i, b) in input[start_of_leftovers..].iter().enumerate() { // '=' padding @@ -355,6 +369,7 @@ fn decode_helper( 0x3D, )); } + last_symbol = *b; // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding. // To minimize shifts, pack the leftovers from left to right. @@ -382,6 +397,14 @@ fn decode_helper( ), }; + // if there are bits set outside the bits we care about, last symbol encodes trailing bits that + // will not be included in the output + let mask = !0 >> leftover_bits_ready_to_append; + if (leftover_bits & mask) != 0 { + // last morsel is at `morsels_in_leftover` - 1 + return Err(DecodeError::InvalidLastSymbol(start_of_leftovers + morsels_in_leftover - 1, last_symbol)); + } + let mut leftover_bits_appended_to_buf = 0; while leftover_bits_appended_to_buf < leftover_bits_ready_to_append { // `as` simply truncates the higher bits, which is what we want here @@ -686,4 +709,107 @@ mod tests { assert_eq!(orig_data, decode_buf); } } + + #[test] + fn detect_invalid_last_symbol_two_bytes() { + // example from https://github.com/alicemaz/rust-base64/issues/75 + assert!(decode("iYU=").is_ok()); + // trailing 01 + assert_eq!(Err(DecodeError::InvalidLastSymbol(2, b'V')), decode("iYV=")); + // trailing 10 + assert_eq!(Err(DecodeError::InvalidLastSymbol(2, b'W')), decode("iYW=")); + // trailing 11 + assert_eq!(Err(DecodeError::InvalidLastSymbol(2, b'X')), decode("iYX=")); + + // also works when there are 2 quads in the last block + assert_eq!(Err(DecodeError::InvalidLastSymbol(6, b'X')), decode("AAAAiYX=")); + } + + #[test] + fn detect_invalid_last_symbol_one_byte() { + // 0xFF -> "/w==", so all letters > w, 0-9, and '+', '/' should get InvalidLastSymbol + + assert!(decode("/w==").is_ok()); + // trailing 01 + assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'x')), decode("/x==")); + assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'z')), decode("/z==")); + assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'0')), decode("/0==")); + assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'9')), decode("/9==")); + assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'+')), decode("/+==")); + assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'/')), decode("//==")); + + // also works when there are 2 quads in the last block + assert_eq!(Err(DecodeError::InvalidLastSymbol(5, b'x')), decode("AAAA/x==")); + } + + #[test] + fn detect_invalid_last_symbol_every_possible_three_symbols() { + let mut base64_to_bytes = std::collections::HashMap::new(); + + let mut bytes = [0_u8; 2]; + for b1 in 0_u16..256 { + bytes[0] = b1 as u8; + for b2 in 0_u16..256 { + bytes[1] = b2 as u8; + let mut b64 = vec![0_u8; 4]; + assert_eq!(4, ::encode_config_slice(&bytes, STANDARD, &mut b64[..])); + let mut v = std::vec::Vec::with_capacity(2); + v.extend_from_slice(&bytes[..]); + + assert!(base64_to_bytes.insert(b64, v).is_none()); + }; + } + + // every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol + + let mut symbols = [0_u8; 4]; + for &s1 in STANDARD.char_set.encode_table().iter() { + symbols[0] = s1; + for &s2 in STANDARD.char_set.encode_table().iter() { + symbols[1] = s2; + for &s3 in STANDARD.char_set.encode_table().iter() { + symbols[2] = s3; + symbols[3] = b'='; + + match base64_to_bytes.get(&symbols[..]) { + Some(bytes) => assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD)), + None => assert_eq!(Err(DecodeError::InvalidLastSymbol(2, s3)), + decode_config(&symbols[..], STANDARD)) + } + } + } + } + } + + #[test] + fn detect_invalid_last_symbol_every_possible_two_symbols() { + let mut base64_to_bytes = std::collections::HashMap::new(); + + for b in 0_u16..256 { + let mut b64 = vec![0_u8; 4]; + assert_eq!(4, ::encode_config_slice(&[b as u8], STANDARD, &mut b64[..])); + let mut v = std::vec::Vec::with_capacity(1); + v.push(b as u8); + + assert!(base64_to_bytes.insert(b64, v).is_none()); + }; + + // every possible combination of symbols must either decode to 1 byte or get InvalidLastSymbol + + let mut symbols = [0_u8; 4]; + for &s1 in STANDARD.char_set.encode_table().iter() { + symbols[0] = s1; + for &s2 in STANDARD.char_set.encode_table().iter() { + symbols[1] = s2; + symbols[2] = b'='; + symbols[3] = b'='; + + match base64_to_bytes.get(&symbols[..]) { + Some(bytes) => assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD)), + None => assert_eq!(Err(DecodeError::InvalidLastSymbol(1, s2)), + decode_config(&symbols[..], STANDARD)) + } + } + } + } } From 5d0d99932544de4f09db310c0e810c136bc4bb28 Mon Sep 17 00:00:00 2001 From: Marshall Pierce Date: Sun, 28 Oct 2018 15:42:05 -0600 Subject: [PATCH 2/2] Fix for 1.23.0 compat --- src/decode.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index 7e67c3e..9e5a762 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -744,7 +744,7 @@ mod tests { #[test] fn detect_invalid_last_symbol_every_possible_three_symbols() { - let mut base64_to_bytes = std::collections::HashMap::new(); + let mut base64_to_bytes = ::std::collections::HashMap::new(); let mut bytes = [0_u8; 2]; for b1 in 0_u16..256 { @@ -753,7 +753,7 @@ mod tests { bytes[1] = b2 as u8; let mut b64 = vec![0_u8; 4]; assert_eq!(4, ::encode_config_slice(&bytes, STANDARD, &mut b64[..])); - let mut v = std::vec::Vec::with_capacity(2); + let mut v = ::std::vec::Vec::with_capacity(2); v.extend_from_slice(&bytes[..]); assert!(base64_to_bytes.insert(b64, v).is_none()); @@ -783,12 +783,12 @@ mod tests { #[test] fn detect_invalid_last_symbol_every_possible_two_symbols() { - let mut base64_to_bytes = std::collections::HashMap::new(); + let mut base64_to_bytes = ::std::collections::HashMap::new(); for b in 0_u16..256 { let mut b64 = vec![0_u8; 4]; assert_eq!(4, ::encode_config_slice(&[b as u8], STANDARD, &mut b64[..])); - let mut v = std::vec::Vec::with_capacity(1); + let mut v = ::std::vec::Vec::with_capacity(1); v.push(b as u8); assert!(base64_to_bytes.insert(b64, v).is_none());