aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-06-11 15:24:22 -0400
committerGravatar Jason Gross <jgross@mit.edu>2017-06-11 15:24:22 -0400
commit9b6d577171419a8478f6cb1649020e7237a962f7 (patch)
treea5a7268bd2455f45bd8640975abfb4592235ca12 /src
parent165bc0e8b8f1f62eea3d9dafab14425050f1f57a (diff)
Fix loop notations, add for loops
Diffstat (limited to 'src')
-rw-r--r--src/Util/Loop.v224
1 files changed, 203 insertions, 21 deletions
diff --git a/src/Util/Loop.v b/src/Util/Loop.v
index ac763b11f..66b229281 100644
--- a/src/Util/Loop.v
+++ b/src/Util/Loop.v
@@ -1,7 +1,11 @@
(** * Definition and Notations for [do { body }] *)
+Require Import Coq.ZArith.BinInt.
+Require Import Coq.micromega.Lia.
Require Import Coq.omega.Omega.
+Require Import Crypto.Util.ZUtil.
Require Import Crypto.Util.Notations.
Require Import Crypto.Util.LetIn.
+Require Import Crypto.Util.Tactics.SpecializeBy.
(* TODO: move *)
Module CPSNotations.
@@ -11,7 +15,7 @@ Module CPSNotations.
(* TODO: [cpscall] is a marker to get Coq to print code using this notation only when it was actually used *)
Definition cpscall {R} (f:forall{T}(continuation:R->T),T) {T} (continuation:R->T)
:= @f T continuation.
- Notation "x <- v ; C" := (cpscall v (fun x => C)) (at level 70, right associativity, format "'[v' x <- v ; '/' C ']'").
+ Notation "x' <- v ; C" := (cpscall v (fun x' => C)).
(** A value of type [~>R] accepts a continuation that takes an
argument of type [R]. It is meant to be used in [Definition] and
@@ -23,7 +27,7 @@ Module CPSNotations.
[~> R] is universally quantified over the possible return types
of the continuations that it can be applied to.
*)
- Notation "~> R" := (forall {T} (_:R->T), T) (at level 70).
+ Notation "~> R" := (forall {T} (_:R->T), T).
(** The type [A ~> R] contains functions that takes an argument of
type [A] and pass a value of type [R] to the continuation. Functions
@@ -31,12 +35,12 @@ Module CPSNotations.
~> B ~>C] -- the first form requires both arguments to be specified
before its output can be CPS-bound, the latter must be bound once it
is partially applied to one argument. *)
- Notation "A ~> R" := (A -> ~>R) (at level 99).
-
+ Notation "A ~> R" := (A -> ~>R).
+
(* TODO: [cpsreturn] is a marker to get Coq to print loop notations before a [return] *)
Definition cpsreturn {T} (x:T) := x.
(** [return x] passes [x] to the continuation implicit in the previous notations. *)
- Notation "'return' x" := (cpsreturn (fun {T} (continuation:_->T) => continuation x)) (at level 70, format "'return' x").
+ Notation "'return' x" := (cpsreturn (fun {T} (continuation:_->T) => continuation x)).
End CPSNotations.
Section with_state.
@@ -104,7 +108,8 @@ Section with_state.
Local Hint Extern 2 => omega.
- Theorem loop_cps_wf_ind
+ (** TODO(andreser): Remove this if we don't need it *)
+ Theorem loop_cps_wf_ind_break
(measure : state -> nat)
(invariant : state -> Prop)
T (P : T -> Prop) n v0 body rest
@@ -122,25 +127,203 @@ Section with_state.
induction n as [|n IHn]; intros v0 rest Hinv Hbody Hmeasure; simpl; try omega.
edestruct Hbody as [Hbody'|Hbody']; eauto.
Qed.
+
+ Theorem loop_cps_wf_ind
+ (measure : state -> nat)
+ (invariant : state -> Prop)
+ T (P : T -> Prop) n v0 body rest
+ : invariant v0
+ -> (forall v continue,
+ invariant v
+ -> ((forall v', measure v' < measure v -> invariant v' -> P (continue v'))
+ -> P (body v T continue rest)))
+ -> measure v0 < n
+ -> P (loop_cps n v0 body T rest).
+ Proof.
+ revert v0.
+ induction n as [|n IHn]; intros v0 Hinv Hbody Hmeasure; simpl; try omega.
+ eauto.
+ Qed.
End with_state.
-(* N.B., Coq doesn't yet print this *)
+(** N.B. If the body is polymorphic (that is, if the type argument
+ shows up in the body), then we need to bind the name of the type
+ parameter somewhere in the notation for it to show up; we have a
+ separate notation for this case. *)
+(** TODO: When these notations are finalized, reserve them in Notations.v and moving the level and formatting rules there *)
+Notation "'loop' _{ fuel } ( state1 .. staten = initial ) 'labels' ( continue , break , @ T ) {{ body }} ; rest"
+ := (@loop_cps _ fuel initial
+ (fun state1 => .. (fun staten => id (fun T continue break => body)) .. )
+ _ (fun state1 => .. (fun staten => rest) .. ))
+ (at level 200, state1 binder, staten binder,
+ format "'[v ' 'loop' _{ fuel } ( state1 .. staten = initial ) 'labels' ( continue , break , @ T ) {{ '//' body ']' '//' }} ; '//' rest").
Notation "'loop' _{ fuel } ( state1 .. staten = initial ) 'labels' ( continue , break ) {{ body }} ; rest"
:= (@loop_cps _ fuel initial
(fun state1 => .. (fun staten => id (fun T continue break => body)) .. )
_ (fun state1 => .. (fun staten => rest) .. ))
(at level 200, state1 binder, staten binder,
format "'[v ' 'loop' _{ fuel } ( state1 .. staten = initial ) 'labels' ( continue , break ) {{ '//' body ']' '//' }} ; '//' rest").
+Notation "'loop' _{ fuel } ( state1 .. staten = initial ) 'labels' ( continue ) {{ body }} ; rest"
+ := (@loop_cps _ fuel initial
+ (fun state1 => .. (fun staten => id (fun T continue _ => body)) .. )
+ _ (fun state1 => .. (fun staten => rest) .. ))
+ (at level 200, state1 binder, staten binder,
+ format "'[v ' 'loop' _{ fuel } ( state1 .. staten = initial ) 'labels' ( continue ) {{ '//' body ']' '//' }} ; '//' rest").
+
+Section with_for_state.
+ Import CPSNotations.
+ Context {state : Type}.
+
+ Section with_loop_params.
+ Context (test : Z -> Z -> bool) (i_final : Z) (upd_i : Z -> Z)
+ (body : state -> Z -> forall {T} (continue : state -> T) (break : state -> T), T).
+
+ (* we assume that [upd_i] is linear to compute the fuel *)
+ Definition for_cps (i0 : Z) (initial : state)
+ : ~> state
+ := fun T ret
+ => @loop_cps
+ (Z * state)
+ (S (S (Z.to_nat ((i_final - i0) / (upd_i 0%Z)))))
+ (i0, initial)
+ (fun '(i, st) T continue break
+ => if test i i_final
+ then @body st i T
+ (fun st' => continue (upd_i i, st')%Z)
+ (fun st' => break (i, st'))
+ else break (i, st))
+ T (fun '(i, st) => ret st).
+
+ Section lemmas.
+ Local Open Scope Z_scope.
+ Context (upd_linear : forall x, upd_i x = upd_i 0 + x)
+ (upd_signed : forall i0, test i0 i_final = true -> 0 < (i_final - i0) / (upd_i 0)).
+
+ (** TODO: Strengthen this to take into account the value of
+ the loop counter at the end of the loop; based on
+ [ForLoop.v], it should be something like [(finish -
+ Z.sgn (finish - i0 + step - Z.sgn step) * (Z.abs
+ (finish - i0 + step - Z.sgn step) mod Z.abs step) +
+ step - Z.sgn step)] *)
+ Theorem for_cps_ind
+ (invariant : Z -> state -> Prop)
+ T (P : (*Z ->*) T -> Prop) i0 v0 rest
+ : invariant i0 v0
+ -> (forall i v continue,
+ test i i_final = true
+ -> (forall v, invariant (upd_i i) v -> P (continue v))
+ -> invariant i v
+ -> P (@body v i T continue rest))
+ -> (forall i v, test i i_final = false -> invariant i v -> P (rest v))
+ -> P (for_cps i0 v0 T rest).
+ Proof.
+ unfold for_cps, cpscall, cpsreturn.
+ intros Hinv IH Hrest.
+ eapply @loop_cps_wf_ind with (T:=T)
+ (invariant := fun '(i, s) => invariant i s)
+ (measure := fun '(i, s) => S (Z.to_nat ((i_final - i) / upd_i 0)));
+ [ assumption
+ |
+ | omega ].
+ intros [i st] continue Hinv' IH'.
+ destruct (test i i_final) eqn:Hi; [ | solve [ eauto ] ].
+ pose proof (upd_signed _ Hi) as upd_signed'.
+ assert (upd_i 0 <> 0)
+ by (intro H'; rewrite H' in upd_signed'; autorewrite with zsimplify in upd_signed';
+ omega).
+ specialize (IH i st (fun st' => continue (upd_i i, st')) Hi).
+ specialize (fun v pf => IH' (upd_i i, v) pf).
+ cbv beta iota in *.
+ specialize (fun pf v => IH' v pf).
+ rewrite upd_linear in IH'.
+ replace ((i_final - (upd_i 0 + i)) / upd_i 0)
+ with ((i_final - i) / upd_i 0 - 1)
+ in IH'
+ by (Z.div_mod_to_quot_rem; nia).
+ rewrite <- upd_linear, Z2Nat.inj_sub in IH' by omega.
+ assert ((Z.to_nat 0 < Z.to_nat ((i_final - i) / upd_i 0))%nat)
+ by (apply Z2Nat.inj_lt; omega).
+ change (Z.to_nat 0) with 0%nat in *.
+ change (Z.to_nat 1) with 1%nat in *.
+ auto with omega.
+ Qed.
+ End lemmas.
+ End with_loop_params.
+End with_for_state.
+
+Delimit Scope for_upd_scope with for_upd.
+Delimit Scope for_test_scope with for_test.
+Notation "i += k" := (Z.add i k) : for_upd_scope.
+Notation "i -= k" := (Z.sub i k) : for_upd_scope.
+Notation "i ++" := (i += 1)%for_upd : for_upd_scope.
+Notation "i --" := (i -= 1)%for_upd : for_upd_scope.
+Notation "<" := Z.ltb (at level 71) : for_test_scope.
+Notation ">" := Z.gtb (at level 71) : for_test_scope.
+Notation "<=" := Z.leb (at level 71) : for_test_scope.
+Notation ">=" := Z.geb (at level 71) : for_test_scope.
+Notation "≤" := Z.leb (at level 71) : for_test_scope.
+Notation "≥" := Z.geb (at level 71) : for_test_scope.
+Definition force_idZ (f : Z -> Z) (pf : f = id) {T} (v : T) := v.
+(** [lhs] and [cmp_expr] go at level 9 so that they bind more tightly
+ than application (so that [i (<)] sticks [i] in [lhs] and [(<)] in
+ [cmp_expr], rather than sticking [i (<)] in [lhs] and then
+ complaining about a missing value for [cmp_expr]. Unfortunately,
+ because the comparison operators need to be at level > 70 to not
+ conflict with their infix versions, putting [cmp_expr] at level 9
+ forces us to wrap parentheses around the comparison operator. *)
+(** TODO(andreser): If it's worth it, duplicate these notations for
+ each value of [cmp_expr] so that we don't need to wrap the
+ comparison operator in parentheses. *)
+(** TODO: When these notations are finalized, reserve them in Notations.v and moving the level and formatting rules there *)
+Notation "'for' ( i1 .. i2 = i0 ; lhs cmp_expr final ; upd_expr ) 'updating' ( state1 .. staten = initial ) 'labels' ( continue , break , @ T ) {{ body }} ; rest"
+ := (force_idZ
+ (fun i1 => .. (fun i2 => lhs) ..)
+ eq_refl
+ (@for_cps _ cmp_expr%for_test
+ final
+ (fun i1 => .. (fun i2 => upd_expr%for_upd) .. )
+ (fun state1 => .. (fun staten => id (fun i1 => .. (fun i2 => id (fun T continue break => body)) .. )) .. )
+ i0
+ initial
+ _ (fun state1 => .. (fun staten => rest) .. )))
+ (at level 200, state1 binder, staten binder, i1 binder, i2 binder, lhs at level 9, cmp_expr at level 9,
+ format "'[v ' 'for' ( i1 .. i2 = i0 ; lhs cmp_expr final ; upd_expr ) 'updating' ( state1 .. staten = initial ) 'labels' ( continue , break , @ T ) {{ '//' body ']' '//' }} ; '//' rest").
+Notation "'for' ( i1 .. i2 = i0 ; lhs cmp_expr final ; upd_expr ) 'updating' ( state1 .. staten = initial ) 'labels' ( continue , break ) {{ body }} ; rest"
+ := (force_idZ
+ (fun i1 => .. (fun i2 => lhs) ..)
+ eq_refl
+ (@for_cps _ cmp_expr%for_test
+ final
+ (fun i1 => .. (fun i2 => upd_expr%for_upd) .. )
+ (fun state1 => .. (fun staten => id (fun i1 => .. (fun i2 => id (fun T continue break => body)) .. )) .. )
+ i0
+ initial
+ _ (fun state1 => .. (fun staten => rest) .. )))
+ (at level 200, state1 binder, staten binder, i1 binder, i2 binder, lhs at level 9, cmp_expr at level 9,
+ format "'[v ' 'for' ( i1 .. i2 = i0 ; lhs cmp_expr final ; upd_expr ) 'updating' ( state1 .. staten = initial ) 'labels' ( continue , break ) {{ '//' body ']' '//' }} ; '//' rest").
+Notation "'for' ( i1 .. i2 = i0 ; lhs cmp_expr final ; upd_expr ) 'updating' ( state1 .. staten = initial ) 'labels' ( continue ) {{ body }} ; rest"
+ := (force_idZ
+ (fun i1 => .. (fun i2 => lhs) ..)
+ eq_refl
+ (@for_cps _ cmp_expr%for_test
+ final
+ (fun i1 => .. (fun i2 => upd_expr%for_upd) .. )
+ (fun state1 => .. (fun staten => id (fun i1 => .. (fun i2 => id (fun T continue break => body)) .. )) .. )
+ i0
+ initial
+ _ (fun state1 => .. (fun staten => rest) .. )))
+ (at level 200, state1 binder, staten binder, i1 binder, i2 binder, lhs at level 9, cmp_expr at level 9,
+ format "'[v ' 'for' ( i1 .. i2 = i0 ; lhs cmp_expr final ; upd_expr ) 'updating' ( state1 .. staten = initial ) 'labels' ( continue ) {{ '//' body ']' '//' }} ; '//' rest").
Section LoopTest.
Import CPSNotations.
Check loop _{ 10 } (x = 0) labels (continue, break) {{ continue (x + 1) }} ; x.
-
+
Check
loop _{ 1234 } ('(i, a) = (0, 0)) labels (continue, break)
{{
if i <? 10
- then
+ then
continue (i + 1, a+1)
else
break (0, a)
@@ -155,27 +338,26 @@ Section LoopTest.
Check loop _{ 10 } (x = 0) labels (continue, break) {{ continue (x + 1) }} ;
return x.
- Check loop _{ 10 } (x = 0) labels (continue, break) {{ continue (x + 1) }} ;
+ Check loop _{ 10 } (x = 0) labels (continue, break) {{ continue (x + 1) }} ;
x <- f x;
return x.
-
- (* TODO: the loop notation should print here. *)
- Check loop _{ 10 } (x = 0) labels (continue, break) {{ x <- f x; continue (x) }} ; x.
- (*
- (* TODO LATER: something like these notations would be nice, desugaring to a [state := nat * T] *)
- for ( i = s; i < f; i++) updating (P = zero) labels (continue, break)
+ Check loop _{ 10 } (x = 0) labels (continue, break) {{ x <- f x ; continue (x) }} ; x.
+
+
+ Axiom s F : Z.
+ Axiom zero : nat.
+ Check for ( i = s; i (<) F; i++) updating (P = zero) labels (continue, break)
{{
continue (P+P)
}};
P.
- for ( i = s; i < f; i++) updating (P = zero) labels (continue)
+ Check for ( i = s; i (<) F; i++) updating (P = zero) labels (continue)
{{
continue (P+P)
}};
P.
- *)
End LoopTest.
Require Import Crypto.Util.Tuple Crypto.Util.CPSUtil.
@@ -187,7 +369,7 @@ Section ScalarMult.
(*
table of xi*(32^i)*B
- for
+ for
0 <= xi <= 8 (with sign flips, -8 <= xi <= 8)
0 <= i <= 31
*)
@@ -201,5 +383,5 @@ Section ScalarMult.
continue (i', P_s)
}};
return P_s.
- Print ScalarMultBase. (* TODO: the loop notation should print here *)
-End ScalarMult. \ No newline at end of file
+ Print ScalarMultBase.
+End ScalarMult.