diff options
author | Andres Erbsen <andreser@google.com> | 2017-12-19 12:28:19 -0500 |
---|---|---|
committer | Andres Erbsen <andreser@mit.edu> | 2017-12-22 12:55:33 -0500 |
commit | c3ee4f53a93c54ab2a65a5535e3335f4e8e8a3f5 (patch) | |
tree | 4710a627e26d28e8e87cc8fe5377e56e2f20a0ae | |
parent | 1b8932d64d40329678dcdb230fb5cc6a95064799 (diff) |
Montgomery.XZ, Loops: montladder proof scaffolding
-rw-r--r-- | src/Curves/Montgomery/XZ.v | 39 | ||||
-rw-r--r-- | src/Curves/Montgomery/XZProofs.v | 39 | ||||
-rw-r--r-- | src/Experiments/Loops.v | 346 |
3 files changed, 236 insertions, 188 deletions
diff --git a/src/Curves/Montgomery/XZ.v b/src/Curves/Montgomery/XZ.v index 87a53b7fe..336ee6b95 100644 --- a/src/Curves/Montgomery/XZ.v +++ b/src/Curves/Montgomery/XZ.v @@ -2,7 +2,7 @@ Require Import Crypto.Algebra.Field. Require Import Crypto.Util.GlobalSettings Crypto.Util.Notations. Require Import Crypto.Util.Sum Crypto.Util.Prod Crypto.Util.LetIn. Require Import Crypto.Util.Decidable. -Require Import Crypto.Util.ForLoop. +Require Import Crypto.Experiments.Loops. Require Import Crypto.Spec.MontgomeryCurve Crypto.Curves.Montgomery.Affine. Module M. @@ -110,26 +110,27 @@ Module M. ((x2, z2), (x3, z3))%core end. - Context {cswap:bool->F*F->F*F->(F*F)*(F*F)}. - + Context {cswap:bool->F->F->F*F}. Local Notation xor := Coq.Init.Datatypes.xorb. - - (* Ideally, we would verify that this corresponds to x coordinate - multiplication *) Local Open Scope core_scope. - Definition montladder (bound : positive) (testbit:Z->bool) (u:F) := - let '(P1, P2, swap) := - for (int i = BinInt.Z.pos bound; i >= 0; i--) - updating ('(P1, P2, swap) = ((1%F, 0%F), (u, 1%F), false)) {{ - dlet s_i := testbit i in - dlet swap := xor swap s_i in - let '(P1, P2) := cswap swap P1 P2 in - dlet swap := s_i in - let '(P1, P2) := xzladderstep u P1 P2 in - (P1, P2, swap) - }} in - let '((x, z), _) := cswap swap P1 P2 in - x * Finv z. + (* TODO: make a nice notations for loops like here *) + Definition montladder (scalarbits : Z) (testbit:Z->bool) (x1:F) : F := + let '(x2, z2, x3, z3, swap, _) := (* names of variables as used after the loop *) + (while (fun '(_, i) => BinInt.Z.geb i 0) (* the test of the loop *) + (fun '(x2, z2, x3, z3, swap, i) => (* names of variables as used in the loop; we should reuse the same names as for after the loop *) + dlet b := testbit i in (* the body... *) + dlet swap := xor swap b in + let (x2, x3) := cswap swap x2 x3 in + let (z2, z3) := cswap swap z2 z3 in + dlet swap := b in + let '((x2, z2), (x3, z3)) := xzladderstep x1 (x2, z2) (x3, z3) in + let i := BinInt.Z.pred i in (* the third "increment" component of a for loop; either between the test and body or just inlined into the body like here *) + (x2, z2, x3, z3, swap, i)) (* the "return value" of the body is always the exact same variable names as in the beginning of the body because we shadow the original binders, but I think for now this will be unavoidable boilerplate. *) + (BinInt.Z.to_nat scalarbits) (* bound on number of loop iterations, should come between test and body *) + (1%F, 0%F, x1, 1%F, false, BinInt.Z.pred scalarbits)) in (* initial values, these should come before the test and body *) + let (x2, x3) := cswap swap x2 x3 in + let (z2, z3) := cswap swap z2 z3 in + x2 * Finv z2. End MontgomeryCurve. End M. diff --git a/src/Curves/Montgomery/XZProofs.v b/src/Curves/Montgomery/XZProofs.v index fd2a7ee49..1e584f35f 100644 --- a/src/Curves/Montgomery/XZProofs.v +++ b/src/Curves/Montgomery/XZProofs.v @@ -138,12 +138,17 @@ Module M. Proof. cbv [ladder_invariant] in *. pose proof difference_preserved Q Q' as Hrw. - (* TODO: rewrite with in match argument with [sumwise (fieldwise (n:=2) Feq) (fun _ _ => True)] *) - match type of Hrw with - M.eq ?X ?Y => replace X with Y by admit - end. - assumption. - Admitted. + (* FIXME: what we actually want to do here is to rewrite with in + match argument with + [sumwise (fieldwise (n:=2) Feq) (fun _ _ => True)] *) + cbv [M.eq] in *; break_match; break_match_hyps; + destruct_head' @and; repeat split; subst; + try solve [intuition congruence]. + congruence (* congruence failed, idk WHY *) + || (match goal with + [H:?f1 = ?x1, G:?f = ?f1 |- ?f = ?x1] => rewrite G; exact H + end). + Qed. Lemma to_xz_add x1 xz x'z' Q Q' (Hxz : projective xz) (Hx'z' : projective x'z') @@ -179,5 +184,27 @@ Module M. Lemma projective_to_xz Q : projective (to_xz Q). Proof. t. Qed. + + Local Notation montladder := (M.montladder(a24:=a24)(Fadd:=Fadd)(Fsub:=Fsub)(Fmul:=Fmul)(Fzero:=Fzero)(Fone:=Fone)(Finv:=Finv)(cswap:=fun b x y => if b then pair y x else pair x y)). + Import Crypto.Experiments.Loops. + Lemma montladder_correct P scalarbits testbit point : P (montladder scalarbits testbit point). + Proof. + cbv beta delta [M.montladder]. + lazymatch goal with + | |- context [while ?t ?b ?l ?i] => eassert (_ (while t b l i) : Prop) + end. + lazymatch goal with + | |- ?e ?x + => eapply (while.by_invariant + (* loop invariant *) _ + (* decreasing measure *) (fun s => BinInt.Z.to_nat (snd s)) + e) + end. + { (* invariant start *) admit. } + { (* invariant preservation *) admit. } + { (* measure start *) admit. } + { (* measure decreases *) admit. } + { (* invariant implies postcondition *) admit. } + Abort. End MontgomeryCurve. End M. diff --git a/src/Experiments/Loops.v b/src/Experiments/Loops.v index 125e0f197..94e127764 100644 --- a/src/Experiments/Loops.v +++ b/src/Experiments/Loops.v @@ -1,10 +1,14 @@ Require Import Coq.Lists.List. Require Import Lia. -Require Import Crypto.Util.ListUtil. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.DestructHead. Require Import Crypto.Util.Decidable. Require Import Crypto.Util.Prod. Require Import Crypto.Util.Option. +Require Import Crypto.Util.Sum. Require Import Crypto.Util.LetIn. +Require Crypto.Util.ListUtil. (* for tests *) Section Loops. Context {continue_state break_state} @@ -16,31 +20,31 @@ Section Loops. Definition funapp {A B} (f : A -> B) (x : A) := f x. Fixpoint loop_cps (fuel: nat) (start : continue_state) - {T} (ret : break_state -> T) : option T := + {T} (ret : break_state -> T) : continue_state + T := funapp (body_cps start _) (fun next => match next with - | inl state => Some (ret state) + | inl state => inr (ret state) | inr state => match fuel with - | O => None + | O => inl state | S fuel' => loop_cps fuel' state ret end end). Fixpoint loop (fuel: nat) (start : continue_state) - : option break_state := + : continue_state + break_state := match (body start) with - | inl state => Some state + | inl state => inr state | inr state => match fuel with - | O => None + | O => inl state | S fuel' => loop fuel' state end end. Lemma loop_break_step fuel start state : (body start = inl state) -> - loop fuel start = Some state. + loop fuel start = inr state. Proof. destruct fuel; simpl loop; break_match; intros; congruence. Qed. @@ -48,25 +52,42 @@ Section Loops. Lemma loop_continue_step fuel start state : (body start = inr state) -> loop fuel start = - match fuel with | O => None | S fuel' => loop fuel' state end. + match fuel with | O => inl state | S fuel' => loop fuel' state end. Proof. destruct fuel; simpl loop; break_match; intros; congruence. Qed. - Definition terminates fuel start := - loop fuel start <> None. + (* TODO: provide [invariant state] to proofs of this *) + Definition progress (measure : continue_state -> nat) := + forall state state', body state = inr state' -> measure state' < measure state. + Definition terminates fuel start := forall l, loop fuel start <> inl l. + Lemma terminates_by_measure measure (H : progress measure) : + forall fuel start, measure start <= fuel -> terminates fuel start. + Proof. + induction fuel; intros; + repeat match goal with + | _ => solve [ congruence | lia ] + | _ => progress cbv [progress terminates] in * + | _ => progress cbn [loop] + | _ => progress break_match + | H : forall _ _, body _ = inr _ -> _ , Heq : body _ = inr _ |- _ => specialize (H _ _ Heq) + | _ => eapply IHfuel + end. + Qed. Definition loop_default fuel start default - : break_state := match (loop fuel start) with - | None => default - | Some result => result - end. + : break_state := + sum_rect + (fun _ => break_state) + (fun _ => default) + (fun result => result) + (loop fuel start). Lemma loop_default_eq fuel start default (Hterm : terminates fuel start) : - loop fuel start = Some (loop_default fuel start default). + loop fuel start = inr (loop_default fuel start default). Proof. - cbv [terminates loop_default] in *; break_match; congruence. + cbv [terminates loop_default sum_rect] in *; break_match; congruence. Qed. Lemma invariant_iff fuel start default (H : terminates fuel start) P : @@ -77,188 +98,187 @@ Section Loops. /\ (forall s s', body s = inl s' -> inv s -> P s')). Proof. split; - [ exists (fun st => exists f e, (loop f st = Some e /\ P e )) + [ exists (fun st => exists f e, (loop f st = inr e /\ P e )) | destruct 1 as [?[??]]; revert dependent start; induction fuel ]; repeat match goal with | _ => solve [ trivial | congruence | eauto ] | _ => progress destruct_head' @ex | _ => progress destruct_head' @and | _ => progress intros - | _ => progress (cbv [loop_default terminates] in * ) - | _ => progress (cbn [loop] in * ) + | _ => progress cbv [loop_default terminates] in * + | _ => progress cbn [loop] in * | _ => progress erewrite loop_default_eq by eassumption | _ => progress erewrite loop_continue_step in * by eassumption | _ => progress erewrite loop_break_step in * by eassumption | _ => progress break_match_hyps | _ => progress break_match | _ => progress eexists + | H1:_, c:_ |- _ => progress specialize (H1 c); congruence end. Qed. - - Lemma to_measure (measure : continue_state -> nat) : - (forall state state', body state = inr state' -> - 0 <= measure state' < measure state) -> - forall fuel start, - measure start <= fuel -> - terminates fuel start. - Proof. - induction fuel; intros; - repeat match goal with - | _ => solve [ congruence | lia ] - | _ => progress cbv [terminates] - | _ => progress cbn [loop] - | _ => progress break_match - | H : forall _ _, body _ = inr _ -> _ , Heq : body _ = inr _ |- _ => specialize (H _ _ Heq) - | _ => eapply IHfuel - end. - Qed. End Loops. Definition by_invariant {continue_state break_state body fuel start default} - invariant P term invariant_start invariant_continue invariant_break - := proj2 (@invariant_iff continue_state break_state body fuel start default term P) + invariant measure P invariant_start invariant_continue invariant_break le_start progress + := proj2 (@invariant_iff continue_state break_state body fuel start default (terminates_by_measure body measure progress fuel start le_start) P) (ex_intro _ invariant (conj invariant_start (conj invariant_continue invariant_break))). +Arguments terminates_by_measure {_ _ _}. -Section While. - Context {state} - (test : state -> bool) - (body : state -> state). +Module while. + Section While. + Context {state} + (test : state -> bool) + (body : state -> state). - Fixpoint while (fuel: nat) (s : state) {struct fuel} : option state := - if test s - then - match fuel with - | O => None - | S fuel' => while fuel' (body s) - end - else Some s. + Fixpoint while (fuel: nat) (s : state) {struct fuel} : state := + if test s + then + let s := body s in + match fuel with + | O => s + | S fuel' => while fuel' s + end + else s. - Section AsLoop. - Local Definition lbody := fun s => if test s then inr (body s) else inl s. - Definition while_using_loop := loop lbody. + Section AsLoop. + Local Definition lbody := fun s => if test s then inr (body s) else inl s. - Lemma while_eq_loop : forall n s, while n s = while_using_loop n s. - Proof. - induction n; intros; - cbv [lbody while_using_loop]; cbn [while loop]; break_match; auto. - Qed. - End AsLoop. + Lemma eq_loop : forall fuel start, while fuel start = loop_default lbody fuel start (while fuel start). + Proof. + induction fuel; intros; + cbv [lbody loop_default sum_rect id] in *; + cbn [while loop]; [|rewrite IHfuel]; break_match; auto. + Qed. - Definition while_default d f s := - match while f s with - | None => d - | Some x => x - end. -End While. + Lemma by_invariant fuel start + (invariant : state -> Prop) (measure : state -> nat) (P : state -> Prop) + (_: invariant start) + (_: forall s, invariant s -> if test s then invariant (body s) else P s) + (_: measure start <= fuel) + (_: forall s, if test s then measure (body s) < measure s else True) + : P (while fuel start). + Proof. + rewrite eq_loop; cbv [lbody]. + eapply (by_invariant invariant measure); + repeat match goal with + | [ H : forall s, invariant s -> _, G: invariant ?s |- _ ] => unique pose proof (H _ G) + | [ H : forall s, if ?f s then _ else _, G: ?f ?s = _ |- _ ] => unique pose proof (H s) + | _ => solve [ trivial | congruence ] + | _ => progress cbv [progress] + | _ => progress intros + | _ => progress subst + | _ => progress inversion_sum + | _ => progress break_match_hyps (* FIXME: this must be last? *) + end. + Qed. + End AsLoop. + End While. + Arguments by_invariant {_ _ _ _ _}. +End while. +Notation while := while.while. Definition for2 {state} (test : state -> bool) (increment body : state -> state) := while test (fun s => increment (body s)). -Definition for2_default {state} d test increment body f s := - match @for2 state test increment body f s with - | None => d - | Some x => x - end. -Definition for3 {state} init test increment body fuel := @for2 state test increment body fuel init. -Definition for3_default {state} d init test increment body fuel := - match @for3 state init test increment body fuel with - | None => d - | Some x => x - end. +Definition for3 {state} init test increment body fuel := + @for2 state test increment body fuel init. -Section GCD. - Definition gcd_step := - fun '(a, b) => if Nat.ltb a b - then inr (a, b-a) - else if Nat.ltb b a - then inr (a-b, b) - else inl a. +Module _test. + Section GCD. + Definition gcd_step := + fun '(a, b) => if Nat.ltb a b + then inr (a, b-a) + else if Nat.ltb b a + then inr (a-b, b) + else inl a. - Definition gcd fuel a b := loop_default gcd_step fuel (a,b) 0. + Definition gcd fuel a b := loop_default gcd_step fuel (a,b) 0. - Eval cbv [gcd loop_default loop gcd_step] in (gcd 10 5 7). - - Example gcd_test : gcd 1000 28 35 = 7 := eq_refl. + (* Eval cbv [gcd loop_default loop gcd_step] in (gcd 10 5 7). *) + + Example gcd_test : gcd 1000 28 35 = 7 := eq_refl. - Definition gcd_step_cps - : (nat * nat) -> forall T, (nat + (nat * nat) -> T) -> T - := - fun st T ret => - let a := fst st in - let b := snd st in - if Nat.ltb a b - then ret (inr (a, b-a)) - else if Nat.ltb b a - then ret (inr (a-b, b)) - else ret (inl a). + Definition gcd_step_cps + : (nat * nat) -> forall T, (nat + (nat * nat) -> T) -> T + := + fun st T ret => + let a := fst st in + let b := snd st in + if Nat.ltb a b + then ret (inr (a, b-a)) + else if Nat.ltb b a + then ret (inr (a-b, b)) + else ret (inl a). - Definition gcd_cps fuel a b {T} (ret:nat->T) - := loop_cps gcd_step_cps fuel (a,b) ret. + Definition gcd_cps fuel a b {T} (ret:nat->T) + := loop_cps gcd_step_cps fuel (a,b) ret. - Example gcd_test2 : gcd_cps 1000 28 35 id = Some 7 := eq_refl. - - Eval cbv [gcd_cps loop_cps gcd_step_cps id] in (gcd_cps 2 5 7 id). + Example gcd_test2 : gcd_cps 1000 28 35 id = inr 7 := eq_refl. + + (* Eval cbv [gcd_cps loop_cps gcd_step_cps id] in (gcd_cps 2 5 7 id). *) -End GCD. + End GCD. -(* simple example--set all elements in a list to 0 *) -Section ZeroLoop. + (* simple example--set all elements in a list to 0 *) + Section ZeroLoop. + Import Crypto.Util.ListUtil. - Definition zero_body (state : nat * list nat) : - list nat + (nat * list nat) := - if dec (fst state < length (snd state)) - then inr (S (fst state), set_nth (fst state) 0 (snd state)) - else inl (snd state). + Definition zero_body (state : nat * list nat) : + list nat + (nat * list nat) := + if dec (fst state < length (snd state)) + then inr (S (fst state), set_nth (fst state) 0 (snd state)) + else inl (snd state). - Lemma zero_body_terminates (arr : list nat) : - terminates zero_body (length arr) (0,arr). - Proof. - eapply to_measure with (measure :=(fun state => length (snd state) - (fst state))); cbv [zero_body]; intros until 0; - repeat match goal with - | _ => progress autorewrite with cancel_pair distr_length - | _ => progress subst - | _ => progress break_match; intros - | _ => congruence - | H: inl _ = inl _ |- _ => injection H; intros; subst; clear H - | H: inr _ = inr _ |- _ => injection H; intros; subst ;clear H - | _ => lia - end. - Qed. + Lemma zero_body_progress (arr : list nat) : + progress zero_body (fun state : nat * list nat => length (snd state) - fst state). + Proof. + cbv [zero_body progress]; intros until 0; + repeat match goal with + | _ => progress autorewrite with cancel_pair distr_length + | _ => progress subst + | _ => progress break_match; intros + | _ => congruence + | H: inl _ = inl _ |- _ => injection H; intros; subst; clear H + | H: inr _ = inr _ |- _ => injection H; intros; subst ;clear H + | _ => lia + end. + Qed. - Definition zero_loop (arr : list nat) : list nat := - loop_default zero_body (length arr) (0,arr) nil. - - Definition zero_invariant (state : nat * list nat) := - fst state <= length (snd state) - /\ forall n, n < fst state -> nth_default 0 (snd state) n = 0. + Definition zero_loop (arr : list nat) : list nat := + loop_default zero_body (length arr) (0,arr) nil. + + Definition zero_invariant (state : nat * list nat) := + fst state <= length (snd state) + /\ forall n, n < fst state -> nth_default 0 (snd state) n = 0. - Lemma zero_correct (arr : list nat) : - forall n, nth_default 0 (zero_loop arr) n = 0. - Proof. - intros. cbv [zero_loop]. - eapply (by_invariant zero_invariant); auto using zero_body_terminates; - [ cbv [zero_invariant]; autorewrite with cancel_pair; split; intros; lia | ..]; - cbv [zero_invariant zero_body]; - intros until 0; - break_match; intros; - repeat match goal with - | _ => congruence - | H: inl _ = inl _ |- _ => injection H; intros; subst; clear H - | H: inr _ = inr _ |- _ => injection H; intros; subst ;clear H - | _ => progress split - | _ => progress intros - | _ => progress subst - | _ => progress (autorewrite with cancel_pair distr_length in * ) - | _ => rewrite set_nth_nth_default by lia - | _ => progress break_match - | H : _ /\ _ |- _ => destruct H - | H : (_,_) = ?x |- _ => - destruct x; inversion H; subst; destruct H - | H : _ |- _ => apply H; lia - | _ => lia - end. - destruct (Compare_dec.lt_dec n (fst s)). - apply H1; lia. - apply nth_default_out_of_bounds; lia. - Qed. -End ZeroLoop.
\ No newline at end of file + Lemma zero_correct (arr : list nat) : + forall n, nth_default 0 (zero_loop arr) n = 0. + Proof. + intros. cbv [zero_loop]. + eapply (by_invariant zero_invariant); eauto using zero_body_progress; + [ cbv [zero_invariant]; autorewrite with cancel_pair; split; intros; lia | ..]; + cbv [zero_invariant zero_body]; + intros until 0; + break_match; intros; + repeat match goal with + | _ => congruence + | H: inl _ = inl _ |- _ => injection H; intros; subst; clear H + | H: inr _ = inr _ |- _ => injection H; intros; subst ;clear H + | _ => progress split + | _ => progress intros + | _ => progress subst + | _ => progress (autorewrite with cancel_pair distr_length in * ) + | _ => rewrite set_nth_nth_default by lia + | _ => progress break_match + | H : _ /\ _ |- _ => destruct H + | H : (_,_) = ?x |- _ => + destruct x; inversion H; subst; destruct H + | H : _ |- _ => apply H; lia + | _ => lia + end. + destruct (Compare_dec.lt_dec n (fst s)). + apply H1; lia. + apply nth_default_out_of_bounds; lia. + Qed. + End ZeroLoop. +End _test.
\ No newline at end of file |