aboutsummaryrefslogtreecommitdiff
path: root/src/PushButtonSynthesis/Primitives.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/PushButtonSynthesis/Primitives.v')
-rw-r--r--src/PushButtonSynthesis/Primitives.v472
1 files changed, 454 insertions, 18 deletions
diff --git a/src/PushButtonSynthesis/Primitives.v b/src/PushButtonSynthesis/Primitives.v
index f36f4cb9c..e45940369 100644
--- a/src/PushButtonSynthesis/Primitives.v
+++ b/src/PushButtonSynthesis/Primitives.v
@@ -133,19 +133,7 @@ Local Notation out_bounds_of_pipeline result
(only parsing).
Notation FromPipelineToString prefix name result
- := (((prefix ++ name)%string,
- match result with
- | Success E'
- => let E := ToString.C.ToFunctionLines
- true true (* static *) prefix (prefix ++ name)%string [] E' None
- (arg_bounds_of_pipeline result)
- (out_bounds_of_pipeline result) in
- match E with
- | inl E => Success E
- | inr err => Error (Pipeline.Stringification_failed E' err)
- end
- | Error err => Error err
- end)).
+ := (Pipeline.FromPipelineToString prefix name result).
Ltac prove_correctness use_curve_good :=
let Hres := match goal with H : _ = Success _ |- _ => H end in
@@ -175,6 +163,419 @@ Ltac prove_correctness use_curve_good :=
| progress autorewrite with distr_length in * ]
| .. ].
+Module CorrectnessStringification.
+ Module dyn_context.
+ Inductive list :=
+ | nil
+ | cons {T1 T2} (k : T1) (v : T2) (ctx : list).
+ End dyn_context.
+
+ Ltac strip_bounds_info correctness :=
+ lazymatch correctness with
+ | (fun x : ?T => ?f)
+ => let fx := fresh in
+ constr:(fun x : T => match f return _ with
+ | fx => ltac:(let fx := (eval cbv [fx] in fx) in
+ let v := strip_bounds_info fx in
+ exact v)
+ end)
+ | ((lower ?r <=? ?v) && (?v <=? upper ?r))%bool%Z = true -> ?T
+ => strip_bounds_info T
+ | list_Z_bounded_by _ _ -> ?T
+ => strip_bounds_info T
+ | ?T /\ list_Z_bounded_by _ _
+ => T
+ | ?T /\ (match _ with pair _ _ => _ end = true)
+ => T
+ | ?T /\ ((lower ?r <=? ?v) && (?v <=? upper ?r))%bool%Z = true
+ => T
+ | iff _ _
+ => correctness
+ | _ = _ /\ (_ = _ /\ (_ <= _ < _))
+ => correctness
+ | _ = _ :> list Z
+ => correctness
+ | forall x : ?T, ?f
+ => let fx := fresh in
+ constr:(forall x : T, match f return _ with
+ | fx => ltac:(let fx := (eval cbv [fx] in fx) in
+ let v := strip_bounds_info fx in
+ exact v)
+ end)
+ | ?T
+ => let __ := match goal with _ => idtac "Unrecognized bounds component:" T end in
+ constr:(I : I)
+ end.
+
+ Ltac with_assoc_list ctx correctness arg_var_names out_var_names cont :=
+ lazymatch correctness with
+ | (fun x : ?T => ?f)
+ => let fx := fresh in
+ constr:(fun x : T
+ => match f return _ with
+ | fx
+ => ltac:(let fx' := (eval cbv delta [fx] in fx) in
+ clear fx;
+ let ret := with_assoc_list
+ (dyn_context.cons x out_var_names ctx)
+ fx'
+ arg_var_names
+ ()
+ cont in
+ exact ret)
+ end)
+ | _
+ => let T := type of arg_var_names in
+ lazymatch (eval hnf in T) with
+ | prod _ _
+ => lazymatch correctness with
+ | (forall x : ?T, ?f)
+ => let fx := fresh in
+ constr:(fun x : T
+ => match f return _ with
+ | fx
+ => ltac:(let fx' := (eval cbv delta [fx] in fx) in
+ clear fx;
+ let ret := with_assoc_list
+ (dyn_context.cons x (fst arg_var_names) ctx)
+ fx'
+ (snd arg_var_names)
+ out_var_names
+ cont in
+ exact ret)
+ end)
+ | ?T
+ => cont ctx T
+ end
+ | _ => cont ctx correctness
+ end
+ end.
+
+ Ltac maybe_parenthesize str natural cur_lvl :=
+ let should_paren := (eval cbv in (Z.ltb cur_lvl natural)) in
+ lazymatch should_paren with
+ | true => constr:(("(" ++ str ++ ")")%string)
+ | false => str
+ end.
+
+ Ltac find_head_in_ctx' ctx x cont :=
+ let h := head x in
+ lazymatch ctx with
+ | context[dyn_context.cons h ?name _] => cont name
+ | context[dyn_context.cons x ?name _] => cont name
+ | _ => lazymatch x with
+ | fst ?x
+ => find_head_in_ctx' ctx x ltac:(fun x => cont (fst x))
+ | snd ?x
+ => find_head_in_ctx' ctx x ltac:(fun x => cont (snd x))
+ | _ => constr:(@None string)
+ end
+ end.
+ Ltac find_head_in_ctx ctx x :=
+ find_head_in_ctx' ctx x ltac:(fun x => constr:(Some x)).
+
+ Ltac find_in_ctx' ctx x cont :=
+ lazymatch ctx with
+ | context[dyn_context.cons x ?name _] => cont name
+ | _ => lazymatch x with
+ | fst ?x
+ => find_in_ctx' ctx x ltac:(fun x => cont (fst x))
+ | snd ?x
+ => find_in_ctx' ctx x ltac:(fun x => cont (snd x))
+ | _ => constr:(@None string)
+ end
+ end.
+ Ltac find_in_ctx ctx x :=
+ find_in_ctx' ctx x ltac:(fun x => constr:(Some x)).
+
+ Ltac test_is_var v :=
+ constr:(ltac:(tryif is_var v then exact true else exact false)).
+
+ Local Open Scope string_scope.
+
+ Ltac fresh_from' ctx check_list start_val :=
+ lazymatch check_list with
+ | cons ?n ?check_list
+ => lazymatch ctx with
+ | context[dyn_context.cons _ n]
+ => fresh_from' ctx check_list start_val
+ | _ => n
+ end
+ | _
+ => let n := (eval cbv in ("x" ++ decimal_string_of_Z start_val)) in
+ lazymatch ctx with
+ | context[dyn_context.cons _ n]
+ => fresh_from' ctx check_list (Z.succ start_val)
+ | _ => n
+ end
+ end.
+
+ Ltac fresh_from ctx := fresh_from' ctx ["x"; "y"; "z"] 0%Z.
+
+ Ltac stringify_function_binders ctx correctness stringify_body :=
+ lazymatch correctness with
+ | (fun x : ?T => ?f)
+ => let fx := fresh in
+ let xn := fresh_from ctx in
+ lazymatch
+ constr:(
+ fun x : T
+ => match f return string with
+ | fx
+ => ltac:(
+ let fx' := (eval cbv delta [fx] in fx) in
+ clear fx;
+ let res := stringify_function_binders
+ (dyn_context.cons x xn ctx)
+ fx'
+ stringify_body in
+ exact (" " ++ xn ++ res))
+ end) with
+ | fun _ => ?f => f
+ | ?F => let __ := match goal with _ => idtac "Failed to eliminate functional dependencies in" F end in
+ constr:(I : I)
+ end
+ | ?v => let res := stringify_body ctx v in
+ constr:(", " ++ res)
+ end.
+
+ Ltac is_literal x :=
+ lazymatch x with
+ | O => true
+ | S ?x => is_literal x
+ | _ => false
+ end.
+
+ Ltac stringify_rec0 evalf ctx correctness lvl :=
+ let recurse v lvl := stringify_rec0 evalf ctx v lvl in
+ let name_of_var := find_head_in_ctx ctx correctness in
+ let weightf := lazymatch evalf with eval ?weightf _ => weightf | _ => I end in
+ let stringify_if testv t f :=
+ let stest := recurse testv 200 in
+ let st := recurse t 200 in
+ let sf := recurse f 200 in
+ maybe_parenthesize (("if " ++ stest ++ " then " ++ st ++ " else " ++ sf)%string) 200 lvl in
+ let show_Z _ :=
+ maybe_parenthesize (Show.Decimal.show_Z false correctness) 1 lvl in
+ let show_nat _ :=
+ maybe_parenthesize (Show.Decimal.show_nat false correctness) 1 lvl in
+ let stringify_prefix f natural arg_lvl :=
+ lazymatch correctness with
+ | ?F ?x
+ => let sx := recurse x arg_lvl in
+ maybe_parenthesize (f ++ sx)%string natural lvl
+ end in
+ let stringify_postfix f natural arg_lvl :=
+ lazymatch correctness with
+ | ?F ?x
+ => let sx := recurse x arg_lvl in
+ maybe_parenthesize (sx ++ f)%string natural lvl
+ end in
+ let stringify_infix' lvl space f natural l_lvl r_lvl :=
+ lazymatch correctness with
+ | ?F ?x ?y
+ => let sx := recurse x l_lvl in
+ let sy := recurse y r_lvl in
+ maybe_parenthesize (sx ++ space ++ f ++ space ++ sy)%string natural lvl
+ end in
+ let stringify_infix := stringify_infix' lvl " " in
+ let stringify_infix_without_space := stringify_infix' lvl "" in
+ let stringify_infix2 f1 f2 natural l_lvl c_lvl r_lvl :=
+ lazymatch correctness with
+ | and (?F1 ?x ?y) (?F2 ?y ?z)
+ => let sx := recurse x l_lvl in
+ let sy := recurse y c_lvl in
+ let sz := recurse z r_lvl in
+ maybe_parenthesize (sx ++ " " ++ f1 ++ " " ++ sy ++ " " ++ f2 ++ " " ++ sz)%string natural lvl
+ end in
+ let name_of_fun :=
+ lazymatch correctness with
+ | ?f ?x => find_in_ctx ctx f
+ | _ => constr:(@None string)
+ end in
+ lazymatch constr:((name_of_var, name_of_fun)) with
+ | (Some ?name, _)
+ => maybe_parenthesize name 1 lvl
+ | (None, Some ?name)
+ => lazymatch correctness with
+ | ?f ?x
+ => let sx := recurse x 9 in
+ maybe_parenthesize ((name ++ " " ++ sx)%string) 10 lvl
+ end
+ | (None, None)
+ => lazymatch correctness with
+ | ?x = ?y :> ?T
+ => lazymatch (eval cbv in T) with
+ | Z => let sx := recurse x 69 in
+ let sy := recurse y 69 in
+ maybe_parenthesize ((sx ++ " = " ++ sy)%string) 70 lvl
+ | list Z
+ => let sx := recurse x 69 in
+ let sy := recurse y 69 in
+ maybe_parenthesize ((sx ++ " = " ++ sy)%string) 70 lvl
+ | prod ?A ?B
+ => let v := (eval cbn [fst snd] in (fst x = fst y /\ snd x = snd y)) in
+ recurse v lvl
+ | ?T' => let __ := match goal with _ => idtac "Error: Unrecognized type for equality:" T' end in
+ constr:(I : I)
+ end
+ | evalf ?v
+ => let sv := recurse v 9 in
+ maybe_parenthesize (("eval " ++ sv)%string) 10 lvl
+ | weightf ?v
+ => let sv := recurse v 9 in
+ maybe_parenthesize (("weight " ++ sv)%string) 10 lvl
+ | eval (weight 8 1) _ ?v
+ => let sv := recurse v 9 in
+ maybe_parenthesize (("bytes_eval " ++ sv)%string) 10 lvl
+ | UniformWeight.uweight ?machine_wordsize ?v
+ => recurse (2^(machine_wordsize * Z.of_nat v)) lvl
+ | weight 8 1 ?i
+ => recurse (2^(8 * Z.of_nat i)) lvl
+ | List.map ?x ?y
+ => let sx := recurse x 9 in
+ let sy := recurse y 9 in
+ maybe_parenthesize (("map " ++ sx ++ " " ++ sy)%string) 10 lvl
+ | match ?testv with true => ?t | false => ?f end
+ => stringify_if testv t f
+ | match ?testv with or_introl _ => ?t | or_intror _ => ?f end
+ => stringify_if testv t f
+ | match ?testv with left _ => ?t | right _ => ?f end
+ => stringify_if testv t f
+ | Decidable.dec ?p
+ => recurse p lvl
+ | Z0 => show_Z ()
+ | Zpos _ => show_Z ()
+ | Zneg _ => show_Z ()
+ | O => show_nat ()
+ | S ?x
+ => let is_lit := is_literal x in
+ lazymatch is_lit with
+ | true => show_nat ()
+ | false => recurse (x + 1)%nat lvl
+ end
+ | Z.of_nat ?x => recurse x lvl
+ | ?x <= ?y < ?z => stringify_infix2 "≤" "<" 70 69 69 69
+ | ?x <= ?y <= ?z => stringify_infix2 "≤" "≤" 70 69 69 69
+ | ?x < ?y <= ?z => stringify_infix2 "<" "≤" 70 69 69 69
+ | ?x < ?y < ?z => stringify_infix2 "<" "<" 70 69 69 69
+ | iff _ _ => stringify_infix "↔" 95 94 94
+ | and _ _ => stringify_infix "∧" 80 80 80
+ | Z.modulo _ _ => stringify_infix "mod" 40 39 39
+ | Z.mul _ _ => stringify_infix "*" 40 40 39
+ | Z.pow _ _ => stringify_infix_without_space "^" 30 29 30
+ | Z.add _ _ => stringify_infix "+" 50 50 49
+ | Z.sub _ _ => stringify_infix "-" 50 50 49
+ | Z.opp _ => stringify_prefix "-" 35 35
+ | Z.le _ _ => stringify_infix "≤" 70 69 69
+ | Z.lt _ _ => stringify_infix "<" 70 69 69
+ | Nat.mul _ _ => stringify_infix "*" 40 40 39
+ | Nat.pow _ _ => stringify_infix "^" 30 29 30
+ | Nat.add _ _ => stringify_infix "+" 50 50 49
+ | Nat.sub _ _ => stringify_infix "-ℕ" 50 50 49
+ | Z.div _ _
+ => let res := stringify_infix' 69 " " "/" 40 40 39 in
+ maybe_parenthesize ("⌊" ++ res ++ "⌋") 9 lvl
+ | List.seq ?x ?y
+ => let sx := recurse x 9 in
+ let sy := recurse (pred y) 9 in
+ constr:("[" ++ sx ++ ".." ++ sy ++ "]")
+ | pred ?n
+ => let iv := test_is_var n in
+ let il := is_literal n in
+ lazymatch (eval cbv in (orb il iv)) with
+ | true => show_nat ()
+ | false
+ => recurse (n - 1)%nat lvl
+ end
+ | fun x : ?T => ?f
+ => let slam := stringify_function_binders ctx correctness ltac:(fun ctx body => stringify_rec0 evalf ctx body 200) in
+ maybe_parenthesize ("λ" ++ slam) 200 lvl
+ | ?v
+ => let iv := test_is_var v in
+ lazymatch iv with
+ | true
+ => let T := type of v in
+ lazymatch (eval hnf in T) with
+ | Z => show_Z ()
+ | nat => show_nat ()
+ | _
+ => let __ := match goal with _ => idtac "Error: Unrecognized var:" v " in " ctx end in
+ constr:(I : I)
+ end
+ | false
+ => let __ := match goal with _ => idtac "Error: Unrecognized term:" v " in " ctx end in
+ constr:(I : I)
+ end
+ end
+ end.
+
+ Ltac stringify_rec prefix evalf ctx correctness lvl :=
+ let recurse' prefix v lvl := stringify_rec prefix evalf ctx v lvl in
+ let recurse := recurse' "" in
+ let default _ := let v := stringify_rec0 evalf ctx correctness lvl in
+ constr:((prefix ++ v)::nil) in
+ lazymatch correctness with
+ | ?A -> ?B
+ => let sA := stringify_rec0 evalf ctx A 98 in
+ let sB := recurse B 200 in
+ constr:((prefix ++ sA ++ " →")%string :: sB)
+ | _ <= _ < _ => default ()
+ | _ <= _ <= _ => default ()
+ | _ < _ <= _ => default ()
+ | _ < _ < _ => default ()
+ | and ?A ?B
+ => let sA := recurse' prefix A 80 in
+ let sB := recurse' "∧ " B 80 in
+ constr:(List.app sA sB)
+ | ?x = ?y :> prod ?A ?B
+ => let v := (eval cbn [fst snd] in (fst x = fst y /\ snd x = snd y)) in
+ recurse' prefix v lvl
+ | _
+ => default ()
+ end.
+
+ Ltac strip_lambdas v :=
+ lazymatch v with
+ | fun _ => ?f => strip_lambdas f
+ | ?v => v
+ end.
+
+ Ltac stringify ctx correctness evalf fname arg_var_data out_var_data :=
+ let G := match goal with |- ?G => G end in
+ let correctness := (eval hnf in correctness) in
+ let correctness := (eval cbv [Partition.partition Arithmetic.WordByWordMontgomery.valid Arithmetic.WordByWordMontgomery.small] in correctness) in
+ let correctness := strip_bounds_info correctness in
+ let arg_var_names := constr:(type.map_for_each_lhs_of_arrow (@ToString.C.OfPHOAS.names_of_var_data) arg_var_data) in
+ let out_var_names := constr:(ToString.C.OfPHOAS.names_of_base_var_data out_var_data) in
+ let res := with_assoc_list
+ ctx
+ correctness
+ arg_var_names
+ out_var_names
+ ltac:(
+ fun ctx T
+ => let v := stringify_rec "" evalf ctx T 200 in refine v
+ ) in
+ let res := strip_lambdas res in
+ res.
+
+ Notation stringify_correctness_with_ctx ctx evalf pre_extra correctness
+ := (fun fname arg_var_data out_var_data
+ => ltac:(let res := stringify ctx correctness evalf fname arg_var_data out_var_data in
+ refine (List.app (pre_extra fname) res))) (only parsing).
+ Notation stringify_correctness evalf pre_extra correctness
+ := (match dyn_context.nil with
+ | ctx' => stringify_correctness_with_ctx ctx' evalf pre_extra correctness
+ end)
+ (only parsing).
+End CorrectnessStringification.
+
+Notation stringify_correctness_with_ctx ctx evalf pre_extra correctness
+ := (CorrectnessStringification.stringify_correctness_with_ctx ctx evalf pre_extra correctness) (only parsing).
+Notation stringify_correctness evalf pre_extra correctness
+ := (CorrectnessStringification.stringify_correctness evalf pre_extra correctness) (only parsing).
+
Section __.
Context (n : nat)
(machine_wordsize : Z).
@@ -199,6 +600,9 @@ Section __.
Proof using Type. cbv [saturated_bounds_list]; now autorewrite with distr_length. Qed.
Hint Rewrite length_saturated_bounds_list : distr_length.
+ Local Notation dummy_weight := (weight 0 0).
+ Local Notation evalf := (eval dummy_weight n).
+
Definition selectznz
:= Pipeline.BoundsPipeline
false (* subst01 *)
@@ -210,7 +614,13 @@ Section __.
Definition sselectznz (prefix : string)
: string * (Pipeline.ErrorT (list string * ToString.C.ident_infos))
- := Eval cbv beta in FromPipelineToString prefix "selectznz" selectznz.
+ := Eval cbv beta in
+ FromPipelineToString
+ prefix "selectznz" selectznz
+ (stringify_correctness
+ evalf
+ (fun fname : string => ["The function " ++ fname ++ " is a multi-limb conditional select."]%string)
+ (selectznz_correct dummy_weight n saturated_bounds_list)).
Definition mulx (s : Z)
:= Pipeline.BoundsPipeline
@@ -224,7 +634,13 @@ Section __.
Definition smulx (prefix : string) (s : Z)
: string * (Pipeline.ErrorT (list string * ToString.C.ident_infos))
- := Eval cbv beta in FromPipelineToString prefix ("mulx_u" ++ decimal_string_of_Z s) (mulx s).
+ := Eval cbv beta in
+ FromPipelineToString
+ prefix ("mulx_u" ++ decimal_string_of_Z s) (mulx s)
+ (stringify_correctness
+ evalf
+ (fun fname : string => ["The function " ++ fname ++ " is an extended multiplication."]%string)
+ (mulx_correct s)).
Definition addcarryx (s : Z)
:= Pipeline.BoundsPipeline
@@ -236,9 +652,16 @@ Section __.
(Some r[0~>1], (Some r[0~>2^s-1], (Some r[0~>2^s-1], tt)))%zrange
(Some r[0~>2^s-1], Some r[0~>1])%zrange.
+
Definition saddcarryx (prefix : string) (s : Z)
: string * (Pipeline.ErrorT (list string * ToString.C.ident_infos))
- := Eval cbv beta in FromPipelineToString prefix ("addcarryx_u" ++ decimal_string_of_Z s) (addcarryx s).
+ := Eval cbv beta in
+ FromPipelineToString
+ prefix ("addcarryx_u" ++ decimal_string_of_Z s) (addcarryx s)
+ (stringify_correctness
+ evalf
+ (fun fname : string => ["The function " ++ fname ++ " is an add with carry."]%string)
+ (addcarryx_correct s)).
Definition subborrowx (s : Z)
:= Pipeline.BoundsPipeline
@@ -252,7 +675,14 @@ Section __.
Definition ssubborrowx (prefix : string) (s : Z)
: string * (Pipeline.ErrorT (list string * ToString.C.ident_infos))
- := Eval cbv beta in FromPipelineToString prefix ("subborrowx_u" ++ decimal_string_of_Z s) (subborrowx s).
+ := Eval cbv beta in
+ FromPipelineToString
+ prefix ("subborrowx_u" ++ decimal_string_of_Z s) (subborrowx s)
+ (stringify_correctness
+ evalf
+ (fun fname : string => ["The function " ++ fname ++ " is a sub with borrow."]%string)
+ (subborrowx_correct s)).
+
Definition cmovznz (s : Z)
:= Pipeline.BoundsPipeline
@@ -266,7 +696,13 @@ Section __.
Definition scmovznz (prefix : string) (s : Z)
: string * (Pipeline.ErrorT (list string * ToString.C.ident_infos))
- := Eval cbv beta in FromPipelineToString prefix ("cmovznz_u" ++ decimal_string_of_Z s) (cmovznz s).
+ := Eval cbv beta in
+ FromPipelineToString
+ prefix ("cmovznz_u" ++ decimal_string_of_Z s) (cmovznz s)
+ (stringify_correctness
+ evalf
+ (fun fname : string => ["The function " ++ fname ++ " is a single-word conditional move."]%string)
+ (cmovznz_correct s)).
Local Ltac solve_extra_bounds_side_conditions :=
cbn [lower upper fst snd] in *; Bool.split_andb; Z.ltb_to_lt; lia.