diff --git a/src/errors.rs b/src/errors.rs index c1177269..d416aa04 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -60,6 +60,9 @@ pub enum Error { /// Label too long. LabelTooLong, + + /// Invalid padding length. + InvalidPadLen, } #[cfg(feature = "std")] @@ -87,6 +90,7 @@ impl core::fmt::Display for Error { Error::Pkcs8(err) => write!(f, "{}", err), Error::Internal => write!(f, "internal error"), Error::LabelTooLong => write!(f, "label too long"), + Error::InvalidPadLen => write!(f, "invalid padding length"), } } } diff --git a/src/internals.rs b/src/internals.rs index 0a1e5f76..90ae2181 100644 --- a/src/internals.rs +++ b/src/internals.rs @@ -174,14 +174,34 @@ pub fn unblind(key: &impl PublicKeyParts, m: &BigUint, unblinder: &BigUint) -> B /// Returns a new vector of the given length, with 0s left padded. #[inline] -pub fn left_pad(input: &[u8], size: usize) -> Vec { - let n = if input.len() > size { - size - } else { - input.len() - }; +pub fn left_pad(input: &[u8], padded_len: usize) -> Result> { + if input.len() > padded_len { + return Err(Error::InvalidPadLen); + } - let mut out = vec![0u8; size]; - out[size - n..].copy_from_slice(input); - out + let mut out = vec![0u8; padded_len]; + out[padded_len - input.len()..].copy_from_slice(input); + Ok(out) +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_left_pad() { + const INPUT_LEN: usize = 3; + let input = vec![0u8; INPUT_LEN]; + + // input len < padded len + let padded = left_pad(&input, INPUT_LEN + 1).unwrap(); + assert_eq!(padded.len(), INPUT_LEN + 1); + + // input len == padded len + let padded = left_pad(&input, INPUT_LEN).unwrap(); + assert_eq!(padded.len(), INPUT_LEN); + + // input len > padded len + let padded = left_pad(&input, INPUT_LEN - 1); + assert!(padded.is_err()); + } } diff --git a/src/raw.rs b/src/raw.rs index f77fd89b..cc0af721 100644 --- a/src/raw.rs +++ b/src/raw.rs @@ -3,7 +3,7 @@ use num_bigint::BigUint; use rand_core::CryptoRngCore; use zeroize::Zeroize; -use crate::errors::{Error, Result}; +use crate::errors::Result; use crate::internals; use crate::key::{RsaPrivateKey, RsaPublicKey}; @@ -29,16 +29,12 @@ impl EncryptionPrimitive for RsaPublicKey { let mut c_bytes = c.to_bytes_be(); let ciphertext = internals::left_pad(&c_bytes, pad_size); - if pad_size < ciphertext.len() { - return Err(Error::Verification); - } - // clear out tmp values m.zeroize(); c.zeroize(); c_bytes.zeroize(); - Ok(ciphertext) + ciphertext } } @@ -59,6 +55,6 @@ impl DecryptionPrimitive for RsaPrivateKey { m.zeroize(); m_bytes.zeroize(); - Ok(plaintext) + plaintext } }