diff --git a/src/aes128gcm.rs b/src/aes128gcm.rs index 4dc4388..3d376bf 100644 --- a/src/aes128gcm.rs +++ b/src/aes128gcm.rs @@ -76,39 +76,16 @@ pub(crate) fn encrypt( keyid: &keyid, }; - // We always add at least one padding byte, for the delimiter. - let padding = std::cmp::max(params.pad_length, ECE_AES128GCM_PAD_SIZE); + let records = split_into_records(plaintext, params.pad_length, params.rs as usize)?; - // For now, everything must fit in a single record. - // Calling code will ensure that this is the case. - if params.rs < ECE_AES128GCM_MIN_RS { - return Err(Error::InvalidRecordSize); - } - if plaintext.len() + padding + ECE_TAG_LENGTH > params.rs as usize { - dbg!(format!( - "Message content too long for a single record (rs={}, plaintext={}, padding={})", - params.rs, - plaintext.len(), - padding - )); - return Err(Error::MultipleRecordsNotSupported); - } - let record = PlaintextRecord { - plaintext, - padding, - sequence_number: 0, - is_final: true, - }; + let mut ciphertext = vec![0; header.encoded_size() + records.total_ciphertext_size()]; + let mut offset = 0; - let mut ciphertext = vec![0; header.encoded_size() + record.encrypted_size()]; - - header.write_into(&mut ciphertext); - record.encrypt_into( - cryptographer, - &key, - &nonce, - &mut ciphertext[header.encoded_size()..], - )?; + offset += header.write_into(&mut ciphertext); + for record in records { + offset += record.encrypt_into(cryptographer, &key, &nonce, &mut ciphertext[offset..])?; + } + assert!(offset == ciphertext.len()); Ok(ciphertext) } @@ -234,12 +211,15 @@ impl<'a> Header<'a> { /// This assumes that the buffer has sufficient space for the data, and will /// panic (via Rust's runtime safety checks) if it does not. /// - pub fn write_into(&self, output: &mut [u8]) { + /// Returns the number of bytes written. + /// + pub fn write_into(&self, output: &mut [u8]) -> usize { output[0..ECE_SALT_LENGTH].copy_from_slice(self.salt); BigEndian::write_u32(&mut output[ECE_SALT_LENGTH..], self.rs); output[ECE_AES128GCM_HEADER_LENGTH - 1] = self.keyid.len() as u8; output[ECE_AES128GCM_HEADER_LENGTH..ECE_AES128GCM_HEADER_LENGTH + self.keyid.len()] .copy_from_slice(self.keyid); + self.encoded_size() } /// Get the size occupied by this header when written to the encrypted data. @@ -338,6 +318,8 @@ impl<'a> PlaintextRecord<'a> { /// and this method will panic (via Rust's runtime safety checks) if there is insufficient /// space available. /// + /// Returns the number of bytes written. + /// pub(crate) fn encrypt_into( &self, cryptographer: &dyn Cryptographer, @@ -363,9 +345,174 @@ impl<'a> PlaintextRecord<'a> { output[0..ciphertext.len()].copy_from_slice(&ciphertext); Ok(ciphertext.len()) } +} + +/// Iterator returning record-sized chunks of plaintext + padding. +/// +/// Given a plaintext, an amount of padding data to add, and a target encrypted record +/// size, this function returns an iterator of `PlaintextRecord` structs such that: +/// +/// * The encrypted size of each plaintext chunk plus its padding will be equal +/// to the given record size, except for the final record which may be shorter. +/// +/// * Each record has at least one padding byte; if necessary, additional padding +/// bytes will be inserted beyond what was requested by the caller in order +/// to meet this requirement. (This ensures each record has enough room for the +/// padding delimiter byte). +/// +/// * The plaintext is distributed as evenly as possible between records. Records +/// consisting entirely of padding will only be produced in degenerate cases such +/// as where the caller requested far more padding than available plaintext, or +/// where the requested total size falls just beyond a record boundary. +/// +fn split_into_records( + plaintext: &[u8], + pad_length: usize, + rs: usize, +) -> Result> { + // Adjust for encryption overhead. + if rs < ECE_AES128GCM_MIN_RS as usize { + return Err(Error::InvalidRecordSize); + } + let rs = rs - ECE_TAG_LENGTH; + // Ensure we have enough padding to give at least one byte of it to each record. + // This is the only reason why we might expand the padding beyond what was requested. + let mut min_num_records = plaintext.len() / (rs - 1); + if plaintext.len() % (rs - 1) != 0 { + min_num_records += 1; + } + let pad_length = std::cmp::max(pad_length, min_num_records); + // Knowing the total data size, determines the number of records. + let total_size = plaintext.len() + pad_length; + let mut num_records = total_size / rs; + let size_of_final_record = total_size % rs; + if size_of_final_record > 0 { + num_records += 1; + } + assert!( + num_records >= min_num_records, + "record chunking error: we miscalculated the minimum number of records ({} < {})", + num_records, + min_num_records, + ); + // Evenly distribute the plaintext between that many records. + // There may of course be some leftover that won't distribute evenly. + let plaintext_per_record = plaintext.len() / num_records; + let mut extra_plaintext = plaintext.len() % num_records; + // If the final record is very small, we might not be able to fit + // the recommended number of plaintext bytes, so redistribute them. + // (Remember, the final block must contain at least one padding byte). + if size_of_final_record > 0 && plaintext_per_record > size_of_final_record - 1 { + extra_plaintext += plaintext_per_record - (size_of_final_record - 1) + } + // And now we can iterate! + Ok(PlaintextRecordIterator { + plaintext, + pad_length, + plaintext_per_record, + extra_plaintext, + rs, + sequence_number: 0, + num_records, + total_size, + }) +} + +/// The underlying iterator implementation for `split_into_records`. +/// +struct PlaintextRecordIterator<'a> { + /// The plaintext that remains to be split. + plaintext: &'a [u8], + /// The amount of padding that remains to be split. + pad_length: usize, + /// The amount of plaintext to put in each record. + plaintext_per_record: usize, + /// The amount of leftover plaintext that could not be distributed evenly. + extra_plaintext: usize, + /// The total number of bytes that will be produced by this iterator. + total_size: usize, + /// The target unencrypted record size. + rs: usize, + /// The total number of records that will be produced. + num_records: usize, + /// The sequence number of the next record to be produced. + sequence_number: usize, +} + +impl<'a> PlaintextRecordIterator<'a> { + pub(crate) fn total_ciphertext_size(&self) -> usize { + self.total_size + self.num_records * ECE_TAG_LENGTH + } +} - pub(crate) fn encrypted_size(&self) -> usize { - self.plaintext.len() + self.padding + ECE_TAG_LENGTH +impl<'a> Iterator for PlaintextRecordIterator<'a> { + type Item = PlaintextRecord<'a>; + fn next(&mut self) -> Option { + let records_remaining = self.num_records - self.sequence_number; + // We stop iterating when we've produced all records. + if records_remaining == 0 { + assert!( + self.plaintext.is_empty(), + "record chunking error: the plaintext was not fully consumed" + ); + assert!( + self.extra_plaintext == 0, + "record chunking error: the extra plaintext was not fully consumed" + ); + assert!( + self.pad_length == 0, + "record chunking error: the padding was not fully consumed" + ); + return None; + } + // Allocate a chunk of plaintext to this record. + // We target `plaintext_per_record` bytes per record, but it's a little + // more complicated than that... + let mut plaintext_share = self.plaintext_per_record; + if plaintext_share > self.plaintext.len() { + // ...because the final record is allowed to be smaller. + assert!( + records_remaining == 1, + "record chunking error: the plaintext was consumed too early" + ); + plaintext_share = self.plaintext.len(); + } else { + // ...because non-final records need to consume any extra plaintext. + if self.extra_plaintext > 0 { + // The extra plaintext must be distributed as evenly as possible + // amongst all but the final record. + let mut extra_share = self.extra_plaintext / (records_remaining - 1); + if self.extra_plaintext % (records_remaining - 1) != 0 { + extra_share += 1; + } + plaintext_share += extra_share; + self.extra_plaintext -= extra_share; + } + } + let plaintext = &self.plaintext[0..plaintext_share]; + self.plaintext = &self.plaintext[plaintext_share..]; + // Fill the rest of the record with padding. + let padding_share = std::cmp::min(self.pad_length, self.rs - plaintext_share); + self.pad_length -= padding_share; + assert!( + padding_share > 0, + "record chunking error: the padding was consumed too early" + ); + // Check where we are in the iteration. + let sequence_number = self.sequence_number; + self.sequence_number += 1; + let is_final = self.sequence_number == self.num_records; + assert!( + is_final || plaintext.len() + padding_share == self.rs, + "record chunking error: non-final record is too short" + ); + // That's a record! + Some(PlaintextRecord { + plaintext, + padding: padding_share, + sequence_number, + is_final, + }) } } @@ -432,3 +579,120 @@ fn generate_info( info[offset..].copy_from_slice(raw_sender_pub_key); Ok(info) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_split_into_records_17_0_20() { + let records = split_into_records(&[0u8; 17], 0, 20 + ECE_TAG_LENGTH) + .unwrap() + .collect::>(); + // Should fit comfortably into a single record. + assert_eq!(records.len(), 1); + assert_eq!(records[0].plaintext.len(), 17); + assert_eq!(records[0].padding, 1); + assert_eq!(records[0].sequence_number, 0); + assert!(records[0].is_final); + } + + #[test] + fn test_split_into_records_15_0_6() { + let records = split_into_records(&[0u8; 15], 0, 6 + ECE_TAG_LENGTH) + .unwrap() + .collect::>(); + // Should fit exactly across three records. + assert_eq!(records.len(), 3); + + assert_eq!(records[0].plaintext.len(), 5); + assert_eq!(records[0].padding, 1); + assert_eq!(records[0].sequence_number, 0); + assert!(!records[0].is_final); + + assert_eq!(records[1].plaintext.len(), 5); + assert_eq!(records[1].padding, 1); + assert_eq!(records[1].sequence_number, 1); + assert!(!records[1].is_final); + + assert_eq!(records[2].plaintext.len(), 5); + assert_eq!(records[2].padding, 1); + assert_eq!(records[2].sequence_number, 2); + assert!(records[2].is_final); + } + + fn split_and_summarize(payload_len: usize, padding: usize, rs: usize) -> Vec<(usize, usize)> { + split_into_records(&vec![0u8; payload_len], padding, rs + ECE_TAG_LENGTH) + .unwrap() + .map(|record| (record.plaintext.len(), record.padding)) + .collect() + } + + #[test] + fn test_split_into_records_8_2_3() { + // Should expand to 4 bytes of padding, then return 4 equal records + // with two bytes of plaintext and one byte of padding. + assert_eq!( + split_and_summarize(8, 2, 3), + vec![(2, 1), (2, 1), (2, 1), (2, 1)] + ); + } + + #[test] + fn test_split_into_records_8_0_8() { + // Should expand to 2 bytes of padding, 2 records. + // The last record is only size 2, so can only fit 1 plaintext byte. + assert_eq!(split_and_summarize(8, 0, 8), vec![(7, 1), (1, 1)]); + } + + #[test] + fn test_split_into_records_24_6_8() { + // Total length of 30, 4 records. + // Ideally we'd have 6 bytes of plaintext in each, but the final record + // is only length 6 so it can't hold more than 5 bytes of plaintext. + assert_eq!( + split_and_summarize(24, 6, 8), + vec![(7, 1), (6, 2), (6, 2), (5, 1)] + ); + } + + #[test] + fn test_split_into_records_8_6_3() { + // Total length 14, 4 records, the last only 2 bytes long. + // But we can still spread the plaintext so that there's some in each record. + assert_eq!( + split_and_summarize(8, 6, 3), + vec![(2, 1), (2, 1), (2, 1), (1, 2), (1, 1)] + ); + } + + #[test] + fn test_split_into_records_3_25_8() { + // Total length of 28, meaning 4 records. + // One of the records will have to be only padding. + assert_eq!( + split_and_summarize(3, 25, 8), + vec![(1, 7), (1, 7), (1, 7), (0, 4)] + ); + } + + #[test] + fn test_split_into_records_3_35_8() { + // Total length of 38, meaning 5 records. + // Two of the records will have to be only padding. + assert_eq!( + split_and_summarize(3, 35, 8), + vec![(1, 7), (1, 7), (1, 7), (0, 8), (0, 6)] + ); + } + + #[test] + fn test_split_into_records_19_6_8() { + // Total length of 25, 4 records with the final record being only a single byte. + // It therefore can only be padding. + assert_eq!( + split_and_summarize(19, 6, 8), + vec![(7, 1), (6, 2), (6, 2), (0, 1)] + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 8a72477..81cf61e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,7 +18,7 @@ pub use crate::{ use crate::{ aes128gcm::ECE_AES128GCM_PAD_SIZE, - common::{WebPushParams, ECE_TAG_LENGTH, ECE_WEBPUSH_AUTH_SECRET_LENGTH}, + common::{WebPushParams, ECE_WEBPUSH_AUTH_SECRET_LENGTH}, }; /// Generate a local ECE key pair and authentication secret. @@ -45,12 +45,7 @@ pub fn encrypt(remote_pub: &[u8], remote_auth: &[u8], data: &[u8]) -> Result params.rs as usize { - params.rs = (data.len() + params.pad_length + ECE_TAG_LENGTH) as u32; - } + let params = WebPushParams::new_for_plaintext(data, ECE_AES128GCM_PAD_SIZE); aes128gcm::encrypt(&*local_key_pair, &*remote_key, &remote_auth, data, params) } @@ -81,6 +76,7 @@ fn generate_keys() -> Result<(Box, Box)> { #[cfg(all(test, feature = "backend-openssl"))] mod aes128gcm_tests { + use super::common::ECE_TAG_LENGTH; use super::*; use hex; @@ -173,6 +169,45 @@ mod aes128gcm_tests { assert_eq!(decrypted, plaintext.to_vec()); } + #[test] + fn test_e2e_with_different_record_sizes_and_padding() { + let (local_key, remote_key) = generate_keys().unwrap(); + let plaintext = b"When I grow up, I want to be a watermelon"; + let mut auth_secret = vec![0u8; 16]; + let cryptographer = crypto::holder::get_cryptographer(); + cryptographer.random_bytes(&mut auth_secret).unwrap(); + let remote_public = cryptographer + .import_public_key(&remote_key.pub_as_raw().unwrap()) + .unwrap(); + let plen = plaintext.len(); + // Try a variety of different record sizes. The numbers here aren't particularly deeply + // considered, just a selection of numbers that might be interesting. (Although they did + // trigger a bunch of interesting edge-cases during development, which is re-assuring). + for plaintext_rs in &[2, 3, 7, 8, plen - 1, plen, plen + 1, 1024, 8192] { + let rs = (*plaintext_rs + ECE_TAG_LENGTH) as u32; + // Try a variety of padding lengths. Again, not deeply considered numbers. + for pad_length in &[0, 1, 2, 8, 37, 127, 128] { + let pad_length = *pad_length; + let params = WebPushParams { + rs, + pad_length, + ..WebPushParams::default() + }; + let ciphertext = aes128gcm::encrypt( + &*local_key, + &*remote_public, + &auth_secret, + plaintext, + params, + ) + .unwrap(); + let decrypted = + aes128gcm::decrypt(&*remote_key, &auth_secret, &ciphertext).unwrap(); + assert_eq!(decrypted, plaintext.to_vec()); + } + } + } + #[test] fn test_conv_fn() -> Result<()> { let (local_key, auth) = generate_keypair_and_auth_secret()?;