aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jason Gross <jagro@google.com>2018-06-28 18:22:17 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2018-07-03 19:28:55 -0400
commit1ca948e2ca70d2b16109aeea7ea173be2d827367 (patch)
treeeb227028dfabd7daf066993ad5f1fb15c3480879
parent6e6e636fb91290b13c78596907582294d71fa65c (diff)
Allow passing functions to synthesize on the command line, and scmul for 25519
-rw-r--r--src/Experiments/NewPipeline/Arithmetic.v20
-rw-r--r--src/Experiments/NewPipeline/CLI.v34
-rw-r--r--src/Experiments/NewPipeline/Toplevel1.v112
3 files changed, 149 insertions, 17 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v
index 3ab5a5e09..87cc5c7b2 100644
--- a/src/Experiments/NewPipeline/Arithmetic.v
+++ b/src/Experiments/NewPipeline/Arithmetic.v
@@ -455,6 +455,7 @@ Module Positional. Section Positional.
:= let ca := add n balance a in
let _b := negate_snd b in
add n ca _b.
+
Lemma eval_sub a b
: (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) ->
(List.length a = n) -> (List.length b = n) ->
@@ -652,6 +653,25 @@ Section mod_ops.
subst carry_squaremod; reflexivity.
Qed.
+ Derive carry_scmulmod
+ SuchThat (forall (x : Z) (f : list Z)
+ (Hf : length f = n),
+ (eval weight n (carry_scmulmod x f)) mod (s - Associational.eval c)
+ = (x * eval weight n f) mod (s - Associational.eval c))
+ As eval_carry_scmulmod.
+ Proof.
+ intros.
+ push_Zmod.
+ rewrite <-eval_encode with (s:=s) (c:=c) (x:=x) (weight:=weight) (n:=n) by auto.
+ pull_Zmod.
+ rewrite<-eval_mulmod with (s:=s) (c:=c) by (auto; distr_length).
+ etransitivity;
+ [ | rewrite <- @eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs)
+ by auto; reflexivity ].
+ eapply f_equal2; [|trivial]. eapply f_equal.
+ subst carry_scmulmod; reflexivity.
+ Qed.
+
Derive carrymod
SuchThat (forall (f : list Z)
(Hf : length f = n),
diff --git a/src/Experiments/NewPipeline/CLI.v b/src/Experiments/NewPipeline/CLI.v
index ebe1fcca9..abcf70f1b 100644
--- a/src/Experiments/NewPipeline/CLI.v
+++ b/src/Experiments/NewPipeline/CLI.v
@@ -111,6 +111,7 @@ Module ForExtraction.
(s : string)
(c : string)
(machine_wordsize : string)
+ (requests : list string)
: list (string * Pipeline.ErrorT (list string)) + list string
:= let prefix := ("fiat_" ++ curve_description ++ "_")%string in
let str_n := n in
@@ -119,6 +120,7 @@ Module ForExtraction.
let str_c := c in
let str_s := s in
let machine_wordsize := parse_machine_wordsize machine_wordsize in
+ let show_requests := match requests with nil => "(all)" | _ => String.concat ", " requests end in
match parse_s s, parse_c c with
| None, None
=> inr ["Could not parse s (" ++ s ++ ") nor c (" ++ c ++ ")"]
@@ -128,10 +130,11 @@ Module ForExtraction.
=> inr ["Could not parse c (" ++ c ++ ")"]
| Some s, Some c
=> let '(res, types_used)
- := UnsaturatedSolinas.Synthesize n s c machine_wordsize prefix in
+ := UnsaturatedSolinas.Synthesize n s c machine_wordsize prefix requests in
let header :=
((["/* Autogenerated */";
"/* curve description: " ++ curve_description ++ " */";
+ "/* requested operations: " ++ show_requests ++ "*/";
"/* n = " ++ show false n ++ " (from """ ++ str_n ++ """) */";
"/* s = " ++ Hex.show_Z false s ++ " (from """ ++ str_s ++ """) */";
"/* c = " ++ show false c ++ " (from """ ++ str_c ++ """) */";
@@ -153,8 +156,9 @@ Module ForExtraction.
(s : string)
(c : string)
(machine_wordsize : string)
+ (requests : list string)
: list string + list string
- := match CollectErrors (PipelineLines curve_description n s c machine_wordsize) with
+ := match CollectErrors (PipelineLines curve_description n s c machine_wordsize requests) with
| inl ls
=> inl
(List.map (fun s => String.concat NewLine s ++ NewLine ++ NewLine)
@@ -173,10 +177,11 @@ Module ForExtraction.
(s : string)
(c : string)
(machine_wordsize : string)
+ (requests : list string)
(success : list string -> A)
(error : list string -> A)
: A
- := match ProcessedLines curve_description n s c machine_wordsize with
+ := match ProcessedLines curve_description n s c machine_wordsize requests with
| inl s => success s
| inr s => error s
end.
@@ -188,14 +193,14 @@ Module ForExtraction.
(error : list string -> A)
: A
:= match argv with
- | _::curve_description::n::s::c::machine_wordsize::nil
+ | _::curve_description::n::s::c::machine_wordsize::requests
=> Pipeline
- curve_description n s c machine_wordsize
+ curve_description n s c machine_wordsize requests
success
error
| nil => error ["empty argv"]
| prog::args
- => error ["Expected arguments curve_description, n, s, c, machine_wordsize, got " ++ show false (List.length args) ++ " arguments in " ++ prog]
+ => error ["Expected arguments curve_description, n, s, c, machine_wordsize, [function_to_synthesize*] got " ++ show false (List.length args) ++ " arguments in " ++ prog]
end.
End UnsaturatedSolinas.
@@ -205,12 +210,14 @@ Module ForExtraction.
(s : string)
(c : string)
(machine_wordsize : string)
+ (requests : list string)
: list (string * Pipeline.ErrorT (list string)) + list string
:= let prefix := ("fiat_" ++ curve_description ++ "_")%string in
let str_machine_wordsize := machine_wordsize in
let str_c := c in
let str_s := s in
let machine_wordsize := parse_machine_wordsize machine_wordsize in
+ let show_requests := match requests with nil => "(all)" | _ => String.concat ", " requests end in
match parse_s s, parse_c c with
| None, None
=> inr ["Could not parse s (" ++ s ++ ") nor c (" ++ c ++ ")"]
@@ -220,10 +227,11 @@ Module ForExtraction.
=> inr ["Could not parse c (" ++ c ++ ")"]
| Some s, Some c
=> let '(res, types_used)
- := SaturatedSolinas.Synthesize s c machine_wordsize prefix in
+ := SaturatedSolinas.Synthesize s c machine_wordsize prefix requests in
let header :=
((["/* Autogenerated */";
"/* curve description: " ++ curve_description ++ " */";
+ "/* requested operations: " ++ show_requests ++ "*/";
"/* s = " ++ Hex.show_Z false s ++ " (from """ ++ str_s ++ """) */";
"/* c = " ++ show false c ++ " (from """ ++ str_c ++ """) */";
"/* machine_wordsize = " ++ show false machine_wordsize ++ " (from """ ++ str_machine_wordsize ++ """) */";
@@ -243,8 +251,9 @@ Module ForExtraction.
(s : string)
(c : string)
(machine_wordsize : string)
+ (requests : list string)
: list string + list string
- := match CollectErrors (PipelineLines curve_description s c machine_wordsize) with
+ := match CollectErrors (PipelineLines curve_description s c machine_wordsize requests) with
| inl ls
=> inl
(List.map (fun s => String.concat NewLine s ++ NewLine ++ NewLine)
@@ -262,10 +271,11 @@ Module ForExtraction.
(s : string)
(c : string)
(machine_wordsize : string)
+ (requests : list string)
(success : list string -> A)
(error : list string -> A)
: A
- := match ProcessedLines curve_description s c machine_wordsize with
+ := match ProcessedLines curve_description s c machine_wordsize requests with
| inl s => success s
| inr s => error s
end.
@@ -277,14 +287,14 @@ Module ForExtraction.
(error : list string -> A)
: A
:= match argv with
- | _::curve_description::s::c::machine_wordsize::nil
+ | _::curve_description::s::c::machine_wordsize::requests
=> Pipeline
- curve_description s c machine_wordsize
+ curve_description s c machine_wordsize requests
success
error
| nil => error ["empty argv"]
| prog::args
- => error ["Expected arguments curve_description, s, c, machine_wordsize, got " ++ show false (List.length args) ++ " arguments in " ++ prog]
+ => error ["Expected arguments curve_description, s, c, machine_wordsize, [function_to_synthesize*] got " ++ show false (List.length args) ++ " arguments in " ++ prog]
end.
End SaturatedSolinas.
End ForExtraction.
diff --git a/src/Experiments/NewPipeline/Toplevel1.v b/src/Experiments/NewPipeline/Toplevel1.v
index ac1edbbea..80fbf398e 100644
--- a/src/Experiments/NewPipeline/Toplevel1.v
+++ b/src/Experiments/NewPipeline/Toplevel1.v
@@ -39,6 +39,7 @@ Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs.
Require Import Crypto.Util.ErrorT.
Require Import Crypto.Util.Strings.Show.
Require Import Crypto.Util.ZRange.Show.
+Require Import Crypto.Util.Strings.Equality.
Require Import Crypto.Experiments.NewPipeline.Arithmetic.
Require Crypto.Experiments.NewPipeline.Language.
Require Crypto.Experiments.NewPipeline.UnderLets.
@@ -302,7 +303,8 @@ Module Pipeline.
| Values_not_provably_distinctZ (descr : string) (lhs rhs : Z)
| Values_not_provably_equalZ (descr : string) (lhs rhs : Z)
| Values_not_provably_equal_listZ (descr : string) (lhs rhs : list Z)
- | Stringification_failed {t} (e : @Compilers.defaults.Expr t) (err : string).
+ | Stringification_failed {t} (e : @Compilers.defaults.Expr t) (err : string)
+ | Invalid_argument (msg : string).
Notation ErrorT := (ErrorT ErrorMessage).
@@ -409,6 +411,8 @@ Module Pipeline.
| Values_not_provably_equal_listZ descr lhs rhs
=> ["Values not provably equal (" ++ descr ++ ") : expected " ++ show true lhs ++ " = " ++ show true rhs]
| Stringification_failed t e err => ["Stringification failed on the syntax tree:"] ++ show_lines false e ++ [err]
+ | Invalid_argument msg
+ => ["Invalid argument:" ++ msg]%string
end.
Local Instance show_ErrorMessage : Show ErrorMessage
:= fun parens err => String.concat String.NewLine (show_lines parens err).
@@ -678,6 +682,20 @@ Derive carry_square_gen
Proof. Time cache_reify (). Time Qed.
Hint Extern 1 (_ = carry_squaremod _ _ _ _ _ _ _) => simple apply carry_square_gen_correct : reify_gen_cache.
+Derive carry_scmul_gen
+ SuchThat (forall (limbwidth_num limbwidth_den : Z)
+ (x : Z) (f : list Z)
+ (n : nat)
+ (s : Z)
+ (c : list (Z * Z))
+ (idxs : list nat),
+ Interp (t:=reify_type_of carry_scmulmod)
+ carry_scmul_gen limbwidth_num limbwidth_den s c n idxs x f
+ = carry_scmulmod limbwidth_num limbwidth_den s c n idxs x f)
+ As carry_scmul_gen_correct.
+Proof. Time cache_reify (). Time Qed.
+Hint Extern 1 (_ = carry_scmulmod _ _ _ _ _ _ _ _) => simple apply carry_scmul_gen_correct : reify_gen_cache.
+
Derive carry_gen
SuchThat (forall (limbwidth_num limbwidth_den : Z)
(f : list Z)
@@ -1013,6 +1031,20 @@ Module Import UnsaturatedSolinas.
(Some tight_bounds)
(carry_squaremod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n idxs).
+ Definition srcarry_scmul_const prefix (x : Z)
+ := BoundsPipelineToStrings_no_subst01
+ prefix ("carry_scmul_" ++ decimal_string_of_Z x)
+ (carry_scmul_gen
+ @ GallinaReify.Reify (Qnum limbwidth) @ GallinaReify.Reify (Z.pos (Qden limbwidth)) @ GallinaReify.Reify s @ GallinaReify.Reify c @ GallinaReify.Reify n @ GallinaReify.Reify idxs @ GallinaReify.Reify x)
+ (Some loose_bounds, tt)
+ (Some tight_bounds).
+
+ Definition rcarry_scmul_const_correct (x : Z)
+ := BoundsPipeline_no_subst01_correct
+ (Some loose_bounds, tt)
+ (Some tight_bounds)
+ (carry_scmulmod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n idxs x).
+
Definition srcarry prefix
:= BoundsPipelineToStrings
prefix "carry"
@@ -1187,6 +1219,10 @@ Module Import UnsaturatedSolinas.
(* we need to strip off [Hrv : ... = Pipeline.Success rv] and related arguments *)
Definition rcarry_mul_correctT rv : Prop
:= type_of_strip_3arrow (@rcarry_mul_correct rv).
+ Definition rcarry_square_correctT rv : Prop
+ := type_of_strip_3arrow (@rcarry_square_correct rv).
+ Definition rcarry_scmul_const_correctT x rv : Prop
+ := type_of_strip_3arrow (@rcarry_scmul_const_correct x rv).
Definition rcarry_correctT rv : Prop
:= type_of_strip_3arrow (@rcarry_correct rv).
Definition rrelax_correctT rv : Prop
@@ -1444,10 +1480,50 @@ Module Import UnsaturatedSolinas.
(List.map (fun '(name, res) => (name, (res <- res; Success (fst res))%error)) ls,
ToString.C.bitwidths_used infos).
+ Local Open Scope string_scope.
+ Local Open Scope list_scope.
+
+ Definition known_functions
+ := [("carry_mul", srcarry_mul);
+ ("carry_square", srcarry_square);
+ ("carry", srcarry);
+ ("add", sradd);
+ ("sub", srsub);
+ ("opp", sropp);
+ ("to_bytes", srto_bytes);
+ ("from_bytes", srfrom_bytes)].
+
+ Definition synthesize_of_name (function_name_prefix : string) (name : string)
+ : string * ErrorT Pipeline.ErrorMessage (list string * ToString.C.ident_infos)
+ := fold_right
+ (fun v default
+ => match v with
+ | Some res => res
+ | None => default
+ end)
+ ((name,
+ Error
+ (Pipeline.Invalid_argument
+ ("Unrecognized request to synthesize """ ++ name ++ """; valid names are " ++ String.concat ", " (List.map (@fst _ _) known_functions) ++ ", or 'carry_scmul' followed by a decimal literal."))))
+ ((map
+ (fun '(expected_name, resf) => if string_beq name expected_name then Some (resf function_name_prefix) else None)
+ known_functions)
+ ++ [if prefix "carry_scmul" name
+ then let sc := substring (String.length "carry_scmul") (String.length name) name in
+ let scZ := Z_of_decimal_string sc in
+ if string_beq sc (decimal_string_of_Z scZ)
+ then Some (srcarry_scmul_const function_name_prefix scZ)
+ else None
+ else None]).
+
(** Note: If you change the name or type signature of this
function, you will need to update the code in CLI.v *)
- Definition Synthesize (function_name_prefix : string) : list (string * Pipeline.ErrorT (list string)) * PositiveSet.t (* types used *)
- := let ls := List.map (fun sr => sr function_name_prefix) [srcarry_mul; srcarry_square; srcarry; sradd; srsub; sropp; srto_bytes; srfrom_bytes] in
+ Definition Synthesize (function_name_prefix : string) (requests : list string)
+ : list (string * Pipeline.ErrorT (list string)) * PositiveSet.t (* types used *)
+ := let ls := match requests with
+ | nil => List.map (fun '(_, sr) => sr function_name_prefix) known_functions
+ | requests => List.map (synthesize_of_name function_name_prefix) requests
+ end in
let infos := aggregate_infos ls in
let '(extra_ls, extra_bit_widths) := extra_synthesis function_name_prefix infos in
(extra_ls ++ List.map (fun '(name, res) => (name, (res <- res; Success (fst res))%error)) ls,
@@ -1871,10 +1947,36 @@ Module SaturatedSolinas.
(List.map (fun '(name, res) => (name, (res <- res; Success (fst res))%error)) ls,
ToString.C.bitwidths_used infos).
+ Local Open Scope string_scope.
+ Local Open Scope list_scope.
+
+ Definition known_functions
+ := [("mulmod", srmulmod)].
+
+ Definition synthesize_of_name (function_name_prefix : string) (name : string)
+ : string * ErrorT Pipeline.ErrorMessage (list string * ToString.C.ident_infos)
+ := fold_right
+ (fun v default
+ => match v with
+ | Some res => res
+ | None => default
+ end)
+ ((name,
+ Error
+ (Pipeline.Invalid_argument
+ ("Unrecognized request to synthesize """ ++ name ++ """; valid names are " ++ String.concat ", " (List.map (@fst _ _) known_functions) ++ "."))))
+ (map
+ (fun '(expected_name, resf) => if string_beq name expected_name then Some (resf function_name_prefix) else None)
+ known_functions).
+
(** Note: If you change the name or type signature of this
function, you will need to update the code in CLI.v *)
- Definition Synthesize (function_name_prefix : string) : list (string * Pipeline.ErrorT (list string)) * PositiveSet.t (* types used *)
- := let ls := List.map (fun sr => sr function_name_prefix) [srmulmod] in
+ Definition Synthesize (function_name_prefix : string) (requests : list string)
+ : list (string * Pipeline.ErrorT (list string)) * PositiveSet.t (* types used *)
+ := let ls := match requests with
+ | nil => List.map (fun '(_, sr) => sr function_name_prefix) known_functions
+ | requests => List.map (synthesize_of_name function_name_prefix) requests
+ end in
let infos := aggregate_infos ls in
let '(extra_ls, extra_bit_widths) := extra_synthesis function_name_prefix infos in
(extra_ls ++ List.map (fun '(name, res) => (name, (res <- res; Success (fst res))%error)) ls,