From a84b69be3ceb34b0f84d101849e2f76ccac85622 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Mon, 26 Jun 2023 12:03:54 -0600 Subject: [PATCH 01/41] Misc doc fixes --- sunscreen/src/fhe/mod.rs | 2 +- sunscreen_compiler_macros/src/zkp_program.rs | 5 +++- sunscreen_runtime/src/metadata.rs | 24 +++++++++++--------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/sunscreen/src/fhe/mod.rs b/sunscreen/src/fhe/mod.rs index 8245f8715..8473f8278 100644 --- a/sunscreen/src/fhe/mod.rs +++ b/sunscreen/src/fhe/mod.rs @@ -157,7 +157,7 @@ pub type FheFrontendCompilation = CompilationResult; thread_local! { /** - * Contains the graph of a ZKP program during compilation. An + * Contains the graph of an FHE program during compilation. An * implementation detail and not for public consumption. */ pub static CURRENT_FHE_CTX: RefCell> = RefCell::new(None); diff --git a/sunscreen_compiler_macros/src/zkp_program.rs b/sunscreen_compiler_macros/src/zkp_program.rs index fd32d489f..34ce612de 100644 --- a/sunscreen_compiler_macros/src/zkp_program.rs +++ b/sunscreen_compiler_macros/src/zkp_program.rs @@ -35,7 +35,10 @@ pub fn zkp_program_impl( fn get_generic_arg(generics: &Generics) -> Result<(Ident, Path)> { if generics.type_params().count() != 1 { - return Err(Error::compile_error(generics.span(), "ZKP programs must take 1 generic argument with bound sunscreen::BackendField.sunscreen::BackendField")); + return Err(Error::compile_error( + generics.span(), + "ZKP programs must take 1 generic argument with bound sunscreen::BackendField", + )); } if generics.lifetimes().count() > 0 { diff --git a/sunscreen_runtime/src/metadata.rs b/sunscreen_runtime/src/metadata.rs index 40ad8bd91..d34a9c143 100644 --- a/sunscreen_runtime/src/metadata.rs +++ b/sunscreen_runtime/src/metadata.rs @@ -8,7 +8,8 @@ use sunscreen_fhe_program::{FheProgram, SchemeType}; use crate::{Error, Result}; /** - * Indicates the type signatures of an Fhe Program. Serves as a piece of the [`FheProgramMetadata`]. + * Indicates the type signature of an FHE or ZKP program. Serves as a piece of the + * [`FheProgramMetadata`] or [`ZkpProgramFn`] respectively. * * # Remarks * This type is serializable and FHE program implementors can give this object @@ -18,28 +19,29 @@ use crate::{Error, Result}; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct CallSignature { /** - * The type of each argument in the FHE program. + * The type of each argument in the program. * * # Remarks - * The ith argument to the FHE program occupies the ith argument of the vector. - * The length of this vector equals the number of arguments to the FHE program. + * The ith argument to the program occupies the ith element of the vector. + * The length of this vector equals the number of arguments to the program. */ pub arguments: Vec, /** - * The type of the single return value of the FHE program if the return type is - * not a type. If the return type of the FHE program is a tuple, then this contains + * The type of the single return value of the program if the return type is + * not a tuple. If the return type of the program is a tuple, then this contains * each type in the tuple. - * - * # Remarks - * The ith argument to the FHE program occupies the ith argument of the vector. - * The length of this vector equals the number of arguments to the FHE program. */ pub returns: Vec, /** - * The number of ciphertexts that compose the nth return value. + * The number of ciphertexts that compose the corresponding return value. + * + * # Remarks + * The number of ciphertexts composing the ith return value of the program occupies the ith + * element of the vector. The length of this vector equals the length of the returns. */ + // TODO This field is specific to FHE; should we segment the types here? CallSignature ? pub num_ciphertexts: Vec, } From a91e07b93692467a8d4aab1ed08371a721aced26 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Mon, 26 Jun 2023 17:23:57 -0600 Subject: [PATCH 02/41] Fix sunscreen zkp exports --- Cargo.lock | 1 - examples/sudoku_zkp/Cargo.toml | 4 ++-- examples/sudoku_zkp/src/main.rs | 4 ++-- sunscreen/src/lib.rs | 2 ++ sunscreen_zkp_backend/src/bulletproofs.rs | 2 +- 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f4e32b6db..095c391e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2619,7 +2619,6 @@ name = "sudoku_zkp" version = "0.1.0" dependencies = [ "sunscreen", - "sunscreen_zkp_backend", ] [[package]] diff --git a/examples/sudoku_zkp/Cargo.toml b/examples/sudoku_zkp/Cargo.toml index 7abe4c726..d0431a6e3 100644 --- a/examples/sudoku_zkp/Cargo.toml +++ b/examples/sudoku_zkp/Cargo.toml @@ -4,5 +4,5 @@ version = "0.1.0" edition = "2021" [dependencies] -sunscreen = { path = "../../sunscreen" } -sunscreen_zkp_backend = { path = "../../sunscreen_zkp_backend" } \ No newline at end of file +sunscreen = { path = "../../sunscreen", features = ["bulletproofs"] } + diff --git a/examples/sudoku_zkp/src/main.rs b/examples/sudoku_zkp/src/main.rs index d46d085b2..b06f9329f 100644 --- a/examples/sudoku_zkp/src/main.rs +++ b/examples/sudoku_zkp/src/main.rs @@ -1,7 +1,7 @@ use sunscreen::{ - types::zkp::NativeField, zkp_program, BackendField, Compiler, Runtime, ZkpProgramInput, + types::zkp::NativeField, zkp_program, BackendField, BulletproofsBackend, Compiler, Runtime, + ZkpBackend, ZkpProgramInput, }; -use sunscreen_zkp_backend::{bulletproofs::BulletproofsBackend, ZkpBackend}; type BPField = NativeField<::Field>; diff --git a/sunscreen/src/lib.rs b/sunscreen/src/lib.rs index 55fd1cff8..ea2be6d16 100644 --- a/sunscreen/src/lib.rs +++ b/sunscreen/src/lib.rs @@ -85,6 +85,8 @@ pub use sunscreen_runtime::{ InnerPlaintext, Params, Plaintext, PrivateKey, PublicKey, RequiredKeys, Runtime, WithContext, ZkpProgramInput, ZkpRuntime, }; +#[cfg(feature = "bulletproofs")] +pub use sunscreen_zkp_backend::bulletproofs::{BulletproofsBackend, BulletproofsR1CSProof}; pub use sunscreen_zkp_backend::{BackendField, Error as ZkpError, Result as ZkpResult, ZkpBackend}; pub use zkp::ZkpProgramFn; pub use zkp::{ diff --git a/sunscreen_zkp_backend/src/bulletproofs.rs b/sunscreen_zkp_backend/src/bulletproofs.rs index 303a20dc3..2ec0c40e7 100644 --- a/sunscreen_zkp_backend/src/bulletproofs.rs +++ b/sunscreen_zkp_backend/src/bulletproofs.rs @@ -99,7 +99,7 @@ impl Neg for Node { /** * A Bulletproofs R1CS circuit. */ -pub struct BulletproofsCircuit { +struct BulletproofsCircuit { nodes: Vec>, } From ff8e3be8aa423eb8de13f730ba0f4236cde90862 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Mon, 26 Jun 2023 18:03:44 -0600 Subject: [PATCH 03/41] Fix broken api doc reference --- sunscreen_runtime/src/metadata.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sunscreen_runtime/src/metadata.rs b/sunscreen_runtime/src/metadata.rs index d34a9c143..8652fd004 100644 --- a/sunscreen_runtime/src/metadata.rs +++ b/sunscreen_runtime/src/metadata.rs @@ -8,8 +8,7 @@ use sunscreen_fhe_program::{FheProgram, SchemeType}; use crate::{Error, Result}; /** - * Indicates the type signature of an FHE or ZKP program. Serves as a piece of the - * [`FheProgramMetadata`] or [`ZkpProgramFn`] respectively. + * Indicates the type signature of an FHE or ZKP program. * * # Remarks * This type is serializable and FHE program implementors can give this object From e60ee450d2581d8b66d52ba3de7d905be2adbaf7 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Mon, 26 Jun 2023 23:29:18 -0600 Subject: [PATCH 04/41] Add starter zkp example --- Cargo.lock | 7 +++ examples/ordering_zkp/Cargo.toml | 7 +++ examples/ordering_zkp/src/main.rs | 82 +++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+) create mode 100644 examples/ordering_zkp/Cargo.toml create mode 100644 examples/ordering_zkp/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 095c391e0..87898a6f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1954,6 +1954,13 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "ordering_zkp" +version = "0.1.0" +dependencies = [ + "sunscreen", +] + [[package]] name = "os_str_bytes" version = "6.5.0" diff --git a/examples/ordering_zkp/Cargo.toml b/examples/ordering_zkp/Cargo.toml new file mode 100644 index 000000000..e3f8bb2bc --- /dev/null +++ b/examples/ordering_zkp/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "ordering_zkp" +version = "0.1.0" +edition = "2021" + +[dependencies] +sunscreen = { path = "../../sunscreen", features = ["bulletproofs"] } diff --git a/examples/ordering_zkp/src/main.rs b/examples/ordering_zkp/src/main.rs new file mode 100644 index 000000000..e48ed8f16 --- /dev/null +++ b/examples/ordering_zkp/src/main.rs @@ -0,0 +1,82 @@ +use sunscreen::{ + types::zkp::{ConstrainCmp, NativeField}, + zkp_program, BackendField, BulletproofsBackend, Compiler, ZkpBackend, ZkpRuntime, +}; + +type BPField = NativeField<::Field>; + +fn main() { + let app = Compiler::new() + .zkp_backend::() + .zkp_program(greater_than) + .compile() + .unwrap(); + + let greater_than_zkp = app.get_zkp_program(greater_than).unwrap(); + + let runtime = ZkpRuntime::new(&BulletproofsBackend::new()).unwrap(); + + let amount = BPField::from(64); + let threshold = BPField::from(232); + + // Prove that amount > threshold + + let proof = runtime + .prove(greater_than_zkp, vec![threshold], vec![], vec![amount]) + .unwrap(); + + let verify = runtime.verify(greater_than_zkp, &proof, vec![threshold], vec![]); + + assert!(verify.is_ok()); +} + +#[zkp_program(backend = "bulletproofs")] +fn greater_than(a: NativeField, #[constant] b: NativeField) { + a.constrain_gt_bounded(b, 32) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn run_test(amount: BPField, threshold: BPField, should_succeed: bool) { + let app = Compiler::new() + .zkp_backend::() + .zkp_program(greater_than) + .compile() + .unwrap(); + let gt_zkp = app.get_zkp_program(greater_than).unwrap(); + let runtime = ZkpRuntime::new(&BulletproofsBackend::new()).unwrap(); + let proof = runtime.prove(gt_zkp, vec![threshold], vec![], vec![amount]); + if !should_succeed { + assert!(proof.is_err()); + } else { + assert!(runtime + .verify(gt_zkp, &proof.unwrap(), vec![threshold], vec![]) + .is_ok()) + } + } + + #[test] + fn test_gt() { + run_test(1.into(), 0.into(), true); + run_test(100.into(), 0.into(), true); + run_test(100.into(), 99.into(), true); + run_test(u32::MAX.into(), 0.into(), true); + } + + #[test] + fn test_le() { + run_test(0.into(), 1.into(), false); + } + + #[test] + fn test_eq() { + run_test(1.into(), 1.into(), false); + } + + #[test] + fn test_bounded_failure() { + run_test(u64::MAX.into(), 0.into(), false); + } +} From 523a611715f6e517d81f83a1b05de1f65964af65 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Mon, 26 Jun 2023 23:29:45 -0600 Subject: [PATCH 05/41] Use ZkpRuntime::new in sudoku example --- examples/sudoku_zkp/src/main.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/sudoku_zkp/src/main.rs b/examples/sudoku_zkp/src/main.rs index b06f9329f..ea2957d1e 100644 --- a/examples/sudoku_zkp/src/main.rs +++ b/examples/sudoku_zkp/src/main.rs @@ -1,6 +1,6 @@ use sunscreen::{ - types::zkp::NativeField, zkp_program, BackendField, BulletproofsBackend, Compiler, Runtime, - ZkpBackend, ZkpProgramInput, + types::zkp::NativeField, zkp_program, BackendField, BulletproofsBackend, Compiler, ZkpBackend, + ZkpProgramInput, ZkpRuntime, }; type BPField = NativeField<::Field>; @@ -14,7 +14,7 @@ fn main() { let prog = app.get_zkp_program(sudoku_proof).unwrap(); - let runtime = Runtime::new_zkp(&BulletproofsBackend::new()).unwrap(); + let runtime = ZkpRuntime::new(&BulletproofsBackend::new()).unwrap(); let ex_puzzle = [ [0, 7, 0, 0, 2, 0, 0, 4, 6], @@ -120,7 +120,7 @@ mod tests { let prog = app.get_zkp_program(sudoku_proof).unwrap(); - let runtime = Runtime::new_zkp(&BulletproofsBackend::new()).unwrap(); + let runtime = ZkpRuntime::new(&BulletproofsBackend::new()).unwrap(); let ex_puzzle = [ [0, 7, 0, 0, 2, 0, 0, 4, 6], @@ -167,7 +167,7 @@ mod tests { let prog = app.get_zkp_program(sudoku_proof).unwrap(); - let runtime = Runtime::new_zkp(&BulletproofsBackend::new()).unwrap(); + let runtime = ZkpRuntime::new(&BulletproofsBackend::new()).unwrap(); let ex_puzzle = [ [0, 7, 0, 0, 2, 0, 0, 4, 6], @@ -212,7 +212,7 @@ mod tests { let prog = app.get_zkp_program(sudoku_proof).unwrap(); - let runtime = Runtime::new_zkp(&BulletproofsBackend::new()).unwrap(); + let runtime = ZkpRuntime::new(&BulletproofsBackend::new()).unwrap(); let ex_puzzle = [[0; 9]; 9]; From 7cbadbd5ec1b7e0bf2ef7f8e55002cd1a52a0d79 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Tue, 27 Jun 2023 09:53:29 -0600 Subject: [PATCH 06/41] Use ? over unwrap in zkp examples --- examples/ordering_zkp/src/main.rs | 17 +++++++---------- examples/sudoku_zkp/src/main.rs | 17 ++++++++--------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/examples/ordering_zkp/src/main.rs b/examples/ordering_zkp/src/main.rs index e48ed8f16..d87a79b9e 100644 --- a/examples/ordering_zkp/src/main.rs +++ b/examples/ordering_zkp/src/main.rs @@ -1,33 +1,30 @@ use sunscreen::{ types::zkp::{ConstrainCmp, NativeField}, - zkp_program, BackendField, BulletproofsBackend, Compiler, ZkpBackend, ZkpRuntime, + zkp_program, BackendField, BulletproofsBackend, Compiler, Error, ZkpBackend, ZkpRuntime, }; type BPField = NativeField<::Field>; -fn main() { +fn main() -> Result<(), Error> { let app = Compiler::new() .zkp_backend::() .zkp_program(greater_than) - .compile() - .unwrap(); + .compile()?; let greater_than_zkp = app.get_zkp_program(greater_than).unwrap(); - let runtime = ZkpRuntime::new(&BulletproofsBackend::new()).unwrap(); + let runtime = ZkpRuntime::new(&BulletproofsBackend::new())?; let amount = BPField::from(64); let threshold = BPField::from(232); // Prove that amount > threshold - let proof = runtime - .prove(greater_than_zkp, vec![threshold], vec![], vec![amount]) - .unwrap(); + let proof = runtime.prove(greater_than_zkp, vec![threshold], vec![], vec![amount])?; - let verify = runtime.verify(greater_than_zkp, &proof, vec![threshold], vec![]); + runtime.verify(greater_than_zkp, &proof, vec![threshold], vec![])?; - assert!(verify.is_ok()); + Ok(()) } #[zkp_program(backend = "bulletproofs")] diff --git a/examples/sudoku_zkp/src/main.rs b/examples/sudoku_zkp/src/main.rs index ea2957d1e..4b7e768a7 100644 --- a/examples/sudoku_zkp/src/main.rs +++ b/examples/sudoku_zkp/src/main.rs @@ -1,20 +1,19 @@ use sunscreen::{ - types::zkp::NativeField, zkp_program, BackendField, BulletproofsBackend, Compiler, ZkpBackend, - ZkpProgramInput, ZkpRuntime, + types::zkp::NativeField, zkp_program, BackendField, BulletproofsBackend, Compiler, Error, + ZkpBackend, ZkpProgramInput, ZkpRuntime, }; type BPField = NativeField<::Field>; -fn main() { +fn main() -> Result<(), Error> { let app = Compiler::new() .zkp_backend::() .zkp_program(sudoku_proof) - .compile() - .unwrap(); + .compile()?; let prog = app.get_zkp_program(sudoku_proof).unwrap(); - let runtime = ZkpRuntime::new(&BulletproofsBackend::new()).unwrap(); + let runtime = ZkpRuntime::new(&BulletproofsBackend::new())?; let ex_puzzle = [ [0, 7, 0, 0, 2, 0, 0, 4, 6], @@ -44,11 +43,11 @@ fn main() { let cons: Vec = vec![ex_puzzle.map(|a| a.map(BPField::from)).into()]; - let proof = runtime.prove(prog, cons.clone(), vec![], board).unwrap(); + let proof = runtime.prove(prog, cons.clone(), vec![], board)?; - let verify = runtime.verify(prog, &proof, cons, vec![]); + runtime.verify(prog, &proof, cons, vec![])?; - assert!(verify.is_ok()); + Ok(()) } #[zkp_program(backend = "bulletproofs")] From c99f1e9032a546312c2965afc491b4ad7ffbcc22 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Wed, 28 Jun 2023 11:34:33 -0600 Subject: [PATCH 07/41] Refactor pattern matching No functionality changes --- sunscreen_compiler_common/src/macros/mod.rs | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/sunscreen_compiler_common/src/macros/mod.rs b/sunscreen_compiler_common/src/macros/mod.rs index 674930ee4..14d21c6cc 100644 --- a/sunscreen_compiler_common/src/macros/mod.rs +++ b/sunscreen_compiler_common/src/macros/mod.rs @@ -95,7 +95,7 @@ pub enum ExtractFnArgumentsError { IllegalType(Span), /** - * The given type pattern is not of the a qualified path to a type. + * The given type pattern is not a qualified path to a type. */ IllegalPat(Span), } @@ -121,18 +121,11 @@ pub fn extract_fn_arguments( return Err(ExtractFnArgumentsError::ContainsSelf(i.span())); } FnArg::Typed(t) => match (&*t.ty, &*t.pat) { - (Type::Path(_), Pat::Ident(i)) => (t.attrs.clone(), &*t.ty, &i.ident), - (Type::Array(_), Pat::Ident(i)) => (t.attrs.clone(), &*t.ty, &i.ident), - _ => { - match &*t.pat { - Pat::Ident(_) => {} - _ => { - return Err(ExtractFnArgumentsError::IllegalPat(t.span())); - } - }; - - return Err(ExtractFnArgumentsError::IllegalType(t.span())); + (Type::Path(_) | Type::Array(_), Pat::Ident(i)) => { + (t.attrs.clone(), &*t.ty, &i.ident) } + (_, Pat::Ident(_)) => return Err(ExtractFnArgumentsError::IllegalType(t.span())), + _ => return Err(ExtractFnArgumentsError::IllegalPat(t.span())), }, }; From a6ddd6f41b25a380c50df8f822017fde662d135a Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Wed, 28 Jun 2023 11:56:32 -0600 Subject: [PATCH 08/41] Disallow `mut` args in fhe/zkp programs --- sunscreen/tests/array.rs | 4 +- sunscreen_compiler_common/src/macros/mod.rs | 44 +++++++++++++++++--- sunscreen_compiler_macros/src/fhe_program.rs | 3 ++ sunscreen_compiler_macros/src/zkp_program.rs | 3 +- 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/sunscreen/tests/array.rs b/sunscreen/tests/array.rs index 37a2422e1..67631bf6f 100644 --- a/sunscreen/tests/array.rs +++ b/sunscreen/tests/array.rs @@ -195,7 +195,7 @@ fn cipher_plain_arrays() { #[test] fn can_mutate_array() { #[fhe_program(scheme = "bfv")] - fn mult(mut a: [Cipher; 6]) -> Cipher { + fn mult(a: [Cipher; 6]) -> Cipher { let mut a = a; for i in 0..a.len() { @@ -242,7 +242,7 @@ fn can_mutate_array() { #[test] fn can_return_array() { #[fhe_program(scheme = "bfv")] - fn mult(mut a: [Cipher; 6]) -> [Cipher; 6] { + fn mult(a: [Cipher; 6]) -> [Cipher; 6] { let mut a = a; for i in 0..a.len() { diff --git a/sunscreen_compiler_common/src/macros/mod.rs b/sunscreen_compiler_common/src/macros/mod.rs index 14d21c6cc..75963133a 100644 --- a/sunscreen_compiler_common/src/macros/mod.rs +++ b/sunscreen_compiler_common/src/macros/mod.rs @@ -2,8 +2,8 @@ use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::{ parse_quote, parse_quote_spanned, punctuated::Punctuated, spanned::Spanned, token::PathSep, - AngleBracketedGenericArguments, Attribute, FnArg, Ident, Index, Pat, PathArguments, ReturnType, - Token, Type, + AngleBracketedGenericArguments, Attribute, FnArg, Ident, Index, Pat, PatIdent, PathArguments, + ReturnType, Token, Type, }; mod type_name; @@ -89,6 +89,14 @@ pub enum ExtractFnArgumentsError { */ ContainsSelf(Span), + /** + * The method specifies a mutable argument. + * + * # Remarks + * FHE and ZKP programs must be pure functions. + */ + ContainsMut(Span), + /** * The given type is not allowed. */ @@ -120,11 +128,17 @@ pub fn extract_fn_arguments( FnArg::Receiver(_) => { return Err(ExtractFnArgumentsError::ContainsSelf(i.span())); } - FnArg::Typed(t) => match (&*t.ty, &*t.pat) { - (Type::Path(_) | Type::Array(_), Pat::Ident(i)) => { - (t.attrs.clone(), &*t.ty, &i.ident) + FnArg::Typed(t) => match &*t.pat { + Pat::Ident(PatIdent { + mutability: Some(m), + .. + }) => { + return Err(ExtractFnArgumentsError::ContainsMut(m.span())); } - (_, Pat::Ident(_)) => return Err(ExtractFnArgumentsError::IllegalType(t.span())), + Pat::Ident(i) => match *t.ty { + Type::Path(_) | Type::Array(_) => (t.attrs.clone(), &*t.ty, &i.ident), + _ => return Err(ExtractFnArgumentsError::IllegalType(t.span())), + }, _ => return Err(ExtractFnArgumentsError::IllegalPat(t.span())), }, }; @@ -545,6 +559,24 @@ mod test { }; } + #[test] + fn disallows_mut_arguments() { + let type_name = quote! { + mut a: [[Cipher; 7]; 6], b: Cipher + }; + + let args: Punctuated = parse_quote!(#type_name); + + let extracted = extract_fn_arguments(&args); + + match extracted { + Err(ExtractFnArgumentsError::ContainsMut(_)) => {} + _ => { + panic!("Expected ExtractFnArgumentsError::ContainsMut"); + } + }; + } + #[test] fn can_extract_no_return_type() { let return_type: Type = parse_quote! { diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index d53e150f3..85154f776 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -48,6 +48,9 @@ pub fn fhe_program_impl( ExtractFnArgumentsError::ContainsSelf(s) => { quote_spanned! {s => compile_error!("FHE programs must not contain `self`") } } + ExtractFnArgumentsError::ContainsMut(s) => { + quote_spanned! {s => compile_error!("FHE program arguments cannot be `mut`") } + } ExtractFnArgumentsError::IllegalPat(s) => quote_spanned! { s => compile_error! { "Expected Identifier" } }, diff --git a/sunscreen_compiler_macros/src/zkp_program.rs b/sunscreen_compiler_macros/src/zkp_program.rs index 34ce612de..2220b3361 100644 --- a/sunscreen_compiler_macros/src/zkp_program.rs +++ b/sunscreen_compiler_macros/src/zkp_program.rs @@ -132,7 +132,8 @@ fn parse_inner(_attr_params: ZkpProgramAttrs, input_fn: ItemFn) -> Result>>()? }, - Err(ExtractFnArgumentsError::ContainsSelf(s)) => Err(Error::compile_error(s, "ZKP programs must not contain self"))?, + Err(ExtractFnArgumentsError::ContainsSelf(s)) => Err(Error::compile_error(s, "ZKP programs must not contain `self`"))?, + Err(ExtractFnArgumentsError::ContainsMut(s)) => Err(Error::compile_error(s, "ZKP program arguments cannot be `mut`"))?, Err(ExtractFnArgumentsError::IllegalPat(s)) => Err(Error::compile_error(s, "Expected Identifier"))?, Err(ExtractFnArgumentsError::IllegalType(s)) => Err(Error::compile_error(s, "ZKP program arguments must be an array or named struct type"))?, }; From 2def1edae9c8afc4b782283ee22c636ebb59ceef Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Wed, 28 Jun 2023 20:57:18 -0600 Subject: [PATCH 09/41] Play around with allowing cipher|plain values --- .../src/types/intern/fhe_program_node.rs | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/sunscreen/src/types/intern/fhe_program_node.rs b/sunscreen/src/types/intern/fhe_program_node.rs index b0da3721e..a3299ab99 100644 --- a/sunscreen/src/types/intern/fhe_program_node.rs +++ b/sunscreen/src/types/intern/fhe_program_node.rs @@ -508,3 +508,79 @@ where T::type_name() } } + +/// Marker that indeterminate value is in stage literal +pub struct StageLiteral; +/// Marker that indeterminate value is in stage cipher +pub struct StageCipher; // TODO or StageT ? do ppl lit + plain? + +/// This type comes in handy when constructing values that start off as literals but turn into +/// ciphertexts; e.g. `let sum = 0; sum = sum + cipher`. +// Notice `lit` is the only field; if this turns into a cipher text value, it won't actually be +// held at the value level, but rather in the graph as a node. This will be denoted at the type +// level via the stage `S`. +pub struct Indeterminate { + lit: L, + _type: std::marker::PhantomData, + _stage: std::marker::PhantomData, +} + +impl Indeterminate { + /// Create a new `Indeterminate` value. This _always_ starts off as a literal value. + pub fn new(lit: L) -> Self { + Self { + lit, + _type: std::marker::PhantomData, + _stage: std::marker::PhantomData, + } + } +} + +impl NumCiphertexts for Indeterminate { + const NUM_CIPHERTEXTS: usize = ::NUM_CIPHERTEXTS; +} + +// Below is kinda hacky, but it would be very tedious to add impls for GraphCipher* on each of the +// FHE types. + +// [literal]|cipher + cipher outputs literal|[cipher] +impl Add>> for Indeterminate +where + T: FheType + GraphCipherConstAdd, + L: FheLiteral, +{ + type Output = FheProgramNode>; + + fn add(self, rhs: FheProgramNode>) -> Self::Output { + // perform addition as usual + let node = T::graph_cipher_const_add(rhs, self.lit); + // but swap marker type + FheProgramNode { + ids: node.ids, + _phantom: std::marker::PhantomData, + } + } +} +// literal|[cipher] + cipher outputs literal|[cipher] +impl Add>> for FheProgramNode> +where + T: FheType + GraphCipherAdd, + L: FheLiteral, +{ + type Output = Self; + + fn add(self, rhs: FheProgramNode>) -> Self::Output { + // swap marker on the indeterminate; we know its in stage cipher + let cipher_node = FheProgramNode { + ids: self.ids, + _phantom: std::marker::PhantomData, + }; + // perform addition as usual + let out = T::graph_cipher_add(cipher_node, rhs); + // swap marker type back + FheProgramNode { + ids: out.ids, + _phantom: std::marker::PhantomData, + } + } +} From b06c848fcd26f4b2dd37dd28f1785489fb2f55ed Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Wed, 28 Jun 2023 23:04:21 -0600 Subject: [PATCH 10/41] Allow user-declared plain|cipher values NOTE: Not fully implemented. Will not work on Rational types until we factor out literal->plaintext into a proper trait. This allows, e.g. ```rust fn simple_sum(a: Cipher, b: Cipher) -> Cipher { let mut sum = fhe_var(0); sum = sum + a; sum = sum + b; fhe_out(sum) } ```` --- .../src/types/intern/fhe_program_node.rs | 169 ++++++++++++------ 1 file changed, 113 insertions(+), 56 deletions(-) diff --git a/sunscreen/src/types/intern/fhe_program_node.rs b/sunscreen/src/types/intern/fhe_program_node.rs index a3299ab99..8693390e3 100644 --- a/sunscreen/src/types/intern/fhe_program_node.rs +++ b/sunscreen/src/types/intern/fhe_program_node.rs @@ -1,5 +1,5 @@ use crate::{ - fhe::with_fhe_ctx, + fhe::{with_fhe_ctx, FheContextOps}, types::{ intern::FheLiteral, ops::*, Cipher, FheType, LaneCount, NumCiphertexts, SwapRows, Type, TypeName, @@ -7,6 +7,7 @@ use crate::{ INDEX_ARENA, }; use petgraph::stable_graph::NodeIndex; +use sunscreen_runtime::TypeNameInstance; use std::ops::{Add, Div, Mul, Neg, Shl, Shr, Sub}; @@ -42,13 +43,15 @@ use std::ops::{Add, Div, Mul, Neg, Shl, Shr, Sub}; * Violating any of these conditions may result in memory corruption or * use-after-free. */ -pub struct FheProgramNode { - /** - * The ids on this node. The 'static lifetime on this slice is a lie. The sunscreen - * compiler must ensure that no FheProgramNode exists after FHE program construction. - */ +pub struct FheProgramNode { + /// The ids on this node. The 'static lifetime on this slice is a lie. The sunscreen + /// compiler must ensure that no FheProgramNode exists after FHE program construction. pub ids: &'static [NodeIndex], + /// Typically unused, but can be added to store value-level information on the graph nodes. + stage: S, + + /// Marks the type of the value that this graph node corresponds to. _phantom: std::marker::PhantomData, } @@ -73,6 +76,12 @@ impl FheProgramNode { * result in use-after-free. */ pub fn new(ids: &[NodeIndex]) -> Self { + Self::new_with_stage(ids, ()) + } +} + +impl FheProgramNode { + fn new_with_stage(ids: &[NodeIndex], stage: S) -> Self { INDEX_ARENA.with(|allocator| { let allocator = allocator.borrow(); let ids_dest = allocator.alloc_slice_copy(ids); @@ -85,6 +94,7 @@ impl FheProgramNode { // We invoke the dark transmutation ritual to turn a finite lifetime into a 'static. Self { ids: unsafe { std::mem::transmute(ids_dest) }, + stage, _phantom: std::marker::PhantomData, } }) @@ -509,78 +519,125 @@ where } } -/// Marker that indeterminate value is in stage literal -pub struct StageLiteral; -/// Marker that indeterminate value is in stage cipher -pub struct StageCipher; // TODO or StageT ? do ppl lit + plain? - /// This type comes in handy when constructing values that start off as literals but turn into /// ciphertexts; e.g. `let sum = 0; sum = sum + cipher`. -// Notice `lit` is the only field; if this turns into a cipher text value, it won't actually be -// held at the value level, but rather in the graph as a node. This will be denoted at the type -// level via the stage `S`. -pub struct Indeterminate { - lit: L, +/// +/// # Warning +/// It is illegal to output an `FheProgramNode` with `S == Stage::Literal`. +pub enum Stage { + /// Initial stage of indeterminate type: literal/plaintext + Literal, + /// Ciphertext stage: occurs after any operations with ciphertext + Cipher, +} + +/// Used in tandem with `Stage`. Ultimately, the purpose is to allow a single type to span +/// plaintexts and ciphertexts. The only requirement is that, upon output, the type must resolve to +/// a ciphertext. +pub struct Indeterminate { + _lit: std::marker::PhantomData, _type: std::marker::PhantomData, - _stage: std::marker::PhantomData, } -impl Indeterminate { - /// Create a new `Indeterminate` value. This _always_ starts off as a literal value. - pub fn new(lit: L) -> Self { - Self { - lit, - _type: std::marker::PhantomData, - _stage: std::marker::PhantomData, - } +// TypeNameInstance + TryIntoPlaintext + TryFromPlaintext + FheProgramInputTrait + NumCiphertexts + +impl TypeNameInstance for Indeterminate +where + L: FheLiteral, + T: FheType + TypeName, +{ + fn type_name_instance(&self) -> Type { + T::type_name() } } -impl NumCiphertexts for Indeterminate { - const NUM_CIPHERTEXTS: usize = ::NUM_CIPHERTEXTS; +impl NumCiphertexts for Indeterminate +where + L: FheLiteral, + T: FheType + NumCiphertexts, +{ + const NUM_CIPHERTEXTS: usize = T::NUM_CIPHERTEXTS; } -// Below is kinda hacky, but it would be very tedious to add impls for GraphCipher* on each of the -// FHE types. - -// [literal]|cipher + cipher outputs literal|[cipher] -impl Add>> for Indeterminate +// TODO turn this into `fhe_var!` macro +/// Create a new fhe program variable from any supported literal type. +pub fn fhe_var(lit: L) -> FheProgramNode, Stage> +where + L: FheLiteral, + T: FheType + TryFrom, + >::Error: std::fmt::Debug, +{ + with_fhe_ctx(|ctx| { + // TODO We need a trait like FheLiteral to hold the logic for making plaintext node_ids in the + // graph context (the stuff that happens in graph_cipher_*); for now just restrict to types that + // have one ciphertext length; this WILL NOT WORK for rationals. + let lit = T::try_from(lit) + .unwrap() + .try_into_plaintext(&ctx.data) + .unwrap(); + let lit = ctx.add_plaintext_literal(lit.inner); + FheProgramNode::new_with_stage(&[lit], Stage::Literal) + }) +} + +// TODO make this automatic somehow +/// Output your fhe program variable as a ciphertext. This will fail (at fhe program compile time) +/// if the variable is still a literal. +pub fn fhe_out(var: FheProgramNode, Stage>) -> FheProgramNode> where - T: FheType + GraphCipherConstAdd, L: FheLiteral, + T: FheType, { - type Output = FheProgramNode>; - - fn add(self, rhs: FheProgramNode>) -> Self::Output { - // perform addition as usual - let node = T::graph_cipher_const_add(rhs, self.lit); - // but swap marker type - FheProgramNode { - ids: node.ids, - _phantom: std::marker::PhantomData, + match var.stage { + Stage::Literal => panic!("User created FHE variables must undergo arithmetic operations with ciphertexts before they are returned as output."), + Stage::Cipher => { + FheProgramNode { + ids: var.ids, + stage: (), + _phantom: std::marker::PhantomData, + } } } } -// literal|[cipher] + cipher outputs literal|[cipher] -impl Add>> for FheProgramNode> + +// Below is kinda hacky, but it would be very tedious to add impls for GraphCipher* on each of the +// FHE types. + +// literal|cipher + cipher outputs literal|[cipher] +impl Add>> for FheProgramNode, Stage> where - T: FheType + GraphCipherAdd, + T: FheType + GraphCipherPlainAdd + GraphCipherAdd, L: FheLiteral, { type Output = Self; fn add(self, rhs: FheProgramNode>) -> Self::Output { - // swap marker on the indeterminate; we know its in stage cipher - let cipher_node = FheProgramNode { - ids: self.ids, - _phantom: std::marker::PhantomData, + let node = match self.stage { + Stage::Literal => { + let lit_node = coerce(self, ()); + // N.B. we've already added this literal as a plaintext node + T::graph_cipher_plain_add(rhs, lit_node) + } + Stage::Cipher => { + let cipher_node = coerce(self, ()); + T::graph_cipher_add(rhs, cipher_node) + } }; - // perform addition as usual - let out = T::graph_cipher_add(cipher_node, rhs); - // swap marker type back - FheProgramNode { - ids: out.ids, - _phantom: std::marker::PhantomData, - } + // No matter what `self.stage` currently is, it is being added to a ciphertext, so its next + // stage is cipher. + coerce(node, Stage::Cipher) + } +} + +/// WARNING: This is an unsafe function. It allows casting graph nodes arbitrarily. Use with +/// caution. +fn coerce( + a: FheProgramNode, + t: T, +) -> FheProgramNode { + FheProgramNode { + ids: a.ids, + stage: t, + _phantom: std::marker::PhantomData, } } From b02dc8e5d16f17daa8b0c068365ac4e3da72c768 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 09:13:26 -0600 Subject: [PATCH 11/41] Refactor array::output() --- sunscreen/src/types/intern/output.rs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/sunscreen/src/types/intern/output.rs b/sunscreen/src/types/intern/output.rs index 2fd54b036..9743ee539 100644 --- a/sunscreen/src/types/intern/output.rs +++ b/sunscreen/src/types/intern/output.rs @@ -50,15 +50,6 @@ where type Output = [T::Output; N]; fn output(&self) -> Self::Output { - let mut outputs = Vec::with_capacity(N); - - for i in self { - outputs.push(i.output()); - } - - match outputs.try_into() { - Ok(v) => v, - _ => unreachable!("Internal error. Length mismatch"), - } + self.map(|i| i.output()) } } From a7358fd0c86e8d0dfc59242eae764173d28ab4d0 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 10:02:22 -0600 Subject: [PATCH 12/41] More targeted compiler error messages on invalid return values --- .../src/fhe_program_transforms.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/sunscreen_compiler_macros/src/fhe_program_transforms.rs b/sunscreen_compiler_macros/src/fhe_program_transforms.rs index b017eb73e..69d5f2acf 100644 --- a/sunscreen_compiler_macros/src/fhe_program_transforms.rs +++ b/sunscreen_compiler_macros/src/fhe_program_transforms.rs @@ -119,17 +119,21 @@ pub fn pack_return_type(return_types: &[Type]) -> Type { } pub fn emit_output_capture(return_types: &[Type]) -> TokenStream2 { - match return_types.len() { - 1 => quote_spanned! { return_types[0].span() => v.output(); }, + match return_types { + [ty] => quote_spanned! { ty.span() => { + struct _AssertOutput where FheProgramNode<#ty>: Output; + v.output(); + }}, _ => return_types .iter() .enumerate() - .map(|(i, t)| { + .map(|(i, ty)| { let index = Index::from(i); - quote_spanned! {t.span() => + quote_spanned! {ty.span() => { + struct _AssertOutput where FheProgramNode<#ty>: Output; v.#index.output(); - } + }} }) .collect(), } From 1c179eda76e71128b1fc12058001f65a2a7ed1b1 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 13:09:43 -0600 Subject: [PATCH 13/41] Add option for var.into() rather than fhe_out(var) --- sunscreen/src/types/intern/fhe_program_node.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sunscreen/src/types/intern/fhe_program_node.rs b/sunscreen/src/types/intern/fhe_program_node.rs index 8693390e3..8d269a83b 100644 --- a/sunscreen/src/types/intern/fhe_program_node.rs +++ b/sunscreen/src/types/intern/fhe_program_node.rs @@ -580,9 +580,9 @@ where }) } -// TODO make this automatic somehow +// TODO make this automatic somehow? /// Output your fhe program variable as a ciphertext. This will fail (at fhe program compile time) -/// if the variable is still a literal. +/// if the variable is still a literal. You can also use `.into()` to accomplish the same thing. pub fn fhe_out(var: FheProgramNode, Stage>) -> FheProgramNode> where L: FheLiteral, @@ -600,6 +600,16 @@ where } } +impl From, Stage>> for FheProgramNode> +where + L: FheLiteral, + T: FheType, +{ + fn from(value: FheProgramNode, Stage>) -> Self { + fhe_out(value) + } +} + // Below is kinda hacky, but it would be very tedious to add impls for GraphCipher* on each of the // FHE types. From ee1455d69b27cf25ab271c4b6afb17e1db7d7803 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 13:28:06 -0600 Subject: [PATCH 14/41] Fix incorrect macro invocation --- sunscreen_compiler_macros/src/fhe_program.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index 85154f776..d2334b90f 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -46,10 +46,10 @@ pub fn fhe_program_impl( Err(e) => { return proc_macro::TokenStream::from(match e { ExtractFnArgumentsError::ContainsSelf(s) => { - quote_spanned! {s => compile_error!("FHE programs must not contain `self`") } + quote_spanned! {s => compile_error! { "FHE programs must not contain `self`" } } } ExtractFnArgumentsError::ContainsMut(s) => { - quote_spanned! {s => compile_error!("FHE program arguments cannot be `mut`") } + quote_spanned! {s => compile_error! { "FHE program arguments cannot be `mut`" } } } ExtractFnArgumentsError::IllegalPat(s) => quote_spanned! { s => compile_error! { "Expected Identifier" } From b48f870adf8b740b4b309560f1cac52e4d7ffa40 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 14:09:17 -0600 Subject: [PATCH 15/41] Add trait for inserting const as plaintext --- sunscreen/src/types/bfv/batched.rs | 14 ++++++++ sunscreen/src/types/bfv/fractional.rs | 20 +++++++++-- sunscreen/src/types/bfv/rational.rs | 19 +++++++++++ sunscreen/src/types/bfv/signed.rs | 15 +++++++++ sunscreen/src/types/bfv/unsigned.rs | 15 +++++++++ .../src/types/intern/fhe_program_node.rs | 33 ++++++------------- sunscreen/src/types/ops/insert.rs | 28 ++++++++++++++++ sunscreen/src/types/ops/mod.rs | 2 ++ 8 files changed, 120 insertions(+), 26 deletions(-) create mode 100644 sunscreen/src/types/ops/insert.rs diff --git a/sunscreen/src/types/bfv/batched.rs b/sunscreen/src/types/bfv/batched.rs index 2eff2de43..7f8e2ef1a 100644 --- a/sunscreen/src/types/bfv/batched.rs +++ b/sunscreen/src/types/bfv/batched.rs @@ -554,6 +554,20 @@ impl GraphCipherMul for Batched { } } +impl GraphCipherInsert for Batched { + type Lit = i64; + type Val = Self; + + fn graph_cipher_insert(lit: Self::Lit) -> FheProgramNode { + with_fhe_ctx(|ctx| { + let lit = Self::from(lit).try_into_plaintext(&ctx.data).unwrap(); + let l = ctx.add_plaintext_literal(lit.inner); + + FheProgramNode::new(&[l]) + }) + } +} + impl GraphCipherConstMul for Batched { type Left = Self; type Right = i64; diff --git a/sunscreen/src/types/bfv/fractional.rs b/sunscreen/src/types/bfv/fractional.rs index 867a4d2b9..ec0876f6e 100644 --- a/sunscreen/src/types/bfv/fractional.rs +++ b/sunscreen/src/types/bfv/fractional.rs @@ -5,9 +5,9 @@ use crate::{ types::{ ops::{ GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstDiv, GraphCipherConstMul, - GraphCipherConstSub, GraphCipherMul, GraphCipherNeg, GraphCipherPlainAdd, - GraphCipherPlainMul, GraphCipherPlainSub, GraphCipherSub, GraphConstCipherSub, - GraphPlainCipherSub, + GraphCipherConstSub, GraphCipherInsert, GraphCipherMul, GraphCipherNeg, + GraphCipherPlainAdd, GraphCipherPlainMul, GraphCipherPlainSub, GraphCipherSub, + GraphConstCipherSub, GraphPlainCipherSub, }, Cipher, }, @@ -236,6 +236,20 @@ impl GraphCipherPlainAdd for Fractional { } } +impl GraphCipherInsert for Fractional { + type Lit = f64; + type Val = Self; + + fn graph_cipher_insert(lit: Self::Lit) -> FheProgramNode { + with_fhe_ctx(|ctx| { + let lit = Self::from(lit).try_into_plaintext(&ctx.data).unwrap(); + let lit = ctx.add_plaintext_literal(lit.inner); + + FheProgramNode::new(&[lit]) + }) + } +} + impl GraphCipherConstAdd for Fractional { type Left = Fractional; type Right = f64; diff --git a/sunscreen/src/types/bfv/rational.rs b/sunscreen/src/types/bfv/rational.rs index b472bf0fb..2e677b1f2 100644 --- a/sunscreen/src/types/bfv/rational.rs +++ b/sunscreen/src/types/bfv/rational.rs @@ -314,6 +314,25 @@ impl GraphCipherPlainAdd for Rational { } } +impl GraphCipherInsert for Rational { + type Lit = f64; + type Val = Self; + + fn graph_cipher_insert(lit: Self::Lit) -> FheProgramNode { + with_fhe_ctx(|ctx| { + let lit = Self::try_from(lit).unwrap(); + + let lit_num = + ctx.add_plaintext_literal(lit.num.try_into_plaintext(&ctx.data).unwrap().inner); + + let lit_den = + ctx.add_plaintext_literal(lit.den.try_into_plaintext(&ctx.data).unwrap().inner); + + FheProgramNode::new(&[lit_num, lit_den]) + }) + } +} + impl GraphCipherConstAdd for Rational { type Left = Self; type Right = f64; diff --git a/sunscreen/src/types/bfv/signed.rs b/sunscreen/src/types/bfv/signed.rs index 874ad7081..1088bf968 100644 --- a/sunscreen/src/types/bfv/signed.rs +++ b/sunscreen/src/types/bfv/signed.rs @@ -1,6 +1,7 @@ use seal_fhe::Plaintext as SealPlaintext; use crate as sunscreen; +use crate::types::ops::GraphCipherInsert; use crate::{ fhe::{with_fhe_ctx, FheContextOps}, types::{ @@ -277,6 +278,20 @@ impl GraphCipherPlainAdd for Signed { } } +impl GraphCipherInsert for Signed { + type Lit = i64; + type Val = Self; + + fn graph_cipher_insert(lit: Self::Lit) -> FheProgramNode { + with_fhe_ctx(|ctx| { + let lit = Self::from(lit).try_into_plaintext(&ctx.data).unwrap(); + let lit = ctx.add_plaintext_literal(lit.inner); + + FheProgramNode::new(&[lit]) + }) + } +} + impl GraphCipherConstAdd for Signed { type Left = Self; type Right = i64; diff --git a/sunscreen/src/types/bfv/unsigned.rs b/sunscreen/src/types/bfv/unsigned.rs index b782c1f07..270230204 100644 --- a/sunscreen/src/types/bfv/unsigned.rs +++ b/sunscreen/src/types/bfv/unsigned.rs @@ -9,6 +9,7 @@ use sunscreen_runtime::{ }; use crate as sunscreen; +use crate::types::ops::GraphCipherInsert; use crate::{ fhe::{with_fhe_ctx, FheContextOps}, types::{ @@ -278,6 +279,20 @@ impl_graph_cipher_op! { (Mul, multiplication) } +impl GraphCipherInsert for Unsigned { + type Lit = UInt; + type Val = Self; + + fn graph_cipher_insert(lit: Self::Lit) -> FheProgramNode { + with_fhe_ctx(|ctx| { + let lit = Self::from(lit).try_into_plaintext(&ctx.data).unwrap(); + let lit = ctx.add_plaintext_literal(lit.inner); + + FheProgramNode::new(&[lit]) + }) + } +} + impl GraphConstCipherSub for Unsigned { type Left = UInt; type Right = Self; diff --git a/sunscreen/src/types/intern/fhe_program_node.rs b/sunscreen/src/types/intern/fhe_program_node.rs index 8d269a83b..1b61ac035 100644 --- a/sunscreen/src/types/intern/fhe_program_node.rs +++ b/sunscreen/src/types/intern/fhe_program_node.rs @@ -1,5 +1,5 @@ use crate::{ - fhe::{with_fhe_ctx, FheContextOps}, + fhe::with_fhe_ctx, types::{ intern::FheLiteral, ops::*, Cipher, FheType, LaneCount, NumCiphertexts, SwapRows, Type, TypeName, @@ -564,20 +564,10 @@ where pub fn fhe_var(lit: L) -> FheProgramNode, Stage> where L: FheLiteral, - T: FheType + TryFrom, - >::Error: std::fmt::Debug, -{ - with_fhe_ctx(|ctx| { - // TODO We need a trait like FheLiteral to hold the logic for making plaintext node_ids in the - // graph context (the stuff that happens in graph_cipher_*); for now just restrict to types that - // have one ciphertext length; this WILL NOT WORK for rationals. - let lit = T::try_from(lit) - .unwrap() - .try_into_plaintext(&ctx.data) - .unwrap(); - let lit = ctx.add_plaintext_literal(lit.inner); - FheProgramNode::new_with_stage(&[lit], Stage::Literal) - }) + T: FheType + GraphCipherInsert, +{ + let node = T::graph_cipher_insert(lit); + coerce(node, Stage::Literal) } // TODO make this automatic somehow? @@ -591,11 +581,11 @@ where match var.stage { Stage::Literal => panic!("User created FHE variables must undergo arithmetic operations with ciphertexts before they are returned as output."), Stage::Cipher => { - FheProgramNode { - ids: var.ids, - stage: (), - _phantom: std::marker::PhantomData, - } + FheProgramNode { + ids: var.ids, + stage: (), + _phantom: std::marker::PhantomData, + } } } } @@ -610,9 +600,6 @@ where } } -// Below is kinda hacky, but it would be very tedious to add impls for GraphCipher* on each of the -// FHE types. - // literal|cipher + cipher outputs literal|[cipher] impl Add>> for FheProgramNode, Stage> where diff --git a/sunscreen/src/types/ops/insert.rs b/sunscreen/src/types/ops/insert.rs new file mode 100644 index 000000000..adb719a34 --- /dev/null +++ b/sunscreen/src/types/ops/insert.rs @@ -0,0 +1,28 @@ +use crate::types::{ + intern::{FheLiteral, FheProgramNode}, + FheType, +}; + +/** + * Called when an Fhe Program encounters a literal type and inserts it as plaintext node. + * + * This trait is an implementation detail of FHE program compilation; + * you should not directly call methods on this trait. + */ +pub trait GraphCipherInsert { + /** + * The type of the literal + */ + type Lit: FheLiteral; + + /** + * The type of the plaintext encoding + */ + // TODO if this is always Self, then remove it + type Val: FheType; + + /** + * Process the insertion + */ + fn graph_cipher_insert(lit: Self::Lit) -> FheProgramNode; +} diff --git a/sunscreen/src/types/ops/mod.rs b/sunscreen/src/types/ops/mod.rs index 6afb682bf..f0690b6f7 100644 --- a/sunscreen/src/types/ops/mod.rs +++ b/sunscreen/src/types/ops/mod.rs @@ -1,5 +1,6 @@ mod add; mod div; +mod insert; mod mul; mod neg; mod rotate; @@ -7,6 +8,7 @@ mod sub; pub use add::*; pub use div::*; +pub use insert::*; pub use mul::*; pub use neg::*; pub use rotate::*; From 59664442bb1f3e5d5605a1d6c8f09c7884fcb21f Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 14:48:57 -0600 Subject: [PATCH 16/41] Impl all arithmetic operations for indeterminate nodes --- .../src/types/intern/fhe_program_node.rs | 96 +++++++++++++------ 1 file changed, 68 insertions(+), 28 deletions(-) diff --git a/sunscreen/src/types/intern/fhe_program_node.rs b/sunscreen/src/types/intern/fhe_program_node.rs index 1b61ac035..4d4d97b89 100644 --- a/sunscreen/src/types/intern/fhe_program_node.rs +++ b/sunscreen/src/types/intern/fhe_program_node.rs @@ -6,6 +6,7 @@ use crate::{ }, INDEX_ARENA, }; +use paste::paste; use petgraph::stable_graph::NodeIndex; use sunscreen_runtime::TypeNameInstance; @@ -157,6 +158,8 @@ where } } +// TODO can these literal impls be combined into `L: FheLiteral` ? + // literal + cipher impl Add>> for u64 where @@ -539,8 +542,6 @@ pub struct Indeterminate { _type: std::marker::PhantomData, } -// TypeNameInstance + TryIntoPlaintext + TryFromPlaintext + FheProgramInputTrait + NumCiphertexts - impl TypeNameInstance for Indeterminate where L: FheLiteral, @@ -600,32 +601,6 @@ where } } -// literal|cipher + cipher outputs literal|[cipher] -impl Add>> for FheProgramNode, Stage> -where - T: FheType + GraphCipherPlainAdd + GraphCipherAdd, - L: FheLiteral, -{ - type Output = Self; - - fn add(self, rhs: FheProgramNode>) -> Self::Output { - let node = match self.stage { - Stage::Literal => { - let lit_node = coerce(self, ()); - // N.B. we've already added this literal as a plaintext node - T::graph_cipher_plain_add(rhs, lit_node) - } - Stage::Cipher => { - let cipher_node = coerce(self, ()); - T::graph_cipher_add(rhs, cipher_node) - } - }; - // No matter what `self.stage` currently is, it is being added to a ciphertext, so its next - // stage is cipher. - coerce(node, Stage::Cipher) - } -} - /// WARNING: This is an unsafe function. It allows casting graph nodes arbitrarily. Use with /// caution. fn coerce( @@ -638,3 +613,68 @@ fn coerce( _phantom: std::marker::PhantomData, } } + +macro_rules! impl_indeterminate_arithmetic_op { + ($($op:ident),+) => { + $( + paste! { + // literal|cipher <> cipher outputs literal|[cipher] + impl $op>> for FheProgramNode, Stage> + where + T: FheType + [] + [], + L: FheLiteral, + { + type Output = Self; + + fn [<$op:lower>](self, rhs: FheProgramNode>) -> Self::Output { + let node = match self.stage { + Stage::Literal => { + let lit_node = coerce(self, ()); + // N.B. we've already added this literal as a plaintext node + T::[](rhs, lit_node) + } + Stage::Cipher => { + let cipher_node = coerce(self, ()); + T::[](rhs, cipher_node) + } + }; + // No matter what `self.stage` currently is, it is being operated on with a + // ciphertext, so its next stage is cipher. + coerce(node, Stage::Cipher) + } + } + + // cipher <> literal|cipher outputs literal|[cipher] + impl $op, Stage>> for FheProgramNode> + where + T: FheType + [] + [], + L: FheLiteral, + { + // A little bit of pick your poison here. However it is more likely that the + // user is mutating an `fhe_var` than a normal ciphertext. Worst case, they call + // `.into()` on the resulting node. + type Output = FheProgramNode, Stage>; + + fn [<$op:lower>](self, rhs: FheProgramNode, Stage>) -> Self::Output { + let node = match rhs.stage { + Stage::Literal => { + let lit_node = coerce(rhs, ()); + // N.B. we've already added this literal as a plaintext node + T::[](self, lit_node) + } + Stage::Cipher => { + let cipher_node = coerce(rhs, ()); + T::[](self, cipher_node) + } + }; + // No matter what `rhs.stage` currently is, it is being added to a ciphertext, so its next + // stage is cipher. + coerce(node, Stage::Cipher) + } + } + } + )+ + }; +} + +impl_indeterminate_arithmetic_op! {Add, Sub, Mul, Div} From 2c0f69d2270474108f447118853749669dc6489f Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 20:31:09 -0600 Subject: [PATCH 17/41] Offer an `fhe_var!` macro --- .../src/types/intern/fhe_program_node.rs | 6 +- sunscreen/src/types/mod.rs | 69 +++++++++++++++++++ 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/sunscreen/src/types/intern/fhe_program_node.rs b/sunscreen/src/types/intern/fhe_program_node.rs index 4d4d97b89..71a18ec21 100644 --- a/sunscreen/src/types/intern/fhe_program_node.rs +++ b/sunscreen/src/types/intern/fhe_program_node.rs @@ -560,9 +560,8 @@ where const NUM_CIPHERTEXTS: usize = T::NUM_CIPHERTEXTS; } -// TODO turn this into `fhe_var!` macro -/// Create a new fhe program variable from any supported literal type. -pub fn fhe_var(lit: L) -> FheProgramNode, Stage> +/// Create a new fhe program node from any supported literal type. +pub fn fhe_node(lit: L) -> FheProgramNode, Stage> where L: FheLiteral, T: FheType + GraphCipherInsert, @@ -571,7 +570,6 @@ where coerce(node, Stage::Literal) } -// TODO make this automatic somehow? /// Output your fhe program variable as a ciphertext. This will fail (at fhe program compile time) /// if the variable is still a literal. You can also use `.into()` to accomplish the same thing. pub fn fhe_out(var: FheProgramNode, Stage>) -> FheProgramNode> diff --git a/sunscreen/src/types/mod.rs b/sunscreen/src/types/mod.rs index 434e19209..90f851b0f 100644 --- a/sunscreen/src/types/mod.rs +++ b/sunscreen/src/types/mod.rs @@ -135,3 +135,72 @@ where } } } + +/// Creates new FHE variables from literals. Note that literals can be used directly in +/// arithmetic operations with ciphertexts: +/// +/// ``` +/// #[fhe_program(scheme = "bfv")] +/// fn add_ten(a: Cipher) -> Cipher { +/// a + 10 +/// } +/// ```` +/// +/// But if you want to define a variable that starts as a literal and later takes on a ciphertext +/// value, this won't work: +/// +/// ```compile_fail +/// #[fhe_program(scheme = "bfv")] +/// fn add_ten(a: Cipher) -> Cipher { +/// let sum = 0; +/// sum = sum + a +/// sum = sum + 10 +/// sum +/// } +/// ``` +/// +/// This is because the literal `0` won't have the correct [`Cipher`] type. Instead, you can use +/// this macro: +/// +/// ``` +/// #[fhe_program(scheme = "bfv")] +/// fn add_ten(a: Cipher) -> Cipher { +/// let sum = fhe_var!(0); +/// sum = sum + a +/// sum = sum + 10 +/// sum +/// } +/// ``` +/// +/// You can also create arrays of variables: +/// +/// ``` +/// #[fhe_program(scheme = "bfv")] +/// fn add_ten(a: Cipher) -> Cipher { +/// let mut sum = fhe_var(0); +/// let arr = fhe_var![1, 2, 4]; +/// let ones = fhe_var![1; 3]; +/// for x in arr { +/// sum = sum + x; +/// } +/// for y in ones { +/// sum = sum + y; +/// } +/// sum + a +/// } +/// ``` +#[macro_export] +macro_rules! fhe_var { + ($elem:literal) => ( + $crate::types::intern::fhe_node($elem) + ); + ($elem:literal; $n:expr) => ( + // TODO this will just copy the same node IDs. But that's ok, right? the graph nodes are + // immtuable anyway. + [$crate::types::intern::fhe_node($elem); $n] + ); + ($($elem:literal),+ $(,)?) => ( + [$($crate::types::intern::fhe_node($x)),+] + ); +} +pub use fhe_var; From bfa6e47b05d1b056328a245012ab0fbb0b30eee8 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 20:48:51 -0600 Subject: [PATCH 18/41] Offer a zkp_var! macro --- .../src/types/intern/fhe_program_node.rs | 5 +++ sunscreen/src/types/mod.rs | 36 +++++++++++++++++-- sunscreen/src/types/zkp/program_node.rs | 12 ++++++- 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/sunscreen/src/types/intern/fhe_program_node.rs b/sunscreen/src/types/intern/fhe_program_node.rs index 71a18ec21..3d1609476 100644 --- a/sunscreen/src/types/intern/fhe_program_node.rs +++ b/sunscreen/src/types/intern/fhe_program_node.rs @@ -527,6 +527,7 @@ where /// /// # Warning /// It is illegal to output an `FheProgramNode` with `S == Stage::Literal`. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Stage { /// Initial stage of indeterminate type: literal/plaintext Literal, @@ -537,6 +538,7 @@ pub enum Stage { /// Used in tandem with `Stage`. Ultimately, the purpose is to allow a single type to span /// plaintexts and ciphertexts. The only requirement is that, upon output, the type must resolve to /// a ciphertext. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct Indeterminate { _lit: std::marker::PhantomData, _type: std::marker::PhantomData, @@ -570,6 +572,9 @@ where coerce(node, Stage::Literal) } +// TODO make this automatic (in #[fhe_program], define `internal` with generics from within a +// function with the proper +// return values, and call .into() on each output). /// Output your fhe program variable as a ciphertext. This will fail (at fhe program compile time) /// if the variable is still a literal. You can also use `.into()` to accomplish the same thing. pub fn fhe_out(var: FheProgramNode, Stage>) -> FheProgramNode> diff --git a/sunscreen/src/types/mod.rs b/sunscreen/src/types/mod.rs index 90f851b0f..9c1607991 100644 --- a/sunscreen/src/types/mod.rs +++ b/sunscreen/src/types/mod.rs @@ -200,7 +200,39 @@ macro_rules! fhe_var { [$crate::types::intern::fhe_node($elem); $n] ); ($($elem:literal),+ $(,)?) => ( - [$($crate::types::intern::fhe_node($x)),+] + [$($crate::types::intern::fhe_node($elem)),+] + ); +} + +/// Creates new ZKP variables from literals. +/// ``` +/// #[zkp_program(backend = "bulletproofs")] +/// fn equals_ten(a: NativeField) { +/// let ten = zkp_var!(10); +/// a.constrain_eq(ten); +/// } +/// ``` +/// +/// You can also create arrays of variables: +/// +/// ``` +/// #[zkp_program(backend = "bulletproofs")] +/// fn equals_ten(a: NativeField) { +/// let tens = zkp_var![10, 10, 10]; +/// for ten in tens { +/// a.constrain_eq(ten); +/// } +/// } +/// ``` +#[macro_export] +macro_rules! zkp_var { + ($elem:literal) => ( + $crate::types::zkp::zkp_node($elem) + ); + ($elem:literal; $n:expr) => ( + [$crate::types::zkp::zkp_node($elem); $n] + ); + ($($elem:literal),+ $(,)?) => ( + [$($crate::types::zkp::zkp_node($elem)),+] ); } -pub use fhe_var; diff --git a/sunscreen/src/types/zkp/program_node.rs b/sunscreen/src/types/zkp/program_node.rs index 3df6bc4ef..ac1f77f3a 100644 --- a/sunscreen/src/types/zkp/program_node.rs +++ b/sunscreen/src/types/zkp/program_node.rs @@ -1,4 +1,5 @@ use petgraph::stable_graph::NodeIndex; +use sunscreen_zkp_backend::BackendField; use std::{ marker::PhantomData, @@ -11,7 +12,7 @@ use crate::{ INDEX_ARENA, }; -use super::{ConstrainCmpVarVar, ConstrainEqVarVar}; +use super::{ConstrainCmpVarVar, ConstrainEqVarVar, NativeField}; #[derive(Clone, Copy)] /** @@ -37,6 +38,15 @@ where _phantom: PhantomData, } +/// Convenience function to create a ZKP program node +pub fn zkp_node(lit: L) -> ProgramNode> +where + F: BackendField, + NativeField: From, +{ + NativeField::::from(lit).into_program_node() +} + /** * Trait for adding inputs to a ZKP program */ From bd70632b9b0aa9277c9c14f166cffe1c4ce8ba49 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 20:49:16 -0600 Subject: [PATCH 19/41] Offer a (safe) debug impl for zkp program nodes --- sunscreen/src/types/zkp/program_node.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sunscreen/src/types/zkp/program_node.rs b/sunscreen/src/types/zkp/program_node.rs index ac1f77f3a..d3878c226 100644 --- a/sunscreen/src/types/zkp/program_node.rs +++ b/sunscreen/src/types/zkp/program_node.rs @@ -38,6 +38,12 @@ where _phantom: PhantomData, } +impl std::fmt::Debug for ProgramNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("ProgramNode") + } +} + /// Convenience function to create a ZKP program node pub fn zkp_node(lit: L) -> ProgramNode> where From 1d7e3742815dd18e77e71c6d32d8698b6967ddb4 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 21:38:52 -0600 Subject: [PATCH 20/41] Fix tests --- .../src/fhe_program_transforms.rs | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/sunscreen_compiler_macros/src/fhe_program_transforms.rs b/sunscreen_compiler_macros/src/fhe_program_transforms.rs index 69d5f2acf..f0d769408 100644 --- a/sunscreen_compiler_macros/src/fhe_program_transforms.rs +++ b/sunscreen_compiler_macros/src/fhe_program_transforms.rs @@ -429,7 +429,10 @@ mod test { let actual = emit_output_capture(&extracted); let expected = quote! { - v.output(); + { + struct _AssertOutput where FheProgramNode< Cipher < Signed > > : Output; + v.output(); + } }; assert_syn_eq(&actual, &expected); @@ -448,8 +451,14 @@ mod test { let actual = emit_output_capture(&extracted); let expected = quote! { - v.0.output(); - v.1.output(); + { + struct _AssertOutput where FheProgramNode< Cipher < Signed > > : Output; + v.0.output(); + } + { + struct _AssertOutput where FheProgramNode<[[Cipher; 6]; 7]>: Output; + v.1.output(); + } }; assert_syn_eq(&actual, &expected); From f0cdf89e66d4886ec09937bdf2dd77990932bac7 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 22:40:12 -0600 Subject: [PATCH 21/41] Add test for fhe_var! --- sunscreen/tests/fhe_program_tests.rs | 51 ++++++++++++++++++++++++++-- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/sunscreen/tests/fhe_program_tests.rs b/sunscreen/tests/fhe_program_tests.rs index 68b6727e5..095ce9667 100644 --- a/sunscreen/tests/fhe_program_tests.rs +++ b/sunscreen/tests/fhe_program_tests.rs @@ -1,6 +1,7 @@ +use petgraph::stable_graph::node_index; use sunscreen::{ - fhe::{FheFrontendCompilation, CURRENT_FHE_CTX}, - fhe_program, + fhe::{FheFrontendCompilation, FheOperation, Literal, CURRENT_FHE_CTX}, + fhe_program, fhe_var, types::{bfv::Signed, Cipher, TypeName}, CallSignature, FheProgramFn, Params, SchemeType, SecurityLevel, }; @@ -268,6 +269,52 @@ fn can_mul() { ); } +#[test] +fn can_insert_literals() { + #[fhe_program(scheme = "bfv")] + fn fhe_program_sum(xs: [Cipher; 2]) -> Cipher { + let mut sum = fhe_var!(0); + for x in xs { + sum = sum + x; + } + sum.into() + } + + let arg_type_name = <[Cipher; 2]>::type_name(); + let ret_type_name = Cipher::::type_name(); + + let expected_signature = CallSignature { + arguments: vec![arg_type_name], + returns: vec![ret_type_name], + num_ciphertexts: vec![1], + }; + assert_eq!(fhe_program_sum.signature(), expected_signature); + assert_eq!(fhe_program_sum.scheme_type(), SchemeType::Bfv); + + let context = fhe_program_sum.build(&get_params()).unwrap(); + + // N.B. Can't match on json like the other tests because the operation to insert a literal + // plaintext ends up with a pointer handle in the json that + // changes on each run. This appears necessary (the underlying seal ptr gets bincode encoded to + // be tossed around). + assert_eq!(context.node_count(), 6); + assert_eq!( + context[node_index(0)].operation, + FheOperation::InputCiphertext + ); + assert_eq!( + context[node_index(1)].operation, + FheOperation::InputCiphertext + ); + assert!(matches!( + context[node_index(2)].operation, + FheOperation::Literal(Literal::Plaintext(_)) + )); + assert_eq!(context[node_index(3)].operation, FheOperation::AddPlaintext); + assert_eq!(context[node_index(4)].operation, FheOperation::Add); + assert_eq!(context[node_index(5)].operation, FheOperation::Output); +} + #[test] fn can_collect_output() { #[fhe_program(scheme = "bfv")] From ae7cb9e5b04f787ddcbf2d5a79d265745fdc4163 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 22:58:39 -0600 Subject: [PATCH 22/41] Simplify tf out of sudoku --- examples/sudoku_zkp/src/main.rs | 26 ++++++++++++++------------ sunscreen/src/types/mod.rs | 6 +++--- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/examples/sudoku_zkp/src/main.rs b/examples/sudoku_zkp/src/main.rs index 4b7e768a7..bf962eb62 100644 --- a/examples/sudoku_zkp/src/main.rs +++ b/examples/sudoku_zkp/src/main.rs @@ -1,6 +1,6 @@ use sunscreen::{ - types::zkp::NativeField, zkp_program, BackendField, BulletproofsBackend, Compiler, Error, - ZkpBackend, ZkpProgramInput, ZkpRuntime, + types::zkp::NativeField, zkp_program, zkp_var, BackendField, BulletproofsBackend, Compiler, + Error, ZkpBackend, ZkpProgramInput, ZkpRuntime, }; type BPField = NativeField<::Field>; @@ -55,22 +55,24 @@ fn sudoku_proof( #[constant] constraints: [[NativeField; 9]; 9], board: [[NativeField; 9]; 9], ) { - fn assert_unique_numbers(arr: [ProgramNode>; 9]) { + let zero = zkp_var!(0); + + let assert_unique_numbers = |squares| { for i in 1..=9 { - let mut circuit = NativeField::::from(1).into_program_node(); - for a in arr { - circuit = circuit * (NativeField::::from(i).into_program_node() - a); + let mut circuit = zkp_var!(1); + for s in squares { + circuit = circuit * (zkp_var!(i) - s); } - circuit.constrain_eq(NativeField::::from(0)); + circuit.constrain_eq(zero); } - } + }; + // Proves that the board matches up with the puzzle where applicable - let zero = NativeField::::from(0).into_program_node(); for i in 0..9 { for j in 0..9 { - let square = board[i][j].into_program_node(); - let constraint = constraints[i][j].into_program_node(); + let square = board[i][j]; + let constraint = constraints[i][j]; (constraint * (constraint - square)).constrain_eq(zero); } } @@ -93,7 +95,7 @@ fn sudoku_proof( let square = rows.iter().map(|s| &s[(j * 3)..(j * 3 + 3)]); - let flattened_sq: [ProgramNode>; 9] = square + let flattened_sq = square .flatten() .copied() .collect::>() diff --git a/sunscreen/src/types/mod.rs b/sunscreen/src/types/mod.rs index 9c1607991..c1aca63fb 100644 --- a/sunscreen/src/types/mod.rs +++ b/sunscreen/src/types/mod.rs @@ -226,13 +226,13 @@ macro_rules! fhe_var { /// ``` #[macro_export] macro_rules! zkp_var { - ($elem:literal) => ( + ($elem:expr) => ( $crate::types::zkp::zkp_node($elem) ); - ($elem:literal; $n:expr) => ( + ($elem:expr; $n:expr) => ( [$crate::types::zkp::zkp_node($elem); $n] ); - ($($elem:literal),+ $(,)?) => ( + ($($elem:expr),+ $(,)?) => ( [$($crate::types::zkp::zkp_node($elem)),+] ); } From 1656ec8be0b7f78923480e65aa9d257866763801 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 23:31:23 -0600 Subject: [PATCH 23/41] Simplify fhe input() codegen --- .../src/fhe_program_transforms.rs | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/sunscreen_compiler_macros/src/fhe_program_transforms.rs b/sunscreen_compiler_macros/src/fhe_program_transforms.rs index f0d769408..352776997 100644 --- a/sunscreen_compiler_macros/src/fhe_program_transforms.rs +++ b/sunscreen_compiler_macros/src/fhe_program_transforms.rs @@ -9,8 +9,8 @@ pub enum MapFheTypeError { /** * Given an input type T, returns - * * FheProgramInput when T is a Path - * * [map_input_type(T); N] when T is Array + * * FheProgramNode when T is a Path + * * [FheProgramNode; N] when T is Array */ pub fn map_fhe_type(arg_type: &Type) -> Result { let transformed_type = match arg_type { @@ -44,20 +44,12 @@ pub fn create_fhe_program_node(var_name: &str, arg_type: &Type) -> TokenStream2 }; } }; - let var_name = format_ident!("{}", var_name); - let type_annotation = match arg_type { - Type::Path(ty) => quote_spanned! { ty.span() => FheProgramNode }, - Type::Array(a) => quote_spanned! { a.span() => - <#mapped_type> - }, - _ => quote! { - compile_error!("fhe_program arguments' name must be a simple identifier and type must be a plain path."); - }, - }; + let var_name = format_ident!("{}", var_name); quote_spanned! {arg_type.span() => - let #var_name: #mapped_type = #type_annotation::input(); + + let #var_name: #mapped_type = <#mapped_type as Input>::input(); } } @@ -305,7 +297,7 @@ mod test { let actual = create_fhe_program_node("horse", &type_name); let expected = quote! { - let horse: FheProgramNode > = FheProgramNode::input(); + let horse: FheProgramNode > = > as Input>::input(); }; assert_syn_eq(&actual, &expected); @@ -322,7 +314,7 @@ mod test { let actual = create_fhe_program_node("horse", &type_name); let expected = quote! { - let horse: [FheProgramNode >; 7] = <[FheProgramNode >; 7]>::input(); + let horse: [FheProgramNode >; 7] = <[FheProgramNode >; 7] as Input>::input(); }; assert_syn_eq(&actual, &expected); @@ -339,7 +331,7 @@ mod test { let actual = create_fhe_program_node("horse", &type_name); let expected = quote! { - let horse: [[FheProgramNode >; 7]; 6] = <[[FheProgramNode >; 7]; 6]>::input(); + let horse: [[FheProgramNode >; 7]; 6] = <[[FheProgramNode >; 7]; 6] as Input>::input(); }; assert_syn_eq(&actual, &expected); From 77c1980035159a58af4b3184864efd78d850bccf Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Thu, 29 Jun 2023 23:41:50 -0600 Subject: [PATCH 24/41] Marginally better compiler error messages on invalid fhe program arg types --- sunscreen_compiler_macros/src/fhe_program_transforms.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sunscreen_compiler_macros/src/fhe_program_transforms.rs b/sunscreen_compiler_macros/src/fhe_program_transforms.rs index 352776997..05cd5d4a2 100644 --- a/sunscreen_compiler_macros/src/fhe_program_transforms.rs +++ b/sunscreen_compiler_macros/src/fhe_program_transforms.rs @@ -48,7 +48,7 @@ pub fn create_fhe_program_node(var_name: &str, arg_type: &Type) -> TokenStream2 let var_name = format_ident!("{}", var_name); quote_spanned! {arg_type.span() => - + { struct _AssertInput where #mapped_type: Input; } let #var_name: #mapped_type = <#mapped_type as Input>::input(); } } @@ -297,6 +297,7 @@ mod test { let actual = create_fhe_program_node("horse", &type_name); let expected = quote! { + { struct _AssertInput where FheProgramNode >: Input; } let horse: FheProgramNode > = > as Input>::input(); }; @@ -314,6 +315,7 @@ mod test { let actual = create_fhe_program_node("horse", &type_name); let expected = quote! { + { struct _AssertInput where [FheProgramNode >; 7]: Input; } let horse: [FheProgramNode >; 7] = <[FheProgramNode >; 7] as Input>::input(); }; @@ -331,6 +333,7 @@ mod test { let actual = create_fhe_program_node("horse", &type_name); let expected = quote! { + { struct _AssertInput where [[FheProgramNode >; 7]; 6]: Input; } let horse: [[FheProgramNode >; 7]; 6] = <[[FheProgramNode >; 7]; 6] as Input>::input(); }; From c593f74ce14d2105e7b57d07553ec2fe58863195 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 00:22:11 -0600 Subject: [PATCH 25/41] Fix error for fhe program argument attributes --- sunscreen_compiler_macros/src/fhe_program.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index d2334b90f..34d955630 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -35,9 +35,9 @@ pub fn fhe_program_impl( Ok(v) => { for arg in &v { if !arg.0.is_empty() { - return proc_macro::TokenStream::from( - quote_spanned! { arg.1.span() => compile_error!("FHE program arguments do not support attributes.")}, - ); + return proc_macro::TokenStream::from(quote_spanned! { arg.0[0].span() => + compile_error!{"FHE program arguments do not support attributes."} + }); } } From 6cb0de3708df2d29db08c8f6bf0652298f26a49a Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 00:22:35 -0600 Subject: [PATCH 26/41] Throw appropriate compiler error on generics --- sunscreen_compiler_macros/src/fhe_program.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index 34d955630..59a316e65 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -16,6 +16,7 @@ pub fn fhe_program_impl( let fhe_program_name = &input_fn.sig.ident; let vis = &input_fn.vis; let body = &input_fn.block; + let generics = &input_fn.sig.generics; let inputs = &input_fn.sig.inputs; let ret = &input_fn.sig.output; @@ -31,6 +32,12 @@ pub fn fhe_program_impl( let chain_count = attr_params.chain_count; + if !generics.params.is_empty() { + return proc_macro::TokenStream::from( + quote_spanned! { generics.params.span() => compile_error!{"FHE programs do not support generics."}}, + ); + } + let unwrapped_inputs = match extract_fn_arguments(inputs) { Ok(v) => { for arg in &v { From c201db8f033876d527fc8a2ebe8e6899c3e1727e Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 00:35:35 -0600 Subject: [PATCH 27/41] Silence clippy warnings in generated code These I think are typically ignored by default when consuming proc macros but might as well be explicit --- sunscreen_compiler_macros/src/fhe_program.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index 59a316e65..eb0a20ddc 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -140,6 +140,7 @@ pub fn fhe_program_impl( } impl sunscreen::FheProgramFn for #fhe_program_struct_name { + #[allow(unused_imports)] fn build(&self, params: &sunscreen::Params) -> sunscreen::Result { use std::cell::RefCell; use std::mem::transmute; From 33a24e8493ec1d39db3c8bb7c3143db732148c96 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 08:57:59 -0600 Subject: [PATCH 28/41] Fixup quote_spanned invocations Unsure how important this is, but see here: https://docs.rs/quote/latest/quote/macro.quote_spanned.html#syntax --- sunscreen_compiler_macros/src/fhe_program.rs | 33 +++++++++----------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index eb0a20ddc..6a4913287 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -34,7 +34,7 @@ pub fn fhe_program_impl( if !generics.params.is_empty() { return proc_macro::TokenStream::from( - quote_spanned! { generics.params.span() => compile_error!{"FHE programs do not support generics."}}, + quote_spanned!(generics.params.span()=> compile_error!{"FHE programs do not support generics."}), ); } @@ -42,7 +42,7 @@ pub fn fhe_program_impl( Ok(v) => { for arg in &v { if !arg.0.is_empty() { - return proc_macro::TokenStream::from(quote_spanned! { arg.0[0].span() => + return proc_macro::TokenStream::from(quote_spanned! {arg.0[0].span() => compile_error!{"FHE program arguments do not support attributes."} }); } @@ -53,16 +53,16 @@ pub fn fhe_program_impl( Err(e) => { return proc_macro::TokenStream::from(match e { ExtractFnArgumentsError::ContainsSelf(s) => { - quote_spanned! {s => compile_error! { "FHE programs must not contain `self`" } } + quote_spanned!(s=> compile_error! { "FHE programs must not contain `self`" }) } ExtractFnArgumentsError::ContainsMut(s) => { - quote_spanned! {s => compile_error! { "FHE program arguments cannot be `mut`" } } + quote_spanned!(s=> compile_error! { "FHE program arguments cannot be `mut`" }) } - ExtractFnArgumentsError::IllegalPat(s) => quote_spanned! { - s => compile_error! { "Expected Identifier" } + ExtractFnArgumentsError::IllegalPat(s) => quote_spanned! {s=> + compile_error! { "Expected Identifier" } }, - ExtractFnArgumentsError::IllegalType(s) => quote_spanned! { - s => compile_error! { "FHE program arguments must be an array or named struct type" } + ExtractFnArgumentsError::IllegalType(s) => quote_spanned! {s=> + compile_error! { "FHE program arguments must be an array or named struct type" } }, }); } @@ -88,9 +88,9 @@ pub fn fhe_program_impl( let return_types = match extract_return_types(ret) { Ok(v) => v, Err(ExtractReturnTypesError::IllegalType(s)) => { - return proc_macro::TokenStream::from( - quote_spanned! {s => compile_error! {"FHE programs may return a single value or a tuple of values. Each type must be an FHE type or array of such."}}, - ); + return proc_macro::TokenStream::from(quote_spanned! {s=> + compile_error! {"FHE programs may return a single value or a tuple of values. Each type must be an FHE type or array of such."} + }); } }; @@ -103,9 +103,9 @@ pub fn fhe_program_impl( { Ok(v) => v, Err(MapFheTypeError::IllegalType(s)) => { - return proc_macro::TokenStream::from( - quote_spanned! {s => compile_error! {"Each return type for an FHE program must be either an array or named struct type."}}, - ); + return proc_macro::TokenStream::from(quote_spanned! {s=> + compile_error! {"Each return type for an FHE program must be either an array or named struct type."} + }); } }; @@ -121,10 +121,7 @@ pub fn fhe_program_impl( let args = unwrapped_inputs.iter().enumerate().map(|(i, t)| { let id = Ident::new(&format!("c_{}", i), Span::call_site()); - - quote_spanned! {t.1.span() => - #id - } + quote_spanned!(t.1.span()=> #id) }); let fhe_program_struct_name = From b0f3cafe31a28557bddee1d4b592bab3f261c01d Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 08:58:36 -0600 Subject: [PATCH 29/41] Automatically call `.into()` on fhe prog return values --- sunscreen/tests/fhe_program_tests.rs | 2 +- sunscreen_compiler_macros/src/fhe_program.rs | 35 +++++++++++++++++-- .../src/fhe_program_transforms.rs | 16 +++++++++ 3 files changed, 49 insertions(+), 4 deletions(-) diff --git a/sunscreen/tests/fhe_program_tests.rs b/sunscreen/tests/fhe_program_tests.rs index 095ce9667..3981d91e2 100644 --- a/sunscreen/tests/fhe_program_tests.rs +++ b/sunscreen/tests/fhe_program_tests.rs @@ -277,7 +277,7 @@ fn can_insert_literals() { for x in xs { sum = sum + x; } - sum.into() + sum } let arg_type_name = <[Cipher; 2]>::type_name(); diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index 6a4913287..d2fc7e013 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -111,6 +111,32 @@ pub fn fhe_program_impl( let fhe_program_return = pack_return_type(&fhe_program_returns); + // Tokens necessary for the `internal_internal` function, which returns any types that can + // `.into` the `fhe_program_return` types. + // E.g. (impl Into>, impl Into>) + let inner_return = pack_return_into_type(&fhe_program_returns); + // E.g. a, b + let inner_arg_values = unwrapped_inputs.iter().map(|(_, _, name)| *name); + // E.g. (_r1, _r2) + let inner_return_idents = fhe_program_returns + .iter() + .enumerate() + .map(|(i, t)| { + let id = Ident::new(&format!("__r_{}", i), Span::call_site()); + quote_spanned!(t.span()=> #id) + }) + .collect::>(); + // TODO generalize the pack function to do more than just types + let inner_return_values = match &inner_return_idents[..] { + [r1] => quote! { #r1 }, + _ => quote! { ( #(#inner_return_idents),* ) }, + }; + // E.g. (_r1.into(), _r2.into()) + let inner_return_into_values = match &inner_return_idents[..] { + [r1] => quote! { #r1.into() }, + _ => quote! { ( #(#inner_return_idents.into()),* ) }, + }; + let signature = emit_signature(&argument_types, &return_types); let var_decl = unwrapped_inputs.iter().enumerate().map(|(i, t)| { @@ -153,9 +179,12 @@ pub fn fhe_program_impl( CURRENT_FHE_CTX.with(|ctx| { #[allow(clippy::type_complexity)] #[forbid(unused_variables)] - let internal = | #(#fhe_program_args)* | -> #fhe_program_return - #body - ; + let internal = | #(#fhe_program_args)* | -> #fhe_program_return { + fn internal_internal(#(#fhe_program_args)*) -> #inner_return #body + + let #inner_return_values = internal_internal( #(#inner_arg_values),* ); + #inner_return_into_values + }; // Transmute away the lifetime to 'static. So long as we are careful with internal() // panicing, this is safe because we set the context back to none before the funtion diff --git a/sunscreen_compiler_macros/src/fhe_program_transforms.rs b/sunscreen_compiler_macros/src/fhe_program_transforms.rs index 05cd5d4a2..5796ee92c 100644 --- a/sunscreen_compiler_macros/src/fhe_program_transforms.rs +++ b/sunscreen_compiler_macros/src/fhe_program_transforms.rs @@ -110,6 +110,22 @@ pub fn pack_return_type(return_types: &[Type]) -> Type { } } +/** + * Takes an array of return types, wraps each in an `impl Into<_>`, and packs them up like +* `pack_return_type`. +*/ +pub fn pack_return_into_type(return_types: &[Type]) -> Type { + match return_types { + [] => parse_quote! { () }, + [ty] => parse_quote_spanned! { ty.span() => + impl Into<#ty> + }, + _ => parse_quote_spanned! { return_types[0].span() => + ( #(impl Into<#return_types>),* ) + }, + } +} + pub fn emit_output_capture(return_types: &[Type]) -> TokenStream2 { match return_types { [ty] => quote_spanned! { ty.span() => { From b49b4d4b5e9a7a8a519824028284880f9461ffd2 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 10:37:26 -0600 Subject: [PATCH 30/41] Factor fhe_program_impl --- sunscreen_compiler_macros/src/fhe_program.rs | 424 ++++++++++--------- 1 file changed, 215 insertions(+), 209 deletions(-) diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index d2fc7e013..4e68c529f 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -5,250 +5,256 @@ use crate::{ use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; use sunscreen_compiler_common::macros::{extract_fn_arguments, ExtractFnArgumentsError}; -use syn::{parse_macro_input, spanned::Spanned, Ident, ItemFn, Type}; +use syn::{parse_macro_input, spanned::Spanned, Error, Ident, ItemFn, Result, Type}; pub fn fhe_program_impl( metadata: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { - let input_fn = parse_macro_input!(input as ItemFn); - - let fhe_program_name = &input_fn.sig.ident; - let vis = &input_fn.vis; - let body = &input_fn.block; - let generics = &input_fn.sig.generics; - let inputs = &input_fn.sig.inputs; - let ret = &input_fn.sig.output; - + let item_fn = parse_macro_input!(input as ItemFn); let attr_params = parse_macro_input!(metadata as FheProgramAttrs); - - let scheme_type = match attr_params.scheme { - Scheme::Bfv => { - quote! { - sunscreen::SchemeType::Bfv - } - } + let fhe_program = FheProgram { + item_fn, + attr_params, }; - - let chain_count = attr_params.chain_count; - - if !generics.params.is_empty() { - return proc_macro::TokenStream::from( - quote_spanned!(generics.params.span()=> compile_error!{"FHE programs do not support generics."}), - ); + match fhe_program.output() { + Ok(t) => proc_macro::TokenStream::from(t), + Err(e) => proc_macro::TokenStream::from(Error::into_compile_error(e)), } +} - let unwrapped_inputs = match extract_fn_arguments(inputs) { - Ok(v) => { - for arg in &v { - if !arg.0.is_empty() { - return proc_macro::TokenStream::from(quote_spanned! {arg.0[0].span() => - compile_error!{"FHE program arguments do not support attributes."} - }); +// TODO move validation errors into a `FheProgram::new() -> Result` function, +// gather the useful fields onto the struct, +// then call `FheProgram::new().and_then(|f|f.output())`. +struct FheProgram { + item_fn: ItemFn, + attr_params: FheProgramAttrs, +} + +impl FheProgram { + fn output(self) -> Result { + let input_fn = self.item_fn; + let attr_params = self.attr_params; + + let fhe_program_name = &input_fn.sig.ident; + let vis = &input_fn.vis; + let body = &input_fn.block; + let generics = &input_fn.sig.generics; + let inputs = &input_fn.sig.inputs; + let ret = &input_fn.sig.output; + + let scheme_type = match attr_params.scheme { + Scheme::Bfv => { + quote! { + sunscreen::SchemeType::Bfv } } + }; - v + let chain_count = attr_params.chain_count; + + if !generics.params.is_empty() { + return Err(Error::new( + generics.params.span(), + "FHE programs do not support generics.", + )); } - Err(e) => { - return proc_macro::TokenStream::from(match e { + + let unwrapped_inputs = extract_fn_arguments(inputs) + .map_err(|e| match e { ExtractFnArgumentsError::ContainsSelf(s) => { - quote_spanned!(s=> compile_error! { "FHE programs must not contain `self`" }) + Error::new(s, "FHE programs must not contain `self`") } ExtractFnArgumentsError::ContainsMut(s) => { - quote_spanned!(s=> compile_error! { "FHE program arguments cannot be `mut`" }) + Error::new(s, "FHE program arguments cannot be `mut`") } - ExtractFnArgumentsError::IllegalPat(s) => quote_spanned! {s=> - compile_error! { "Expected Identifier" } - }, - ExtractFnArgumentsError::IllegalType(s) => quote_spanned! {s=> - compile_error! { "FHE program arguments must be an array or named struct type" } - }, - }); - } - }; - - let argument_types = unwrapped_inputs - .iter() - .map(|(_, t, _)| (**t).clone()) - .collect::>(); - - let fhe_program_args = unwrapped_inputs - .iter() - .map(|i| { - let (_, ty, name) = i; - let ty = map_fhe_type(ty).unwrap(); - - quote! { - #name: #ty, - } - }) - .collect::>(); - - let return_types = match extract_return_types(ret) { - Ok(v) => v, - Err(ExtractReturnTypesError::IllegalType(s)) => { - return proc_macro::TokenStream::from(quote_spanned! {s=> - compile_error! {"FHE programs may return a single value or a tuple of values. Each type must be an FHE type or array of such."} - }); - } - }; - - let output_capture = emit_output_capture(&return_types); - - let fhe_program_returns = match return_types - .iter() - .map(map_fhe_type) - .collect::, MapFheTypeError>>() - { - Ok(v) => v, - Err(MapFheTypeError::IllegalType(s)) => { - return proc_macro::TokenStream::from(quote_spanned! {s=> - compile_error! {"Each return type for an FHE program must be either an array or named struct type."} - }); - } - }; - - let fhe_program_return = pack_return_type(&fhe_program_returns); - - // Tokens necessary for the `internal_internal` function, which returns any types that can - // `.into` the `fhe_program_return` types. - // E.g. (impl Into>, impl Into>) - let inner_return = pack_return_into_type(&fhe_program_returns); - // E.g. a, b - let inner_arg_values = unwrapped_inputs.iter().map(|(_, _, name)| *name); - // E.g. (_r1, _r2) - let inner_return_idents = fhe_program_returns - .iter() - .enumerate() - .map(|(i, t)| { - let id = Ident::new(&format!("__r_{}", i), Span::call_site()); - quote_spanned!(t.span()=> #id) - }) - .collect::>(); - // TODO generalize the pack function to do more than just types - let inner_return_values = match &inner_return_idents[..] { - [r1] => quote! { #r1 }, - _ => quote! { ( #(#inner_return_idents),* ) }, - }; - // E.g. (_r1.into(), _r2.into()) - let inner_return_into_values = match &inner_return_idents[..] { - [r1] => quote! { #r1.into() }, - _ => quote! { ( #(#inner_return_idents.into()),* ) }, - }; - - let signature = emit_signature(&argument_types, &return_types); - - let var_decl = unwrapped_inputs.iter().enumerate().map(|(i, t)| { - let var_name = format!("c_{}", i); - - create_fhe_program_node(&var_name, t.1) - }); - - let args = unwrapped_inputs.iter().enumerate().map(|(i, t)| { - let id = Ident::new(&format!("c_{}", i), Span::call_site()); - quote_spanned!(t.1.span()=> #id) - }); - - let fhe_program_struct_name = - Ident::new(&format!("{}_struct", fhe_program_name), Span::call_site()); - - let fhe_program_name_literal = format!("{}", fhe_program_name); - - proc_macro::TokenStream::from(quote! { - #[allow(non_camel_case_types)] - #[derive(Clone)] - #vis struct #fhe_program_struct_name { - chain_count: usize - } - - impl sunscreen::FheProgramFn for #fhe_program_struct_name { - #[allow(unused_imports)] - fn build(&self, params: &sunscreen::Params) -> sunscreen::Result { - use std::cell::RefCell; - use std::mem::transmute; - use sunscreen::{fhe::{CURRENT_FHE_CTX, FheContext}, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{FheProgramNode, Input, Output}, NumCiphertexts, Type, TypeName, SwapRows, LaneCount, TypeNameInstance}}; - - if SchemeType::Bfv != params.scheme_type { - return Err(Error::IncorrectScheme) + ExtractFnArgumentsError::IllegalPat(s) => Error::new(s, "Expected Identifier"), + ExtractFnArgumentsError::IllegalType(s) => Error::new( + s, + "FHE program arguments must be an array or named struct type", + ), + }) + .and_then(|v| { + for arg in &v { + if !arg.0.is_empty() { + return Err(Error::new( + arg.0[0].span(), + "FHE program arguments do not support attributes.", + )); + } } + Ok(v) + })?; + + let argument_types = unwrapped_inputs + .iter() + .map(|(_, t, _)| (**t).clone()) + .collect::>(); + + let fhe_program_args = unwrapped_inputs + .iter() + .map(|(_, ty, name)| { + let ty = map_fhe_type(ty).unwrap(); + quote! { + #name: #ty, + } + }) + .collect::>(); + + let return_types = extract_return_types(ret).map_err(|ExtractReturnTypesError::IllegalType(s)| + Error::new(s, "FHE programs may return a single value or a tuple of values. Each type must be an FHE type or array of such.") + )?; + + let output_capture = emit_output_capture(&return_types); + + let fhe_program_returns = return_types + .iter() + .map(map_fhe_type) + .collect::, MapFheTypeError>>() + .map_err(|MapFheTypeError::IllegalType(s)| + Error::new(s, "Each return type for an FHE program must be either an array or named struct type.") + )?; + + let fhe_program_return = pack_return_type(&fhe_program_returns); + + // Tokens necessary for the `internal_internal` function, which returns any types that can + // `.into` the `fhe_program_return` types. + // E.g. (impl Into>, impl Into>) + let inner_return = pack_return_into_type(&fhe_program_returns); + // E.g. a, b + let inner_arg_values = unwrapped_inputs.iter().map(|(_, _, name)| *name); + // E.g. (_r1, _r2) + let inner_return_idents = fhe_program_returns + .iter() + .enumerate() + .map(|(i, t)| { + let id = Ident::new(&format!("__r_{}", i), Span::call_site()); + quote_spanned!(t.span()=> #id) + }) + .collect::>(); + // TODO generalize the pack function to do more than just types + let inner_return_values = match &inner_return_idents[..] { + [r1] => quote! { #r1 }, + _ => quote! { ( #(#inner_return_idents),* ) }, + }; + // E.g. (_r1.into(), _r2.into()) + let inner_return_into_values = match &inner_return_idents[..] { + [r1] => quote! { #r1.into() }, + _ => quote! { ( #(#inner_return_idents.into()),* ) }, + }; - // TODO: Other schemes. - let mut context = FheContext::new(params.clone()); + let signature = emit_signature(&argument_types, &return_types); - CURRENT_FHE_CTX.with(|ctx| { - #[allow(clippy::type_complexity)] - #[forbid(unused_variables)] - let internal = | #(#fhe_program_args)* | -> #fhe_program_return { - fn internal_internal(#(#fhe_program_args)*) -> #inner_return #body + let var_decl = unwrapped_inputs.iter().enumerate().map(|(i, t)| { + let var_name = format!("c_{}", i); + create_fhe_program_node(&var_name, t.1) + }); - let #inner_return_values = internal_internal( #(#inner_arg_values),* ); - #inner_return_into_values - }; + let args = unwrapped_inputs.iter().enumerate().map(|(i, t)| { + let id = Ident::new(&format!("c_{}", i), Span::call_site()); + quote_spanned!(t.1.span()=> #id) + }); - // Transmute away the lifetime to 'static. So long as we are careful with internal() - // panicing, this is safe because we set the context back to none before the funtion - // returns. - ctx.swap(&RefCell::new(Some(unsafe { transmute(&mut context) }))); + let fhe_program_struct_name = + Ident::new(&format!("{}_struct", fhe_program_name), Span::call_site()); - #(#var_decl)* + let fhe_program_name_literal = format!("{}", fhe_program_name); - let panic_res = std::panic::catch_unwind(|| { - internal(#(#args),*) - }); + Ok(quote! { + #[allow(non_camel_case_types)] + #[derive(Clone)] + #vis struct #fhe_program_struct_name { + chain_count: usize + } - // when panicing or not, we need to collect our indicies arena and - // unset the context reference. - match panic_res { - Ok(v) => { #output_capture }, - Err(err) => { - INDEX_ARENA.with(|allocator| { - allocator.borrow_mut().reset() - }); - ctx.swap(&RefCell::new(None)); - std::panic::resume_unwind(err) - } - }; - - INDEX_ARENA.with(|allocator| { - allocator.borrow_mut().reset() + impl sunscreen::FheProgramFn for #fhe_program_struct_name { + #[allow(unused_imports)] + fn build(&self, params: &sunscreen::Params) -> sunscreen::Result { + use std::cell::RefCell; + use std::mem::transmute; + use sunscreen::{fhe::{CURRENT_FHE_CTX, FheContext}, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{FheProgramNode, Input, Output}, NumCiphertexts, Type, TypeName, SwapRows, LaneCount, TypeNameInstance}}; + + if SchemeType::Bfv != params.scheme_type { + return Err(Error::IncorrectScheme) + } + + // TODO: Other schemes. + let mut context = FheContext::new(params.clone()); + + CURRENT_FHE_CTX.with(|ctx| { + #[allow(clippy::type_complexity)] + #[forbid(unused_variables)] + let internal = | #(#fhe_program_args)* | -> #fhe_program_return { + fn internal_internal(#(#fhe_program_args)*) -> #inner_return #body + + let #inner_return_values = internal_internal( #(#inner_arg_values),* ); + #inner_return_into_values + }; + + // Transmute away the lifetime to 'static. So long as we are careful with internal() + // panicing, this is safe because we set the context back to none before the funtion + // returns. + ctx.swap(&RefCell::new(Some(unsafe { transmute(&mut context) }))); + + #(#var_decl)* + + let panic_res = std::panic::catch_unwind(|| { + internal(#(#args),*) + }); + + // when panicing or not, we need to collect our indicies arena and + // unset the context reference. + match panic_res { + Ok(v) => { #output_capture }, + Err(err) => { + INDEX_ARENA.with(|allocator| { + allocator.borrow_mut().reset() + }); + ctx.swap(&RefCell::new(None)); + std::panic::resume_unwind(err) + } + }; + + INDEX_ARENA.with(|allocator| { + allocator.borrow_mut().reset() + }); + ctx.swap(&RefCell::new(None)); }); - ctx.swap(&RefCell::new(None)); - }); - Ok(context.graph) - } + Ok(context.graph) + } - fn signature(&self) -> sunscreen::CallSignature { - use sunscreen::types::NumCiphertexts; + fn signature(&self) -> sunscreen::CallSignature { + use sunscreen::types::NumCiphertexts; - #signature - } + #signature + } - fn scheme_type(&self) -> sunscreen::SchemeType { - #scheme_type - } + fn scheme_type(&self) -> sunscreen::SchemeType { + #scheme_type + } - fn name(&self) -> &str { - #fhe_program_name_literal - } + fn name(&self) -> &str { + #fhe_program_name_literal + } - fn chain_count(&self) -> usize { - self.chain_count + fn chain_count(&self) -> usize { + self.chain_count + } } - } - impl AsRef for #fhe_program_struct_name { - fn as_ref(&self) -> &str { - use sunscreen::FheProgramFn; + impl AsRef for #fhe_program_struct_name { + fn as_ref(&self) -> &str { + use sunscreen::FheProgramFn; - self.name() + self.name() + } } - } - #[allow(non_upper_case_globals)] - #vis const #fhe_program_name: #fhe_program_struct_name = #fhe_program_struct_name { - chain_count: #chain_count - }; - }) + #[allow(non_upper_case_globals)] + #vis const #fhe_program_name: #fhe_program_struct_name = #fhe_program_struct_name { + chain_count: #chain_count + }; + }) + } } From 341019706963af9869d6a397708b52180b3e1095 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 12:45:03 -0600 Subject: [PATCH 31/41] Further factor fhe_program_impl So that token generation happens in helper methods, and the ultimate output() func is readable --- sunscreen_compiler_macros/src/fhe_program.rs | 300 ++++++++++++------ .../src/fhe_program_transforms.rs | 43 +-- 2 files changed, 217 insertions(+), 126 deletions(-) diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index 4e68c529f..2ee74b6a2 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -4,7 +4,7 @@ use crate::{ }; use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; -use sunscreen_compiler_common::macros::{extract_fn_arguments, ExtractFnArgumentsError}; +use sunscreen_compiler_common::macros::{extract_fn_arguments, ExtractFnArgumentsError, FnArgInfo}; use syn::{parse_macro_input, spanned::Spanned, Error, Ident, ItemFn, Result, Type}; pub fn fhe_program_impl( @@ -13,11 +13,7 @@ pub fn fhe_program_impl( ) -> proc_macro::TokenStream { let item_fn = parse_macro_input!(input as ItemFn); let attr_params = parse_macro_input!(metadata as FheProgramAttrs); - let fhe_program = FheProgram { - item_fn, - attr_params, - }; - match fhe_program.output() { + match FheProgram::new(&item_fn, attr_params).map(|f| f.output()) { Ok(t) => proc_macro::TokenStream::from(t), Err(e) => proc_macro::TokenStream::from(Error::into_compile_error(e)), } @@ -26,32 +22,25 @@ pub fn fhe_program_impl( // TODO move validation errors into a `FheProgram::new() -> Result` function, // gather the useful fields onto the struct, // then call `FheProgram::new().and_then(|f|f.output())`. -struct FheProgram { - item_fn: ItemFn, +struct FheProgram<'a> { + // The function passed to the proc macro + item_fn: &'a ItemFn, + // The attributes on the proc macro attr_params: FheProgramAttrs, + // Each argument's attributes, type, and identifier + unwrapped_inputs: Vec>, + // Return types of the input program (tuple turns into vector) + return_types: Vec, + // Return types of the fhe program (i.e. wrapped in FheProgramNode) + fhe_program_return_types: Vec, } -impl FheProgram { - fn output(self) -> Result { - let input_fn = self.item_fn; - let attr_params = self.attr_params; - - let fhe_program_name = &input_fn.sig.ident; - let vis = &input_fn.vis; - let body = &input_fn.block; - let generics = &input_fn.sig.generics; - let inputs = &input_fn.sig.inputs; - let ret = &input_fn.sig.output; - - let scheme_type = match attr_params.scheme { - Scheme::Bfv => { - quote! { - sunscreen::SchemeType::Bfv - } - } - }; - - let chain_count = attr_params.chain_count; +impl<'a> FheProgram<'a> { + // Handles validation + fn new(item_fn: &'a ItemFn, attr_params: FheProgramAttrs) -> Result { + let generics = &item_fn.sig.generics; + let inputs = &item_fn.sig.inputs; + let ret = &item_fn.sig.output; if !generics.params.is_empty() { return Err(Error::new( @@ -60,38 +49,68 @@ impl FheProgram { )); } - let unwrapped_inputs = extract_fn_arguments(inputs) - .map_err(|e| match e { - ExtractFnArgumentsError::ContainsSelf(s) => { - Error::new(s, "FHE programs must not contain `self`") - } - ExtractFnArgumentsError::ContainsMut(s) => { - Error::new(s, "FHE program arguments cannot be `mut`") - } - ExtractFnArgumentsError::IllegalPat(s) => Error::new(s, "Expected Identifier"), - ExtractFnArgumentsError::IllegalType(s) => Error::new( - s, - "FHE program arguments must be an array or named struct type", - ), - }) - .and_then(|v| { - for arg in &v { - if !arg.0.is_empty() { - return Err(Error::new( - arg.0[0].span(), - "FHE program arguments do not support attributes.", - )); + let unwrapped_inputs: Vec<(Vec, &Type, &Ident)> = + extract_fn_arguments(inputs) + .map_err(|e| match e { + ExtractFnArgumentsError::ContainsSelf(s) => { + Error::new(s, "FHE programs must not contain `self`") } - } - Ok(v) - })?; + ExtractFnArgumentsError::ContainsMut(s) => { + Error::new(s, "FHE program arguments cannot be `mut`") + } + ExtractFnArgumentsError::IllegalPat(s) => Error::new(s, "Expected Identifier"), + ExtractFnArgumentsError::IllegalType(s) => Error::new( + s, + "FHE program arguments must be an array or named struct type", + ), + }) + .and_then(|v| { + for arg in &v { + if !arg.0.is_empty() { + return Err(Error::new( + arg.0[0].span(), + "FHE program arguments do not support attributes.", + )); + } + } + Ok(v) + })?; + + let return_types = extract_return_types(ret) + .map_err(|ExtractReturnTypesError::IllegalType(s)| + Error::new(s, "FHE programs may return a single value or a tuple of values. Each type must be an FHE type or array of such.") + )?; - let argument_types = unwrapped_inputs + let fhe_program_return_types = return_types + .iter() + .map(map_fhe_type) + .collect::, MapFheTypeError>>() + .map_err(|MapFheTypeError::IllegalType(s)| + Error::new(s, "Each return type for an FHE program must be either an array or named struct type.") + )?; + + Ok(Self { + item_fn, + attr_params, + unwrapped_inputs, + return_types, + fhe_program_return_types, + }) + } + + // The sunscreen::CallSignature value + fn signature(&self) -> TokenStream { + let argument_types = self + .unwrapped_inputs .iter() .map(|(_, t, _)| (**t).clone()) .collect::>(); + emit_signature(&argument_types, &self.return_types) + } - let fhe_program_args = unwrapped_inputs + // The arguments to the internal closure (input args wrapped in FheProgramNode) + fn fhe_program_args(&self) -> Vec { + self.unwrapped_inputs .iter() .map(|(_, ty, name)| { let ty = map_fhe_type(ty).unwrap(); @@ -99,68 +118,87 @@ impl FheProgram { #name: #ty, } }) - .collect::>(); - - let return_types = extract_return_types(ret).map_err(|ExtractReturnTypesError::IllegalType(s)| - Error::new(s, "FHE programs may return a single value or a tuple of values. Each type must be an FHE type or array of such.") - )?; - - let output_capture = emit_output_capture(&return_types); + .collect() + } - let fhe_program_returns = return_types + // Variable declarations like (but not exactly): + // `__c_0: FheProgramNode> = FheProgramNode::input()` + fn fhe_arg_var_decl(&self) -> Vec { + self.unwrapped_inputs .iter() - .map(map_fhe_type) - .collect::, MapFheTypeError>>() - .map_err(|MapFheTypeError::IllegalType(s)| - Error::new(s, "Each return type for an FHE program must be either an array or named struct type.") - )?; - - let fhe_program_return = pack_return_type(&fhe_program_returns); + .enumerate() + .map(|(i, t)| { + let var_name = format!("__c_{}", i); + create_fhe_program_node(&var_name, t.1) + }) + .collect() + } - // Tokens necessary for the `internal_internal` function, which returns any types that can - // `.into` the `fhe_program_return` types. - // E.g. (impl Into>, impl Into>) - let inner_return = pack_return_into_type(&fhe_program_returns); - // E.g. a, b - let inner_arg_values = unwrapped_inputs.iter().map(|(_, _, name)| *name); - // E.g. (_r1, _r2) - let inner_return_idents = fhe_program_returns + // The variables themselves (used after declaration): e.g. `__c_0` + // Note: must match naming format from `fhe_arg_var_decl`. + fn fhe_arg_vars(&self) -> Vec { + self.unwrapped_inputs .iter() .enumerate() .map(|(i, t)| { - let id = Ident::new(&format!("__r_{}", i), Span::call_site()); - quote_spanned!(t.span()=> #id) + let id = Ident::new(&format!("__c_{}", i), Span::call_site()); + quote_spanned!(t.1.span()=> #id) }) - .collect::>(); - // TODO generalize the pack function to do more than just types - let inner_return_values = match &inner_return_idents[..] { - [r1] => quote! { #r1 }, - _ => quote! { ( #(#inner_return_idents),* ) }, - }; - // E.g. (_r1.into(), _r2.into()) - let inner_return_into_values = match &inner_return_idents[..] { - [r1] => quote! { #r1.into() }, - _ => quote! { ( #(#inner_return_idents.into()),* ) }, + .collect() + } + + // Identifiers of the internal_inner return values, e.g. `__r_0` + // These are spanned on their respective return types. + fn inner_return_idents(&self) -> Vec { + self.fhe_program_return_types + .iter() + .enumerate() + .map(|(i, t)| Ident::new(&format!("__r_{}", i), t.span())) + .collect() + } + + fn output(self) -> TokenStream { + let input_fn = self.item_fn; + let attr_params = &self.attr_params; + let unwrapped_inputs = &self.unwrapped_inputs; + let return_types = &self.return_types; + let fhe_program_return_types = &self.fhe_program_return_types; + + let fhe_program_name = &input_fn.sig.ident; + let vis = &input_fn.vis; + let body = &input_fn.block; + + let chain_count = attr_params.chain_count; + let scheme_type = match attr_params.scheme { + Scheme::Bfv => { + quote! { + sunscreen::SchemeType::Bfv + } + } }; - let signature = emit_signature(&argument_types, &return_types); + let fhe_program_args = self.fhe_program_args(); + let fhe_program_return = pack_into_tuple(&fhe_program_return_types); + + let inner_return = pack_into_tuple(&wrap_impl_into(&fhe_program_return_types)); + let inner_arg_values = unwrapped_inputs.iter().map(|(_, _, name)| *name); + let inner_return_idents = self.inner_return_idents(); + let inner_return_values = pack_into_tuple(&inner_return_idents); + // E.g. (_r1.into(), _r2.into()) + let inner_return_into_values = pack_into_tuple(&suffix_into(&inner_return_idents)); - let var_decl = unwrapped_inputs.iter().enumerate().map(|(i, t)| { - let var_name = format!("c_{}", i); - create_fhe_program_node(&var_name, t.1) - }); + let signature = self.signature(); - let args = unwrapped_inputs.iter().enumerate().map(|(i, t)| { - let id = Ident::new(&format!("c_{}", i), Span::call_site()); - quote_spanned!(t.1.span()=> #id) - }); + let fhe_arg_var_decl = self.fhe_arg_var_decl(); + let fhe_arg_vars = self.fhe_arg_vars(); + let output_capture = emit_output_capture(&return_types); let fhe_program_struct_name = Ident::new(&format!("{}_struct", fhe_program_name), Span::call_site()); let fhe_program_name_literal = format!("{}", fhe_program_name); - Ok(quote! { + quote! { #[allow(non_camel_case_types)] #[derive(Clone)] #vis struct #fhe_program_struct_name { @@ -185,9 +223,9 @@ impl FheProgram { #[allow(clippy::type_complexity)] #[forbid(unused_variables)] let internal = | #(#fhe_program_args)* | -> #fhe_program_return { - fn internal_internal(#(#fhe_program_args)*) -> #inner_return #body + fn internal_inner(#(#fhe_program_args)*) -> #inner_return #body - let #inner_return_values = internal_internal( #(#inner_arg_values),* ); + let #inner_return_values = internal_inner( #(#inner_arg_values),* ); #inner_return_into_values }; @@ -196,10 +234,10 @@ impl FheProgram { // returns. ctx.swap(&RefCell::new(Some(unsafe { transmute(&mut context) }))); - #(#var_decl)* + #(#fhe_arg_var_decl)* let panic_res = std::panic::catch_unwind(|| { - internal(#(#args),*) + internal(#(#fhe_arg_vars),*) }); // when panicing or not, we need to collect our indicies arena and @@ -255,6 +293,58 @@ impl FheProgram { #vis const #fhe_program_name: #fhe_program_struct_name = #fhe_program_struct_name { chain_count: #chain_count }; - }) + } + } +} + +#[cfg(test)] +mod test { + use syn::parse_quote; + + use super::*; + + #[test] + fn basic_multiply_works() { + let attrs = FheProgramAttrs { + scheme: Scheme::Bfv, + chain_count: 1, + }; + let attempt_fn = parse_quote! { + fn simple_multiply(a: Cipher, b: Cipher) -> Cipher { + a * b + } + }; + + assert!(FheProgram::new(&attempt_fn, attrs).is_ok()) + } + + #[test] + fn disallows_mut() { + let attrs = FheProgramAttrs { + scheme: Scheme::Bfv, + chain_count: 1, + }; + let attempt_fn = parse_quote! { + fn simple_multiply(mut a: Cipher, b: Cipher) -> Cipher { + a * b + } + }; + + assert!(FheProgram::new(&attempt_fn, attrs).is_err()) + } + + #[test] + fn disallows_generics() { + let attrs = FheProgramAttrs { + scheme: Scheme::Bfv, + chain_count: 1, + }; + let attempt_fn = parse_quote! { + fn simple_multiply(a: Cipher, b: Cipher) -> Cipher { + b + } + }; + + assert!(FheProgram::new(&attempt_fn, attrs).is_err()) } } diff --git a/sunscreen_compiler_macros/src/fhe_program_transforms.rs b/sunscreen_compiler_macros/src/fhe_program_transforms.rs index 5796ee92c..870cf2c11 100644 --- a/sunscreen_compiler_macros/src/fhe_program_transforms.rs +++ b/sunscreen_compiler_macros/src/fhe_program_transforms.rs @@ -1,6 +1,6 @@ use proc_macro2::{Span, TokenStream as TokenStream2}; -use quote::{format_ident, quote, quote_spanned}; -use syn::{parse_quote, parse_quote_spanned, spanned::Spanned, Index, ReturnType, Type}; +use quote::{format_ident, quote, quote_spanned, ToTokens}; +use syn::{parse_quote, parse_quote_spanned, spanned::Spanned, Ident, Index, ReturnType, Type}; #[derive(Debug)] pub enum MapFheTypeError { @@ -97,33 +97,34 @@ pub fn extract_return_types(ret: &ReturnType) -> Result, ExtractReturn } /** - * Takes an array of return types and packages them into a tuple - * if needed. + * Takes an array of tokens and packages them into a tuple if needed. */ -pub fn pack_return_type(return_types: &[Type]) -> Type { - match return_types.len() { - 0 => parse_quote! { () }, - 1 => return_types[0].clone(), +pub fn pack_into_tuple(ts: &[T]) -> TokenStream2 { + match ts { + [] => parse_quote! { () }, + [t] => t.to_token_stream(), _ => { - parse_quote_spanned! {return_types[0].span() => ( #(#return_types),* ) } + quote_spanned! {ts[0].span()=> ( #(#ts),* ) } } } } /** - * Takes an array of return types, wraps each in an `impl Into<_>`, and packs them up like -* `pack_return_type`. + * Takes an array of types and wraps each in an `impl Into<_>` */ -pub fn pack_return_into_type(return_types: &[Type]) -> Type { - match return_types { - [] => parse_quote! { () }, - [ty] => parse_quote_spanned! { ty.span() => - impl Into<#ty> - }, - _ => parse_quote_spanned! { return_types[0].span() => - ( #(impl Into<#return_types>),* ) - }, - } +pub fn wrap_impl_into(ts: &[Type]) -> Vec { + ts.iter() + .map(|t| parse_quote_spanned! {t.span()=> impl Into<#t>}) + .collect() +} + +/** + * Takes an array of idents and suffixes each with `.into()` +*/ +pub fn suffix_into(is: &[Ident]) -> Vec { + is.iter() + .map(|i| quote_spanned! {i.span()=> #i.into()}) + .collect() } pub fn emit_output_capture(return_types: &[Type]) -> TokenStream2 { From 72797befd2ff2573634c57bb9424c14ac7be0ec0 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 12:45:16 -0600 Subject: [PATCH 32/41] Fix doctests --- sunscreen/src/types/mod.rs | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/sunscreen/src/types/mod.rs b/sunscreen/src/types/mod.rs index c1aca63fb..3c6771bb5 100644 --- a/sunscreen/src/types/mod.rs +++ b/sunscreen/src/types/mod.rs @@ -140,6 +140,7 @@ where /// arithmetic operations with ciphertexts: /// /// ``` +/// # use sunscreen::{fhe_program, types::{Cipher, bfv::Signed}}; /// #[fhe_program(scheme = "bfv")] /// fn add_ten(a: Cipher) -> Cipher { /// a + 10 @@ -150,11 +151,11 @@ where /// value, this won't work: /// /// ```compile_fail +/// # use sunscreen::{fhe_program, types::{Cipher, bfv::Signed}}; /// #[fhe_program(scheme = "bfv")] /// fn add_ten(a: Cipher) -> Cipher { -/// let sum = 0; -/// sum = sum + a -/// sum = sum + 10 +/// let sum = 10; +/// sum = sum + a; /// sum /// } /// ``` @@ -163,11 +164,11 @@ where /// this macro: /// /// ``` +/// # use sunscreen::{fhe_var, fhe_program, types::{Cipher, bfv::Signed}}; /// #[fhe_program(scheme = "bfv")] /// fn add_ten(a: Cipher) -> Cipher { -/// let sum = fhe_var!(0); -/// sum = sum + a -/// sum = sum + 10 +/// let mut sum = fhe_var!(10); +/// sum = sum + a; /// sum /// } /// ``` @@ -175,18 +176,15 @@ where /// You can also create arrays of variables: /// /// ``` +/// # use sunscreen::{fhe_var, fhe_program, types::{Cipher, bfv::Signed}}; /// #[fhe_program(scheme = "bfv")] -/// fn add_ten(a: Cipher) -> Cipher { -/// let mut sum = fhe_var(0); -/// let arr = fhe_var![1, 2, 4]; -/// let ones = fhe_var![1; 3]; -/// for x in arr { -/// sum = sum + x; -/// } -/// for y in ones { -/// sum = sum + y; +/// fn add_ten(arrs: [[Cipher; 10]; 10]) { +/// let mut sum = fhe_var![0; 10]; +/// for i in 0..10 { +/// for x in arrs[i] { +/// sum[i] = sum[i] + x; +/// } /// } -/// sum + a /// } /// ``` #[macro_export] @@ -205,7 +203,9 @@ macro_rules! fhe_var { } /// Creates new ZKP variables from literals. +/// /// ``` +/// # use sunscreen::{zkp_var, zkp_program, BackendField, types::zkp::NativeField}; /// #[zkp_program(backend = "bulletproofs")] /// fn equals_ten(a: NativeField) { /// let ten = zkp_var!(10); @@ -216,6 +216,7 @@ macro_rules! fhe_var { /// You can also create arrays of variables: /// /// ``` +/// # use sunscreen::{zkp_var, zkp_program, BackendField, types::zkp::NativeField}; /// #[zkp_program(backend = "bulletproofs")] /// fn equals_ten(a: NativeField) { /// let tens = zkp_var![10, 10, 10]; From 13f9a34b739f88946b4131118eab24efa721744a Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 13:06:02 -0600 Subject: [PATCH 33/41] Fix clippy warnings --- sunscreen_compiler_macros/src/fhe_program.rs | 12 ++++++++---- .../src/fhe_program_transforms.rs | 18 ++++++++++-------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index 2ee74b6a2..0ece218b7 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -178,9 +178,9 @@ impl<'a> FheProgram<'a> { }; let fhe_program_args = self.fhe_program_args(); - let fhe_program_return = pack_into_tuple(&fhe_program_return_types); + let fhe_program_return = pack_into_tuple(fhe_program_return_types); - let inner_return = pack_into_tuple(&wrap_impl_into(&fhe_program_return_types)); + let inner_return = pack_into_tuple(&wrap_impl_into(fhe_program_return_types)); let inner_arg_values = unwrapped_inputs.iter().map(|(_, _, name)| *name); let inner_return_idents = self.inner_return_idents(); let inner_return_values = pack_into_tuple(&inner_return_idents); @@ -191,7 +191,8 @@ impl<'a> FheProgram<'a> { let fhe_arg_var_decl = self.fhe_arg_var_decl(); let fhe_arg_vars = self.fhe_arg_vars(); - let output_capture = emit_output_capture(&return_types); + let output_var = Ident::new("__v", Span::call_site()); + let output_capture = emit_output_capture(&output_var, return_types); let fhe_program_struct_name = Ident::new(&format!("{}_struct", fhe_program_name), Span::call_site()); @@ -220,6 +221,8 @@ impl<'a> FheProgram<'a> { let mut context = FheContext::new(params.clone()); CURRENT_FHE_CTX.with(|ctx| { + #[allow(clippy::let_unit_value)] + #[allow(clippy::unused_unit)] #[allow(clippy::type_complexity)] #[forbid(unused_variables)] let internal = | #(#fhe_program_args)* | -> #fhe_program_return { @@ -243,7 +246,7 @@ impl<'a> FheProgram<'a> { // when panicing or not, we need to collect our indicies arena and // unset the context reference. match panic_res { - Ok(v) => { #output_capture }, + Ok(#output_var) => { #output_capture }, Err(err) => { INDEX_ARENA.with(|allocator| { allocator.borrow_mut().reset() @@ -263,6 +266,7 @@ impl<'a> FheProgram<'a> { } fn signature(&self) -> sunscreen::CallSignature { + #[allow(unused_imports)] use sunscreen::types::NumCiphertexts; #signature diff --git a/sunscreen_compiler_macros/src/fhe_program_transforms.rs b/sunscreen_compiler_macros/src/fhe_program_transforms.rs index 870cf2c11..9aa721ff9 100644 --- a/sunscreen_compiler_macros/src/fhe_program_transforms.rs +++ b/sunscreen_compiler_macros/src/fhe_program_transforms.rs @@ -127,11 +127,11 @@ pub fn suffix_into(is: &[Ident]) -> Vec { .collect() } -pub fn emit_output_capture(return_types: &[Type]) -> TokenStream2 { +pub fn emit_output_capture(var: &Ident, return_types: &[Type]) -> TokenStream2 { match return_types { [ty] => quote_spanned! { ty.span() => { struct _AssertOutput where FheProgramNode<#ty>: Output; - v.output(); + #var.output(); }}, _ => return_types .iter() @@ -141,7 +141,7 @@ pub fn emit_output_capture(return_types: &[Type]) -> TokenStream2 { quote_spanned! {ty.span() => { struct _AssertOutput where FheProgramNode<#ty>: Output; - v.#index.output(); + #var.#index.output(); }} }) .collect(), @@ -438,12 +438,13 @@ mod test { let extracted = extract_return_types(&return_type).unwrap(); - let actual = emit_output_capture(&extracted); + let var = format_ident!("__v"); + let actual = emit_output_capture(&var, &extracted); let expected = quote! { { struct _AssertOutput where FheProgramNode< Cipher < Signed > > : Output; - v.output(); + __v.output(); } }; @@ -460,16 +461,17 @@ mod test { let extracted = extract_return_types(&return_type).unwrap(); - let actual = emit_output_capture(&extracted); + let var = format_ident!("__v"); + let actual = emit_output_capture(&var, &extracted); let expected = quote! { { struct _AssertOutput where FheProgramNode< Cipher < Signed > > : Output; - v.0.output(); + __v.0.output(); } { struct _AssertOutput where FheProgramNode<[[Cipher; 6]; 7]>: Output; - v.1.output(); + __v.1.output(); } }; From e749c3e5e722b8c37c25ab288da8b55729960010 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 13:50:53 -0600 Subject: [PATCH 34/41] Remove TODOs --- sunscreen/src/types/intern/fhe_program_node.rs | 7 ++----- sunscreen/src/types/ops/insert.rs | 1 - sunscreen_compiler_macros/src/fhe_program.rs | 4 ---- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/sunscreen/src/types/intern/fhe_program_node.rs b/sunscreen/src/types/intern/fhe_program_node.rs index 3d1609476..b1715478c 100644 --- a/sunscreen/src/types/intern/fhe_program_node.rs +++ b/sunscreen/src/types/intern/fhe_program_node.rs @@ -572,12 +572,9 @@ where coerce(node, Stage::Literal) } -// TODO make this automatic (in #[fhe_program], define `internal` with generics from within a -// function with the proper -// return values, and call .into() on each output). -/// Output your fhe program variable as a ciphertext. This will fail (at fhe program compile time) +/// Convert your fhe program variable to a proper ciphertext variable. This will fail (at fhe program compile time) /// if the variable is still a literal. You can also use `.into()` to accomplish the same thing. -pub fn fhe_out(var: FheProgramNode, Stage>) -> FheProgramNode> +fn fhe_out(var: FheProgramNode, Stage>) -> FheProgramNode> where L: FheLiteral, T: FheType, diff --git a/sunscreen/src/types/ops/insert.rs b/sunscreen/src/types/ops/insert.rs index adb719a34..51c75ec8a 100644 --- a/sunscreen/src/types/ops/insert.rs +++ b/sunscreen/src/types/ops/insert.rs @@ -18,7 +18,6 @@ pub trait GraphCipherInsert { /** * The type of the plaintext encoding */ - // TODO if this is always Self, then remove it type Val: FheType; /** diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index 0ece218b7..2e01a53fe 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -19,9 +19,6 @@ pub fn fhe_program_impl( } } -// TODO move validation errors into a `FheProgram::new() -> Result` function, -// gather the useful fields onto the struct, -// then call `FheProgram::new().and_then(|f|f.output())`. struct FheProgram<'a> { // The function passed to the proc macro item_fn: &'a ItemFn, @@ -217,7 +214,6 @@ impl<'a> FheProgram<'a> { return Err(Error::IncorrectScheme) } - // TODO: Other schemes. let mut context = FheContext::new(params.clone()); CURRENT_FHE_CTX.with(|ctx| { From a5e3670b74372c93bdb5c3d717a701eb397b82ee Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 13:54:52 -0600 Subject: [PATCH 35/41] Add missing example runs to CI --- .github/workflows/rust.yml | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 73be0d975..8bffb3984 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -46,14 +46,22 @@ jobs: # Run our non-interactive examples and assert the complete without error - name: Verify examples (amm) run: cargo run --release --bin amm + - name: Verify examples (bigint) + run: cargo run --release --bin bigint - name: Verify examples (chi_sq) run: cargo run --release --bin chi_sq - - name: Verify examples (simple_multiply) - run: cargo run --release --bin simple_multiply - name: Verify examples (dot_prod) run: cargo run --release --bin dot_prod + - name: Verify examples (mean_variance) + run: cargo run --release --bin mean_variance + - name: Verify examples (ordering_zkp) + run: cargo run --release --bin ordering_zkp - name: Verify examples (pir) run: cargo run --release --bin pir + - name: Verify examples (simple_multiply) + run: cargo run --release --bin simple_multiply + - name: Verify examples (sudoku_zkp) + run: cargo run --release --bin sudoku_zkp - name: Build sunscreen and bincode run: cargo build --release --package sunscreen --package bincode - name: Build mdBook From ca31c91da3fa633427996ca491a41bd46b897688 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 14:04:47 -0600 Subject: [PATCH 36/41] Oops: fix 232 > 64 --- examples/ordering_zkp/src/main.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ordering_zkp/src/main.rs b/examples/ordering_zkp/src/main.rs index d87a79b9e..44f208600 100644 --- a/examples/ordering_zkp/src/main.rs +++ b/examples/ordering_zkp/src/main.rs @@ -15,8 +15,8 @@ fn main() -> Result<(), Error> { let runtime = ZkpRuntime::new(&BulletproofsBackend::new())?; - let amount = BPField::from(64); - let threshold = BPField::from(232); + let amount = BPField::from(232); + let threshold = BPField::from(64); // Prove that amount > threshold From e63c44ae8266ccf8891c2f967faa5597095a03ac Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Fri, 30 Jun 2023 15:26:16 -0600 Subject: [PATCH 37/41] Allow arbitrary expressions in fhe_var! --- sunscreen/src/types/mod.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sunscreen/src/types/mod.rs b/sunscreen/src/types/mod.rs index 3c6771bb5..ff79d408b 100644 --- a/sunscreen/src/types/mod.rs +++ b/sunscreen/src/types/mod.rs @@ -189,15 +189,13 @@ where /// ``` #[macro_export] macro_rules! fhe_var { - ($elem:literal) => ( + ($elem:expr) => ( $crate::types::intern::fhe_node($elem) ); - ($elem:literal; $n:expr) => ( - // TODO this will just copy the same node IDs. But that's ok, right? the graph nodes are - // immtuable anyway. + ($elem:expr; $n:expr) => ( [$crate::types::intern::fhe_node($elem); $n] ); - ($($elem:literal),+ $(,)?) => ( + ($($elem:expr),+ $(,)?) => ( [$($crate::types::intern::fhe_node($elem)),+] ); } From 3be1f39b635d9cfbb8d743f00938b8543b4836fd Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Sat, 1 Jul 2023 18:39:42 -0600 Subject: [PATCH 38/41] Use custom "into" to support impls on [] --- .../src/types/intern/fhe_program_node.rs | 79 ++++++++++++++----- sunscreen/tests/fhe_program_tests.rs | 18 +++++ sunscreen_compiler_macros/src/fhe_program.rs | 6 +- .../src/fhe_program_transforms.rs | 12 +-- 4 files changed, 86 insertions(+), 29 deletions(-) diff --git a/sunscreen/src/types/intern/fhe_program_node.rs b/sunscreen/src/types/intern/fhe_program_node.rs index b1715478c..7a3bb0bfd 100644 --- a/sunscreen/src/types/intern/fhe_program_node.rs +++ b/sunscreen/src/types/intern/fhe_program_node.rs @@ -569,41 +569,80 @@ where T: FheType + GraphCipherInsert, { let node = T::graph_cipher_insert(lit); - coerce(node, Stage::Literal) + unsafe_coerce(node, Stage::Literal) } -/// Convert your fhe program variable to a proper ciphertext variable. This will fail (at fhe program compile time) -/// if the variable is still a literal. You can also use `.into()` to accomplish the same thing. -fn fhe_out(var: FheProgramNode, Stage>) -> FheProgramNode> +/// Used for converting between types of an [`FheProgramNode`]. Implementations of this trait +/// define the legal conversions. +// Note: this is more of an `Into::into` flavor, which typically has weaker inference than +// `From::from`. However, the resulting type is specified explicitly by the `internal_inner` +// function generated in the `#[fhe_program]` macro, so there shouldn't be any problems. If +// necessary, we could adapt the macro to support a `From` style trait pretty easily. +pub trait Coerce { + /// Coerce one `FheProgramNode` to another. The underlying value stays the same, but the types + /// change. This function is allowed to panic on invalid coercions, which will result in a + /// failed `Err` state when compiling the FHE program. + /// + /// An invalid coercion would be defining a variable with [`crate::fhe_var!`] and coercing + /// it as a `Cipher` before assigning it to the result of an arithmetic operation with a + /// ciphertext. + fn coerce(self) -> T; +} + +// Allow trivial conversion T -> T +impl Coerce> for FheProgramNode { + fn coerce(self) -> Self { + self + } +} + +// If T -> V is legal, then so is [T] -> [V] +impl Coerce<[V; N]> for [T; N] +where + V: NumCiphertexts, + T: Coerce, +{ + fn coerce(self) -> [V; N] { + self.map(Coerce::::coerce) + } +} + +// Allow `Indeterminate` -> `Cipher`, but panic if at `Stage::Literal` +impl Coerce>> for FheProgramNode, Stage> where L: FheLiteral, T: FheType, { - match var.stage { - Stage::Literal => panic!("User created FHE variables must undergo arithmetic operations with ciphertexts before they are returned as output."), - Stage::Cipher => { - FheProgramNode { - ids: var.ids, - stage: (), - _phantom: std::marker::PhantomData, + fn coerce(self) -> FheProgramNode> { + match self.stage { + Stage::Literal => panic!("User created FHE variables must undergo arithmetic operations with ciphertexts before they are returned as output."), + Stage::Cipher => { + FheProgramNode { + ids: self.ids, + stage: (), + _phantom: std::marker::PhantomData, + } } - } + } } } +// This is such a common one, let the user call `var.into()` in their programs. +// We unfortunately can't make a very generic impl here, as it would conflict with the blanket +// `From for T`. impl From, Stage>> for FheProgramNode> where L: FheLiteral, T: FheType, { fn from(value: FheProgramNode, Stage>) -> Self { - fhe_out(value) + value.coerce() } } /// WARNING: This is an unsafe function. It allows casting graph nodes arbitrarily. Use with /// caution. -fn coerce( +fn unsafe_coerce( a: FheProgramNode, t: T, ) -> FheProgramNode { @@ -629,18 +668,18 @@ macro_rules! impl_indeterminate_arithmetic_op { fn [<$op:lower>](self, rhs: FheProgramNode>) -> Self::Output { let node = match self.stage { Stage::Literal => { - let lit_node = coerce(self, ()); + let lit_node = unsafe_coerce(self, ()); // N.B. we've already added this literal as a plaintext node T::[](rhs, lit_node) } Stage::Cipher => { - let cipher_node = coerce(self, ()); + let cipher_node = unsafe_coerce(self, ()); T::[](rhs, cipher_node) } }; // No matter what `self.stage` currently is, it is being operated on with a // ciphertext, so its next stage is cipher. - coerce(node, Stage::Cipher) + unsafe_coerce(node, Stage::Cipher) } } @@ -658,18 +697,18 @@ macro_rules! impl_indeterminate_arithmetic_op { fn [<$op:lower>](self, rhs: FheProgramNode, Stage>) -> Self::Output { let node = match rhs.stage { Stage::Literal => { - let lit_node = coerce(rhs, ()); + let lit_node = unsafe_coerce(rhs, ()); // N.B. we've already added this literal as a plaintext node T::[](self, lit_node) } Stage::Cipher => { - let cipher_node = coerce(rhs, ()); + let cipher_node = unsafe_coerce(rhs, ()); T::[](self, cipher_node) } }; // No matter what `rhs.stage` currently is, it is being added to a ciphertext, so its next // stage is cipher. - coerce(node, Stage::Cipher) + unsafe_coerce(node, Stage::Cipher) } } } diff --git a/sunscreen/tests/fhe_program_tests.rs b/sunscreen/tests/fhe_program_tests.rs index 3981d91e2..abecc583e 100644 --- a/sunscreen/tests/fhe_program_tests.rs +++ b/sunscreen/tests/fhe_program_tests.rs @@ -453,3 +453,21 @@ fn can_collect_multiple_outputs() { serde_json::from_value::(expected).unwrap() ); } + +#[test] +fn coercion_supports_arbitrarily_nested_outputs() { + // This test just tests compilation is valid + #[fhe_program(scheme = "bfv")] + fn fhe_program_just_cipher( + a: Cipher, + ) -> ([Cipher; 2], [[Cipher; 3]; 2]) { + ([a; 2], [[a; 3]; 2]) + } + + #[fhe_program(scheme = "bfv")] + fn fhe_program_var(a: Cipher) -> ([Cipher; 2], [[Cipher; 3]; 2]) { + let mut sum = fhe_var!(0); + sum = sum + a; + ([sum; 2], [[sum; 3]; 2]) + } +} diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index 2e01a53fe..957755cf1 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -177,12 +177,12 @@ impl<'a> FheProgram<'a> { let fhe_program_args = self.fhe_program_args(); let fhe_program_return = pack_into_tuple(fhe_program_return_types); - let inner_return = pack_into_tuple(&wrap_impl_into(fhe_program_return_types)); + let inner_return = pack_into_tuple(&wrap_impl_coerce(fhe_program_return_types)); let inner_arg_values = unwrapped_inputs.iter().map(|(_, _, name)| *name); let inner_return_idents = self.inner_return_idents(); let inner_return_values = pack_into_tuple(&inner_return_idents); // E.g. (_r1.into(), _r2.into()) - let inner_return_into_values = pack_into_tuple(&suffix_into(&inner_return_idents)); + let inner_return_into_values = pack_into_tuple(&suffix_coerce(&inner_return_idents)); let signature = self.signature(); @@ -208,7 +208,7 @@ impl<'a> FheProgram<'a> { fn build(&self, params: &sunscreen::Params) -> sunscreen::Result { use std::cell::RefCell; use std::mem::transmute; - use sunscreen::{fhe::{CURRENT_FHE_CTX, FheContext}, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{FheProgramNode, Input, Output}, NumCiphertexts, Type, TypeName, SwapRows, LaneCount, TypeNameInstance}}; + use sunscreen::{fhe::{CURRENT_FHE_CTX, FheContext}, Error, INDEX_ARENA, Result, Params, SchemeType, Value, types::{intern::{FheProgramNode, Input, Output, Coerce}, NumCiphertexts, Type, TypeName, SwapRows, LaneCount, TypeNameInstance}}; if SchemeType::Bfv != params.scheme_type { return Err(Error::IncorrectScheme) diff --git a/sunscreen_compiler_macros/src/fhe_program_transforms.rs b/sunscreen_compiler_macros/src/fhe_program_transforms.rs index 9aa721ff9..1b8b3dc58 100644 --- a/sunscreen_compiler_macros/src/fhe_program_transforms.rs +++ b/sunscreen_compiler_macros/src/fhe_program_transforms.rs @@ -110,20 +110,20 @@ pub fn pack_into_tuple(ts: &[T]) -> TokenStream2 { } /** - * Takes an array of types and wraps each in an `impl Into<_>` + * Takes an array of types and wraps each in an `impl Coerce<_>` */ -pub fn wrap_impl_into(ts: &[Type]) -> Vec { +pub fn wrap_impl_coerce(ts: &[Type]) -> Vec { ts.iter() - .map(|t| parse_quote_spanned! {t.span()=> impl Into<#t>}) + .map(|t| parse_quote_spanned! {t.span()=> impl Coerce<#t>}) .collect() } /** - * Takes an array of idents and suffixes each with `.into()` + * Takes an array of idents and suffixes each with `.coerce()` */ -pub fn suffix_into(is: &[Ident]) -> Vec { +pub fn suffix_coerce(is: &[Ident]) -> Vec { is.iter() - .map(|i| quote_spanned! {i.span()=> #i.into()}) + .map(|i| quote_spanned! {i.span()=> #i.coerce()}) .collect() } From 16a62772cdc067a6a786824f553a2e845d81f06e Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Mon, 3 Jul 2023 11:38:12 -0500 Subject: [PATCH 39/41] Support explicit #[private] params --- sunscreen_compiler_macros/src/zkp_program.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sunscreen_compiler_macros/src/zkp_program.rs b/sunscreen_compiler_macros/src/zkp_program.rs index 2220b3361..a23ec61a7 100644 --- a/sunscreen_compiler_macros/src/zkp_program.rs +++ b/sunscreen_compiler_macros/src/zkp_program.rs @@ -117,10 +117,11 @@ fn parse_inner(_attr_params: ZkpProgramAttrs, input_fn: ItemFn) -> Result {}, Some("public") => arg_kind = ArgumentKind::Public, Some("constant") => arg_kind = ArgumentKind::Constant, _ => { - return Err(Error::compile_error(a.0[0].path().span(), &format!("Expected #[public] or #[constant], found {}", a.0[0].path().to_token_stream()))); + return Err(Error::compile_error(a.0[0].path().span(), &format!("Expected #[private], #[public] or #[constant], found {}", a.0[0].path().to_token_stream()))); } } }, From 2e6a0815833f42be5ae1eb7fb27f60baf068050b Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Mon, 3 Jul 2023 12:40:07 -0500 Subject: [PATCH 40/41] Remove `backend = "bulletproofs"` attribute --- benchmarks/bfv_zkp/src/bfv.rs | 2 +- examples/ordering_zkp/src/main.rs | 10 ++--- examples/sudoku_zkp/src/main.rs | 2 +- sunscreen/Cargo.toml | 4 +- sunscreen/benches/fractional_range_proof.rs | 8 ++-- sunscreen/src/compiler.rs | 6 +-- sunscreen/src/types/mod.rs | 4 +- sunscreen/src/types/zkp/gadgets/arithmetic.rs | 2 +- sunscreen/src/types/zkp/gadgets/binary.rs | 4 +- sunscreen/src/types/zkp/native_field.rs | 8 ++-- sunscreen/src/types/zkp/rns_polynomial.rs | 6 +-- sunscreen/tests/zkp_program_tests.rs | 10 ++--- .../src/internals/attr.rs | 44 ++----------------- 13 files changed, 38 insertions(+), 72 deletions(-) diff --git a/benchmarks/bfv_zkp/src/bfv.rs b/benchmarks/bfv_zkp/src/bfv.rs index b77e1eee4..fc4fb4370 100644 --- a/benchmarks/bfv_zkp/src/bfv.rs +++ b/benchmarks/bfv_zkp/src/bfv.rs @@ -201,7 +201,7 @@ fn div_round_bigint(a: BigInt, b: BigInt) -> BigInt { type BfvPoly = RnsRingPolynomial; -#[zkp_program(backend = "bulletproofs")] +#[zkp_program] fn prove_enc( m: BfvPoly, e_1: BfvPoly, diff --git a/examples/ordering_zkp/src/main.rs b/examples/ordering_zkp/src/main.rs index 44f208600..2c2ffeb4d 100644 --- a/examples/ordering_zkp/src/main.rs +++ b/examples/ordering_zkp/src/main.rs @@ -3,6 +3,11 @@ use sunscreen::{ zkp_program, BackendField, BulletproofsBackend, Compiler, Error, ZkpBackend, ZkpRuntime, }; +#[zkp_program] +fn greater_than(a: NativeField, #[constant] b: NativeField) { + a.constrain_gt_bounded(b, 32) +} + type BPField = NativeField<::Field>; fn main() -> Result<(), Error> { @@ -27,11 +32,6 @@ fn main() -> Result<(), Error> { Ok(()) } -#[zkp_program(backend = "bulletproofs")] -fn greater_than(a: NativeField, #[constant] b: NativeField) { - a.constrain_gt_bounded(b, 32) -} - #[cfg(test)] mod tests { use super::*; diff --git a/examples/sudoku_zkp/src/main.rs b/examples/sudoku_zkp/src/main.rs index bf962eb62..22c297d41 100644 --- a/examples/sudoku_zkp/src/main.rs +++ b/examples/sudoku_zkp/src/main.rs @@ -50,7 +50,7 @@ fn main() -> Result<(), Error> { Ok(()) } -#[zkp_program(backend = "bulletproofs")] +#[zkp_program] fn sudoku_proof( #[constant] constraints: [[NativeField; 9]; 9], board: [[NativeField; 9]; 9], diff --git a/sunscreen/Cargo.toml b/sunscreen/Cargo.toml index 5039e5914..b933941c2 100644 --- a/sunscreen/Cargo.toml +++ b/sunscreen/Cargo.toml @@ -47,7 +47,9 @@ env_logger = "0.10.0" float-cmp = "0.9.0" lazy_static = "1.4.0" proptest = "1.1.0" -sunscreen_zkp_backend = { path = "../sunscreen_zkp_backend", features = ["bulletproofs"] } +sunscreen_zkp_backend = { path = "../sunscreen_zkp_backend", features = [ + "bulletproofs", +] } sunscreen_compiler_common = { path = "../sunscreen_compiler_common" } serde_json = "1.0.74" diff --git a/sunscreen/benches/fractional_range_proof.rs b/sunscreen/benches/fractional_range_proof.rs index dfac8459a..e71630e27 100644 --- a/sunscreen/benches/fractional_range_proof.rs +++ b/sunscreen/benches/fractional_range_proof.rs @@ -80,7 +80,7 @@ fn make_fractional_value(bits: &[i8]) -> [[BPField; 8]; 64] { /// the number of decimal places in the fractional amount. This is /// basically free, so we don't need to time it here. fn unshield_tx_fractional_range_proof(_c: &mut Criterion) { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] /** * Proves the 0 < a <= b and a == c */ @@ -153,7 +153,7 @@ fn unshield_tx_fractional_range_proof(_c: &mut Criterion) { /// in the SDLP, which reduces the number of circuit inputs. However, this proof is /// orders of magnitude faster than the SDLP so 🤷‍♀️. fn private_tx_fractional_range_proof(_c: &mut Criterion) { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] /** * Proves the 0 < a <= b and a == c */ @@ -228,7 +228,7 @@ fn private_tx_fractional_range_proof(_c: &mut Criterion) { /// * a is the submitted value under a given user's key. /// * b is the maximum value encrypted under the same user's key. fn mean_variance_fractional_range_proof(_c: &mut Criterion) { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] /** * Proves the 0 < a <= b and a == c */ @@ -291,7 +291,7 @@ fn mean_variance_fractional_range_proof(_c: &mut Criterion) { /// /// * a_0, a_1, a_2 are submitted value under a given user's key. fn chi_sq_fractional_range_proof(_c: &mut Criterion) { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] /** * Proves the 0 < a <= b and a == c */ diff --git a/sunscreen/src/compiler.rs b/sunscreen/src/compiler.rs index ab1d48843..12792217b 100644 --- a/sunscreen/src/compiler.rs +++ b/sunscreen/src/compiler.rs @@ -591,7 +591,7 @@ mod tests { #[test] fn fhe_zkp_program_yields_fhezkp_compiler() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn kitty() {} #[fhe_program(scheme = "bfv")] @@ -619,7 +619,7 @@ mod tests { #[test] fn compiling_zkp_program_yields_zkp_application() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn kitty() {} let app = GenericCompiler::new() @@ -633,7 +633,7 @@ mod tests { #[test] fn compiling_fhe_and_zkp_program_yields_fhezkp_application() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn kitty(_a: NativeField) {} #[fhe_program(scheme = "bfv")] diff --git a/sunscreen/src/types/mod.rs b/sunscreen/src/types/mod.rs index ff79d408b..c0ce25ffc 100644 --- a/sunscreen/src/types/mod.rs +++ b/sunscreen/src/types/mod.rs @@ -204,7 +204,7 @@ macro_rules! fhe_var { /// /// ``` /// # use sunscreen::{zkp_var, zkp_program, BackendField, types::zkp::NativeField}; -/// #[zkp_program(backend = "bulletproofs")] +/// #[zkp_program] /// fn equals_ten(a: NativeField) { /// let ten = zkp_var!(10); /// a.constrain_eq(ten); @@ -215,7 +215,7 @@ macro_rules! fhe_var { /// /// ``` /// # use sunscreen::{zkp_var, zkp_program, BackendField, types::zkp::NativeField}; -/// #[zkp_program(backend = "bulletproofs")] +/// #[zkp_program] /// fn equals_ten(a: NativeField) { /// let tens = zkp_var![10, 10, 10]; /// for ten in tens { diff --git a/sunscreen/src/types/zkp/gadgets/arithmetic.rs b/sunscreen/src/types/zkp/gadgets/arithmetic.rs index 86a0e1fc2..d07099f81 100644 --- a/sunscreen/src/types/zkp/gadgets/arithmetic.rs +++ b/sunscreen/src/types/zkp/gadgets/arithmetic.rs @@ -225,7 +225,7 @@ mod tests { #[test] fn modulus_gadget_works() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn div_rem( x: NativeField, m: NativeField, diff --git a/sunscreen/src/types/zkp/gadgets/binary.rs b/sunscreen/src/types/zkp/gadgets/binary.rs index 7bd328061..c97413a8e 100644 --- a/sunscreen/src/types/zkp/gadgets/binary.rs +++ b/sunscreen/src/types/zkp/gadgets/binary.rs @@ -166,7 +166,7 @@ mod tests { #[test] fn can_assert_binary() { // Prove we know the value that decomposes into 0b101010 - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn test(a: NativeField) { invoke_gadget(AssertBinary, a.ids); } @@ -206,7 +206,7 @@ mod tests { #[test] fn can_convert_to_binary() { // Prove we know the value that decomposes into 0b101010 - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn test(a: NativeField) { let bits = a.to_unsigned::<6>(); diff --git a/sunscreen/src/types/zkp/native_field.rs b/sunscreen/src/types/zkp/native_field.rs index e3e7d44b1..9fc239baa 100644 --- a/sunscreen/src/types/zkp/native_field.rs +++ b/sunscreen/src/types/zkp/native_field.rs @@ -428,7 +428,7 @@ mod tests { #[test] fn can_compare_le_bounded() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn le(x: NativeField, y: NativeField) { x.constrain_le_bounded(y, 16); } @@ -477,7 +477,7 @@ mod tests { #[test] fn can_compare_lt_bounded() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn le(x: NativeField, y: NativeField) { x.constrain_lt_bounded(y, 16); } @@ -526,7 +526,7 @@ mod tests { #[test] fn can_compare_ge_bounded() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn le(x: NativeField, y: NativeField) { x.constrain_ge_bounded(y, 16); } @@ -575,7 +575,7 @@ mod tests { #[test] fn can_compare_gt_bounded() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn le(x: NativeField, y: NativeField) { x.constrain_gt_bounded(y, 16); } diff --git a/sunscreen/src/types/zkp/rns_polynomial.rs b/sunscreen/src/types/zkp/rns_polynomial.rs index 572fc90e2..3f0cc86c6 100644 --- a/sunscreen/src/types/zkp/rns_polynomial.rs +++ b/sunscreen/src/types/zkp/rns_polynomial.rs @@ -201,7 +201,7 @@ mod tests { #[test] fn can_prove_added_polynomials() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn add_poly( #[constant] a: RnsRingPolynomial, #[constant] b: RnsRingPolynomial, @@ -258,7 +258,7 @@ mod tests { #[test] fn can_prove_multiply_polynomials() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn add_poly( #[constant] a: RnsRingPolynomial, #[constant] b: RnsRingPolynomial, @@ -311,7 +311,7 @@ mod tests { #[test] fn can_scale_polynomial() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn scale_poly( #[constant] a: RnsRingPolynomial, #[constant] b: NativeField, diff --git a/sunscreen/tests/zkp_program_tests.rs b/sunscreen/tests/zkp_program_tests.rs index ec139c403..e05f71509 100644 --- a/sunscreen/tests/zkp_program_tests.rs +++ b/sunscreen/tests/zkp_program_tests.rs @@ -6,7 +6,7 @@ type BPField = NativeField<::Field>; #[test] fn can_add_and_mul_native_fields() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn add_mul(a: NativeField, b: NativeField, c: NativeField) { let x = a * b + c; @@ -42,7 +42,7 @@ fn get_input_mismatch_on_incorrect_args() { use sunscreen_runtime::Error; use sunscreen_zkp_backend::Error as ZkpError; - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn add_mul(a: NativeField, b: NativeField) { let _ = a + b * a; } @@ -67,7 +67,7 @@ fn get_input_mismatch_on_incorrect_args() { #[test] fn can_use_public_inputs() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn add_mul(#[public] a: NativeField, b: NativeField, c: NativeField) { let x = a * b + c; @@ -100,7 +100,7 @@ fn can_use_public_inputs() { #[test] fn can_use_constant_inputs() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn add_mul( #[constant] a: NativeField, b: NativeField, @@ -137,7 +137,7 @@ fn can_use_constant_inputs() { #[test] fn can_declare_array_inputs() { - #[zkp_program(backend = "bulletproofs")] + #[zkp_program] fn in_range(a: [[NativeField; 9]; 64]) { for (i, a_i) in a.iter().enumerate() { for (j, a_i_j) in a_i.iter().enumerate() { diff --git a/sunscreen_compiler_macros/src/internals/attr.rs b/sunscreen_compiler_macros/src/internals/attr.rs index d3185a1b1..abb6e2599 100644 --- a/sunscreen_compiler_macros/src/internals/attr.rs +++ b/sunscreen_compiler_macros/src/internals/attr.rs @@ -221,48 +221,12 @@ impl Parse for FheProgramAttrs { } } -pub enum BackendType { - Bulletproofs, -} - -impl TryFrom<&AttrValue> for BackendType { - type Error = SynError; - - fn try_from(value: &AttrValue) -> SynResult { - let as_str = value.as_str()?; - - match as_str { - "bulletproofs" => Ok(BackendType::Bulletproofs), - _ => Err(SynError::new( - value.span(), - format!("Unknown backend `{}`", as_str.to_owned()), - )), - } - } -} - +// Keeping this struct around in case we add other attributes. #[allow(unused)] -pub struct ZkpProgramAttrs { - backend_type: BackendType, -} +pub struct ZkpProgramAttrs; impl Parse for ZkpProgramAttrs { - fn parse(input: ParseStream) -> SynResult { - let attrs = try_parse_dict(input)?; - - const VALUE_KEYS: &[&str] = &["backend"]; - - for i in attrs.keys() { - if !VALUE_KEYS.iter().any(|x| x == i) { - return Err(SynError::new(input.span(), format!("Unknown key '{}'", i))); - } - } - - let backend_type = attrs.get("backend").ok_or_else(|| { - SynError::new(input.span(), "required 'backend' is missing".to_owned()) - })?; - let backend_type = BackendType::try_from(backend_type)?; - - Ok(Self { backend_type }) + fn parse(_input: ParseStream) -> SynResult { + Ok(Self) } } From 863e05240f73eb7f7fd70c474178fcc9a74e7bb0 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Wed, 5 Jul 2023 16:33:31 -0500 Subject: [PATCH 41/41] Address PR reveiw --- .../src/types/intern/fhe_program_node.rs | 22 +++++++++---------- sunscreen/tests/fhe_program_tests.rs | 4 ---- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/sunscreen/src/types/intern/fhe_program_node.rs b/sunscreen/src/types/intern/fhe_program_node.rs index 7a3bb0bfd..474012ef5 100644 --- a/sunscreen/src/types/intern/fhe_program_node.rs +++ b/sunscreen/src/types/intern/fhe_program_node.rs @@ -506,14 +506,14 @@ where } } -impl NumCiphertexts for FheProgramNode +impl NumCiphertexts for FheProgramNode where T: NumCiphertexts, { const NUM_CIPHERTEXTS: usize = T::NUM_CIPHERTEXTS; } -impl TypeName for FheProgramNode +impl TypeName for FheProgramNode where T: TypeName + NumCiphertexts, { @@ -569,7 +569,7 @@ where T: FheType + GraphCipherInsert, { let node = T::graph_cipher_insert(lit); - unsafe_coerce(node, Stage::Literal) + reinterpret_cast(node, Stage::Literal) } /// Used for converting between types of an [`FheProgramNode`]. Implementations of this trait @@ -590,7 +590,7 @@ pub trait Coerce { } // Allow trivial conversion T -> T -impl Coerce> for FheProgramNode { +impl Coerce> for FheProgramNode { fn coerce(self) -> Self { self } @@ -642,7 +642,7 @@ where /// WARNING: This is an unsafe function. It allows casting graph nodes arbitrarily. Use with /// caution. -fn unsafe_coerce( +fn reinterpret_cast( a: FheProgramNode, t: T, ) -> FheProgramNode { @@ -668,18 +668,18 @@ macro_rules! impl_indeterminate_arithmetic_op { fn [<$op:lower>](self, rhs: FheProgramNode>) -> Self::Output { let node = match self.stage { Stage::Literal => { - let lit_node = unsafe_coerce(self, ()); + let lit_node = reinterpret_cast(self, ()); // N.B. we've already added this literal as a plaintext node T::[](rhs, lit_node) } Stage::Cipher => { - let cipher_node = unsafe_coerce(self, ()); + let cipher_node = reinterpret_cast(self, ()); T::[](rhs, cipher_node) } }; // No matter what `self.stage` currently is, it is being operated on with a // ciphertext, so its next stage is cipher. - unsafe_coerce(node, Stage::Cipher) + reinterpret_cast(node, Stage::Cipher) } } @@ -697,18 +697,18 @@ macro_rules! impl_indeterminate_arithmetic_op { fn [<$op:lower>](self, rhs: FheProgramNode, Stage>) -> Self::Output { let node = match rhs.stage { Stage::Literal => { - let lit_node = unsafe_coerce(rhs, ()); + let lit_node = reinterpret_cast(rhs, ()); // N.B. we've already added this literal as a plaintext node T::[](self, lit_node) } Stage::Cipher => { - let cipher_node = unsafe_coerce(rhs, ()); + let cipher_node = reinterpret_cast(rhs, ()); T::[](self, cipher_node) } }; // No matter what `rhs.stage` currently is, it is being added to a ciphertext, so its next // stage is cipher. - unsafe_coerce(node, Stage::Cipher) + reinterpret_cast(node, Stage::Cipher) } } } diff --git a/sunscreen/tests/fhe_program_tests.rs b/sunscreen/tests/fhe_program_tests.rs index abecc583e..fe722ab40 100644 --- a/sunscreen/tests/fhe_program_tests.rs +++ b/sunscreen/tests/fhe_program_tests.rs @@ -293,10 +293,6 @@ fn can_insert_literals() { let context = fhe_program_sum.build(&get_params()).unwrap(); - // N.B. Can't match on json like the other tests because the operation to insert a literal - // plaintext ends up with a pointer handle in the json that - // changes on each run. This appears necessary (the underlying seal ptr gets bincode encoded to - // be tossed around). assert_eq!(context.node_count(), 6); assert_eq!( context[node_index(0)].operation,