From abe9c386af032e9e1a96eb8fb9b153270cb6484b Mon Sep 17 00:00:00 2001 From: Dmitry Baryshkov Date: Tue, 18 Apr 2023 21:57:22 +0300 Subject: [PATCH 1/6] pss: use `BigUint` as `Signature`'s inner type Follow the pkcs1v15 change and use BigUint as a Signature's internal implementation. Signed-off-by: Dmitry Baryshkov --- src/pss.rs | 113 +++++++++++++++++++++++++++-------------------------- 1 file changed, 58 insertions(+), 55 deletions(-) diff --git a/src/pss.rs b/src/pss.rs index 1ef9a259..a25ddb5d 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -9,13 +9,13 @@ //! [Probabilistic Signature Scheme]: https://en.wikipedia.org/wiki/Probabilistic_signature_scheme //! [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1 -use alloc::boxed::Box; -use alloc::vec::Vec; +use alloc::{boxed::Box, string::ToString, vec::Vec}; use core::fmt::{self, Debug, Display, Formatter, LowerHex, UpperHex}; use core::marker::PhantomData; use const_oid::{AssociatedOid, ObjectIdentifier}; use digest::{Digest, DynDigest, FixedOutputReset}; +use num_bigint::BigUint; use pkcs1::RsaPssParams; use pkcs8::{ spki::{ @@ -106,7 +106,14 @@ impl SignatureScheme for Pss { } fn verify(mut self, pub_key: &Pub, hashed: &[u8], sig: &[u8]) -> Result<()> { - verify(pub_key, hashed, sig, &mut *self.digest, self.salt_len) + verify( + pub_key, + hashed, + &BigUint::from_bytes_be(sig), + sig.len(), + &mut *self.digest, + self.salt_len, + ) } } @@ -125,62 +132,48 @@ impl Debug for Pss { /// [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1 #[derive(Clone, PartialEq, Eq)] pub struct Signature { - bytes: Box<[u8]>, + inner: BigUint, + len: usize, } impl SignatureEncoding for Signature { type Repr = Box<[u8]>; } -impl From> for Signature { - fn from(bytes: Box<[u8]>) -> Self { - Self { bytes } - } -} - impl TryFrom<&[u8]> for Signature { type Error = signature::Error; fn try_from(bytes: &[u8]) -> signature::Result { Ok(Self { - bytes: bytes.into(), + len: bytes.len(), + inner: BigUint::from_bytes_be(bytes), }) } } impl From for Box<[u8]> { fn from(signature: Signature) -> Box<[u8]> { - signature.bytes + signature.inner.to_bytes_be().into_boxed_slice() } } impl Debug for Signature { fn fmt(&self, fmt: &mut Formatter<'_>) -> core::result::Result<(), core::fmt::Error> { - fmt.debug_list().entries(self.bytes.iter()).finish() - } -} - -impl AsRef<[u8]> for Signature { - fn as_ref(&self) -> &[u8] { - self.bytes.as_ref() + fmt.debug_tuple("Signature") + .field(&self.to_string()) + .finish() } } impl LowerHex for Signature { fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { - for byte in self.bytes.iter() { - write!(f, "{:02x}", byte)?; - } - Ok(()) + write!(f, "{:x}", &self.inner) } } impl UpperHex for Signature { fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { - for byte in self.bytes.iter() { - write!(f, "{:02X}", byte)?; - } - Ok(()) + write!(f, "{:X}", &self.inner) } } @@ -193,18 +186,19 @@ impl Display for Signature { pub(crate) fn verify( pub_key: &PK, hashed: &[u8], - sig: &[u8], + sig: &BigUint, + sig_len: usize, digest: &mut dyn DynDigest, salt_len: usize, ) -> Result<()> { - if sig.len() != pub_key.size() { + if sig_len != pub_key.size() { return Err(Error::Verification); } let em_bits = pub_key.n().bits() - 1; let em_len = (em_bits + 7) / 8; let key_len = pub_key.size(); - let mut em = pub_key.raw_encryption_primitive(sig, key_len)?; + let mut em = pub_key.raw_int_encryption_primitive(sig, key_len)?; emsa_pss_verify( hashed, @@ -218,21 +212,22 @@ pub(crate) fn verify( pub(crate) fn verify_digest( pub_key: &PK, hashed: &[u8], - sig: &[u8], + sig: &BigUint, + sig_len: usize, salt_len: usize, ) -> Result<()> where PK: PublicKey, D: Digest + FixedOutputReset, { - if sig.len() != pub_key.size() { + if sig_len != pub_key.size() { return Err(Error::Verification); } let em_bits = pub_key.n().bits() - 1; let em_len = (em_bits + 7) / 8; let key_len = pub_key.size(); - let mut em = pub_key.raw_encryption_primitive(sig, key_len)?; + let mut em = pub_key.raw_int_encryption_primitive(sig, key_len)?; emsa_pss_verify_digest::(hashed, &mut em[key_len - em_len..], em_bits, salt_len) } @@ -762,9 +757,9 @@ where rng: &mut impl CryptoRngCore, msg: &[u8], ) -> signature::Result { - sign_digest::<_, _, D>(rng, false, &self.inner, &D::digest(msg), self.salt_len) - .map(|v| v.into_boxed_slice().into()) - .map_err(|e| e.into()) + sign_digest::<_, _, D>(rng, false, &self.inner, &D::digest(msg), self.salt_len)? + .as_slice() + .try_into() } } @@ -777,9 +772,9 @@ where rng: &mut impl CryptoRngCore, digest: D, ) -> signature::Result { - sign_digest::<_, _, D>(rng, false, &self.inner, &digest.finalize(), self.salt_len) - .map(|v| v.into_boxed_slice().into()) - .map_err(|e| e.into()) + sign_digest::<_, _, D>(rng, false, &self.inner, &digest.finalize(), self.salt_len)? + .as_slice() + .try_into() } } @@ -792,9 +787,9 @@ where rng: &mut impl CryptoRngCore, prehash: &[u8], ) -> signature::Result { - sign_digest::<_, _, D>(rng, false, &self.inner, prehash, self.salt_len) - .map(|v| v.into_boxed_slice().into()) - .map_err(|e| e.into()) + sign_digest::<_, _, D>(rng, false, &self.inner, prehash, self.salt_len)? + .as_slice() + .try_into() } } @@ -935,9 +930,9 @@ where rng: &mut impl CryptoRngCore, msg: &[u8], ) -> signature::Result { - sign_digest::<_, _, D>(rng, true, &self.inner, &D::digest(msg), self.salt_len) - .map(|v| v.into_boxed_slice().into()) - .map_err(|e| e.into()) + sign_digest::<_, _, D>(rng, true, &self.inner, &D::digest(msg), self.salt_len)? + .as_slice() + .try_into() } } @@ -950,9 +945,9 @@ where rng: &mut impl CryptoRngCore, digest: D, ) -> signature::Result { - sign_digest::<_, _, D>(rng, true, &self.inner, &digest.finalize(), self.salt_len) - .map(|v| v.into_boxed_slice().into()) - .map_err(|e| e.into()) + sign_digest::<_, _, D>(rng, true, &self.inner, &digest.finalize(), self.salt_len)? + .as_slice() + .try_into() } } @@ -965,9 +960,9 @@ where rng: &mut impl CryptoRngCore, prehash: &[u8], ) -> signature::Result { - sign_digest::<_, _, D>(rng, true, &self.inner, prehash, self.salt_len) - .map(|v| v.into_boxed_slice().into()) - .map_err(|e| e.into()) + sign_digest::<_, _, D>(rng, true, &self.inner, prehash, self.salt_len)? + .as_slice() + .try_into() } } @@ -1063,7 +1058,8 @@ where verify_digest::<_, D>( &self.inner, &D::digest(msg), - signature.as_ref(), + &signature.inner, + signature.len, self.salt_len, ) .map_err(|e| e.into()) @@ -1078,7 +1074,8 @@ where verify_digest::<_, D>( &self.inner, &digest.finalize(), - signature.as_ref(), + &signature.inner, + signature.len, self.salt_len, ) .map_err(|e| e.into()) @@ -1090,8 +1087,14 @@ where D: Digest + FixedOutputReset, { fn verify_prehash(&self, prehash: &[u8], signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>(&self.inner, prehash, signature.as_ref(), self.salt_len) - .map_err(|e| e.into()) + verify_digest::<_, D>( + &self.inner, + prehash, + &signature.inner, + signature.len, + self.salt_len, + ) + .map_err(|e| e.into()) } } From 487f249de635ea016181dc7125123ce4d905cfe7 Mon Sep 17 00:00:00 2001 From: Dmitry Baryshkov Date: Tue, 18 Apr 2023 22:33:18 +0300 Subject: [PATCH 2/6] feat: simplify public traits The crate contains several exported traits targeting hardware-accelerated implementations (PublicKey, PrivateKey, EncryptionPrimitive, DecriptionPrimitive). However these traits overcomplicate internal structure of the crate. It is not clear, which level of API can be implemented by the hardware accelerators. The crate is already quite complicated, implementing both PaddingScheme-based API and Signer/Verifier/Encryptor/Decryptor API. Remove the complication for now. The proper level of indirection can be introduced once support for actual hardware accelerators is implemented. Signed-off-by: Dmitry Baryshkov --- README.md | 2 +- src/key.rs | 72 ++++++++++++++++++++++++++++++------------------- src/lib.rs | 7 +++-- src/oaep.rs | 60 ++++++++++++++++++----------------------- src/padding.rs | 16 +++++------ src/pkcs1v15.rs | 49 ++++++++++++++++----------------- src/pss.rs | 63 +++++++++++++++++++------------------------ src/raw.rs | 64 ------------------------------------------- 8 files changed, 134 insertions(+), 199 deletions(-) delete mode 100644 src/raw.rs diff --git a/README.md b/README.md index 551058b1..9027659c 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ A portable RSA implementation in pure Rust. ## Example ```rust -use rsa::{Pkcs1v15Encrypt, PublicKey, RsaPrivateKey, RsaPublicKey}; +use rsa::{Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey}; let mut rng = rand::thread_rng(); let bits = 2048; diff --git a/src/key.rs b/src/key.rs index 87784395..4635b530 100644 --- a/src/key.rs +++ b/src/key.rs @@ -7,14 +7,14 @@ use num_traits::{One, ToPrimitive}; use rand_core::CryptoRngCore; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use zeroize::Zeroize; +use zeroize::{Zeroize, Zeroizing}; use crate::algorithms::{generate_multi_prime_key, generate_multi_prime_key_with_exp}; use crate::dummy_rng::DummyRng; use crate::errors::{Error, Result}; +use crate::internals; use crate::padding::{PaddingScheme, SignatureScheme}; -use crate::raw::{DecryptionPrimitive, EncryptionPrimitive}; /// Components of an RSA public key. pub trait PublicKeyParts { @@ -31,8 +31,6 @@ pub trait PublicKeyParts { } } -pub trait PrivateKey: DecryptionPrimitive + PublicKeyParts {} - /// Represents the public part of an RSA key. #[derive(Debug, Clone, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -161,25 +159,6 @@ impl From<&RsaPrivateKey> for RsaPublicKey { } } -/// Generic trait for operations on a public key. -pub trait PublicKey: EncryptionPrimitive + PublicKeyParts { - /// Encrypt the given message. - fn encrypt( - &self, - rng: &mut R, - padding: P, - msg: &[u8], - ) -> Result>; - - /// Verify a signed message. - /// - /// `hashed` must be the result of hashing the input using the hashing function - /// passed in through `hash`. - /// - /// If the message is valid `Ok(())` is returned, otherwise an `Err` indicating failure. - fn verify(&self, scheme: S, hashed: &[u8], sig: &[u8]) -> Result<()>; -} - impl PublicKeyParts for RsaPublicKey { fn n(&self) -> &BigUint { &self.n @@ -190,8 +169,9 @@ impl PublicKeyParts for RsaPublicKey { } } -impl PublicKey for RsaPublicKey { - fn encrypt( +impl RsaPublicKey { + /// Encrypt the given message. + pub fn encrypt( &self, rng: &mut R, padding: P, @@ -200,7 +180,13 @@ impl PublicKey for RsaPublicKey { padding.encrypt(rng, self, msg) } - fn verify(&self, scheme: S, hashed: &[u8], sig: &[u8]) -> Result<()> { + /// Verify a signed message. + /// + /// `hashed` must be the result of hashing the input using the hashing function + /// passed in through `hash`. + /// + /// If the message is valid `Ok(())` is returned, otherwise an `Err` indicating failure. + pub fn verify(&self, scheme: S, hashed: &[u8], sig: &[u8]) -> Result<()> { scheme.verify(self, hashed, sig) } } @@ -239,6 +225,16 @@ impl RsaPublicKey { pub fn new_unchecked(n: BigUint, e: BigUint) -> Self { Self { n, e } } + + pub(crate) fn raw_int_encryption_primitive( + &self, + plaintext: &BigUint, + pad_size: usize, + ) -> Result> { + let c = Zeroizing::new(internals::encrypt(self, plaintext)); + let c_bytes = Zeroizing::new(c.to_bytes_be()); + internals::left_pad(&c_bytes, pad_size) + } } impl PublicKeyParts for RsaPrivateKey { @@ -251,8 +247,6 @@ impl PublicKeyParts for RsaPrivateKey { } } -impl PrivateKey for RsaPrivateKey {} - impl RsaPrivateKey { /// Generate a new Rsa key pair of the given bit size using the passed in `rng`. pub fn new(rng: &mut R, bit_size: usize) -> Result { @@ -461,6 +455,28 @@ impl RsaPrivateKey { ) -> Result> { padding.sign(Some(rng), self, digest_in) } + + /// Do NOT use directly! Only for implementors. + pub(crate) fn raw_decryption_primitive( + &self, + rng: Option<&mut R>, + ciphertext: &[u8], + pad_size: usize, + ) -> Result> { + let int = Zeroizing::new(BigUint::from_bytes_be(ciphertext)); + self.raw_int_decryption_primitive(rng, &int, pad_size) + } + + pub(crate) fn raw_int_decryption_primitive( + &self, + rng: Option<&mut R>, + ciphertext: &BigUint, + pad_size: usize, + ) -> Result> { + let m = Zeroizing::new(internals::decrypt_and_check(rng, self, ciphertext)?); + let m_bytes = Zeroizing::new(m.to_bytes_be()); + internals::left_pad(&m_bytes, pad_size) + } } /// Check that the public key is well formed and has an exponent within acceptable bounds. diff --git a/src/lib.rs b/src/lib.rs index 76956fe6..5efcb86b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ //! //! ## PKCS#1 v1.5 encryption //! ``` -//! use rsa::{PublicKey, RsaPrivateKey, RsaPublicKey, Pkcs1v15Encrypt}; +//! use rsa::{RsaPrivateKey, RsaPublicKey, Pkcs1v15Encrypt}; //! //! let mut rng = rand::thread_rng(); //! @@ -34,7 +34,7 @@ //! //! ## OAEP encryption //! ``` -//! use rsa::{PublicKey, RsaPrivateKey, RsaPublicKey, Oaep}; +//! use rsa::{RsaPrivateKey, RsaPublicKey, Oaep}; //! //! let mut rng = rand::thread_rng(); //! @@ -233,7 +233,6 @@ mod dummy_rng; mod encoding; mod key; mod padding; -mod raw; pub use pkcs1; pub use pkcs8; @@ -241,7 +240,7 @@ pub use pkcs8; pub use sha2; pub use crate::{ - key::{PublicKey, PublicKeyParts, RsaPrivateKey, RsaPublicKey}, + key::{PublicKeyParts, RsaPrivateKey, RsaPublicKey}, oaep::Oaep, padding::{PaddingScheme, SignatureScheme}, pkcs1v15::{Pkcs1v15Encrypt, Pkcs1v15Sign}, diff --git a/src/oaep.rs b/src/oaep.rs index 309a0cd1..7750234c 100644 --- a/src/oaep.rs +++ b/src/oaep.rs @@ -12,13 +12,14 @@ use core::marker::PhantomData; use rand_core::CryptoRngCore; use digest::{Digest, DynDigest, FixedOutputReset}; +use num_bigint::BigUint; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; use zeroize::Zeroizing; use crate::algorithms::{mgf1_xor, mgf1_xor_digest}; use crate::dummy_rng::DummyRng; use crate::errors::{Error, Result}; -use crate::key::{self, PrivateKey, PublicKey, RsaPrivateKey, RsaPublicKey}; +use crate::key::{self, PublicKeyParts, RsaPrivateKey, RsaPublicKey}; use crate::padding::PaddingScheme; use crate::traits::{Decryptor, RandomizedDecryptor, RandomizedEncryptor}; @@ -55,7 +56,7 @@ impl Oaep { /// ``` /// use sha1::Sha1; /// use sha2::Sha256; - /// use rsa::{BigUint, RsaPublicKey, Oaep, PublicKey}; + /// use rsa::{BigUint, RsaPublicKey, Oaep, }; /// use base64ct::{Base64, Encoding}; /// /// let n = Base64::decode_vec("ALHgDoZmBQIx+jTmgeeHW6KsPOrj11f6CvWsiRleJlQpW77AwSZhd21ZDmlTKfaIHBSUxRUsuYNh7E2SHx8rkFVCQA2/gXkZ5GK2IUbzSTio9qXA25MWHvVxjMfKSL8ZAxZyKbrG94FLLszFAFOaiLLY8ECs7g+dXOriYtBwLUJK+lppbd+El+8ZA/zH0bk7vbqph5pIoiWggxwdq3mEz4LnrUln7r6dagSQzYErKewY8GADVpXcq5mfHC1xF2DFBub7bFjMVM5fHq7RK+pG5xjNDiYITbhLYrbVv3X0z75OvN0dY49ITWjM7xyvMWJXVJS7sJlgmCCL6RwWgP8PhcE=").unwrap(); @@ -92,7 +93,7 @@ impl Oaep { /// ``` /// use sha1::Sha1; /// use sha2::Sha256; - /// use rsa::{BigUint, RsaPublicKey, Oaep, PublicKey}; + /// use rsa::{BigUint, RsaPublicKey, Oaep, }; /// use base64ct::{Base64, Encoding}; /// /// let n = Base64::decode_vec("ALHgDoZmBQIx+jTmgeeHW6KsPOrj11f6CvWsiRleJlQpW77AwSZhd21ZDmlTKfaIHBSUxRUsuYNh7E2SHx8rkFVCQA2/gXkZ5GK2IUbzSTio9qXA25MWHvVxjMfKSL8ZAxZyKbrG94FLLszFAFOaiLLY8ECs7g+dXOriYtBwLUJK+lppbd+El+8ZA/zH0bk7vbqph5pIoiWggxwdq3mEz4LnrUln7r6dagSQzYErKewY8GADVpXcq5mfHC1xF2DFBub7bFjMVM5fHq7RK+pG5xjNDiYITbhLYrbVv3X0z75OvN0dY49ITWjM7xyvMWJXVJS7sJlgmCCL6RwWgP8PhcE=").unwrap(); @@ -131,10 +132,10 @@ impl Oaep { } impl PaddingScheme for Oaep { - fn decrypt( + fn decrypt( mut self, rng: Option<&mut Rng>, - priv_key: &Priv, + priv_key: &RsaPrivateKey, ciphertext: &[u8], ) -> Result> { decrypt( @@ -147,10 +148,10 @@ impl PaddingScheme for Oaep { ) } - fn encrypt( + fn encrypt( mut self, rng: &mut Rng, - pub_key: &Pub, + pub_key: &RsaPublicKey, msg: &[u8], ) -> Result> { encrypt( @@ -175,9 +176,9 @@ impl fmt::Debug for Oaep { } #[inline] -fn encrypt_internal( +fn encrypt_internal( rng: &mut R, - pub_key: &K, + pub_key: &RsaPublicKey, msg: &[u8], p_hash: &[u8], h_size: usize, @@ -206,7 +207,8 @@ fn encrypt_internal( +fn encrypt( rng: &mut R, - pub_key: &K, + pub_key: &RsaPublicKey, msg: &[u8], digest: &mut dyn DynDigest, mgf_digest: &mut dyn DynDigest, @@ -249,14 +251,9 @@ fn encrypt( /// /// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 #[inline] -fn encrypt_digest< - R: CryptoRngCore + ?Sized, - K: PublicKey, - D: Digest, - MGD: Digest + FixedOutputReset, ->( +fn encrypt_digest( rng: &mut R, - pub_key: &K, + pub_key: &RsaPublicKey, msg: &[u8], label: Option, ) -> Result> { @@ -289,9 +286,9 @@ fn encrypt_digest< /// /// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 #[inline] -fn decrypt( +fn decrypt( rng: Option<&mut R>, - priv_key: &SK, + priv_key: &RsaPrivateKey, ciphertext: &[u8], digest: &mut dyn DynDigest, mgf_digest: &mut dyn DynDigest, @@ -343,14 +340,9 @@ fn decrypt( /// /// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 #[inline] -fn decrypt_digest< - R: CryptoRngCore + ?Sized, - SK: PrivateKey, - D: Digest, - MGD: Digest + FixedOutputReset, ->( +fn decrypt_digest( rng: Option<&mut R>, - priv_key: &SK, + priv_key: &RsaPrivateKey, ciphertext: &[u8], label: Option, ) -> Result> { @@ -390,9 +382,9 @@ fn decrypt_digest< /// `rng` is given. It returns one or zero in valid that indicates whether the /// plaintext was correctly structured. #[inline] -fn decrypt_inner( +fn decrypt_inner( rng: Option<&mut R>, - priv_key: &SK, + priv_key: &RsaPrivateKey, ciphertext: &[u8], h_size: usize, expected_p_hash: &[u8], @@ -491,7 +483,7 @@ where rng: &mut R, msg: &[u8], ) -> Result> { - encrypt_digest::<_, _, D, MGD>(rng, &self.inner, msg, self.label.as_ref().cloned()) + encrypt_digest::<_, D, MGD>(rng, &self.inner, msg, self.label.as_ref().cloned()) } } @@ -542,7 +534,7 @@ where MGD: Digest + FixedOutputReset, { fn decrypt(&self, ciphertext: &[u8]) -> Result> { - decrypt_digest::( + decrypt_digest::( None, &self.inner, ciphertext, @@ -561,7 +553,7 @@ where rng: &mut R, ciphertext: &[u8], ) -> Result> { - decrypt_digest::<_, _, D, MGD>( + decrypt_digest::<_, D, MGD>( Some(rng), &self.inner, ciphertext, @@ -572,7 +564,7 @@ where #[cfg(test)] mod tests { - use crate::key::{PublicKey, PublicKeyParts, RsaPrivateKey, RsaPublicKey}; + use crate::key::{PublicKeyParts, RsaPrivateKey, RsaPublicKey}; use crate::oaep::{DecryptingKey, EncryptingKey, Oaep}; use crate::traits::{Decryptor, RandomizedDecryptor, RandomizedEncryptor}; diff --git a/src/padding.rs b/src/padding.rs index 391779f6..ce198fc3 100644 --- a/src/padding.rs +++ b/src/padding.rs @@ -5,7 +5,7 @@ use alloc::vec::Vec; use rand_core::CryptoRngCore; use crate::errors::Result; -use crate::key::{PrivateKey, PublicKey}; +use crate::key::{RsaPrivateKey, RsaPublicKey}; /// Padding scheme used for encryption. pub trait PaddingScheme { @@ -13,18 +13,18 @@ pub trait PaddingScheme { /// /// If an `rng` is passed, it uses RSA blinding to help mitigate timing /// side-channel attacks. - fn decrypt( + fn decrypt( self, rng: Option<&mut Rng>, - priv_key: &Priv, + priv_key: &RsaPrivateKey, ciphertext: &[u8], ) -> Result>; /// Encrypt the given message using the given public key. - fn encrypt( + fn encrypt( self, rng: &mut Rng, - pub_key: &Pub, + pub_key: &RsaPublicKey, msg: &[u8], ) -> Result>; } @@ -32,10 +32,10 @@ pub trait PaddingScheme { /// Digital signature scheme. pub trait SignatureScheme { /// Sign the given digest. - fn sign( + fn sign( self, rng: Option<&mut Rng>, - priv_key: &Priv, + priv_key: &RsaPrivateKey, hashed: &[u8], ) -> Result>; @@ -45,5 +45,5 @@ pub trait SignatureScheme { /// passed in through `hash`. /// /// If the message is valid `Ok(())` is returned, otherwise an `Err` indicating failure. - fn verify(self, pub_key: &Pub, hashed: &[u8], sig: &[u8]) -> Result<()>; + fn verify(self, pub_key: &RsaPublicKey, hashed: &[u8], sig: &[u8]) -> Result<()>; } diff --git a/src/pkcs1v15.rs b/src/pkcs1v15.rs index a6cb13a3..e9a7ead8 100644 --- a/src/pkcs1v15.rs +++ b/src/pkcs1v15.rs @@ -29,7 +29,7 @@ use zeroize::Zeroizing; use crate::dummy_rng::DummyRng; use crate::errors::{Error, Result}; -use crate::key::{self, PrivateKey, PublicKey}; +use crate::key::{self, PublicKeyParts}; use crate::padding::{PaddingScheme, SignatureScheme}; use crate::traits::{Decryptor, EncryptingKeypair, RandomizedDecryptor, RandomizedEncryptor}; use crate::{RsaPrivateKey, RsaPublicKey}; @@ -39,19 +39,19 @@ use crate::{RsaPrivateKey, RsaPublicKey}; pub struct Pkcs1v15Encrypt; impl PaddingScheme for Pkcs1v15Encrypt { - fn decrypt( + fn decrypt( self, rng: Option<&mut Rng>, - priv_key: &Priv, + priv_key: &RsaPrivateKey, ciphertext: &[u8], ) -> Result> { decrypt(rng, priv_key, ciphertext) } - fn encrypt( + fn encrypt( self, rng: &mut Rng, - pub_key: &Pub, + pub_key: &RsaPublicKey, msg: &[u8], ) -> Result> { encrypt(rng, pub_key, msg) @@ -103,10 +103,10 @@ impl Pkcs1v15Sign { } impl SignatureScheme for Pkcs1v15Sign { - fn sign( + fn sign( self, rng: Option<&mut Rng>, - priv_key: &Priv, + priv_key: &RsaPrivateKey, hashed: &[u8], ) -> Result> { if let Some(hash_len) = self.hash_len { @@ -118,7 +118,7 @@ impl SignatureScheme for Pkcs1v15Sign { sign(rng, priv_key, &self.prefix, hashed) } - fn verify(self, pub_key: &Pub, hashed: &[u8], sig: &[u8]) -> Result<()> { + fn verify(self, pub_key: &RsaPublicKey, hashed: &[u8], sig: &[u8]) -> Result<()> { if let Some(hash_len) = self.hash_len { if hashed.len() != hash_len { return Err(Error::InputNotHashed); @@ -192,9 +192,9 @@ impl Display for Signature { /// scheme from PKCS#1 v1.5. The message must be no longer than the /// length of the public modulus minus 11 bytes. #[inline] -pub(crate) fn encrypt( +pub(crate) fn encrypt( rng: &mut R, - pub_key: &PK, + pub_key: &RsaPublicKey, msg: &[u8], ) -> Result> { key::check_public(pub_key)?; @@ -211,7 +211,8 @@ pub(crate) fn encrypt( em[k - msg.len() - 1] = 0; em[k - msg.len()..].copy_from_slice(msg); - pub_key.raw_encryption_primitive(&em, pub_key.size()) + let int = Zeroizing::new(BigUint::from_bytes_be(&em)); + pub_key.raw_int_encryption_primitive(&int, pub_key.size()) } /// Decrypts a plaintext using RSA and the padding scheme from PKCS#1 v1.5. @@ -224,9 +225,9 @@ pub(crate) fn encrypt( /// forge signatures as if they had the private key. See /// `decrypt_session_key` for a way of solving this problem. #[inline] -pub(crate) fn decrypt( +pub(crate) fn decrypt( rng: Option<&mut R>, - priv_key: &SK, + priv_key: &RsaPrivateKey, ciphertext: &[u8], ) -> Result> { key::check_public(priv_key)?; @@ -253,9 +254,9 @@ pub(crate) fn decrypt( /// messages to signatures and identify the signed messages. As ever, /// signatures provide authenticity, not confidentiality. #[inline] -pub(crate) fn sign( +pub(crate) fn sign( rng: Option<&mut R>, - priv_key: &SK, + priv_key: &RsaPrivateKey, prefix: &[u8], hashed: &[u8], ) -> Result> { @@ -279,8 +280,8 @@ pub(crate) fn sign( /// Verifies an RSA PKCS#1 v1.5 signature. #[inline] -pub(crate) fn verify( - pub_key: &PK, +pub(crate) fn verify( + pub_key: &RsaPublicKey, prefix: &[u8], hashed: &[u8], sig: &BigUint, @@ -341,9 +342,9 @@ where /// in order to maintain constant memory access patterns. If the plaintext was /// valid then index contains the index of the original message in em. #[inline] -fn decrypt_inner( +fn decrypt_inner( rng: Option<&mut R>, - priv_key: &SK, + priv_key: &RsaPrivateKey, ciphertext: &[u8], ) -> Result<(u8, Vec, u32)> { let k = priv_key.size(); @@ -543,7 +544,7 @@ where D: Digest, { fn try_sign(&self, msg: &[u8]) -> signature::Result { - sign::(None, &self.inner, &self.prefix, &D::digest(msg))? + sign::(None, &self.inner, &self.prefix, &D::digest(msg))? .as_slice() .try_into() } @@ -569,7 +570,7 @@ where D: Digest, { fn try_sign_digest(&self, digest: D) -> signature::Result { - sign::(None, &self.inner, &self.prefix, &digest.finalize())? + sign::(None, &self.inner, &self.prefix, &digest.finalize())? .as_slice() .try_into() } @@ -595,7 +596,7 @@ where D: Digest, { fn sign_prehash(&self, prehash: &[u8]) -> signature::Result { - sign::(None, &self.inner, &self.prefix, prehash)? + sign::(None, &self.inner, &self.prefix, prehash)? .as_slice() .try_into() } @@ -819,7 +820,7 @@ impl DecryptingKey { impl Decryptor for DecryptingKey { fn decrypt(&self, ciphertext: &[u8]) -> Result> { - decrypt::(None, &self.inner, ciphertext) + decrypt::(None, &self.inner, ciphertext) } } @@ -899,7 +900,7 @@ mod tests { use sha3::Sha3_256; use signature::{RandomizedSigner, Signer, Verifier}; - use crate::{PublicKey, PublicKeyParts, RsaPrivateKey, RsaPublicKey}; + use crate::{PublicKeyParts, RsaPrivateKey, RsaPublicKey}; #[test] fn test_non_zero_bytes() { diff --git a/src/pss.rs b/src/pss.rs index a25ddb5d..ac6ac35f 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -34,7 +34,7 @@ use subtle::{Choice, ConstantTimeEq}; use crate::algorithms::{mgf1_xor, mgf1_xor_digest}; use crate::errors::{Error, Result}; -use crate::key::{PrivateKey, PublicKey}; +use crate::key::PublicKeyParts; use crate::padding::SignatureScheme; use crate::{RsaPrivateKey, RsaPublicKey}; @@ -89,10 +89,10 @@ impl Pss { } impl SignatureScheme for Pss { - fn sign( + fn sign( mut self, rng: Option<&mut Rng>, - priv_key: &Priv, + priv_key: &RsaPrivateKey, hashed: &[u8], ) -> Result> { sign( @@ -105,7 +105,7 @@ impl SignatureScheme for Pss { ) } - fn verify(mut self, pub_key: &Pub, hashed: &[u8], sig: &[u8]) -> Result<()> { + fn verify(mut self, pub_key: &RsaPublicKey, hashed: &[u8], sig: &[u8]) -> Result<()> { verify( pub_key, hashed, @@ -183,8 +183,8 @@ impl Display for Signature { } } -pub(crate) fn verify( - pub_key: &PK, +pub(crate) fn verify( + pub_key: &RsaPublicKey, hashed: &[u8], sig: &BigUint, sig_len: usize, @@ -209,15 +209,14 @@ pub(crate) fn verify( ) } -pub(crate) fn verify_digest( - pub_key: &PK, +pub(crate) fn verify_digest( + pub_key: &RsaPublicKey, hashed: &[u8], sig: &BigUint, sig_len: usize, salt_len: usize, ) -> Result<()> where - PK: PublicKey, D: Digest + FixedOutputReset, { if sig_len != pub_key.size() { @@ -237,10 +236,10 @@ where /// Note that hashed must be the result of hashing the input message using the /// given hash function. The opts argument may be nil, in which case sensible /// defaults are used. -pub(crate) fn sign( +pub(crate) fn sign( rng: &mut T, blind: bool, - priv_key: &SK, + priv_key: &RsaPrivateKey, hashed: &[u8], salt_len: usize, digest: &mut dyn DynDigest, @@ -251,21 +250,17 @@ pub(crate) fn sign( sign_pss_with_salt(blind.then_some(rng), priv_key, hashed, &salt, digest) } -pub(crate) fn sign_digest< - T: CryptoRngCore + ?Sized, - SK: PrivateKey, - D: Digest + FixedOutputReset, ->( +pub(crate) fn sign_digest( rng: &mut T, blind: bool, - priv_key: &SK, + priv_key: &RsaPrivateKey, hashed: &[u8], salt_len: usize, ) -> Result> { let mut salt = vec![0; salt_len]; rng.fill_bytes(&mut salt[..]); - sign_pss_with_salt_digest::<_, _, D>(blind.then_some(rng), priv_key, hashed, &salt) + sign_pss_with_salt_digest::<_, D>(blind.then_some(rng), priv_key, hashed, &salt) } /// signPSSWithSalt calculates the signature of hashed using PSS with specified salt. @@ -273,9 +268,9 @@ pub(crate) fn sign_digest< /// Note that hashed must be the result of hashing the input message using the /// given hash function. salt is a random sequence of bytes whose length will be /// later used to verify the signature. -fn sign_pss_with_salt( +fn sign_pss_with_salt( blind_rng: Option<&mut T>, - priv_key: &SK, + priv_key: &RsaPrivateKey, hashed: &[u8], salt: &[u8], digest: &mut dyn DynDigest, @@ -286,13 +281,9 @@ fn sign_pss_with_salt( priv_key.raw_decryption_primitive(blind_rng, &em, priv_key.size()) } -fn sign_pss_with_salt_digest< - T: CryptoRngCore + ?Sized, - SK: PrivateKey, - D: Digest + FixedOutputReset, ->( +fn sign_pss_with_salt_digest( blind_rng: Option<&mut T>, - priv_key: &SK, + priv_key: &RsaPrivateKey, hashed: &[u8], salt: &[u8], ) -> Result> { @@ -757,7 +748,7 @@ where rng: &mut impl CryptoRngCore, msg: &[u8], ) -> signature::Result { - sign_digest::<_, _, D>(rng, false, &self.inner, &D::digest(msg), self.salt_len)? + sign_digest::<_, D>(rng, false, &self.inner, &D::digest(msg), self.salt_len)? .as_slice() .try_into() } @@ -772,7 +763,7 @@ where rng: &mut impl CryptoRngCore, digest: D, ) -> signature::Result { - sign_digest::<_, _, D>(rng, false, &self.inner, &digest.finalize(), self.salt_len)? + sign_digest::<_, D>(rng, false, &self.inner, &digest.finalize(), self.salt_len)? .as_slice() .try_into() } @@ -787,7 +778,7 @@ where rng: &mut impl CryptoRngCore, prehash: &[u8], ) -> signature::Result { - sign_digest::<_, _, D>(rng, false, &self.inner, prehash, self.salt_len)? + sign_digest::<_, D>(rng, false, &self.inner, prehash, self.salt_len)? .as_slice() .try_into() } @@ -930,7 +921,7 @@ where rng: &mut impl CryptoRngCore, msg: &[u8], ) -> signature::Result { - sign_digest::<_, _, D>(rng, true, &self.inner, &D::digest(msg), self.salt_len)? + sign_digest::<_, D>(rng, true, &self.inner, &D::digest(msg), self.salt_len)? .as_slice() .try_into() } @@ -945,7 +936,7 @@ where rng: &mut impl CryptoRngCore, digest: D, ) -> signature::Result { - sign_digest::<_, _, D>(rng, true, &self.inner, &digest.finalize(), self.salt_len)? + sign_digest::<_, D>(rng, true, &self.inner, &digest.finalize(), self.salt_len)? .as_slice() .try_into() } @@ -960,7 +951,7 @@ where rng: &mut impl CryptoRngCore, prehash: &[u8], ) -> signature::Result { - sign_digest::<_, _, D>(rng, true, &self.inner, prehash, self.salt_len)? + sign_digest::<_, D>(rng, true, &self.inner, prehash, self.salt_len)? .as_slice() .try_into() } @@ -1055,7 +1046,7 @@ where D: Digest + FixedOutputReset, { fn verify(&self, msg: &[u8], signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>( + verify_digest::( &self.inner, &D::digest(msg), &signature.inner, @@ -1071,7 +1062,7 @@ where D: Digest + FixedOutputReset, { fn verify_digest(&self, digest: D, signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>( + verify_digest::( &self.inner, &digest.finalize(), &signature.inner, @@ -1087,7 +1078,7 @@ where D: Digest + FixedOutputReset, { fn verify_prehash(&self, prehash: &[u8], signature: &Signature) -> signature::Result<()> { - verify_digest::<_, D>( + verify_digest::( &self.inner, prehash, &signature.inner, @@ -1119,7 +1110,7 @@ where #[cfg(test)] mod test { use crate::pss::{BlindedSigningKey, Pss, Signature, SigningKey, VerifyingKey}; - use crate::{PublicKey, RsaPrivateKey, RsaPublicKey}; + use crate::{RsaPrivateKey, RsaPublicKey}; use hex_literal::hex; use num_bigint::BigUint; diff --git a/src/raw.rs b/src/raw.rs deleted file mode 100644 index 793e68c2..00000000 --- a/src/raw.rs +++ /dev/null @@ -1,64 +0,0 @@ -use alloc::vec::Vec; -use num_bigint::BigUint; -use rand_core::CryptoRngCore; -use zeroize::Zeroizing; - -use crate::errors::Result; -use crate::internals; -use crate::key::{RsaPrivateKey, RsaPublicKey}; - -pub trait EncryptionPrimitive { - /// Do NOT use directly! Only for implementors. - fn raw_encryption_primitive(&self, plaintext: &[u8], pad_size: usize) -> Result> { - let int = Zeroizing::new(BigUint::from_bytes_be(plaintext)); - self.raw_int_encryption_primitive(&int, pad_size) - } - - fn raw_int_encryption_primitive(&self, plaintext: &BigUint, pad_size: usize) - -> Result>; -} - -pub trait DecryptionPrimitive { - /// Do NOT use directly! Only for implementors. - fn raw_decryption_primitive( - &self, - rng: Option<&mut R>, - ciphertext: &[u8], - pad_size: usize, - ) -> Result> { - let int = Zeroizing::new(BigUint::from_bytes_be(ciphertext)); - self.raw_int_decryption_primitive(rng, &int, pad_size) - } - - fn raw_int_decryption_primitive( - &self, - rng: Option<&mut R>, - ciphertext: &BigUint, - pad_size: usize, - ) -> Result>; -} - -impl EncryptionPrimitive for RsaPublicKey { - fn raw_int_encryption_primitive( - &self, - plaintext: &BigUint, - pad_size: usize, - ) -> Result> { - let c = Zeroizing::new(internals::encrypt(self, &plaintext)); - let c_bytes = Zeroizing::new(c.to_bytes_be()); - internals::left_pad(&c_bytes, pad_size) - } -} - -impl DecryptionPrimitive for RsaPrivateKey { - fn raw_int_decryption_primitive( - &self, - rng: Option<&mut R>, - ciphertext: &BigUint, - pad_size: usize, - ) -> Result> { - let m = Zeroizing::new(internals::decrypt_and_check(rng, self, &ciphertext)?); - let m_bytes = Zeroizing::new(m.to_bytes_be()); - internals::left_pad(&m_bytes, pad_size) - } -} From b4ee037058ad64c6a559366a8458e2c7548b66f5 Mon Sep 17 00:00:00 2001 From: Dmitry Baryshkov Date: Wed, 19 Apr 2023 01:31:40 +0300 Subject: [PATCH 3/6] feat: drop RsaPrivateKey::raw_decryption_primitive() Inline and drop the RsaPrivateKey::raw_decryption_primitive() function. There is no need to zeroize argument, it is ciphertext, so it can be assumed to be safe. Signed-off-by: Dmitry Baryshkov --- src/key.rs | 10 ---------- src/oaep.rs | 6 +++++- src/pkcs1v15.rs | 8 ++++++-- src/pss.rs | 4 ++-- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/key.rs b/src/key.rs index 4635b530..88346fa8 100644 --- a/src/key.rs +++ b/src/key.rs @@ -457,16 +457,6 @@ impl RsaPrivateKey { } /// Do NOT use directly! Only for implementors. - pub(crate) fn raw_decryption_primitive( - &self, - rng: Option<&mut R>, - ciphertext: &[u8], - pad_size: usize, - ) -> Result> { - let int = Zeroizing::new(BigUint::from_bytes_be(ciphertext)); - self.raw_int_decryption_primitive(rng, &int, pad_size) - } - pub(crate) fn raw_int_decryption_primitive( &self, rng: Option<&mut R>, diff --git a/src/oaep.rs b/src/oaep.rs index 7750234c..9d13ceef 100644 --- a/src/oaep.rs +++ b/src/oaep.rs @@ -399,7 +399,11 @@ fn decrypt_inner( return Err(Error::Decryption); } - let mut em = priv_key.raw_decryption_primitive(rng, ciphertext, priv_key.size())?; + let mut em = priv_key.raw_int_decryption_primitive( + rng, + &BigUint::from_bytes_be(ciphertext), + priv_key.size(), + )?; let first_byte_is_zero = em[0].ct_eq(&0u8); diff --git a/src/pkcs1v15.rs b/src/pkcs1v15.rs index e9a7ead8..eb75d4a0 100644 --- a/src/pkcs1v15.rs +++ b/src/pkcs1v15.rs @@ -275,7 +275,7 @@ pub(crate) fn sign( em[k - t_len..k - hash_len].copy_from_slice(prefix); em[k - hash_len..k].copy_from_slice(hashed); - priv_key.raw_decryption_primitive(rng, &em, priv_key.size()) + priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(&em), priv_key.size()) } /// Verifies an RSA PKCS#1 v1.5 signature. @@ -352,7 +352,11 @@ fn decrypt_inner( return Err(Error::Decryption); } - let em = priv_key.raw_decryption_primitive(rng, ciphertext, priv_key.size())?; + let em = priv_key.raw_int_decryption_primitive( + rng, + &BigUint::from_bytes_be(ciphertext), + priv_key.size(), + )?; let first_byte_is_zero = em[0].ct_eq(&0u8); let second_byte_is_two = em[1].ct_eq(&2u8); diff --git a/src/pss.rs b/src/pss.rs index ac6ac35f..a46027a9 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -278,7 +278,7 @@ fn sign_pss_with_salt( let em_bits = priv_key.n().bits() - 1; let em = emsa_pss_encode(hashed, em_bits, salt, digest)?; - priv_key.raw_decryption_primitive(blind_rng, &em, priv_key.size()) + priv_key.raw_int_decryption_primitive(blind_rng, &BigUint::from_bytes_be(&em), priv_key.size()) } fn sign_pss_with_salt_digest( @@ -290,7 +290,7 @@ fn sign_pss_with_salt_digest(hashed, em_bits, salt)?; - priv_key.raw_decryption_primitive(blind_rng, &em, priv_key.size()) + priv_key.raw_int_decryption_primitive(blind_rng, &BigUint::from_bytes_be(&em), priv_key.size()) } fn emsa_pss_encode( From 43a67aee5cb79a2286b377262d35edf4f9078d13 Mon Sep 17 00:00:00 2001 From: Dmitry Baryshkov Date: Wed, 19 Apr 2023 01:33:05 +0300 Subject: [PATCH 4/6] feat: make raw primitive functions symmetric Change raw_int_decryption_primitive() and raw_int_decryption_primitive() to output Result instead of Result>, because they also take BigUint rather than Vec or &[u8]. Signed-off-by: Dmitry Baryshkov --- src/internals.rs | 18 ++++++++++++++++-- src/key.rs | 19 +++++-------------- src/oaep.rs | 10 ++++------ src/pkcs1v15.rs | 17 +++++++++-------- src/pss.rs | 15 +++++++++++---- 5 files changed, 45 insertions(+), 34 deletions(-) diff --git a/src/internals.rs b/src/internals.rs index 90ae2181..a56b5f9a 100644 --- a/src/internals.rs +++ b/src/internals.rs @@ -4,7 +4,7 @@ use alloc::vec::Vec; use num_bigint::{BigInt, BigUint, IntoBigInt, IntoBigUint, ModInverse, RandBigInt, ToBigInt}; use num_traits::{One, Signed, Zero}; use rand_core::CryptoRngCore; -use zeroize::Zeroize; +use zeroize::{Zeroize, Zeroizing}; use crate::errors::{Error, Result}; use crate::key::{PublicKeyParts, RsaPrivateKey}; @@ -174,7 +174,7 @@ pub fn unblind(key: &impl PublicKeyParts, m: &BigUint, unblinder: &BigUint) -> B /// Returns a new vector of the given length, with 0s left padded. #[inline] -pub fn left_pad(input: &[u8], padded_len: usize) -> Result> { +fn left_pad(input: &[u8], padded_len: usize) -> Result> { if input.len() > padded_len { return Err(Error::InvalidPadLen); } @@ -184,6 +184,20 @@ pub fn left_pad(input: &[u8], padded_len: usize) -> Result> { Ok(out) } +/// Converts input to the new vector of the given length, using BE and with 0s left padded. +#[inline] +pub fn uint_to_be_pad(input: BigUint, padded_len: usize) -> Result> { + left_pad(&input.to_bytes_be(), padded_len) +} + +/// Converts input to the new vector of the given length, using BE and with 0s left padded. +#[inline] +pub fn uint_to_zeroizing_be_pad(input: BigUint, padded_len: usize) -> Result> { + let m = Zeroizing::new(input); + let m = Zeroizing::new(m.to_bytes_be()); + left_pad(&m, padded_len) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/key.rs b/src/key.rs index 88346fa8..c3514d69 100644 --- a/src/key.rs +++ b/src/key.rs @@ -7,7 +7,7 @@ use num_traits::{One, ToPrimitive}; use rand_core::CryptoRngCore; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use zeroize::{Zeroize, Zeroizing}; +use zeroize::Zeroize; use crate::algorithms::{generate_multi_prime_key, generate_multi_prime_key_with_exp}; use crate::dummy_rng::DummyRng; @@ -226,14 +226,8 @@ impl RsaPublicKey { Self { n, e } } - pub(crate) fn raw_int_encryption_primitive( - &self, - plaintext: &BigUint, - pad_size: usize, - ) -> Result> { - let c = Zeroizing::new(internals::encrypt(self, plaintext)); - let c_bytes = Zeroizing::new(c.to_bytes_be()); - internals::left_pad(&c_bytes, pad_size) + pub(crate) fn raw_int_encryption_primitive(&self, plaintext: &BigUint) -> Result { + Ok(internals::encrypt(self, plaintext)) } } @@ -461,11 +455,8 @@ impl RsaPrivateKey { &self, rng: Option<&mut R>, ciphertext: &BigUint, - pad_size: usize, - ) -> Result> { - let m = Zeroizing::new(internals::decrypt_and_check(rng, self, ciphertext)?); - let m_bytes = Zeroizing::new(m.to_bytes_be()); - internals::left_pad(&m_bytes, pad_size) + ) -> Result { + internals::decrypt_and_check(rng, self, ciphertext) } } diff --git a/src/oaep.rs b/src/oaep.rs index 9d13ceef..cc0c488d 100644 --- a/src/oaep.rs +++ b/src/oaep.rs @@ -19,6 +19,7 @@ use zeroize::Zeroizing; use crate::algorithms::{mgf1_xor, mgf1_xor_digest}; use crate::dummy_rng::DummyRng; use crate::errors::{Error, Result}; +use crate::internals::{uint_to_be_pad, uint_to_zeroizing_be_pad}; use crate::key::{self, PublicKeyParts, RsaPrivateKey, RsaPublicKey}; use crate::padding::PaddingScheme; use crate::traits::{Decryptor, RandomizedDecryptor, RandomizedEncryptor}; @@ -208,7 +209,7 @@ fn encrypt_internal mgf(seed, db); let int = Zeroizing::new(BigUint::from_bytes_be(&em)); - pub_key.raw_int_encryption_primitive(&int, pub_key.size()) + uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) } /// Encrypts the given message with RSA and the padding scheme from @@ -399,11 +400,8 @@ fn decrypt_inner( return Err(Error::Decryption); } - let mut em = priv_key.raw_int_decryption_primitive( - rng, - &BigUint::from_bytes_be(ciphertext), - priv_key.size(), - )?; + let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; + let mut em = uint_to_zeroizing_be_pad(em, priv_key.size())?; let first_byte_is_zero = em[0].ct_eq(&0u8); diff --git a/src/pkcs1v15.rs b/src/pkcs1v15.rs index eb75d4a0..9973761d 100644 --- a/src/pkcs1v15.rs +++ b/src/pkcs1v15.rs @@ -29,6 +29,7 @@ use zeroize::Zeroizing; use crate::dummy_rng::DummyRng; use crate::errors::{Error, Result}; +use crate::internals::{uint_to_be_pad, uint_to_zeroizing_be_pad}; use crate::key::{self, PublicKeyParts}; use crate::padding::{PaddingScheme, SignatureScheme}; use crate::traits::{Decryptor, EncryptingKeypair, RandomizedDecryptor, RandomizedEncryptor}; @@ -212,7 +213,7 @@ pub(crate) fn encrypt( em[k - msg.len()..].copy_from_slice(msg); let int = Zeroizing::new(BigUint::from_bytes_be(&em)); - pub_key.raw_int_encryption_primitive(&int, pub_key.size()) + uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) } /// Decrypts a plaintext using RSA and the padding scheme from PKCS#1 v1.5. @@ -275,7 +276,10 @@ pub(crate) fn sign( em[k - t_len..k - hash_len].copy_from_slice(prefix); em[k - hash_len..k].copy_from_slice(hashed); - priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(&em), priv_key.size()) + uint_to_zeroizing_be_pad( + priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(&em))?, + priv_key.size(), + ) } /// Verifies an RSA PKCS#1 v1.5 signature. @@ -293,7 +297,7 @@ pub(crate) fn verify( return Err(Error::Verification); } - let em = pub_key.raw_int_encryption_primitive(sig, pub_key.size())?; + let em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, pub_key.size())?; // EM = 0x00 || 0x01 || PS || 0x00 || T let mut ok = em[0].ct_eq(&0u8); @@ -352,11 +356,8 @@ fn decrypt_inner( return Err(Error::Decryption); } - let em = priv_key.raw_int_decryption_primitive( - rng, - &BigUint::from_bytes_be(ciphertext), - priv_key.size(), - )?; + let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; + let em = uint_to_zeroizing_be_pad(em, priv_key.size())?; let first_byte_is_zero = em[0].ct_eq(&0u8); let second_byte_is_two = em[1].ct_eq(&2u8); diff --git a/src/pss.rs b/src/pss.rs index a46027a9..a6ce9f93 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -34,6 +34,7 @@ use subtle::{Choice, ConstantTimeEq}; use crate::algorithms::{mgf1_xor, mgf1_xor_digest}; use crate::errors::{Error, Result}; +use crate::internals::{uint_to_be_pad, uint_to_zeroizing_be_pad}; use crate::key::PublicKeyParts; use crate::padding::SignatureScheme; use crate::{RsaPrivateKey, RsaPublicKey}; @@ -198,7 +199,7 @@ pub(crate) fn verify( let em_bits = pub_key.n().bits() - 1; let em_len = (em_bits + 7) / 8; let key_len = pub_key.size(); - let mut em = pub_key.raw_int_encryption_primitive(sig, key_len)?; + let mut em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, key_len)?; emsa_pss_verify( hashed, @@ -226,7 +227,7 @@ where let em_bits = pub_key.n().bits() - 1; let em_len = (em_bits + 7) / 8; let key_len = pub_key.size(); - let mut em = pub_key.raw_int_encryption_primitive(sig, key_len)?; + let mut em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, key_len)?; emsa_pss_verify_digest::(hashed, &mut em[key_len - em_len..], em_bits, salt_len) } @@ -278,7 +279,10 @@ fn sign_pss_with_salt( let em_bits = priv_key.n().bits() - 1; let em = emsa_pss_encode(hashed, em_bits, salt, digest)?; - priv_key.raw_int_decryption_primitive(blind_rng, &BigUint::from_bytes_be(&em), priv_key.size()) + uint_to_zeroizing_be_pad( + priv_key.raw_int_decryption_primitive(blind_rng, &BigUint::from_bytes_be(&em))?, + priv_key.size(), + ) } fn sign_pss_with_salt_digest( @@ -290,7 +294,10 @@ fn sign_pss_with_salt_digest(hashed, em_bits, salt)?; - priv_key.raw_int_decryption_primitive(blind_rng, &BigUint::from_bytes_be(&em), priv_key.size()) + uint_to_zeroizing_be_pad( + priv_key.raw_int_decryption_primitive(blind_rng, &BigUint::from_bytes_be(&em))?, + priv_key.size(), + ) } fn emsa_pss_encode( From 5e2768acace64e4f3b131a405df456b39378619c Mon Sep 17 00:00:00 2001 From: Dmitry Baryshkov Date: Wed, 19 Apr 2023 03:36:26 +0300 Subject: [PATCH 5/6] feat: decompose padding/format algorithms to separate module In order to simplify adding support for RSA hardware accelerators, move all formatting and padding functions to a separate modules, making it theoretically possible to use that for implementing support for low-level RSA hardware accelerators. Signed-off-by: Dmitry Baryshkov --- src/algorithms.rs | 78 +-------- src/algorithms/mgf.rs | 75 +++++++++ src/algorithms/oaep.rs | 246 +++++++++++++++++++++++++++ src/algorithms/pkcs1v15.rs | 198 ++++++++++++++++++++++ src/algorithms/pss.rs | 334 +++++++++++++++++++++++++++++++++++++ src/oaep.rs | 185 ++------------------ src/pkcs1v15.rs | 183 ++------------------ src/pss.rs | 328 +----------------------------------- 8 files changed, 893 insertions(+), 734 deletions(-) create mode 100644 src/algorithms/mgf.rs create mode 100644 src/algorithms/oaep.rs create mode 100644 src/algorithms/pkcs1v15.rs create mode 100644 src/algorithms/pss.rs diff --git a/src/algorithms.rs b/src/algorithms.rs index 8bdc1d9a..9d803b11 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -1,6 +1,10 @@ //! Useful algorithms related to RSA. -use digest::{Digest, DynDigest, FixedOutputReset}; +mod mgf; +pub(crate) mod oaep; +pub(crate) mod pkcs1v15; +pub(crate) mod pss; + use num_bigint::traits::ModInverse; use num_bigint::{BigUint, RandPrime}; #[allow(unused_imports)] @@ -134,75 +138,3 @@ pub fn generate_multi_prime_key_with_exp( RsaPrivateKey::from_components(n_final, exp.clone(), d_final, primes) } - -/// Mask generation function. -/// -/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 -pub fn mgf1_xor(out: &mut [u8], digest: &mut dyn DynDigest, seed: &[u8]) { - let mut counter = [0u8; 4]; - let mut i = 0; - - const MAX_LEN: u64 = core::u32::MAX as u64 + 1; - assert!(out.len() as u64 <= MAX_LEN); - - while i < out.len() { - let mut digest_input = vec![0u8; seed.len() + 4]; - digest_input[0..seed.len()].copy_from_slice(seed); - digest_input[seed.len()..].copy_from_slice(&counter); - - digest.update(digest_input.as_slice()); - let digest_output = &*digest.finalize_reset(); - let mut j = 0; - loop { - if j >= digest_output.len() || i >= out.len() { - break; - } - - out[i] ^= digest_output[j]; - j += 1; - i += 1; - } - inc_counter(&mut counter); - } -} - -/// Mask generation function. -/// -/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 -pub fn mgf1_xor_digest(out: &mut [u8], digest: &mut D, seed: &[u8]) -where - D: Digest + FixedOutputReset, -{ - let mut counter = [0u8; 4]; - let mut i = 0; - - const MAX_LEN: u64 = core::u32::MAX as u64 + 1; - assert!(out.len() as u64 <= MAX_LEN); - - while i < out.len() { - Digest::update(digest, seed); - Digest::update(digest, counter); - - let digest_output = digest.finalize_reset(); - let mut j = 0; - loop { - if j >= digest_output.len() || i >= out.len() { - break; - } - - out[i] ^= digest_output[j]; - j += 1; - i += 1; - } - inc_counter(&mut counter); - } -} -fn inc_counter(counter: &mut [u8; 4]) { - for i in (0..4).rev() { - counter[i] = counter[i].wrapping_add(1); - if counter[i] != 0 { - // No overflow - return; - } - } -} diff --git a/src/algorithms/mgf.rs b/src/algorithms/mgf.rs new file mode 100644 index 00000000..aa8fb2a3 --- /dev/null +++ b/src/algorithms/mgf.rs @@ -0,0 +1,75 @@ +//! Mask generation function common to both PSS and OAEP padding + +use digest::{Digest, DynDigest, FixedOutputReset}; + +/// Mask generation function. +/// +/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 +pub fn mgf1_xor(out: &mut [u8], digest: &mut dyn DynDigest, seed: &[u8]) { + let mut counter = [0u8; 4]; + let mut i = 0; + + const MAX_LEN: u64 = core::u32::MAX as u64 + 1; + assert!(out.len() as u64 <= MAX_LEN); + + while i < out.len() { + let mut digest_input = vec![0u8; seed.len() + 4]; + digest_input[0..seed.len()].copy_from_slice(seed); + digest_input[seed.len()..].copy_from_slice(&counter); + + digest.update(digest_input.as_slice()); + let digest_output = &*digest.finalize_reset(); + let mut j = 0; + loop { + if j >= digest_output.len() || i >= out.len() { + break; + } + + out[i] ^= digest_output[j]; + j += 1; + i += 1; + } + inc_counter(&mut counter); + } +} + +/// Mask generation function. +/// +/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 +pub fn mgf1_xor_digest(out: &mut [u8], digest: &mut D, seed: &[u8]) +where + D: Digest + FixedOutputReset, +{ + let mut counter = [0u8; 4]; + let mut i = 0; + + const MAX_LEN: u64 = core::u32::MAX as u64 + 1; + assert!(out.len() as u64 <= MAX_LEN); + + while i < out.len() { + Digest::update(digest, seed); + Digest::update(digest, counter); + + let digest_output = digest.finalize_reset(); + let mut j = 0; + loop { + if j >= digest_output.len() || i >= out.len() { + break; + } + + out[i] ^= digest_output[j]; + j += 1; + i += 1; + } + inc_counter(&mut counter); + } +} +fn inc_counter(counter: &mut [u8; 4]) { + for i in (0..4).rev() { + counter[i] = counter[i].wrapping_add(1); + if counter[i] != 0 { + // No overflow + return; + } + } +} diff --git a/src/algorithms/oaep.rs b/src/algorithms/oaep.rs new file mode 100644 index 00000000..0ba2de9d --- /dev/null +++ b/src/algorithms/oaep.rs @@ -0,0 +1,246 @@ +//! Encryption and Decryption using [OAEP padding](https://datatracker.ietf.org/doc/html/rfc8017#section-7.1). +//! +use alloc::string::String; +use alloc::vec::Vec; + +use digest::{Digest, DynDigest, FixedOutputReset}; +use rand_core::CryptoRngCore; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; +use zeroize::Zeroizing; + +use super::mgf::{mgf1_xor, mgf1_xor_digest}; +use crate::errors::{Error, Result}; + +// 2**61 -1 (pow is not const yet) +// TODO: This is the maximum for SHA-1, unclear from the RFC what the values are for other hashing functions. +const MAX_LABEL_LEN: u64 = 2_305_843_009_213_693_951; + +#[inline] +fn encrypt_internal( + rng: &mut R, + msg: &[u8], + p_hash: &[u8], + h_size: usize, + k: usize, + mut mgf: MGF, +) -> Result>> { + if msg.len() + 2 * h_size + 2 > k { + return Err(Error::MessageTooLong); + } + + let mut em = Zeroizing::new(vec![0u8; k]); + + let (_, payload) = em.split_at_mut(1); + let (seed, db) = payload.split_at_mut(h_size); + rng.fill_bytes(seed); + + // Data block DB = pHash || PS || 01 || M + let db_len = k - h_size - 1; + + db[0..h_size].copy_from_slice(p_hash); + db[db_len - msg.len() - 1] = 1; + db[db_len - msg.len()..].copy_from_slice(msg); + + mgf(seed, db); + + Ok(em) +} + +/// Encrypts the given message with RSA and the padding scheme from +/// [PKCS#1 OAEP]. +/// +/// The message must be no longer than the length of the public modulus minus +/// `2 + (2 * hash.size())`. +/// +/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 +#[inline] +pub(crate) fn oaep_encrypt( + rng: &mut R, + msg: &[u8], + digest: &mut dyn DynDigest, + mgf_digest: &mut dyn DynDigest, + label: Option, + k: usize, +) -> Result>> { + let h_size = digest.output_size(); + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::LabelTooLong); + } + + digest.update(label.as_bytes()); + let p_hash = digest.finalize_reset(); + + encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| { + mgf1_xor(db, mgf_digest, seed); + mgf1_xor(seed, mgf_digest, db); + }) +} + +/// Encrypts the given message with RSA and the padding scheme from +/// [PKCS#1 OAEP]. +/// +/// The message must be no longer than the length of the public modulus minus +/// `2 + (2 * hash.size())`. +/// +/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 +#[inline] +pub(crate) fn oaep_encrypt_digest< + R: CryptoRngCore + ?Sized, + D: Digest, + MGD: Digest + FixedOutputReset, +>( + rng: &mut R, + msg: &[u8], + label: Option, + k: usize, +) -> Result>> { + let h_size = ::output_size(); + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::LabelTooLong); + } + + let p_hash = D::digest(label.as_bytes()); + + encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| { + let mut mgf_digest = MGD::new(); + mgf1_xor_digest(db, &mut mgf_digest, seed); + mgf1_xor_digest(seed, &mut mgf_digest, db); + }) +} + +///Decrypts OAEP padding. +/// +/// Note that whether this function returns an error or not discloses secret +/// information. If an attacker can cause this function to run repeatedly and +/// learn whether each instance returned an error then they can decrypt and +/// forge signatures as if they had the private key. +/// +/// See `decrypt_session_key` for a way of solving this problem. +/// +/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 +#[inline] +pub(crate) fn oaep_decrypt( + em: &mut [u8], + digest: &mut dyn DynDigest, + mgf_digest: &mut dyn DynDigest, + label: Option, + k: usize, +) -> Result> { + let h_size = digest.output_size(); + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::Decryption); + } + + digest.update(label.as_bytes()); + + let expected_p_hash = digest.finalize_reset(); + + let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| { + mgf1_xor(seed, mgf_digest, db); + mgf1_xor(db, mgf_digest, seed); + })?; + if res.is_none().into() { + return Err(Error::Decryption); + } + + let (out, index) = res.unwrap(); + + Ok(out[index as usize..].to_vec()) +} + +///Decrypts OAEP padding. +/// +/// Note that whether this function returns an error or not discloses secret +/// information. If an attacker can cause this function to run repeatedly and +/// learn whether each instance returned an error then they can decrypt and +/// forge signatures as if they had the private key. +/// +/// See `decrypt_session_key` for a way of solving this problem. +/// +/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 +#[inline] +pub(crate) fn oaep_decrypt_digest( + em: &mut [u8], + label: Option, + k: usize, +) -> Result> { + let h_size = ::output_size(); + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::LabelTooLong); + } + + let expected_p_hash = D::digest(label.as_bytes()); + + let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| { + let mut mgf_digest = MGD::new(); + mgf1_xor_digest(seed, &mut mgf_digest, db); + mgf1_xor_digest(db, &mut mgf_digest, seed); + })?; + if res.is_none().into() { + return Err(Error::Decryption); + } + + let (out, index) = res.unwrap(); + + Ok(out[index as usize..].to_vec()) +} + +/// Decrypts OAEP padding. It returns one or zero in valid that indicates whether the +/// plaintext was correctly structured. +#[inline] +fn decrypt_inner( + em: &mut [u8], + h_size: usize, + expected_p_hash: &[u8], + k: usize, + mut mgf: MGF, +) -> Result, u32)>> { + if k < 11 { + return Err(Error::Decryption); + } + + if k < h_size * 2 + 2 { + return Err(Error::Decryption); + } + + let first_byte_is_zero = em[0].ct_eq(&0u8); + + let (_, payload) = em.split_at_mut(1); + let (seed, db) = payload.split_at_mut(h_size); + + mgf(seed, db); + + let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash); + + // The remainder of the plaintext must be zero or more 0x00, followed + // by 0x01, followed by the message. + // looking_for_index: 1 if we are still looking for the 0x01 + // index: the offset of the first 0x01 byte + // zero_before_one: 1 if we saw a non-zero byte before the 1 + let mut looking_for_index = Choice::from(1u8); + let mut index = 0u32; + let mut nonzero_before_one = Choice::from(0u8); + + for (i, el) in db.iter().skip(h_size).enumerate() { + let equals0 = el.ct_eq(&0u8); + let equals1 = el.ct_eq(&1u8); + index.conditional_assign(&(i as u32), looking_for_index & equals1); + looking_for_index &= !equals1; + nonzero_before_one |= looking_for_index & !equals0; + } + + let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index; + + Ok(CtOption::new( + (em.to_vec(), index + 2 + (h_size * 2) as u32), + valid, + )) +} diff --git a/src/algorithms/pkcs1v15.rs b/src/algorithms/pkcs1v15.rs new file mode 100644 index 00000000..c1f0779a --- /dev/null +++ b/src/algorithms/pkcs1v15.rs @@ -0,0 +1,198 @@ +//! PKCS#1 v1.5 support as described in [RFC8017 § 8.2]. +//! +//! # Usage +//! +//! See [code example in the toplevel rustdoc](../index.html#pkcs1-v15-signatures). +//! +//! [RFC8017 § 8.2]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.2 + +use alloc::vec::Vec; +use digest::Digest; +use pkcs8::AssociatedOid; +use rand_core::CryptoRngCore; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; +use zeroize::Zeroizing; + +use crate::errors::{Error, Result}; + +/// Fills the provided slice with random values, which are guaranteed +/// to not be zero. +#[inline] +fn non_zero_random_bytes(rng: &mut R, data: &mut [u8]) { + rng.fill_bytes(data); + + for el in data { + if *el == 0u8 { + // TODO: break after a certain amount of time + while *el == 0u8 { + rng.fill_bytes(core::slice::from_mut(el)); + } + } + } +} + +/// Applied the padding scheme from PKCS#1 v1.5 for encryption. The message must be no longer than +/// the length of the public modulus minus 11 bytes. +pub(crate) fn pkcs1v15_encrypt_pad( + rng: &mut R, + msg: &[u8], + k: usize, +) -> Result>> +where + R: CryptoRngCore + ?Sized, +{ + if msg.len() > k - 11 { + return Err(Error::MessageTooLong); + } + + // EM = 0x00 || 0x02 || PS || 0x00 || M + let mut em = Zeroizing::new(vec![0u8; k]); + em[1] = 2; + non_zero_random_bytes(rng, &mut em[2..k - msg.len() - 1]); + em[k - msg.len() - 1] = 0; + em[k - msg.len()..].copy_from_slice(msg); + Ok(em) +} + +/// Removes the encryption padding scheme from PKCS#1 v1.5. +/// +/// Note that whether this function returns an error or not discloses secret +/// information. If an attacker can cause this function to run repeatedly and +/// learn whether each instance returned an error then they can decrypt and +/// forge signatures as if they had the private key. See +/// `decrypt_session_key` for a way of solving this problem. +#[inline] +pub(crate) fn pkcs1v15_encrypt_unpad(em: Vec, k: usize) -> Result> { + let (valid, out, index) = decrypt_inner(em, k)?; + if valid == 0 { + return Err(Error::Decryption); + } + + Ok(out[index as usize..].to_vec()) +} + +/// Removes the PKCS1v15 padding It returns one or zero in valid that indicates whether the +/// plaintext was correctly structured. In either case, the plaintext is +/// returned in em so that it may be read independently of whether it was valid +/// in order to maintain constant memory access patterns. If the plaintext was +/// valid then index contains the index of the original message in em. +#[inline] +fn decrypt_inner(em: Vec, k: usize) -> Result<(u8, Vec, u32)> { + if k < 11 { + return Err(Error::Decryption); + } + + let first_byte_is_zero = em[0].ct_eq(&0u8); + let second_byte_is_two = em[1].ct_eq(&2u8); + + // The remainder of the plaintext must be a string of non-zero random + // octets, followed by a 0, followed by the message. + // looking_for_index: 1 iff we are still looking for the zero. + // index: the offset of the first zero byte. + let mut looking_for_index = 1u8; + let mut index = 0u32; + + for (i, el) in em.iter().enumerate().skip(2) { + let equals0 = el.ct_eq(&0u8); + index.conditional_assign(&(i as u32), Choice::from(looking_for_index) & equals0); + looking_for_index.conditional_assign(&0u8, equals0); + } + + // The PS padding must be at least 8 bytes long, and it starts two + // bytes into em. + // TODO: WARNING: THIS MUST BE CONSTANT TIME CHECK: + // Ref: https://github.com/dalek-cryptography/subtle/issues/20 + // This is currently copy & paste from the constant time impl in + // go, but very likely not sufficient. + let valid_ps = Choice::from((((2i32 + 8i32 - index as i32 - 1i32) >> 31) & 1) as u8); + let valid = + first_byte_is_zero & second_byte_is_two & Choice::from(!looking_for_index & 1) & valid_ps; + index = u32::conditional_select(&0, &(index + 1), valid); + + Ok((valid.unwrap_u8(), em, index)) +} + +#[inline] +pub(crate) fn pkcs1v15_sign_pad(prefix: &[u8], hashed: &[u8], k: usize) -> Result> { + let hash_len = hashed.len(); + let t_len = prefix.len() + hashed.len(); + if k < t_len + 11 { + return Err(Error::MessageTooLong); + } + + // EM = 0x00 || 0x01 || PS || 0x00 || T + let mut em = vec![0xff; k]; + em[0] = 0; + em[1] = 1; + em[k - t_len - 1] = 0; + em[k - t_len..k - hash_len].copy_from_slice(prefix); + em[k - hash_len..k].copy_from_slice(hashed); + + Ok(em) +} + +#[inline] +pub(crate) fn pkcs1v15_sign_unpad(prefix: &[u8], hashed: &[u8], em: &[u8], k: usize) -> Result<()> { + let hash_len = hashed.len(); + let t_len = prefix.len() + hashed.len(); + if k < t_len + 11 { + return Err(Error::Verification); + } + + // EM = 0x00 || 0x01 || PS || 0x00 || T + let mut ok = em[0].ct_eq(&0u8); + ok &= em[1].ct_eq(&1u8); + ok &= em[k - hash_len..k].ct_eq(hashed); + ok &= em[k - t_len..k - hash_len].ct_eq(prefix); + ok &= em[k - t_len - 1].ct_eq(&0u8); + + for el in em.iter().skip(2).take(k - t_len - 3) { + ok &= el.ct_eq(&0xff) + } + + if ok.unwrap_u8() != 1 { + return Err(Error::Verification); + } + + Ok(()) +} + +/// prefix = 0x30 0x30 0x06 oid 0x05 0x00 0x04 +#[inline] +pub(crate) fn pkcs1v15_generate_prefix() -> Vec +where + D: Digest + AssociatedOid, +{ + let oid = D::OID.as_bytes(); + let oid_len = oid.len() as u8; + let digest_len = ::output_size() as u8; + let mut v = vec![ + 0x30, + oid_len + 8 + digest_len, + 0x30, + oid_len + 4, + 0x6, + oid_len, + ]; + v.extend_from_slice(oid); + v.extend_from_slice(&[0x05, 0x00, 0x04, digest_len]); + v +} + +#[cfg(test)] +mod tests { + use super::*; + use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; + + #[test] + fn test_non_zero_bytes() { + for _ in 0..10 { + let mut rng = ChaCha8Rng::from_seed([42; 32]); + let mut b = vec![0u8; 512]; + non_zero_random_bytes(&mut rng, &mut b); + for el in &b { + assert_ne!(*el, 0u8); + } + } + } +} diff --git a/src/algorithms/pss.rs b/src/algorithms/pss.rs new file mode 100644 index 00000000..db58584d --- /dev/null +++ b/src/algorithms/pss.rs @@ -0,0 +1,334 @@ +//! Support for the [Probabilistic Signature Scheme] (PSS) a.k.a. RSASSA-PSS. +//! +//! Designed by Mihir Bellare and Phillip Rogaway. Specified in [RFC8017 § 8.1]. +//! +//! # Usage +//! +//! See [code example in the toplevel rustdoc](../index.html#pss-signatures). +//! +//! [Probabilistic Signature Scheme]: https://en.wikipedia.org/wiki/Probabilistic_signature_scheme +//! [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1 + +use alloc::vec::Vec; +use digest::{Digest, DynDigest, FixedOutputReset}; +use subtle::{Choice, ConstantTimeEq}; + +use super::mgf::{mgf1_xor, mgf1_xor_digest}; +use crate::errors::{Error, Result}; + +pub(crate) fn emsa_pss_encode( + m_hash: &[u8], + em_bits: usize, + salt: &[u8], + hash: &mut dyn DynDigest, +) -> Result> { + // See [1], section 9.1.1 + let h_len = hash.output_size(); + let s_len = salt.len(); + let em_len = (em_bits + 7) / 8; + + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "message too + // long" and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen. + if m_hash.len() != h_len { + return Err(Error::InputNotHashed); + } + + // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. + if em_len < h_len + s_len + 2 { + // TODO: Key size too small + return Err(Error::Internal); + } + + let mut em = vec![0; em_len]; + + let (db, h) = em.split_at_mut(em_len - h_len - 1); + let h = &mut h[..(em_len - 1) - db.len()]; + + // 4. Generate a random octet string salt of length s_len; if s_len = 0, + // then salt is the empty string. + // + // 5. Let + // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; + // + // M' is an octet string of length 8 + h_len + s_len with eight + // initial zero octets. + // + // 6. Let H = Hash(M'), an octet string of length h_len. + let prefix = [0u8; 8]; + + hash.update(&prefix); + hash.update(m_hash); + hash.update(salt); + + let hashed = hash.finalize_reset(); + h.copy_from_slice(&hashed); + + // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 + // zero octets. The length of PS may be 0. + // + // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length + // emLen - hLen - 1. + db[em_len - s_len - h_len - 2] = 0x01; + db[em_len - s_len - h_len - 1..].copy_from_slice(salt); + + // 9. Let dbMask = MGF(H, emLen - hLen - 1). + // + // 10. Let maskedDB = DB \xor dbMask. + mgf1_xor(db, hash, h); + + // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in + // maskedDB to zero. + db[0] &= 0xFF >> (8 * em_len - em_bits); + + // 12. Let EM = maskedDB || H || 0xbc. + em[em_len - 1] = 0xBC; + + Ok(em) +} + +pub(crate) fn emsa_pss_encode_digest( + m_hash: &[u8], + em_bits: usize, + salt: &[u8], +) -> Result> +where + D: Digest + FixedOutputReset, +{ + // See [1], section 9.1.1 + let h_len = ::output_size(); + let s_len = salt.len(); + let em_len = (em_bits + 7) / 8; + + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "message too + // long" and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen. + if m_hash.len() != h_len { + return Err(Error::InputNotHashed); + } + + // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. + if em_len < h_len + s_len + 2 { + // TODO: Key size too small + return Err(Error::Internal); + } + + let mut em = vec![0; em_len]; + + let (db, h) = em.split_at_mut(em_len - h_len - 1); + let h = &mut h[..(em_len - 1) - db.len()]; + + // 4. Generate a random octet string salt of length s_len; if s_len = 0, + // then salt is the empty string. + // + // 5. Let + // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; + // + // M' is an octet string of length 8 + h_len + s_len with eight + // initial zero octets. + // + // 6. Let H = Hash(M'), an octet string of length h_len. + let prefix = [0u8; 8]; + + let mut hash = D::new(); + + Digest::update(&mut hash, prefix); + Digest::update(&mut hash, m_hash); + Digest::update(&mut hash, salt); + + let hashed = hash.finalize_reset(); + h.copy_from_slice(&hashed); + + // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 + // zero octets. The length of PS may be 0. + // + // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length + // emLen - hLen - 1. + db[em_len - s_len - h_len - 2] = 0x01; + db[em_len - s_len - h_len - 1..].copy_from_slice(salt); + + // 9. Let dbMask = MGF(H, emLen - hLen - 1). + // + // 10. Let maskedDB = DB \xor dbMask. + mgf1_xor_digest(db, &mut hash, h); + + // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in + // maskedDB to zero. + db[0] &= 0xFF >> (8 * em_len - em_bits); + + // 12. Let EM = maskedDB || H || 0xbc. + em[em_len - 1] = 0xBC; + + Ok(em) +} + +fn emsa_pss_verify_pre<'a>( + m_hash: &[u8], + em: &'a mut [u8], + em_bits: usize, + s_len: usize, + h_len: usize, +) -> Result<(&'a mut [u8], &'a mut [u8])> { + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" + // and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen + if m_hash.len() != h_len { + return Err(Error::Verification); + } + + // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. + let em_len = em.len(); //(em_bits + 7) / 8; + if em_len < h_len + s_len + 2 { + return Err(Error::Verification); + } + + // 4. If the rightmost octet of EM does not have hexadecimal value + // 0xbc, output "inconsistent" and stop. + if em[em.len() - 1] != 0xBC { + return Err(Error::Verification); + } + + // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and + // let H be the next hLen octets. + let (db, h) = em.split_at_mut(em_len - h_len - 1); + let h = &mut h[..h_len]; + + // 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in + // maskedDB are not all equal to zero, output "inconsistent" and + // stop. + if db[0] + & (0xFF_u8 + .checked_shl(8 - (8 * em_len - em_bits) as u32) + .unwrap_or(0)) + != 0 + { + return Err(Error::Verification); + } + + Ok((db, h)) +} + +fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice { + // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero + // or if the octet at position emLen - hLen - sLen - 1 (the leftmost + // position is "position 1") does not have hexadecimal value 0x01, + // output "inconsistent" and stop. + let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2); + let valid: Choice = zeroes + .iter() + .fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00)); + + valid & rest[0].ct_eq(&0x01) +} + +pub(crate) fn emsa_pss_verify( + m_hash: &[u8], + em: &mut [u8], + s_len: usize, + hash: &mut dyn DynDigest, + key_bits: usize, +) -> Result<()> { + let em_bits = key_bits - 1; + let em_len = (em_bits + 7) / 8; + let key_len = (key_bits + 7) / 8; + let h_len = hash.output_size(); + + let em = &mut em[key_len - em_len..]; + + let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?; + + // 7. Let dbMask = MGF(H, em_len - h_len - 1) + // + // 8. Let DB = maskedDB \xor dbMask + mgf1_xor(db, hash, &*h); + + // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB + // to zero. + db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); + + let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); + + // 11. Let salt be the last s_len octets of DB. + let salt = &db[db.len() - s_len..]; + + // 12. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + // + // 13. Let H' = Hash(M'), an octet string of length hLen. + let prefix = [0u8; 8]; + + hash.update(&prefix[..]); + hash.update(m_hash); + hash.update(salt); + let h0 = hash.finalize_reset(); + + // 14. If H = H', output "consistent." Otherwise, output "inconsistent." + if (salt_valid & h0.ct_eq(h)).into() { + Ok(()) + } else { + Err(Error::Verification) + } +} + +pub(crate) fn emsa_pss_verify_digest( + m_hash: &[u8], + em: &mut [u8], + s_len: usize, + key_bits: usize, +) -> Result<()> +where + D: Digest + FixedOutputReset, +{ + let em_bits = key_bits - 1; + let em_len = (em_bits + 7) / 8; + let key_len = (key_bits + 7) / 8; + let h_len = ::output_size(); + + let em = &mut em[key_len - em_len..]; + + let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?; + + let mut hash = D::new(); + + // 7. Let dbMask = MGF(H, em_len - h_len - 1) + // + // 8. Let DB = maskedDB \xor dbMask + mgf1_xor_digest::(db, &mut hash, &*h); + + // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB + // to zero. + db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); + + let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); + + // 11. Let salt be the last s_len octets of DB. + let salt = &db[db.len() - s_len..]; + + // 12. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + // + // 13. Let H' = Hash(M'), an octet string of length hLen. + let prefix = [0u8; 8]; + + Digest::update(&mut hash, &prefix[..]); + Digest::update(&mut hash, m_hash); + Digest::update(&mut hash, salt); + let h0 = hash.finalize_reset(); + + // 14. If H = H', output "consistent." Otherwise, output "inconsistent." + if (salt_valid & h0.ct_eq(h)).into() { + Ok(()) + } else { + Err(Error::Verification) + } +} diff --git a/src/oaep.rs b/src/oaep.rs index cc0c488d..07755cd9 100644 --- a/src/oaep.rs +++ b/src/oaep.rs @@ -5,7 +5,6 @@ //! See [code example in the toplevel rustdoc](../index.html#oaep-encryption). use alloc::boxed::Box; use alloc::string::{String, ToString}; -use alloc::vec; use alloc::vec::Vec; use core::fmt; use core::marker::PhantomData; @@ -13,10 +12,9 @@ use rand_core::CryptoRngCore; use digest::{Digest, DynDigest, FixedOutputReset}; use num_bigint::BigUint; -use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; use zeroize::Zeroizing; -use crate::algorithms::{mgf1_xor, mgf1_xor_digest}; +use crate::algorithms::oaep::*; use crate::dummy_rng::DummyRng; use crate::errors::{Error, Result}; use crate::internals::{uint_to_be_pad, uint_to_zeroizing_be_pad}; @@ -24,10 +22,6 @@ use crate::key::{self, PublicKeyParts, RsaPrivateKey, RsaPublicKey}; use crate::padding::PaddingScheme; use crate::traits::{Decryptor, RandomizedDecryptor, RandomizedEncryptor}; -// 2**61 -1 (pow is not const yet) -// TODO: This is the maximum for SHA-1, unclear from the RFC what the values are for other hashing functions. -const MAX_LABEL_LEN: u64 = 2_305_843_009_213_693_951; - /// Encryption and Decryption using [OAEP padding](https://datatracker.ietf.org/doc/html/rfc8017#section-7.1). /// /// - `digest` is used to hash the label. The maximum possible plaintext length is `m = k - 2 * h_len - 2`, @@ -176,42 +170,6 @@ impl fmt::Debug for Oaep { } } -#[inline] -fn encrypt_internal( - rng: &mut R, - pub_key: &RsaPublicKey, - msg: &[u8], - p_hash: &[u8], - h_size: usize, - mut mgf: MGF, -) -> Result> { - key::check_public(pub_key)?; - - let k = pub_key.size(); - - if msg.len() + 2 * h_size + 2 > k { - return Err(Error::MessageTooLong); - } - - let mut em = Zeroizing::new(vec![0u8; k]); - - let (_, payload) = em.split_at_mut(1); - let (seed, db) = payload.split_at_mut(h_size); - rng.fill_bytes(seed); - - // Data block DB = pHash || PS || 01 || M - let db_len = k - h_size - 1; - - db[0..h_size].copy_from_slice(p_hash); - db[db_len - msg.len() - 1] = 1; - db[db_len - msg.len()..].copy_from_slice(msg); - - mgf(seed, db); - - let int = Zeroizing::new(BigUint::from_bytes_be(&em)); - uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) -} - /// Encrypts the given message with RSA and the padding scheme from /// [PKCS#1 OAEP]. /// @@ -228,20 +186,12 @@ fn encrypt( mgf_digest: &mut dyn DynDigest, label: Option, ) -> Result> { - let h_size = digest.output_size(); - - let label = label.unwrap_or_default(); - if label.len() as u64 > MAX_LABEL_LEN { - return Err(Error::LabelTooLong); - } + key::check_public(pub_key)?; - digest.update(label.as_bytes()); - let p_hash = digest.finalize_reset(); + let em = oaep_encrypt(rng, msg, digest, mgf_digest, label, pub_key.size())?; - encrypt_internal(rng, pub_key, msg, &p_hash, h_size, |seed, db| { - mgf1_xor(db, mgf_digest, seed); - mgf1_xor(seed, mgf_digest, db); - }) + let int = Zeroizing::new(BigUint::from_bytes_be(&em)); + uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) } /// Encrypts the given message with RSA and the padding scheme from @@ -251,27 +201,18 @@ fn encrypt( /// `2 + (2 * hash.size())`. /// /// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 -#[inline] fn encrypt_digest( rng: &mut R, pub_key: &RsaPublicKey, msg: &[u8], label: Option, ) -> Result> { - let h_size = ::output_size(); - - let label = label.unwrap_or_default(); - if label.len() as u64 > MAX_LABEL_LEN { - return Err(Error::LabelTooLong); - } + key::check_public(pub_key)?; - let p_hash = D::digest(label.as_bytes()); + let em = oaep_encrypt_digest::<_, D, MGD>(rng, msg, label, pub_key.size())?; - encrypt_internal(rng, pub_key, msg, &p_hash, h_size, |seed, db| { - let mut mgf_digest = MGD::new(); - mgf1_xor_digest(db, &mut mgf_digest, seed); - mgf1_xor_digest(seed, &mut mgf_digest, db); - }) + let int = Zeroizing::new(BigUint::from_bytes_be(&em)); + uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) } /// Decrypts a plaintext using RSA and the padding scheme from [PKCS#1 OAEP]. @@ -297,35 +238,14 @@ fn decrypt( ) -> Result> { key::check_public(priv_key)?; - let h_size = digest.output_size(); - - let label = label.unwrap_or_default(); - if label.len() as u64 > MAX_LABEL_LEN { - return Err(Error::Decryption); - } - - digest.update(label.as_bytes()); - - let expected_p_hash = digest.finalize_reset(); - - let res = decrypt_inner( - rng, - priv_key, - ciphertext, - h_size, - &expected_p_hash, - |seed, db| { - mgf1_xor(seed, mgf_digest, db); - mgf1_xor(db, mgf_digest, seed); - }, - )?; - if res.is_none().into() { + if ciphertext.len() != priv_key.size() { return Err(Error::Decryption); } - let (out, index) = res.unwrap(); + let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; + let mut em = uint_to_zeroizing_be_pad(em, priv_key.size())?; - Ok(out[index as usize..].to_vec()) + oaep_decrypt(&mut em, digest, mgf_digest, label, priv_key.size()) } /// Decrypts a plaintext using RSA and the padding scheme from [PKCS#1 OAEP]. @@ -349,89 +269,14 @@ fn decrypt_digest Result> { key::check_public(priv_key)?; - let h_size = ::output_size(); - - let label = label.unwrap_or_default(); - if label.len() as u64 > MAX_LABEL_LEN { - return Err(Error::LabelTooLong); - } - - let expected_p_hash = D::digest(label.as_bytes()); - - let res = decrypt_inner( - rng, - priv_key, - ciphertext, - h_size, - &expected_p_hash, - |seed, db| { - let mut mgf_digest = MGD::new(); - mgf1_xor_digest(seed, &mut mgf_digest, db); - mgf1_xor_digest(db, &mut mgf_digest, seed); - }, - )?; - if res.is_none().into() { - return Err(Error::Decryption); - } - - let (out, index) = res.unwrap(); - - Ok(out[index as usize..].to_vec()) -} - -/// Decrypts ciphertext using `priv_key` and blinds the operation if -/// `rng` is given. It returns one or zero in valid that indicates whether the -/// plaintext was correctly structured. -#[inline] -fn decrypt_inner( - rng: Option<&mut R>, - priv_key: &RsaPrivateKey, - ciphertext: &[u8], - h_size: usize, - expected_p_hash: &[u8], - mut mgf: MGF, -) -> Result, u32)>> { - let k = priv_key.size(); - if k < 11 { - return Err(Error::Decryption); - } - - if ciphertext.len() != k || k < h_size * 2 + 2 { + if ciphertext.len() != priv_key.size() { return Err(Error::Decryption); } let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; let mut em = uint_to_zeroizing_be_pad(em, priv_key.size())?; - let first_byte_is_zero = em[0].ct_eq(&0u8); - - let (_, payload) = em.split_at_mut(1); - let (seed, db) = payload.split_at_mut(h_size); - - mgf(seed, db); - - let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash); - - // The remainder of the plaintext must be zero or more 0x00, followed - // by 0x01, followed by the message. - // looking_for_index: 1 if we are still looking for the 0x01 - // index: the offset of the first 0x01 byte - // zero_before_one: 1 if we saw a non-zero byte before the 1 - let mut looking_for_index = Choice::from(1u8); - let mut index = 0u32; - let mut nonzero_before_one = Choice::from(0u8); - - for (i, el) in db.iter().skip(h_size).enumerate() { - let equals0 = el.ct_eq(&0u8); - let equals1 = el.ct_eq(&1u8); - index.conditional_assign(&(i as u32), looking_for_index & equals1); - looking_for_index &= !equals1; - nonzero_before_one |= looking_for_index & !equals0; - } - - let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index; - - Ok(CtOption::new((em, index + 2 + (h_size * 2) as u32), valid)) + oaep_decrypt_digest::(&mut em, label, priv_key.size()) } /// Encryption key for PKCS#1 v1.5 encryption as described in [RFC8017 § 7.1]. diff --git a/src/pkcs1v15.rs b/src/pkcs1v15.rs index 9973761d..d82a7303 100644 --- a/src/pkcs1v15.rs +++ b/src/pkcs1v15.rs @@ -24,9 +24,9 @@ use signature::{ DigestSigner, DigestVerifier, Keypair, RandomizedDigestSigner, RandomizedSigner, SignatureEncoding, Signer, Verifier, }; -use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use zeroize::Zeroizing; +use crate::algorithms::pkcs1v15::*; use crate::dummy_rng::DummyRng; use crate::errors::{Error, Result}; use crate::internals::{uint_to_be_pad, uint_to_zeroizing_be_pad}; @@ -80,7 +80,7 @@ impl Pkcs1v15Sign { { Self { hash_len: Some(::output_size()), - prefix: generate_prefix::().into_boxed_slice(), + prefix: pkcs1v15_generate_prefix::().into_boxed_slice(), } } @@ -193,25 +193,14 @@ impl Display for Signature { /// scheme from PKCS#1 v1.5. The message must be no longer than the /// length of the public modulus minus 11 bytes. #[inline] -pub(crate) fn encrypt( +fn encrypt( rng: &mut R, pub_key: &RsaPublicKey, msg: &[u8], ) -> Result> { key::check_public(pub_key)?; - let k = pub_key.size(); - if msg.len() > k - 11 { - return Err(Error::MessageTooLong); - } - - // EM = 0x00 || 0x02 || PS || 0x00 || M - let mut em = Zeroizing::new(vec![0u8; k]); - em[1] = 2; - non_zero_random_bytes(rng, &mut em[2..k - msg.len() - 1]); - em[k - msg.len() - 1] = 0; - em[k - msg.len()..].copy_from_slice(msg); - + let em = pkcs1v15_encrypt_pad(rng, msg, pub_key.size())?; let int = Zeroizing::new(BigUint::from_bytes_be(&em)); uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) } @@ -226,19 +215,17 @@ pub(crate) fn encrypt( /// forge signatures as if they had the private key. See /// `decrypt_session_key` for a way of solving this problem. #[inline] -pub(crate) fn decrypt( +fn decrypt( rng: Option<&mut R>, priv_key: &RsaPrivateKey, ciphertext: &[u8], ) -> Result> { key::check_public(priv_key)?; - let (valid, out, index) = decrypt_inner(rng, priv_key, ciphertext)?; - if valid == 0 { - return Err(Error::Decryption); - } + let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; + let em = uint_to_zeroizing_be_pad(em, priv_key.size())?; - Ok(out[index as usize..].to_vec()) + pkcs1v15_encrypt_unpad(em, priv_key.size()) } /// Calculates the signature of hashed using @@ -255,26 +242,13 @@ pub(crate) fn decrypt( /// messages to signatures and identify the signed messages. As ever, /// signatures provide authenticity, not confidentiality. #[inline] -pub(crate) fn sign( +fn sign( rng: Option<&mut R>, priv_key: &RsaPrivateKey, prefix: &[u8], hashed: &[u8], ) -> Result> { - let hash_len = hashed.len(); - let t_len = prefix.len() + hashed.len(); - let k = priv_key.size(); - if k < t_len + 11 { - return Err(Error::MessageTooLong); - } - - // EM = 0x00 || 0x01 || PS || 0x00 || T - let mut em = vec![0xff; k]; - em[0] = 0; - em[1] = 1; - em[k - t_len - 1] = 0; - em[k - t_len..k - hash_len].copy_from_slice(prefix); - em[k - hash_len..k].copy_from_slice(hashed); + let em = pkcs1v15_sign_pad(prefix, hashed, priv_key.size())?; uint_to_zeroizing_be_pad( priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(&em))?, @@ -284,125 +258,10 @@ pub(crate) fn sign( /// Verifies an RSA PKCS#1 v1.5 signature. #[inline] -pub(crate) fn verify( - pub_key: &RsaPublicKey, - prefix: &[u8], - hashed: &[u8], - sig: &BigUint, -) -> Result<()> { - let hash_len = hashed.len(); - let t_len = prefix.len() + hashed.len(); - let k = pub_key.size(); - if k < t_len + 11 { - return Err(Error::Verification); - } - +fn verify(pub_key: &RsaPublicKey, prefix: &[u8], hashed: &[u8], sig: &BigUint) -> Result<()> { let em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, pub_key.size())?; - // EM = 0x00 || 0x01 || PS || 0x00 || T - let mut ok = em[0].ct_eq(&0u8); - ok &= em[1].ct_eq(&1u8); - ok &= em[k - hash_len..k].ct_eq(hashed); - ok &= em[k - t_len..k - hash_len].ct_eq(prefix); - ok &= em[k - t_len - 1].ct_eq(&0u8); - - for el in em.iter().skip(2).take(k - t_len - 3) { - ok &= el.ct_eq(&0xff) - } - - if ok.unwrap_u8() != 1 { - return Err(Error::Verification); - } - - Ok(()) -} - -/// prefix = 0x30 0x30 0x06 oid 0x05 0x00 0x04 -#[inline] -pub(crate) fn generate_prefix() -> Vec -where - D: Digest + AssociatedOid, -{ - let oid = D::OID.as_bytes(); - let oid_len = oid.len() as u8; - let digest_len = ::output_size() as u8; - let mut v = vec![ - 0x30, - oid_len + 8 + digest_len, - 0x30, - oid_len + 4, - 0x6, - oid_len, - ]; - v.extend_from_slice(oid); - v.extend_from_slice(&[0x05, 0x00, 0x04, digest_len]); - v -} - -/// Decrypts ciphertext using `priv_key` and blinds the operation if -/// `rng` is given. It returns one or zero in valid that indicates whether the -/// plaintext was correctly structured. In either case, the plaintext is -/// returned in em so that it may be read independently of whether it was valid -/// in order to maintain constant memory access patterns. If the plaintext was -/// valid then index contains the index of the original message in em. -#[inline] -fn decrypt_inner( - rng: Option<&mut R>, - priv_key: &RsaPrivateKey, - ciphertext: &[u8], -) -> Result<(u8, Vec, u32)> { - let k = priv_key.size(); - if k < 11 { - return Err(Error::Decryption); - } - - let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; - let em = uint_to_zeroizing_be_pad(em, priv_key.size())?; - - let first_byte_is_zero = em[0].ct_eq(&0u8); - let second_byte_is_two = em[1].ct_eq(&2u8); - - // The remainder of the plaintext must be a string of non-zero random - // octets, followed by a 0, followed by the message. - // looking_for_index: 1 iff we are still looking for the zero. - // index: the offset of the first zero byte. - let mut looking_for_index = 1u8; - let mut index = 0u32; - - for (i, el) in em.iter().enumerate().skip(2) { - let equals0 = el.ct_eq(&0u8); - index.conditional_assign(&(i as u32), Choice::from(looking_for_index) & equals0); - looking_for_index.conditional_assign(&0u8, equals0); - } - - // The PS padding must be at least 8 bytes long, and it starts two - // bytes into em. - // TODO: WARNING: THIS MUST BE CONSTANT TIME CHECK: - // Ref: https://github.com/dalek-cryptography/subtle/issues/20 - // This is currently copy & paste from the constant time impl in - // go, but very likely not sufficient. - let valid_ps = Choice::from((((2i32 + 8i32 - index as i32 - 1i32) >> 31) & 1) as u8); - let valid = - first_byte_is_zero & second_byte_is_two & Choice::from(!looking_for_index & 1) & valid_ps; - index = u32::conditional_select(&0, &(index + 1), valid); - - Ok((valid.unwrap_u8(), em, index)) -} - -/// Fills the provided slice with random values, which are guaranteed -/// to not be zero. -#[inline] -fn non_zero_random_bytes(rng: &mut R, data: &mut [u8]) { - rng.fill_bytes(data); - - for el in data { - if *el == 0u8 { - // TODO: break after a certain amount of time - while *el == 0u8 { - rng.fill_bytes(core::slice::from_mut(el)); - } - } - } + pkcs1v15_sign_unpad(prefix, hashed, &em, pub_key.size()) } /// Signing key for PKCS#1 v1.5 signatures as described in [RFC8017 § 8.2]. @@ -496,7 +355,7 @@ where pub fn new(key: RsaPrivateKey) -> Self { Self { inner: key, - prefix: generate_prefix::(), + prefix: pkcs1v15_generate_prefix::(), phantom: Default::default(), } } @@ -505,7 +364,7 @@ where pub fn random(rng: &mut R, bit_size: usize) -> Result { Ok(Self { inner: RsaPrivateKey::new(rng, bit_size)?, - prefix: generate_prefix::(), + prefix: pkcs1v15_generate_prefix::(), phantom: Default::default(), }) } @@ -700,7 +559,7 @@ where pub fn new(key: RsaPublicKey) -> Self { Self { inner: key, - prefix: generate_prefix::(), + prefix: pkcs1v15_generate_prefix::(), phantom: Default::default(), } } @@ -907,18 +766,6 @@ mod tests { use crate::{PublicKeyParts, RsaPrivateKey, RsaPublicKey}; - #[test] - fn test_non_zero_bytes() { - for _ in 0..10 { - let mut rng = ChaCha8Rng::from_seed([42; 32]); - let mut b = vec![0u8; 512]; - non_zero_random_bytes(&mut rng, &mut b); - for el in &b { - assert_ne!(*el, 0u8); - } - } - } - fn get_private_key() -> RsaPrivateKey { // In order to generate new test vectors you'll need the PEM form of this key: // -----BEGIN RSA PRIVATE KEY----- diff --git a/src/pss.rs b/src/pss.rs index a6ce9f93..8713a8dd 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -30,9 +30,8 @@ use signature::{ hazmat::{PrehashVerifier, RandomizedPrehashSigner}, DigestVerifier, Keypair, RandomizedDigestSigner, RandomizedSigner, SignatureEncoding, Verifier, }; -use subtle::{Choice, ConstantTimeEq}; -use crate::algorithms::{mgf1_xor, mgf1_xor_digest}; +use crate::algorithms::pss::*; use crate::errors::{Error, Result}; use crate::internals::{uint_to_be_pad, uint_to_zeroizing_be_pad}; use crate::key::PublicKeyParts; @@ -196,18 +195,9 @@ pub(crate) fn verify( return Err(Error::Verification); } - let em_bits = pub_key.n().bits() - 1; - let em_len = (em_bits + 7) / 8; - let key_len = pub_key.size(); - let mut em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, key_len)?; + let mut em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, pub_key.size())?; - emsa_pss_verify( - hashed, - &mut em[key_len - em_len..], - em_bits, - salt_len, - digest, - ) + emsa_pss_verify(hashed, &mut em, salt_len, digest, pub_key.n().bits()) } pub(crate) fn verify_digest( @@ -224,12 +214,9 @@ where return Err(Error::Verification); } - let em_bits = pub_key.n().bits() - 1; - let em_len = (em_bits + 7) / 8; - let key_len = pub_key.size(); - let mut em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, key_len)?; + let mut em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, pub_key.size())?; - emsa_pss_verify_digest::(hashed, &mut em[key_len - em_len..], em_bits, salt_len) + emsa_pss_verify_digest::(hashed, &mut em, salt_len, pub_key.n().bits()) } /// SignPSS calculates the signature of hashed using RSASSA-PSS. @@ -300,311 +287,6 @@ fn sign_pss_with_salt_digest Result> { - // See [1], section 9.1.1 - let h_len = hash.output_size(); - let s_len = salt.len(); - let em_len = (em_bits + 7) / 8; - - // 1. If the length of M is greater than the input limitation for the - // hash function (2^61 - 1 octets for SHA-1), output "message too - // long" and stop. - // - // 2. Let mHash = Hash(M), an octet string of length hLen. - if m_hash.len() != h_len { - return Err(Error::InputNotHashed); - } - - // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. - if em_len < h_len + s_len + 2 { - // TODO: Key size too small - return Err(Error::Internal); - } - - let mut em = vec![0; em_len]; - - let (db, h) = em.split_at_mut(em_len - h_len - 1); - let h = &mut h[..(em_len - 1) - db.len()]; - - // 4. Generate a random octet string salt of length s_len; if s_len = 0, - // then salt is the empty string. - // - // 5. Let - // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; - // - // M' is an octet string of length 8 + h_len + s_len with eight - // initial zero octets. - // - // 6. Let H = Hash(M'), an octet string of length h_len. - let prefix = [0u8; 8]; - - hash.update(&prefix); - hash.update(m_hash); - hash.update(salt); - - let hashed = hash.finalize_reset(); - h.copy_from_slice(&hashed); - - // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 - // zero octets. The length of PS may be 0. - // - // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length - // emLen - hLen - 1. - db[em_len - s_len - h_len - 2] = 0x01; - db[em_len - s_len - h_len - 1..].copy_from_slice(salt); - - // 9. Let dbMask = MGF(H, emLen - hLen - 1). - // - // 10. Let maskedDB = DB \xor dbMask. - mgf1_xor(db, hash, h); - - // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in - // maskedDB to zero. - db[0] &= 0xFF >> (8 * em_len - em_bits); - - // 12. Let EM = maskedDB || H || 0xbc. - em[em_len - 1] = 0xBC; - - Ok(em) -} - -fn emsa_pss_encode_digest(m_hash: &[u8], em_bits: usize, salt: &[u8]) -> Result> -where - D: Digest + FixedOutputReset, -{ - // See [1], section 9.1.1 - let h_len = ::output_size(); - let s_len = salt.len(); - let em_len = (em_bits + 7) / 8; - - // 1. If the length of M is greater than the input limitation for the - // hash function (2^61 - 1 octets for SHA-1), output "message too - // long" and stop. - // - // 2. Let mHash = Hash(M), an octet string of length hLen. - if m_hash.len() != h_len { - return Err(Error::InputNotHashed); - } - - // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. - if em_len < h_len + s_len + 2 { - // TODO: Key size too small - return Err(Error::Internal); - } - - let mut em = vec![0; em_len]; - - let (db, h) = em.split_at_mut(em_len - h_len - 1); - let h = &mut h[..(em_len - 1) - db.len()]; - - // 4. Generate a random octet string salt of length s_len; if s_len = 0, - // then salt is the empty string. - // - // 5. Let - // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; - // - // M' is an octet string of length 8 + h_len + s_len with eight - // initial zero octets. - // - // 6. Let H = Hash(M'), an octet string of length h_len. - let prefix = [0u8; 8]; - - let mut hash = D::new(); - - Digest::update(&mut hash, prefix); - Digest::update(&mut hash, m_hash); - Digest::update(&mut hash, salt); - - let hashed = hash.finalize_reset(); - h.copy_from_slice(&hashed); - - // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 - // zero octets. The length of PS may be 0. - // - // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length - // emLen - hLen - 1. - db[em_len - s_len - h_len - 2] = 0x01; - db[em_len - s_len - h_len - 1..].copy_from_slice(salt); - - // 9. Let dbMask = MGF(H, emLen - hLen - 1). - // - // 10. Let maskedDB = DB \xor dbMask. - mgf1_xor_digest(db, &mut hash, h); - - // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in - // maskedDB to zero. - db[0] &= 0xFF >> (8 * em_len - em_bits); - - // 12. Let EM = maskedDB || H || 0xbc. - em[em_len - 1] = 0xBC; - - Ok(em) -} - -fn emsa_pss_verify_pre<'a>( - m_hash: &[u8], - em: &'a mut [u8], - em_bits: usize, - s_len: usize, - h_len: usize, -) -> Result<(&'a mut [u8], &'a mut [u8])> { - // 1. If the length of M is greater than the input limitation for the - // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" - // and stop. - // - // 2. Let mHash = Hash(M), an octet string of length hLen - if m_hash.len() != h_len { - return Err(Error::Verification); - } - - // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. - let em_len = em.len(); //(em_bits + 7) / 8; - if em_len < h_len + s_len + 2 { - return Err(Error::Verification); - } - - // 4. If the rightmost octet of EM does not have hexadecimal value - // 0xbc, output "inconsistent" and stop. - if em[em.len() - 1] != 0xBC { - return Err(Error::Verification); - } - - // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and - // let H be the next hLen octets. - let (db, h) = em.split_at_mut(em_len - h_len - 1); - let h = &mut h[..h_len]; - - // 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in - // maskedDB are not all equal to zero, output "inconsistent" and - // stop. - if db[0] - & (0xFF_u8 - .checked_shl(8 - (8 * em_len - em_bits) as u32) - .unwrap_or(0)) - != 0 - { - return Err(Error::Verification); - } - - Ok((db, h)) -} - -fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice { - // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero - // or if the octet at position emLen - hLen - sLen - 1 (the leftmost - // position is "position 1") does not have hexadecimal value 0x01, - // output "inconsistent" and stop. - let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2); - let valid: Choice = zeroes - .iter() - .fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00)); - - valid & rest[0].ct_eq(&0x01) -} - -fn emsa_pss_verify( - m_hash: &[u8], - em: &mut [u8], - em_bits: usize, - s_len: usize, - hash: &mut dyn DynDigest, -) -> Result<()> { - let em_len = em.len(); //(em_bits + 7) / 8; - let h_len = hash.output_size(); - - let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?; - - // 7. Let dbMask = MGF(H, em_len - h_len - 1) - // - // 8. Let DB = maskedDB \xor dbMask - mgf1_xor(db, hash, &*h); - - // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB - // to zero. - db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); - - let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); - - // 11. Let salt be the last s_len octets of DB. - let salt = &db[db.len() - s_len..]; - - // 12. Let - // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; - // M' is an octet string of length 8 + hLen + sLen with eight - // initial zero octets. - // - // 13. Let H' = Hash(M'), an octet string of length hLen. - let prefix = [0u8; 8]; - - hash.update(&prefix[..]); - hash.update(m_hash); - hash.update(salt); - let h0 = hash.finalize_reset(); - - // 14. If H = H', output "consistent." Otherwise, output "inconsistent." - if (salt_valid & h0.ct_eq(h)).into() { - Ok(()) - } else { - Err(Error::Verification) - } -} - -fn emsa_pss_verify_digest( - m_hash: &[u8], - em: &mut [u8], - em_bits: usize, - s_len: usize, -) -> Result<()> -where - D: Digest + FixedOutputReset, -{ - let em_len = em.len(); //(em_bits + 7) / 8; - let h_len = ::output_size(); - - let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?; - - let mut hash = D::new(); - - // 7. Let dbMask = MGF(H, em_len - h_len - 1) - // - // 8. Let DB = maskedDB \xor dbMask - mgf1_xor_digest::(db, &mut hash, &*h); - - // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB - // to zero. - db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); - - let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); - - // 11. Let salt be the last s_len octets of DB. - let salt = &db[db.len() - s_len..]; - - // 12. Let - // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; - // M' is an octet string of length 8 + hLen + sLen with eight - // initial zero octets. - // - // 13. Let H' = Hash(M'), an octet string of length hLen. - let prefix = [0u8; 8]; - - Digest::update(&mut hash, &prefix[..]); - Digest::update(&mut hash, m_hash); - Digest::update(&mut hash, salt); - let h0 = hash.finalize_reset(); - - // 14. If H = H', output "consistent." Otherwise, output "inconsistent." - if (salt_valid & h0.ct_eq(h)).into() { - Ok(()) - } else { - Err(Error::Verification) - } -} - /// Signing key for producing RSASSA-PSS signatures as described in /// [RFC8017 § 8.1]. /// From 5b83a69eb8777f0d8a96c2b464dc677c566d7d2f Mon Sep 17 00:00:00 2001 From: Dmitry Baryshkov Date: Wed, 19 Apr 2023 04:00:20 +0300 Subject: [PATCH 6/6] feat: mark several internal functions as private There is no need to export some the functions from internals.rs. Mark them as private. Signed-off-by: Dmitry Baryshkov --- src/internals.rs | 6 +++--- src/key.rs | 11 +++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/internals.rs b/src/internals.rs index a56b5f9a..1c3ba6be 100644 --- a/src/internals.rs +++ b/src/internals.rs @@ -18,7 +18,7 @@ pub fn encrypt(key: &K, m: &BigUint) -> BigUint { /// Performs raw RSA decryption with no padding, resulting in a plaintext `BigUint`. /// Peforms RSA blinding if an `Rng` is passed. #[inline] -pub fn decrypt( +fn decrypt( mut rng: Option<&mut R>, priv_key: &RsaPrivateKey, c: &BigUint, @@ -127,7 +127,7 @@ pub fn decrypt_and_check( } /// Returns the blinded c, along with the unblinding factor. -pub fn blind( +fn blind( rng: &mut R, key: &K, c: &BigUint, @@ -168,7 +168,7 @@ pub fn blind( } /// Given an m and and unblinding factor, unblind the m. -pub fn unblind(key: &impl PublicKeyParts, m: &BigUint, unblinder: &BigUint) -> BigUint { +fn unblind(key: &impl PublicKeyParts, m: &BigUint, unblinder: &BigUint) -> BigUint { (m * unblinder) % key.n() } diff --git a/src/key.rs b/src/key.rs index c3514d69..0f24d497 100644 --- a/src/key.rs +++ b/src/key.rs @@ -492,7 +492,6 @@ fn check_public_with_max_size(public_key: &impl PublicKeyParts, max_size: usize) #[cfg(test)] mod tests { use super::*; - use crate::internals; use hex_literal::hex; use num_traits::{FromPrimitive, ToPrimitive}; @@ -525,12 +524,16 @@ mod tests { let pub_key: RsaPublicKey = private_key.clone().into(); let m = BigUint::from_u64(42).expect("invalid 42"); - let c = internals::encrypt(&pub_key, &m); - let m2 = internals::decrypt::(None, private_key, &c) + let c = pub_key + .raw_int_encryption_primitive(&m) + .expect("encryption successfull"); + let m2 = private_key + .raw_int_decryption_primitive::(None, &c) .expect("unable to decrypt without blinding"); assert_eq!(m, m2); let mut rng = ChaCha8Rng::from_seed([42; 32]); - let m3 = internals::decrypt(Some(&mut rng), private_key, &c) + let m3 = private_key + .raw_int_decryption_primitive(Some(&mut rng), &c) .expect("unable to decrypt with blinding"); assert_eq!(m, m3); }