Merge pull request #82 from alicemaz/detect-invalid-last-byte

Detect trailing symbols that encode unused bits and report an error.
This commit is contained in:
Marshall Pierce 2018-10-28 16:09:33 -06:00 committed by GitHub
commit 51873b6c3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 127 additions and 0 deletions

View File

@ -6,6 +6,7 @@
- Add a streaming encoder `Write` impl to transparently base64 as you write.
- Remove the remaining `unsafe` code.
- Remove whitespace stripping to simplify `no_std` support. No out of the box configs use it, and it's trivial to do yourself if needed: `filter(|b| !b" \n\t\r\x0b\x0c".contains(b)`.
- Detect invalid trailing symbols when decoding and return an error rather than silently ignoring them.
# 0.9.3

View File

@ -25,6 +25,11 @@ pub enum DecodeError {
InvalidByte(usize, u8),
/// The length of the input is invalid.
InvalidLength,
/// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded.
/// This is indicative of corrupted or truncated Base64.
/// Unlike InvalidByte, which reports symbols that aren't in the alphabet, this error is for
/// symbols that are in the alphabet but represent nonsensical encodings.
InvalidLastSymbol(usize, u8),
}
impl fmt::Display for DecodeError {
@ -34,6 +39,9 @@ impl fmt::Display for DecodeError {
write!(f, "Invalid byte {}, offset {}.", byte, index)
}
DecodeError::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."),
DecodeError::InvalidLastSymbol(index, byte) => {
write!(f, "Invalid last symbol {}, offset {}.", byte, index)
}
}
}
}
@ -43,6 +51,7 @@ impl error::Error for DecodeError {
match *self {
DecodeError::InvalidByte(_, _) => "invalid byte",
DecodeError::InvalidLength => "invalid length",
DecodeError::InvalidLastSymbol(_, _) => "invalid last symbol",
}
}
@ -302,6 +311,10 @@ fn decode_helper(
output_index += DECODED_CHUNK_LEN;
}
// always have one more (possibly partial) block of 8 input
debug_assert!(input.len() - input_index > 1 || input.len() == 0);
debug_assert!(input.len() - input_index <= 8);
// Stage 4
// Finally, decode any leftovers that aren't a complete input block of 8 bytes.
// Use a u64 as a stack-resident 8 byte buffer.
@ -309,6 +322,7 @@ fn decode_helper(
let mut morsels_in_leftover = 0;
let mut padding_bytes = 0;
let mut first_padding_index: usize = 0;
let mut last_symbol = 0_u8;
let start_of_leftovers = input_index;
for (i, b) in input[start_of_leftovers..].iter().enumerate() {
// '=' padding
@ -355,6 +369,7 @@ fn decode_helper(
0x3D,
));
}
last_symbol = *b;
// can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
// To minimize shifts, pack the leftovers from left to right.
@ -382,6 +397,14 @@ fn decode_helper(
),
};
// if there are bits set outside the bits we care about, last symbol encodes trailing bits that
// will not be included in the output
let mask = !0 >> leftover_bits_ready_to_append;
if (leftover_bits & mask) != 0 {
// last morsel is at `morsels_in_leftover` - 1
return Err(DecodeError::InvalidLastSymbol(start_of_leftovers + morsels_in_leftover - 1, last_symbol));
}
let mut leftover_bits_appended_to_buf = 0;
while leftover_bits_appended_to_buf < leftover_bits_ready_to_append {
// `as` simply truncates the higher bits, which is what we want here
@ -686,4 +709,107 @@ mod tests {
assert_eq!(orig_data, decode_buf);
}
}
#[test]
fn detect_invalid_last_symbol_two_bytes() {
// example from https://github.com/alicemaz/rust-base64/issues/75
assert!(decode("iYU=").is_ok());
// trailing 01
assert_eq!(Err(DecodeError::InvalidLastSymbol(2, b'V')), decode("iYV="));
// trailing 10
assert_eq!(Err(DecodeError::InvalidLastSymbol(2, b'W')), decode("iYW="));
// trailing 11
assert_eq!(Err(DecodeError::InvalidLastSymbol(2, b'X')), decode("iYX="));
// also works when there are 2 quads in the last block
assert_eq!(Err(DecodeError::InvalidLastSymbol(6, b'X')), decode("AAAAiYX="));
}
#[test]
fn detect_invalid_last_symbol_one_byte() {
// 0xFF -> "/w==", so all letters > w, 0-9, and '+', '/' should get InvalidLastSymbol
assert!(decode("/w==").is_ok());
// trailing 01
assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'x')), decode("/x=="));
assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'z')), decode("/z=="));
assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'0')), decode("/0=="));
assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'9')), decode("/9=="));
assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'+')), decode("/+=="));
assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'/')), decode("//=="));
// also works when there are 2 quads in the last block
assert_eq!(Err(DecodeError::InvalidLastSymbol(5, b'x')), decode("AAAA/x=="));
}
#[test]
fn detect_invalid_last_symbol_every_possible_three_symbols() {
let mut base64_to_bytes = ::std::collections::HashMap::new();
let mut bytes = [0_u8; 2];
for b1 in 0_u16..256 {
bytes[0] = b1 as u8;
for b2 in 0_u16..256 {
bytes[1] = b2 as u8;
let mut b64 = vec![0_u8; 4];
assert_eq!(4, ::encode_config_slice(&bytes, STANDARD, &mut b64[..]));
let mut v = ::std::vec::Vec::with_capacity(2);
v.extend_from_slice(&bytes[..]);
assert!(base64_to_bytes.insert(b64, v).is_none());
};
}
// every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol
let mut symbols = [0_u8; 4];
for &s1 in STANDARD.char_set.encode_table().iter() {
symbols[0] = s1;
for &s2 in STANDARD.char_set.encode_table().iter() {
symbols[1] = s2;
for &s3 in STANDARD.char_set.encode_table().iter() {
symbols[2] = s3;
symbols[3] = b'=';
match base64_to_bytes.get(&symbols[..]) {
Some(bytes) => assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD)),
None => assert_eq!(Err(DecodeError::InvalidLastSymbol(2, s3)),
decode_config(&symbols[..], STANDARD))
}
}
}
}
}
#[test]
fn detect_invalid_last_symbol_every_possible_two_symbols() {
let mut base64_to_bytes = ::std::collections::HashMap::new();
for b in 0_u16..256 {
let mut b64 = vec![0_u8; 4];
assert_eq!(4, ::encode_config_slice(&[b as u8], STANDARD, &mut b64[..]));
let mut v = ::std::vec::Vec::with_capacity(1);
v.push(b as u8);
assert!(base64_to_bytes.insert(b64, v).is_none());
};
// every possible combination of symbols must either decode to 1 byte or get InvalidLastSymbol
let mut symbols = [0_u8; 4];
for &s1 in STANDARD.char_set.encode_table().iter() {
symbols[0] = s1;
for &s2 in STANDARD.char_set.encode_table().iter() {
symbols[1] = s2;
symbols[2] = b'=';
symbols[3] = b'=';
match base64_to_bytes.get(&symbols[..]) {
Some(bytes) => assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD)),
None => assert_eq!(Err(DecodeError::InvalidLastSymbol(1, s2)),
decode_config(&symbols[..], STANDARD))
}
}
}
}
}