diff options
author | Jason Gross <jagro@google.com> | 2018-06-28 18:22:17 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-07-03 19:28:55 -0400 |
commit | 1ca948e2ca70d2b16109aeea7ea173be2d827367 (patch) | |
tree | eb227028dfabd7daf066993ad5f1fb15c3480879 | |
parent | 6e6e636fb91290b13c78596907582294d71fa65c (diff) |
Allow passing functions to synthesize on the command line, and scmul for 25519
-rw-r--r-- | src/Experiments/NewPipeline/Arithmetic.v | 20 | ||||
-rw-r--r-- | src/Experiments/NewPipeline/CLI.v | 34 | ||||
-rw-r--r-- | src/Experiments/NewPipeline/Toplevel1.v | 112 |
3 files changed, 149 insertions, 17 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v index 3ab5a5e09..87cc5c7b2 100644 --- a/src/Experiments/NewPipeline/Arithmetic.v +++ b/src/Experiments/NewPipeline/Arithmetic.v @@ -455,6 +455,7 @@ Module Positional. Section Positional. := let ca := add n balance a in let _b := negate_snd b in add n ca _b. + Lemma eval_sub a b : (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) -> (List.length a = n) -> (List.length b = n) -> @@ -652,6 +653,25 @@ Section mod_ops. subst carry_squaremod; reflexivity. Qed. + Derive carry_scmulmod + SuchThat (forall (x : Z) (f : list Z) + (Hf : length f = n), + (eval weight n (carry_scmulmod x f)) mod (s - Associational.eval c) + = (x * eval weight n f) mod (s - Associational.eval c)) + As eval_carry_scmulmod. + Proof. + intros. + push_Zmod. + rewrite <-eval_encode with (s:=s) (c:=c) (x:=x) (weight:=weight) (n:=n) by auto. + pull_Zmod. + rewrite<-eval_mulmod with (s:=s) (c:=c) by (auto; distr_length). + etransitivity; + [ | rewrite <- @eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs) + by auto; reflexivity ]. + eapply f_equal2; [|trivial]. eapply f_equal. + subst carry_scmulmod; reflexivity. + Qed. + Derive carrymod SuchThat (forall (f : list Z) (Hf : length f = n), diff --git a/src/Experiments/NewPipeline/CLI.v b/src/Experiments/NewPipeline/CLI.v index ebe1fcca9..abcf70f1b 100644 --- a/src/Experiments/NewPipeline/CLI.v +++ b/src/Experiments/NewPipeline/CLI.v @@ -111,6 +111,7 @@ Module ForExtraction. (s : string) (c : string) (machine_wordsize : string) + (requests : list string) : list (string * Pipeline.ErrorT (list string)) + list string := let prefix := ("fiat_" ++ curve_description ++ "_")%string in let str_n := n in @@ -119,6 +120,7 @@ Module ForExtraction. let str_c := c in let str_s := s in let machine_wordsize := parse_machine_wordsize machine_wordsize in + let show_requests := match requests with nil => "(all)" | _ => String.concat ", " requests end in match parse_s s, parse_c c with | None, None => inr ["Could not parse s (" ++ s ++ ") nor c (" ++ c ++ ")"] @@ -128,10 +130,11 @@ Module ForExtraction. => inr ["Could not parse c (" ++ c ++ ")"] | Some s, Some c => let '(res, types_used) - := UnsaturatedSolinas.Synthesize n s c machine_wordsize prefix in + := UnsaturatedSolinas.Synthesize n s c machine_wordsize prefix requests in let header := ((["/* Autogenerated */"; "/* curve description: " ++ curve_description ++ " */"; + "/* requested operations: " ++ show_requests ++ "*/"; "/* n = " ++ show false n ++ " (from """ ++ str_n ++ """) */"; "/* s = " ++ Hex.show_Z false s ++ " (from """ ++ str_s ++ """) */"; "/* c = " ++ show false c ++ " (from """ ++ str_c ++ """) */"; @@ -153,8 +156,9 @@ Module ForExtraction. (s : string) (c : string) (machine_wordsize : string) + (requests : list string) : list string + list string - := match CollectErrors (PipelineLines curve_description n s c machine_wordsize) with + := match CollectErrors (PipelineLines curve_description n s c machine_wordsize requests) with | inl ls => inl (List.map (fun s => String.concat NewLine s ++ NewLine ++ NewLine) @@ -173,10 +177,11 @@ Module ForExtraction. (s : string) (c : string) (machine_wordsize : string) + (requests : list string) (success : list string -> A) (error : list string -> A) : A - := match ProcessedLines curve_description n s c machine_wordsize with + := match ProcessedLines curve_description n s c machine_wordsize requests with | inl s => success s | inr s => error s end. @@ -188,14 +193,14 @@ Module ForExtraction. (error : list string -> A) : A := match argv with - | _::curve_description::n::s::c::machine_wordsize::nil + | _::curve_description::n::s::c::machine_wordsize::requests => Pipeline - curve_description n s c machine_wordsize + curve_description n s c machine_wordsize requests success error | nil => error ["empty argv"] | prog::args - => error ["Expected arguments curve_description, n, s, c, machine_wordsize, got " ++ show false (List.length args) ++ " arguments in " ++ prog] + => error ["Expected arguments curve_description, n, s, c, machine_wordsize, [function_to_synthesize*] got " ++ show false (List.length args) ++ " arguments in " ++ prog] end. End UnsaturatedSolinas. @@ -205,12 +210,14 @@ Module ForExtraction. (s : string) (c : string) (machine_wordsize : string) + (requests : list string) : list (string * Pipeline.ErrorT (list string)) + list string := let prefix := ("fiat_" ++ curve_description ++ "_")%string in let str_machine_wordsize := machine_wordsize in let str_c := c in let str_s := s in let machine_wordsize := parse_machine_wordsize machine_wordsize in + let show_requests := match requests with nil => "(all)" | _ => String.concat ", " requests end in match parse_s s, parse_c c with | None, None => inr ["Could not parse s (" ++ s ++ ") nor c (" ++ c ++ ")"] @@ -220,10 +227,11 @@ Module ForExtraction. => inr ["Could not parse c (" ++ c ++ ")"] | Some s, Some c => let '(res, types_used) - := SaturatedSolinas.Synthesize s c machine_wordsize prefix in + := SaturatedSolinas.Synthesize s c machine_wordsize prefix requests in let header := ((["/* Autogenerated */"; "/* curve description: " ++ curve_description ++ " */"; + "/* requested operations: " ++ show_requests ++ "*/"; "/* s = " ++ Hex.show_Z false s ++ " (from """ ++ str_s ++ """) */"; "/* c = " ++ show false c ++ " (from """ ++ str_c ++ """) */"; "/* machine_wordsize = " ++ show false machine_wordsize ++ " (from """ ++ str_machine_wordsize ++ """) */"; @@ -243,8 +251,9 @@ Module ForExtraction. (s : string) (c : string) (machine_wordsize : string) + (requests : list string) : list string + list string - := match CollectErrors (PipelineLines curve_description s c machine_wordsize) with + := match CollectErrors (PipelineLines curve_description s c machine_wordsize requests) with | inl ls => inl (List.map (fun s => String.concat NewLine s ++ NewLine ++ NewLine) @@ -262,10 +271,11 @@ Module ForExtraction. (s : string) (c : string) (machine_wordsize : string) + (requests : list string) (success : list string -> A) (error : list string -> A) : A - := match ProcessedLines curve_description s c machine_wordsize with + := match ProcessedLines curve_description s c machine_wordsize requests with | inl s => success s | inr s => error s end. @@ -277,14 +287,14 @@ Module ForExtraction. (error : list string -> A) : A := match argv with - | _::curve_description::s::c::machine_wordsize::nil + | _::curve_description::s::c::machine_wordsize::requests => Pipeline - curve_description s c machine_wordsize + curve_description s c machine_wordsize requests success error | nil => error ["empty argv"] | prog::args - => error ["Expected arguments curve_description, s, c, machine_wordsize, got " ++ show false (List.length args) ++ " arguments in " ++ prog] + => error ["Expected arguments curve_description, s, c, machine_wordsize, [function_to_synthesize*] got " ++ show false (List.length args) ++ " arguments in " ++ prog] end. End SaturatedSolinas. End ForExtraction. diff --git a/src/Experiments/NewPipeline/Toplevel1.v b/src/Experiments/NewPipeline/Toplevel1.v index ac1edbbea..80fbf398e 100644 --- a/src/Experiments/NewPipeline/Toplevel1.v +++ b/src/Experiments/NewPipeline/Toplevel1.v @@ -39,6 +39,7 @@ Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. Require Import Crypto.Util.ErrorT. Require Import Crypto.Util.Strings.Show. Require Import Crypto.Util.ZRange.Show. +Require Import Crypto.Util.Strings.Equality. Require Import Crypto.Experiments.NewPipeline.Arithmetic. Require Crypto.Experiments.NewPipeline.Language. Require Crypto.Experiments.NewPipeline.UnderLets. @@ -302,7 +303,8 @@ Module Pipeline. | Values_not_provably_distinctZ (descr : string) (lhs rhs : Z) | Values_not_provably_equalZ (descr : string) (lhs rhs : Z) | Values_not_provably_equal_listZ (descr : string) (lhs rhs : list Z) - | Stringification_failed {t} (e : @Compilers.defaults.Expr t) (err : string). + | Stringification_failed {t} (e : @Compilers.defaults.Expr t) (err : string) + | Invalid_argument (msg : string). Notation ErrorT := (ErrorT ErrorMessage). @@ -409,6 +411,8 @@ Module Pipeline. | Values_not_provably_equal_listZ descr lhs rhs => ["Values not provably equal (" ++ descr ++ ") : expected " ++ show true lhs ++ " = " ++ show true rhs] | Stringification_failed t e err => ["Stringification failed on the syntax tree:"] ++ show_lines false e ++ [err] + | Invalid_argument msg + => ["Invalid argument:" ++ msg]%string end. Local Instance show_ErrorMessage : Show ErrorMessage := fun parens err => String.concat String.NewLine (show_lines parens err). @@ -678,6 +682,20 @@ Derive carry_square_gen Proof. Time cache_reify (). Time Qed. Hint Extern 1 (_ = carry_squaremod _ _ _ _ _ _ _) => simple apply carry_square_gen_correct : reify_gen_cache. +Derive carry_scmul_gen + SuchThat (forall (limbwidth_num limbwidth_den : Z) + (x : Z) (f : list Z) + (n : nat) + (s : Z) + (c : list (Z * Z)) + (idxs : list nat), + Interp (t:=reify_type_of carry_scmulmod) + carry_scmul_gen limbwidth_num limbwidth_den s c n idxs x f + = carry_scmulmod limbwidth_num limbwidth_den s c n idxs x f) + As carry_scmul_gen_correct. +Proof. Time cache_reify (). Time Qed. +Hint Extern 1 (_ = carry_scmulmod _ _ _ _ _ _ _ _) => simple apply carry_scmul_gen_correct : reify_gen_cache. + Derive carry_gen SuchThat (forall (limbwidth_num limbwidth_den : Z) (f : list Z) @@ -1013,6 +1031,20 @@ Module Import UnsaturatedSolinas. (Some tight_bounds) (carry_squaremod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n idxs). + Definition srcarry_scmul_const prefix (x : Z) + := BoundsPipelineToStrings_no_subst01 + prefix ("carry_scmul_" ++ decimal_string_of_Z x) + (carry_scmul_gen + @ GallinaReify.Reify (Qnum limbwidth) @ GallinaReify.Reify (Z.pos (Qden limbwidth)) @ GallinaReify.Reify s @ GallinaReify.Reify c @ GallinaReify.Reify n @ GallinaReify.Reify idxs @ GallinaReify.Reify x) + (Some loose_bounds, tt) + (Some tight_bounds). + + Definition rcarry_scmul_const_correct (x : Z) + := BoundsPipeline_no_subst01_correct + (Some loose_bounds, tt) + (Some tight_bounds) + (carry_scmulmod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n idxs x). + Definition srcarry prefix := BoundsPipelineToStrings prefix "carry" @@ -1187,6 +1219,10 @@ Module Import UnsaturatedSolinas. (* we need to strip off [Hrv : ... = Pipeline.Success rv] and related arguments *) Definition rcarry_mul_correctT rv : Prop := type_of_strip_3arrow (@rcarry_mul_correct rv). + Definition rcarry_square_correctT rv : Prop + := type_of_strip_3arrow (@rcarry_square_correct rv). + Definition rcarry_scmul_const_correctT x rv : Prop + := type_of_strip_3arrow (@rcarry_scmul_const_correct x rv). Definition rcarry_correctT rv : Prop := type_of_strip_3arrow (@rcarry_correct rv). Definition rrelax_correctT rv : Prop @@ -1444,10 +1480,50 @@ Module Import UnsaturatedSolinas. (List.map (fun '(name, res) => (name, (res <- res; Success (fst res))%error)) ls, ToString.C.bitwidths_used infos). + Local Open Scope string_scope. + Local Open Scope list_scope. + + Definition known_functions + := [("carry_mul", srcarry_mul); + ("carry_square", srcarry_square); + ("carry", srcarry); + ("add", sradd); + ("sub", srsub); + ("opp", sropp); + ("to_bytes", srto_bytes); + ("from_bytes", srfrom_bytes)]. + + Definition synthesize_of_name (function_name_prefix : string) (name : string) + : string * ErrorT Pipeline.ErrorMessage (list string * ToString.C.ident_infos) + := fold_right + (fun v default + => match v with + | Some res => res + | None => default + end) + ((name, + Error + (Pipeline.Invalid_argument + ("Unrecognized request to synthesize """ ++ name ++ """; valid names are " ++ String.concat ", " (List.map (@fst _ _) known_functions) ++ ", or 'carry_scmul' followed by a decimal literal.")))) + ((map + (fun '(expected_name, resf) => if string_beq name expected_name then Some (resf function_name_prefix) else None) + known_functions) + ++ [if prefix "carry_scmul" name + then let sc := substring (String.length "carry_scmul") (String.length name) name in + let scZ := Z_of_decimal_string sc in + if string_beq sc (decimal_string_of_Z scZ) + then Some (srcarry_scmul_const function_name_prefix scZ) + else None + else None]). + (** Note: If you change the name or type signature of this function, you will need to update the code in CLI.v *) - Definition Synthesize (function_name_prefix : string) : list (string * Pipeline.ErrorT (list string)) * PositiveSet.t (* types used *) - := let ls := List.map (fun sr => sr function_name_prefix) [srcarry_mul; srcarry_square; srcarry; sradd; srsub; sropp; srto_bytes; srfrom_bytes] in + Definition Synthesize (function_name_prefix : string) (requests : list string) + : list (string * Pipeline.ErrorT (list string)) * PositiveSet.t (* types used *) + := let ls := match requests with + | nil => List.map (fun '(_, sr) => sr function_name_prefix) known_functions + | requests => List.map (synthesize_of_name function_name_prefix) requests + end in let infos := aggregate_infos ls in let '(extra_ls, extra_bit_widths) := extra_synthesis function_name_prefix infos in (extra_ls ++ List.map (fun '(name, res) => (name, (res <- res; Success (fst res))%error)) ls, @@ -1871,10 +1947,36 @@ Module SaturatedSolinas. (List.map (fun '(name, res) => (name, (res <- res; Success (fst res))%error)) ls, ToString.C.bitwidths_used infos). + Local Open Scope string_scope. + Local Open Scope list_scope. + + Definition known_functions + := [("mulmod", srmulmod)]. + + Definition synthesize_of_name (function_name_prefix : string) (name : string) + : string * ErrorT Pipeline.ErrorMessage (list string * ToString.C.ident_infos) + := fold_right + (fun v default + => match v with + | Some res => res + | None => default + end) + ((name, + Error + (Pipeline.Invalid_argument + ("Unrecognized request to synthesize """ ++ name ++ """; valid names are " ++ String.concat ", " (List.map (@fst _ _) known_functions) ++ ".")))) + (map + (fun '(expected_name, resf) => if string_beq name expected_name then Some (resf function_name_prefix) else None) + known_functions). + (** Note: If you change the name or type signature of this function, you will need to update the code in CLI.v *) - Definition Synthesize (function_name_prefix : string) : list (string * Pipeline.ErrorT (list string)) * PositiveSet.t (* types used *) - := let ls := List.map (fun sr => sr function_name_prefix) [srmulmod] in + Definition Synthesize (function_name_prefix : string) (requests : list string) + : list (string * Pipeline.ErrorT (list string)) * PositiveSet.t (* types used *) + := let ls := match requests with + | nil => List.map (fun '(_, sr) => sr function_name_prefix) known_functions + | requests => List.map (synthesize_of_name function_name_prefix) requests + end in let infos := aggregate_infos ls in let '(extra_ls, extra_bit_widths) := extra_synthesis function_name_prefix infos in (extra_ls ++ List.map (fun '(name, res) => (name, (res <- res; Success (fst res))%error)) ls, |