Optimize decode performance. Fixes #30.

This commit is contained in:
Marshall Pierce 2017-05-19 11:40:39 -07:00
parent e7f3208fca
commit 154103eee3
2 changed files with 212 additions and 236 deletions

View File

@ -116,7 +116,6 @@ pub static URL_SAFE_NO_PAD: Config = Config {
line_wrap: LineWrap::NoWrap,
};
#[derive(Debug, PartialEq, Eq)]
pub enum DecodeError {
InvalidByte(usize, u8),
@ -489,15 +488,19 @@ pub fn decode_config_buf<T: ?Sized + AsRef<[u8]>>(input: &T,
let decode_table = &config.char_set.decode_table();
buffer.reserve(input_bytes.len() * 3 / 4);
// decode logic operates on chunks of 8 input bytes without padding
const INPUT_CHUNK_LEN: usize = 8;
const DECODED_CHUNK_LEN: usize = 6;
// we read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last
// 2 bytes of any output u64 should not be counted as written to (but must be available in a
// slice).
const DECODED_CHUNK_SUFFIX: usize = 2;
// the fast loop only handles complete chunks of 8 input bytes without padding
let chunk_len = 8;
let decoded_chunk_len = 6;
let remainder_len = input_bytes.len() % chunk_len;
let remainder_len = input_bytes.len() % INPUT_CHUNK_LEN;
let trailing_bytes_to_skip = if remainder_len == 0 {
// if input is a multiple of the chunk size, ignore the last chunk as it may have padding
chunk_len
// if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
// and the fast decode logic cannot handle padding
INPUT_CHUNK_LEN
} else {
remainder_len
};
@ -506,105 +509,63 @@ pub fn decode_config_buf<T: ?Sized + AsRef<[u8]>>(input: &T,
let starting_output_index = buffer.len();
// Resize to hold decoded output from fast loop. Need the extra two bytes because
// we write a full 8 bytes for the last 6-byte decoded chunk and then truncate off two
// we write a full 8 bytes for the last 6-byte decoded chunk and then truncate off the last two.
let new_size = starting_output_index
+ length_of_full_chunks / chunk_len * decoded_chunk_len
+ (chunk_len - decoded_chunk_len);
.checked_add(length_of_full_chunks / INPUT_CHUNK_LEN * DECODED_CHUNK_LEN)
.and_then(|l| l.checked_add(DECODED_CHUNK_SUFFIX))
.expect("Overflow when calculating output buffer length");
buffer.resize(new_size, 0);
let mut output_index = starting_output_index;
{
let buffer_slice = buffer.as_mut_slice();
let mut output_index = 0;
let mut input_index = 0;
// initial value is never used; always set if fast loop breaks
let mut bad_byte_index: usize = 0;
// a non-invalid value means it's not an error if fast loop never runs
let mut morsel: u8 = 0;
let buffer_slice = &mut buffer.as_mut_slice()[starting_output_index..];
// fast loop of 8 bytes at a time
// how many u64's of input to handle at a time
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;
const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN;
// includes the trailing 2 bytes for the final u64 write
const DECODED_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN +
DECODED_CHUNK_SUFFIX;
// the start index of the last block of data that is big enough to use the unrolled loop
let last_block_start_index = length_of_full_chunks
.saturating_sub(INPUT_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK);
// manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
if last_block_start_index > 0 {
while input_index <= last_block_start_index {
let input_slice = &input_bytes[input_index..(input_index + INPUT_BLOCK_LEN)];
let output_slice = &mut buffer_slice[output_index..(output_index + DECODED_BLOCK_LEN)];
decode_chunk(&input_slice[0..], input_index, decode_table, &mut output_slice[0..])?;
decode_chunk(&input_slice[8..], input_index + 8, decode_table, &mut output_slice[6..])?;
decode_chunk(&input_slice[16..], input_index + 16, decode_table, &mut output_slice[12..])?;
decode_chunk(&input_slice[24..], input_index + 24, decode_table, &mut output_slice[18..])?;
input_index += INPUT_BLOCK_LEN;
output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
}
}
// still pretty fast loop: 8 bytes at a time for whatever we didn't do in the faster loop.
while input_index < length_of_full_chunks {
let mut accum: u64;
decode_chunk(&input_bytes[input_index..(input_index + 8)], input_index, decode_table,
&mut buffer_slice[output_index..(output_index + 8)])?;
let input_chunk = BigEndian::read_u64(&input_bytes[input_index..(input_index + 8)]);
morsel = decode_table[(input_chunk >> 56) as usize];
if morsel == tables::INVALID_VALUE {
bad_byte_index = input_index;
break;
};
accum = (morsel as u64) << 58;
morsel = decode_table[(input_chunk >> 48 & 0xFF) as usize];
if morsel == tables::INVALID_VALUE {
bad_byte_index = input_index + 1;
break;
};
accum |= (morsel as u64) << 52;
morsel = decode_table[(input_chunk >> 40 & 0xFF) as usize];
if morsel == tables::INVALID_VALUE {
bad_byte_index = input_index + 2;
break;
};
accum |= (morsel as u64) << 46;
morsel = decode_table[(input_chunk >> 32 & 0xFF) as usize];
if morsel == tables::INVALID_VALUE {
bad_byte_index = input_index + 3;
break;
};
accum |= (morsel as u64) << 40;
morsel = decode_table[(input_chunk >> 24 & 0xFF) as usize];
if morsel == tables::INVALID_VALUE {
bad_byte_index = input_index + 4;
break;
};
accum |= (morsel as u64) << 34;
morsel = decode_table[(input_chunk >> 16 & 0xFF) as usize];
if morsel == tables::INVALID_VALUE {
bad_byte_index = input_index + 5;
break;
};
accum |= (morsel as u64) << 28;
morsel = decode_table[(input_chunk >> 8 & 0xFF) as usize];
if morsel == tables::INVALID_VALUE {
bad_byte_index = input_index + 6;
break;
};
accum |= (morsel as u64) << 22;
morsel = decode_table[(input_chunk & 0xFF) as usize];
if morsel == tables::INVALID_VALUE {
bad_byte_index = input_index + 7;
break;
};
accum |= (morsel as u64) << 16;
BigEndian::write_u64(&mut buffer_slice[(output_index)..(output_index + 8)],
accum);
output_index += 6;
input_index += chunk_len;
};
if morsel == tables::INVALID_VALUE {
// we got here from a break
return Err(DecodeError::InvalidByte(bad_byte_index, input_bytes[bad_byte_index]));
output_index += DECODED_CHUNK_LEN;
input_index += INPUT_CHUNK_LEN;
}
}
// Truncate off the last two bytes from writing the last u64.
// Unconditional because we added on the extra 2 bytes in the resize before the loop,
// so it will never underflow.
let new_len = buffer.len() - (chunk_len - decoded_chunk_len);
let new_len = buffer.len() - DECODED_CHUNK_SUFFIX;
buffer.truncate(new_len);
// handle leftovers (at most 8 bytes, decoded to 6).
// Use a u64 as a stack-resident 8 bytes buffer.
// Use a u64 as a stack-resident 8 byte buffer.
let mut leftover_bits: u64 = 0;
let mut morsels_in_leftover = 0;
let mut padding_bytes = 0;
@ -623,17 +584,26 @@ pub fn decode_config_buf<T: ?Sized + AsRef<[u8]>>(input: &T,
if i % 4 < 2 {
// Check for case #2.
// TODO InvalidPadding error
return Err(DecodeError::InvalidByte(length_of_full_chunks + i, *b));
};
let bad_padding_index = length_of_full_chunks + if padding_bytes > 0 {
// If we've already seen padding, report the first padding index.
// This is to be consistent with the faster logic above: it will report an error
// on the first padding character (since it doesn't expect to see anything but
// actual encoded data).
first_padding_index
} else {
// haven't seen padding before, just use where we are now
i
};
return Err(DecodeError::InvalidByte(bad_padding_index, *b));
}
if padding_bytes == 0 {
first_padding_index = i;
};
}
padding_bytes += 1;
continue;
};
}
// Check for case #1.
// To make '=' handling consistent with the main loop, don't allow
@ -642,20 +612,20 @@ pub fn decode_config_buf<T: ?Sized + AsRef<[u8]>>(input: &T,
if padding_bytes > 0 {
return Err(DecodeError::InvalidByte(
length_of_full_chunks + first_padding_index, 0x3D));
};
}
// can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
// To minimize shifts, pack the leftovers from left to right.
let shift = 64 - (morsels_in_leftover + 1) * 6;
// tables are all 256 elements, cannot overflow from a u8 index
// tables are all 256 elements, lookup with a u8 index always succeeds
let morsel = decode_table[*b as usize];
if morsel == tables::INVALID_VALUE {
return Err(DecodeError::InvalidByte(length_of_full_chunks + i, *b));
};
}
leftover_bits |= (morsel as u64) << shift;
morsels_in_leftover += 1;
};
}
let leftover_bits_ready_to_append = match morsels_in_leftover {
0 => 0,
@ -682,5 +652,64 @@ pub fn decode_config_buf<T: ?Sized + AsRef<[u8]>>(input: &T,
Ok(())
}
// yes, really inline (worth 30-50% speedup)
#[inline(always)]
fn decode_chunk(input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256],
output: &mut [u8]) -> Result<(), DecodeError> {
let mut accum: u64;
let morsel = decode_table[input[0] as usize];
if morsel == tables::INVALID_VALUE {
return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
}
accum = (morsel as u64) << 58;
let morsel = decode_table[input[1] as usize];
if morsel == tables::INVALID_VALUE {
return Err(DecodeError::InvalidByte(index_at_start_of_input + 1, input[1]));
}
accum |= (morsel as u64) << 52;
let morsel = decode_table[input[2] as usize];
if morsel == tables::INVALID_VALUE {
return Err(DecodeError::InvalidByte(index_at_start_of_input + 2, input[2]));
}
accum |= (morsel as u64) << 46;
let morsel = decode_table[input[3] as usize];
if morsel == tables::INVALID_VALUE {
return Err(DecodeError::InvalidByte(index_at_start_of_input + 3, input[3]));
}
accum |= (morsel as u64) << 40;
let morsel = decode_table[input[4] as usize];
if morsel == tables::INVALID_VALUE {
return Err(DecodeError::InvalidByte(index_at_start_of_input + 4, input[4]));
}
accum |= (morsel as u64) << 34;
let morsel = decode_table[input[5] as usize];
if morsel == tables::INVALID_VALUE {
return Err(DecodeError::InvalidByte(index_at_start_of_input + 5, input[5]));
}
accum |= (morsel as u64) << 28;
let morsel = decode_table[input[6] as usize];
if morsel == tables::INVALID_VALUE {
return Err(DecodeError::InvalidByte(index_at_start_of_input + 6, input[6]));
}
accum |= (morsel as u64) << 22;
let morsel = decode_table[input[7] as usize];
if morsel == tables::INVALID_VALUE {
return Err(DecodeError::InvalidByte(index_at_start_of_input + 7, input[7]));
}
accum |= (morsel as u64) << 16;
BigEndian::write_u64(output, accum);
Ok(())
}
#[cfg(test)]
mod tests;

