aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Andres Erbsen <andreser@google.com>2017-12-19 12:28:19 -0500
committerGravatar Andres Erbsen <andreser@mit.edu>2017-12-22 12:55:33 -0500
commitc3ee4f53a93c54ab2a65a5535e3335f4e8e8a3f5 (patch)
tree4710a627e26d28e8e87cc8fe5377e56e2f20a0ae
parent1b8932d64d40329678dcdb230fb5cc6a95064799 (diff)
Montgomery.XZ, Loops: montladder proof scaffolding
-rw-r--r--src/Curves/Montgomery/XZ.v39
-rw-r--r--src/Curves/Montgomery/XZProofs.v39
-rw-r--r--src/Experiments/Loops.v346
3 files changed, 236 insertions, 188 deletions
diff --git a/src/Curves/Montgomery/XZ.v b/src/Curves/Montgomery/XZ.v
index 87a53b7fe..336ee6b95 100644
--- a/src/Curves/Montgomery/XZ.v
+++ b/src/Curves/Montgomery/XZ.v
@@ -2,7 +2,7 @@ Require Import Crypto.Algebra.Field.
Require Import Crypto.Util.GlobalSettings Crypto.Util.Notations.
Require Import Crypto.Util.Sum Crypto.Util.Prod Crypto.Util.LetIn.
Require Import Crypto.Util.Decidable.
-Require Import Crypto.Util.ForLoop.
+Require Import Crypto.Experiments.Loops.
Require Import Crypto.Spec.MontgomeryCurve Crypto.Curves.Montgomery.Affine.
Module M.
@@ -110,26 +110,27 @@ Module M.
((x2, z2), (x3, z3))%core
end.
- Context {cswap:bool->F*F->F*F->(F*F)*(F*F)}.
-
+ Context {cswap:bool->F->F->F*F}.
Local Notation xor := Coq.Init.Datatypes.xorb.
-
- (* Ideally, we would verify that this corresponds to x coordinate
- multiplication *)
Local Open Scope core_scope.
- Definition montladder (bound : positive) (testbit:Z->bool) (u:F) :=
- let '(P1, P2, swap) :=
- for (int i = BinInt.Z.pos bound; i >= 0; i--)
- updating ('(P1, P2, swap) = ((1%F, 0%F), (u, 1%F), false)) {{
- dlet s_i := testbit i in
- dlet swap := xor swap s_i in
- let '(P1, P2) := cswap swap P1 P2 in
- dlet swap := s_i in
- let '(P1, P2) := xzladderstep u P1 P2 in
- (P1, P2, swap)
- }} in
- let '((x, z), _) := cswap swap P1 P2 in
- x * Finv z.
+ (* TODO: make a nice notations for loops like here *)
+ Definition montladder (scalarbits : Z) (testbit:Z->bool) (x1:F) : F :=
+ let '(x2, z2, x3, z3, swap, _) := (* names of variables as used after the loop *)
+ (while (fun '(_, i) => BinInt.Z.geb i 0) (* the test of the loop *)
+ (fun '(x2, z2, x3, z3, swap, i) => (* names of variables as used in the loop; we should reuse the same names as for after the loop *)
+ dlet b := testbit i in (* the body... *)
+ dlet swap := xor swap b in
+ let (x2, x3) := cswap swap x2 x3 in
+ let (z2, z3) := cswap swap z2 z3 in
+ dlet swap := b in
+ let '((x2, z2), (x3, z3)) := xzladderstep x1 (x2, z2) (x3, z3) in
+ let i := BinInt.Z.pred i in (* the third "increment" component of a for loop; either between the test and body or just inlined into the body like here *)
+ (x2, z2, x3, z3, swap, i)) (* the "return value" of the body is always the exact same variable names as in the beginning of the body because we shadow the original binders, but I think for now this will be unavoidable boilerplate. *)
+ (BinInt.Z.to_nat scalarbits) (* bound on number of loop iterations, should come between test and body *)
+ (1%F, 0%F, x1, 1%F, false, BinInt.Z.pred scalarbits)) in (* initial values, these should come before the test and body *)
+ let (x2, x3) := cswap swap x2 x3 in
+ let (z2, z3) := cswap swap z2 z3 in
+ x2 * Finv z2.
End MontgomeryCurve.
End M.
diff --git a/src/Curves/Montgomery/XZProofs.v b/src/Curves/Montgomery/XZProofs.v
index fd2a7ee49..1e584f35f 100644
--- a/src/Curves/Montgomery/XZProofs.v
+++ b/src/Curves/Montgomery/XZProofs.v
@@ -138,12 +138,17 @@ Module M.
Proof.
cbv [ladder_invariant] in *.
pose proof difference_preserved Q Q' as Hrw.
- (* TODO: rewrite with in match argument with [sumwise (fieldwise (n:=2) Feq) (fun _ _ => True)] *)
- match type of Hrw with
- M.eq ?X ?Y => replace X with Y by admit
- end.
- assumption.
- Admitted.
+ (* FIXME: what we actually want to do here is to rewrite with in
+ match argument with
+ [sumwise (fieldwise (n:=2) Feq) (fun _ _ => True)] *)
+ cbv [M.eq] in *; break_match; break_match_hyps;
+ destruct_head' @and; repeat split; subst;
+ try solve [intuition congruence].
+ congruence (* congruence failed, idk WHY *)
+ || (match goal with
+ [H:?f1 = ?x1, G:?f = ?f1 |- ?f = ?x1] => rewrite G; exact H
+ end).
+ Qed.
Lemma to_xz_add x1 xz x'z' Q Q'
(Hxz : projective xz) (Hx'z' : projective x'z')
@@ -179,5 +184,27 @@ Module M.
Lemma projective_to_xz Q : projective (to_xz Q).
Proof. t. Qed.
+
+ Local Notation montladder := (M.montladder(a24:=a24)(Fadd:=Fadd)(Fsub:=Fsub)(Fmul:=Fmul)(Fzero:=Fzero)(Fone:=Fone)(Finv:=Finv)(cswap:=fun b x y => if b then pair y x else pair x y)).
+ Import Crypto.Experiments.Loops.
+ Lemma montladder_correct P scalarbits testbit point : P (montladder scalarbits testbit point).
+ Proof.
+ cbv beta delta [M.montladder].
+ lazymatch goal with
+ | |- context [while ?t ?b ?l ?i] => eassert (_ (while t b l i) : Prop)
+ end.
+ lazymatch goal with
+ | |- ?e ?x
+ => eapply (while.by_invariant
+ (* loop invariant *) _
+ (* decreasing measure *) (fun s => BinInt.Z.to_nat (snd s))
+ e)
+ end.
+ { (* invariant start *) admit. }
+ { (* invariant preservation *) admit. }
+ { (* measure start *) admit. }
+ { (* measure decreases *) admit. }
+ { (* invariant implies postcondition *) admit. }
+ Abort.
End MontgomeryCurve.
End M.
diff --git a/src/Experiments/Loops.v b/src/Experiments/Loops.v
index 125e0f197..94e127764 100644
--- a/src/Experiments/Loops.v
+++ b/src/Experiments/Loops.v
@@ -1,10 +1,14 @@
Require Import Coq.Lists.List.
Require Import Lia.
-Require Import Crypto.Util.ListUtil.
+Require Import Crypto.Util.Tactics.UniquePose.
+Require Import Crypto.Util.Tactics.BreakMatch.
+Require Import Crypto.Util.Tactics.DestructHead.
Require Import Crypto.Util.Decidable.
Require Import Crypto.Util.Prod.
Require Import Crypto.Util.Option.
+Require Import Crypto.Util.Sum.
Require Import Crypto.Util.LetIn.
+Require Crypto.Util.ListUtil. (* for tests *)
Section Loops.
Context {continue_state break_state}
@@ -16,31 +20,31 @@ Section Loops.
Definition funapp {A B} (f : A -> B) (x : A) := f x.
Fixpoint loop_cps (fuel: nat) (start : continue_state)
- {T} (ret : break_state -> T) : option T :=
+ {T} (ret : break_state -> T) : continue_state + T :=
funapp
(body_cps start _) (fun next =>
match next with
- | inl state => Some (ret state)
+ | inl state => inr (ret state)
| inr state =>
match fuel with
- | O => None
+ | O => inl state
| S fuel' =>
loop_cps fuel' state ret
end end).
Fixpoint loop (fuel: nat) (start : continue_state)
- : option break_state :=
+ : continue_state + break_state :=
match (body start) with
- | inl state => Some state
+ | inl state => inr state
| inr state =>
match fuel with
- | O => None
+ | O => inl state
| S fuel' => loop fuel' state
end end.
Lemma loop_break_step fuel start state :
(body start = inl state) ->
- loop fuel start = Some state.
+ loop fuel start = inr state.
Proof.
destruct fuel; simpl loop; break_match; intros; congruence.
Qed.
@@ -48,25 +52,42 @@ Section Loops.
Lemma loop_continue_step fuel start state :
(body start = inr state) ->
loop fuel start =
- match fuel with | O => None | S fuel' => loop fuel' state end.
+ match fuel with | O => inl state | S fuel' => loop fuel' state end.
Proof.
destruct fuel; simpl loop; break_match; intros; congruence.
Qed.
- Definition terminates fuel start :=
- loop fuel start <> None.
+ (* TODO: provide [invariant state] to proofs of this *)
+ Definition progress (measure : continue_state -> nat) :=
+ forall state state', body state = inr state' -> measure state' < measure state.
+ Definition terminates fuel start := forall l, loop fuel start <> inl l.
+ Lemma terminates_by_measure measure (H : progress measure) :
+ forall fuel start, measure start <= fuel -> terminates fuel start.
+ Proof.
+ induction fuel; intros;
+ repeat match goal with
+ | _ => solve [ congruence | lia ]
+ | _ => progress cbv [progress terminates] in *
+ | _ => progress cbn [loop]
+ | _ => progress break_match
+ | H : forall _ _, body _ = inr _ -> _ , Heq : body _ = inr _ |- _ => specialize (H _ _ Heq)
+ | _ => eapply IHfuel
+ end.
+ Qed.
Definition loop_default fuel start default
- : break_state := match (loop fuel start) with
- | None => default
- | Some result => result
- end.
+ : break_state :=
+ sum_rect
+ (fun _ => break_state)
+ (fun _ => default)
+ (fun result => result)
+ (loop fuel start).
Lemma loop_default_eq fuel start default
(Hterm : terminates fuel start) :
- loop fuel start = Some (loop_default fuel start default).
+ loop fuel start = inr (loop_default fuel start default).
Proof.
- cbv [terminates loop_default] in *; break_match; congruence.
+ cbv [terminates loop_default sum_rect] in *; break_match; congruence.
Qed.
Lemma invariant_iff fuel start default (H : terminates fuel start) P :
@@ -77,188 +98,187 @@ Section Loops.
/\ (forall s s', body s = inl s' -> inv s -> P s')).
Proof.
split;
- [ exists (fun st => exists f e, (loop f st = Some e /\ P e ))
+ [ exists (fun st => exists f e, (loop f st = inr e /\ P e ))
| destruct 1 as [?[??]]; revert dependent start; induction fuel ];
repeat match goal with
| _ => solve [ trivial | congruence | eauto ]
| _ => progress destruct_head' @ex
| _ => progress destruct_head' @and
| _ => progress intros
- | _ => progress (cbv [loop_default terminates] in * )
- | _ => progress (cbn [loop] in * )
+ | _ => progress cbv [loop_default terminates] in *
+ | _ => progress cbn [loop] in *
| _ => progress erewrite loop_default_eq by eassumption
| _ => progress erewrite loop_continue_step in * by eassumption
| _ => progress erewrite loop_break_step in * by eassumption
| _ => progress break_match_hyps
| _ => progress break_match
| _ => progress eexists
+ | H1:_, c:_ |- _ => progress specialize (H1 c); congruence
end.
Qed.
-
- Lemma to_measure (measure : continue_state -> nat) :
- (forall state state', body state = inr state' ->
- 0 <= measure state' < measure state) ->
- forall fuel start,
- measure start <= fuel ->
- terminates fuel start.
- Proof.
- induction fuel; intros;
- repeat match goal with
- | _ => solve [ congruence | lia ]
- | _ => progress cbv [terminates]
- | _ => progress cbn [loop]
- | _ => progress break_match
- | H : forall _ _, body _ = inr _ -> _ , Heq : body _ = inr _ |- _ => specialize (H _ _ Heq)
- | _ => eapply IHfuel
- end.
- Qed.
End Loops.
Definition by_invariant {continue_state break_state body fuel start default}
- invariant P term invariant_start invariant_continue invariant_break
- := proj2 (@invariant_iff continue_state break_state body fuel start default term P)
+ invariant measure P invariant_start invariant_continue invariant_break le_start progress
+ := proj2 (@invariant_iff continue_state break_state body fuel start default (terminates_by_measure body measure progress fuel start le_start) P)
(ex_intro _ invariant (conj invariant_start (conj invariant_continue invariant_break))).
+Arguments terminates_by_measure {_ _ _}.
-Section While.
- Context {state}
- (test : state -> bool)
- (body : state -> state).
+Module while.
+ Section While.
+ Context {state}
+ (test : state -> bool)
+ (body : state -> state).
- Fixpoint while (fuel: nat) (s : state) {struct fuel} : option state :=
- if test s
- then
- match fuel with
- | O => None
- | S fuel' => while fuel' (body s)
- end
- else Some s.
+ Fixpoint while (fuel: nat) (s : state) {struct fuel} : state :=
+ if test s
+ then
+ let s := body s in
+ match fuel with
+ | O => s
+ | S fuel' => while fuel' s
+ end
+ else s.
- Section AsLoop.
- Local Definition lbody := fun s => if test s then inr (body s) else inl s.
- Definition while_using_loop := loop lbody.
+ Section AsLoop.
+ Local Definition lbody := fun s => if test s then inr (body s) else inl s.
- Lemma while_eq_loop : forall n s, while n s = while_using_loop n s.
- Proof.
- induction n; intros;
- cbv [lbody while_using_loop]; cbn [while loop]; break_match; auto.
- Qed.
- End AsLoop.
+ Lemma eq_loop : forall fuel start, while fuel start = loop_default lbody fuel start (while fuel start).
+ Proof.
+ induction fuel; intros;
+ cbv [lbody loop_default sum_rect id] in *;
+ cbn [while loop]; [|rewrite IHfuel]; break_match; auto.
+ Qed.
- Definition while_default d f s :=
- match while f s with
- | None => d
- | Some x => x
- end.
-End While.
+ Lemma by_invariant fuel start
+ (invariant : state -> Prop) (measure : state -> nat) (P : state -> Prop)
+ (_: invariant start)
+ (_: forall s, invariant s -> if test s then invariant (body s) else P s)
+ (_: measure start <= fuel)
+ (_: forall s, if test s then measure (body s) < measure s else True)
+ : P (while fuel start).
+ Proof.
+ rewrite eq_loop; cbv [lbody].
+ eapply (by_invariant invariant measure);
+ repeat match goal with
+ | [ H : forall s, invariant s -> _, G: invariant ?s |- _ ] => unique pose proof (H _ G)
+ | [ H : forall s, if ?f s then _ else _, G: ?f ?s = _ |- _ ] => unique pose proof (H s)
+ | _ => solve [ trivial | congruence ]
+ | _ => progress cbv [progress]
+ | _ => progress intros
+ | _ => progress subst
+ | _ => progress inversion_sum
+ | _ => progress break_match_hyps (* FIXME: this must be last? *)
+ end.
+ Qed.
+ End AsLoop.
+ End While.
+ Arguments by_invariant {_ _ _ _ _}.
+End while.
+Notation while := while.while.
Definition for2 {state} (test : state -> bool) (increment body : state -> state)
:= while test (fun s => increment (body s)).
-Definition for2_default {state} d test increment body f s :=
- match @for2 state test increment body f s with
- | None => d
- | Some x => x
- end.
-Definition for3 {state} init test increment body fuel := @for2 state test increment body fuel init.
-Definition for3_default {state} d init test increment body fuel :=
- match @for3 state init test increment body fuel with
- | None => d
- | Some x => x
- end.
+Definition for3 {state} init test increment body fuel :=
+ @for2 state test increment body fuel init.
-Section GCD.
- Definition gcd_step :=
- fun '(a, b) => if Nat.ltb a b
- then inr (a, b-a)
- else if Nat.ltb b a
- then inr (a-b, b)
- else inl a.
+Module _test.
+ Section GCD.
+ Definition gcd_step :=
+ fun '(a, b) => if Nat.ltb a b
+ then inr (a, b-a)
+ else if Nat.ltb b a
+ then inr (a-b, b)
+ else inl a.
- Definition gcd fuel a b := loop_default gcd_step fuel (a,b) 0.
+ Definition gcd fuel a b := loop_default gcd_step fuel (a,b) 0.
- Eval cbv [gcd loop_default loop gcd_step] in (gcd 10 5 7).
-
- Example gcd_test : gcd 1000 28 35 = 7 := eq_refl.
+ (* Eval cbv [gcd loop_default loop gcd_step] in (gcd 10 5 7). *)
+
+ Example gcd_test : gcd 1000 28 35 = 7 := eq_refl.
- Definition gcd_step_cps
- : (nat * nat) -> forall T, (nat + (nat * nat) -> T) -> T
- :=
- fun st T ret =>
- let a := fst st in
- let b := snd st in
- if Nat.ltb a b
- then ret (inr (a, b-a))
- else if Nat.ltb b a
- then ret (inr (a-b, b))
- else ret (inl a).
+ Definition gcd_step_cps
+ : (nat * nat) -> forall T, (nat + (nat * nat) -> T) -> T
+ :=
+ fun st T ret =>
+ let a := fst st in
+ let b := snd st in
+ if Nat.ltb a b
+ then ret (inr (a, b-a))
+ else if Nat.ltb b a
+ then ret (inr (a-b, b))
+ else ret (inl a).
- Definition gcd_cps fuel a b {T} (ret:nat->T)
- := loop_cps gcd_step_cps fuel (a,b) ret.
+ Definition gcd_cps fuel a b {T} (ret:nat->T)
+ := loop_cps gcd_step_cps fuel (a,b) ret.
- Example gcd_test2 : gcd_cps 1000 28 35 id = Some 7 := eq_refl.
-
- Eval cbv [gcd_cps loop_cps gcd_step_cps id] in (gcd_cps 2 5 7 id).
+ Example gcd_test2 : gcd_cps 1000 28 35 id = inr 7 := eq_refl.
+
+ (* Eval cbv [gcd_cps loop_cps gcd_step_cps id] in (gcd_cps 2 5 7 id). *)
-End GCD.
+ End GCD.
-(* simple example--set all elements in a list to 0 *)
-Section ZeroLoop.
+ (* simple example--set all elements in a list to 0 *)
+ Section ZeroLoop.
+ Import Crypto.Util.ListUtil.
- Definition zero_body (state : nat * list nat) :
- list nat + (nat * list nat) :=
- if dec (fst state < length (snd state))
- then inr (S (fst state), set_nth (fst state) 0 (snd state))
- else inl (snd state).
+ Definition zero_body (state : nat * list nat) :
+ list nat + (nat * list nat) :=
+ if dec (fst state < length (snd state))
+ then inr (S (fst state), set_nth (fst state) 0 (snd state))
+ else inl (snd state).
- Lemma zero_body_terminates (arr : list nat) :
- terminates zero_body (length arr) (0,arr).
- Proof.
- eapply to_measure with (measure :=(fun state => length (snd state) - (fst state))); cbv [zero_body]; intros until 0;
- repeat match goal with
- | _ => progress autorewrite with cancel_pair distr_length
- | _ => progress subst
- | _ => progress break_match; intros
- | _ => congruence
- | H: inl _ = inl _ |- _ => injection H; intros; subst; clear H
- | H: inr _ = inr _ |- _ => injection H; intros; subst ;clear H
- | _ => lia
- end.
- Qed.
+ Lemma zero_body_progress (arr : list nat) :
+ progress zero_body (fun state : nat * list nat => length (snd state) - fst state).
+ Proof.
+ cbv [zero_body progress]; intros until 0;
+ repeat match goal with
+ | _ => progress autorewrite with cancel_pair distr_length
+ | _ => progress subst
+ | _ => progress break_match; intros
+ | _ => congruence
+ | H: inl _ = inl _ |- _ => injection H; intros; subst; clear H
+ | H: inr _ = inr _ |- _ => injection H; intros; subst ;clear H
+ | _ => lia
+ end.
+ Qed.
- Definition zero_loop (arr : list nat) : list nat :=
- loop_default zero_body (length arr) (0,arr) nil.
-
- Definition zero_invariant (state : nat * list nat) :=
- fst state <= length (snd state)
- /\ forall n, n < fst state -> nth_default 0 (snd state) n = 0.
+ Definition zero_loop (arr : list nat) : list nat :=
+ loop_default zero_body (length arr) (0,arr) nil.
+
+ Definition zero_invariant (state : nat * list nat) :=
+ fst state <= length (snd state)
+ /\ forall n, n < fst state -> nth_default 0 (snd state) n = 0.
- Lemma zero_correct (arr : list nat) :
- forall n, nth_default 0 (zero_loop arr) n = 0.
- Proof.
- intros. cbv [zero_loop].
- eapply (by_invariant zero_invariant); auto using zero_body_terminates;
- [ cbv [zero_invariant]; autorewrite with cancel_pair; split; intros; lia | ..];
- cbv [zero_invariant zero_body];
- intros until 0;
- break_match; intros;
- repeat match goal with
- | _ => congruence
- | H: inl _ = inl _ |- _ => injection H; intros; subst; clear H
- | H: inr _ = inr _ |- _ => injection H; intros; subst ;clear H
- | _ => progress split
- | _ => progress intros
- | _ => progress subst
- | _ => progress (autorewrite with cancel_pair distr_length in * )
- | _ => rewrite set_nth_nth_default by lia
- | _ => progress break_match
- | H : _ /\ _ |- _ => destruct H
- | H : (_,_) = ?x |- _ =>
- destruct x; inversion H; subst; destruct H
- | H : _ |- _ => apply H; lia
- | _ => lia
- end.
- destruct (Compare_dec.lt_dec n (fst s)).
- apply H1; lia.
- apply nth_default_out_of_bounds; lia.
- Qed.
-End ZeroLoop. \ No newline at end of file
+ Lemma zero_correct (arr : list nat) :
+ forall n, nth_default 0 (zero_loop arr) n = 0.
+ Proof.
+ intros. cbv [zero_loop].
+ eapply (by_invariant zero_invariant); eauto using zero_body_progress;
+ [ cbv [zero_invariant]; autorewrite with cancel_pair; split; intros; lia | ..];
+ cbv [zero_invariant zero_body];
+ intros until 0;
+ break_match; intros;
+ repeat match goal with
+ | _ => congruence
+ | H: inl _ = inl _ |- _ => injection H; intros; subst; clear H
+ | H: inr _ = inr _ |- _ => injection H; intros; subst ;clear H
+ | _ => progress split
+ | _ => progress intros
+ | _ => progress subst
+ | _ => progress (autorewrite with cancel_pair distr_length in * )
+ | _ => rewrite set_nth_nth_default by lia
+ | _ => progress break_match
+ | H : _ /\ _ |- _ => destruct H
+ | H : (_,_) = ?x |- _ =>
+ destruct x; inversion H; subst; destruct H
+ | H : _ |- _ => apply H; lia
+ | _ => lia
+ end.
+ destruct (Compare_dec.lt_dec n (fst s)).
+ apply H1; lia.
+ apply nth_default_out_of_bounds; lia.
+ Qed.
+ End ZeroLoop.
+End _test. \ No newline at end of file