View File

@ -127,151 +127,134 @@ fn decode_rfc4648_6() {
//this is a MAY in the rfc: https://tools.ietf.org/html/rfc4648#section-3.3
#[test]
fn decode_pad_inside_fast_loop_chunk_error() {
// can't PartialEq Base64Error, so we do this the hard way
match decode("YWxpY2U=====").unwrap_err() {
DecodeError::InvalidByte(offset, byte) => {
// since the first 8 bytes are handled in the fast loop, the
// padding is an error. Could argue that the *next* padding
// byte is technically the first erroneous one, but reporting
// that accurately is more complex and probably nobody cares
assert_eq!(7, offset);
assert_eq!(0x3D, byte);
}
_ => assert!(false)
for num_quads in 0..25 {
let mut s: String = std::iter::repeat("ABCD").take(num_quads).collect();
s.push_str("YWxpY2U=====");
// since the first 8 bytes are handled in the fast loop, the
// padding is an error. Could argue that the *next* padding
// byte is technically the first erroneous one, but reporting
// that accurately is more complex and probably nobody cares
assert_eq!(DecodeError::InvalidByte(num_quads * 4 + 7, b'='), decode(&s).unwrap_err());
}
}
#[test]
fn decode_extra_pad_after_fast_loop_chunk_error() {
match decode("YWxpY2UABB===").unwrap_err() {
DecodeError::InvalidByte(offset, byte) => {
// extraneous third padding byte
assert_eq!(12, offset);
assert_eq!(0x3D, byte);
}
_ => assert!(false)
};
}
for num_quads in 0..25 {
let mut s: String = std::iter::repeat("ABCD").take(num_quads).collect();
s.push_str("YWxpY2UABB===");
//same
#[test]
fn decode_absurd_pad_error() {
match decode("==Y=Wx===pY=2U=====").unwrap_err() {
DecodeError::InvalidByte(size, byte) => {
assert_eq!(0, size);
assert_eq!(0x3D, byte);
}
_ => assert!(false)
// first padding byte
assert_eq!(DecodeError::InvalidByte(num_quads * 4 + 10, b'='), decode(&s).unwrap_err());
}
}
#[test]
fn decode_starts_with_padding_single_quad_error() {
match decode("====").unwrap_err() {
DecodeError::InvalidByte(offset, byte) => {
// with no real input, first padding byte is bogus
assert_eq!(0, offset);
assert_eq!(0x3D, byte);
}
_ => assert!(false)
fn decode_absurd_pad_error() {
for num_quads in 0..25 {
let mut s: String = std::iter::repeat("ABCD").take(num_quads).collect();
s.push_str("==Y=Wx===pY=2U=====");
// first padding byte
assert_eq!(DecodeError::InvalidByte(num_quads * 4, b'='), decode(&s).unwrap_err());
}
}
#[test]
fn decode_extra_padding_in_trailing_quad_returns_error() {
match decode("zzz==").unwrap_err() {
DecodeError::InvalidByte(size, byte) => {
// first unneeded padding byte
assert_eq!(4, size);
assert_eq!(0x3D, byte);
}
_ => assert!(false)
for num_quads in 0..25 {
let mut s: String = std::iter::repeat("ABCD").take(num_quads).collect();
s.push_str("EEE==");
// first padding byte -- which would be legal if it was by itself
assert_eq!(DecodeError::InvalidByte(num_quads * 4 + 3, b'='), decode(&s).unwrap_err());
}
}
#[test]
fn decode_extra_padding_in_trailing_quad_2_returns_error() {
match decode("zz===").unwrap_err() {
DecodeError::InvalidByte(size, byte) => {
// first unneeded padding byte
assert_eq!(4, size);
assert_eq!(0x3D, byte);
}
_ => assert!(false)
for num_quads in 0..25 {
let mut s: String = std::iter::repeat("ABCD").take(num_quads).collect();
s.push_str("EE===");
// first padding byte -- which would be legal if it was by itself
assert_eq!(DecodeError::InvalidByte(num_quads * 4 + 2, b'='), decode(&s).unwrap_err());
}
}
#[test]
fn decode_start_second_quad_with_padding_returns_error() {
match decode("zzzz=").unwrap_err() {
DecodeError::InvalidByte(size, byte) => {
// first unneeded padding byte
assert_eq!(4, size);
assert_eq!(0x3D, byte);
}
_ => assert!(false)
for num_quads in 0..25 {
let mut s: String = std::iter::repeat("ABCD").take(num_quads).collect();
s.push_str("=");
// first padding byte -- must have two non-padding bytes in a quad
assert_eq!(DecodeError::InvalidByte(num_quads * 4, b'='), decode(&s).unwrap_err());
// two padding bytes -- same
s.push_str("=");
assert_eq!(DecodeError::InvalidByte(num_quads * 4, b'='), decode(&s).unwrap_err());
s.push_str("=");
assert_eq!(DecodeError::InvalidByte(num_quads * 4, b'='), decode(&s).unwrap_err());
s.push_str("=");
assert_eq!(DecodeError::InvalidByte(num_quads * 4, b'='), decode(&s).unwrap_err());
}
}
#[test]
fn decode_padding_in_last_quad_followed_by_non_padding_returns_error() {
match decode("zzzz==z").unwrap_err() {
DecodeError::InvalidByte(size, byte) => {
assert_eq!(4, size);
assert_eq!(0x3D, byte);
}
_ => assert!(false)
for num_quads in 0..25 {
let mut s: String = std::iter::repeat("ABCD").take(num_quads).collect();
s.push_str("==E");
// first padding byte -- must have two non-padding bytes in a quad
assert_eq!(DecodeError::InvalidByte(num_quads * 4, b'='), decode(&s).unwrap_err());
}
}
#[test]
fn decode_too_short_with_padding_error() {
match decode("z==").unwrap_err() {
DecodeError::InvalidByte(size, byte) => {
// first unneeded padding byte
assert_eq!(1, size);
assert_eq!(0x3D, byte);
}
_ => assert!(false)
fn decode_one_char_in_quad_with_padding_error() {
for num_quads in 0..25 {
let mut s: String = std::iter::repeat("ABCD").take(num_quads).collect();
s.push_str("E=");
assert_eq!(DecodeError::InvalidByte(num_quads * 4 + 1, b'='), decode(&s).unwrap_err());
// more padding doesn't change the error
s.push_str("=");
assert_eq!(DecodeError::InvalidByte(num_quads * 4 + 1, b'='), decode(&s).unwrap_err());
s.push_str("=");
assert_eq!(DecodeError::InvalidByte(num_quads * 4 + 1, b'='), decode(&s).unwrap_err());
}
}
#[test]
fn decode_too_short_without_padding_error() {
match decode("z").unwrap_err() {
DecodeError::InvalidLength => {}
_ => assert!(false)
fn decode_one_char_in_quad_without_padding_error() {
for num_quads in 0..25 {
let mut s: String = std::iter::repeat("ABCD").take(num_quads).collect();
s.push('E');
assert_eq!(DecodeError::InvalidLength, decode(&s).unwrap_err());
}
}
#[test]
fn decode_too_short_second_quad_without_padding_error() {
match decode("zzzzX").unwrap_err() {
DecodeError::InvalidLength => {}
_ => assert!(false)
}
}
fn decode_reject_invalid_bytes_with_correct_error() {
for length in 1..100 {
for index in 0_usize..length {
for invalid_byte in " \t\n\r\x0C\x0B\x00%*.".bytes() {
let prefix: String = std::iter::repeat("A").take(index).collect();
let suffix: String = std::iter::repeat("B").take(length - index - 1).collect();
#[test]
fn decode_error_for_bogus_char_in_right_position() {
for length in 1..25 {
for error_position in 0_usize..length {
let prefix: String = std::iter::repeat("A").take(error_position).collect();
let suffix: String = std::iter::repeat("B").take(length - error_position - 1).collect();
let input = prefix + &String::from_utf8(vec![invalid_byte]).unwrap() + &suffix;
assert_eq!(length, input.len(), "length {} error position {}", length, index);
let input = prefix + "%" + &suffix;
assert_eq!(length, input.len(),
"length {} error position {}", length, error_position);
match decode(&input).unwrap_err() {
DecodeError::InvalidByte(size, byte) => {
assert_eq!(error_position, size,
"length {} error position {}", length, error_position);
assert_eq!(0x25, byte);
}
_ => assert!(false)
assert_eq!(DecodeError::InvalidByte(index, invalid_byte),
decode(&input).unwrap_err());
}
}
}
@ -317,42 +300,6 @@ fn roundtrip_random_no_padding() {
}
}
//strip yr whitespace kids
#[test]
fn decode_reject_space() {
assert_eq!(DecodeError::InvalidByte(3, 0x20), decode("YWx pY2U=").unwrap_err());
}
#[test]
fn decode_reject_tab() {
assert_eq!(DecodeError::InvalidByte(3, 0x9),decode("YWx\tpY2U=").unwrap_err());
}
#[test]
fn decode_reject_ff() {
assert_eq!(DecodeError::InvalidByte(3, 0xC),decode("YWx\x0cpY2U=").unwrap_err());
}
#[test]
fn decode_reject_vtab() {
assert_eq!(DecodeError::InvalidByte(3, 0xB),decode("YWx\x0bpY2U=").unwrap_err());
}
#[test]
fn decode_reject_nl() {
assert_eq!(DecodeError::InvalidByte(3, 0xA),decode("YWx\npY2U=").unwrap_err());
}
#[test]
fn decode_reject_crnl() {
assert_eq!(DecodeError::InvalidByte(3, 0xD),decode("YWx\r\npY2U=").unwrap_err());
}
#[test]
fn decode_reject_null() {
assert_eq!(DecodeError::InvalidByte(3, 0x0),decode("YWx\0pY2U=").unwrap_err());
}
#[test]
fn decode_mime_allow_space() {
assert!(decode_config("YWx pY2U=", MIME).is_ok());