diff options
author | Andres Erbsen <andreser@mit.edu> | 2018-04-18 09:12:49 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-18 09:12:49 -0400 |
commit | 0f860ce167139c266409640ce2dcd4c1a1ac3996 (patch) | |
tree | 807223f21cae0a011fe82838fe8a8cb93a2f8780 | |
parent | f2ef2c85530035b60d9071abf256da37e84858bf (diff) | |
parent | 5511acfa62857e2e649c401324261dd16d9275f0 (diff) |
Merge pull request #335 from mit-plv/cpsloops
comprehensive loops framework with complete proof theory
-rw-r--r-- | _CoqProject | 8 | ||||
-rw-r--r-- | src/Curves/Montgomery/XZ.v | 2 | ||||
-rw-r--r-- | src/Curves/Montgomery/XZProofs.v | 57 | ||||
-rw-r--r-- | src/Experiments/Loops.v | 284 | ||||
-rw-r--r-- | src/Util/ForLoop.v | 89 | ||||
-rw-r--r-- | src/Util/ForLoop/Instances.v | 67 | ||||
-rw-r--r-- | src/Util/ForLoop/InvariantFramework.v | 369 | ||||
-rw-r--r-- | src/Util/ForLoop/Tests.v | 55 | ||||
-rw-r--r-- | src/Util/ForLoop/Unrolling.v | 314 | ||||
-rw-r--r-- | src/Util/Loop.v | 480 | ||||
-rw-r--r-- | src/Util/Loops.v | 526 |
11 files changed, 557 insertions, 1694 deletions
diff --git a/_CoqProject b/_CoqProject index faecc2fbf..782741ee2 100644 --- a/_CoqProject +++ b/_CoqProject @@ -240,7 +240,6 @@ src/Curves/Weierstrass/Affine.v src/Curves/Weierstrass/AffineProofs.v src/Curves/Weierstrass/Jacobian.v src/Curves/Weierstrass/Projective.v -src/Experiments/Loops.v src/Experiments/SimplyTypedArithmetic.v src/LegacyArithmetic/ArchitectureToZLike.v src/LegacyArithmetic/ArchitectureToZLikeProofs.v @@ -6454,7 +6453,6 @@ src/Util/Factorize.v src/Util/FixCoqMistakes.v src/Util/FixedWordSizes.v src/Util/FixedWordSizesEquality.v -src/Util/ForLoop.v src/Util/FsatzAutoLemmas.v src/Util/GlobalSettings.v src/Util/HList.v @@ -6466,7 +6464,7 @@ src/Util/LetIn.v src/Util/LetInMonad.v src/Util/ListUtil.v src/Util/Logic.v -src/Util/Loop.v +src/Util/Loops.v src/Util/NUtil.v src/Util/NatUtil.v src/Util/Notations.v @@ -6496,10 +6494,6 @@ src/Util/Bool/Equality.v src/Util/Bool/IsTrue.v src/Util/Decidable/Bool2Prop.v src/Util/Decidable/Decidable2Bool.v -src/Util/ForLoop/Instances.v -src/Util/ForLoop/InvariantFramework.v -src/Util/ForLoop/Tests.v -src/Util/ForLoop/Unrolling.v src/Util/ListUtil/FoldBool.v src/Util/ListUtil/Forall.v src/Util/Logic/ImplAnd.v diff --git a/src/Curves/Montgomery/XZ.v b/src/Curves/Montgomery/XZ.v index 336ee6b95..88e1d7398 100644 --- a/src/Curves/Montgomery/XZ.v +++ b/src/Curves/Montgomery/XZ.v @@ -2,7 +2,7 @@ Require Import Crypto.Algebra.Field. Require Import Crypto.Util.GlobalSettings Crypto.Util.Notations. Require Import Crypto.Util.Sum Crypto.Util.Prod Crypto.Util.LetIn. Require Import Crypto.Util.Decidable. -Require Import Crypto.Experiments.Loops. +Require Import Crypto.Util.Loops. Require Import Crypto.Spec.MontgomeryCurve Crypto.Curves.Montgomery.Affine. Module M. diff --git a/src/Curves/Montgomery/XZProofs.v b/src/Curves/Montgomery/XZProofs.v index c17d14b0a..650ed6920 100644 --- a/src/Curves/Montgomery/XZProofs.v +++ b/src/Curves/Montgomery/XZProofs.v @@ -258,7 +258,7 @@ Module M. Local Notation montladder := (M.montladder(a24:=a24)(Fadd:=Fadd)(Fsub:=Fsub)(Fmul:=Fmul)(Fzero:=Fzero)(Fone:=Fone)(Finv:=Finv)(cswap:=fun b x y => if b then pair y x else pair x y)). Local Notation scalarmult := (@ScalarMult.scalarmult_ref Mpoint Madd M.zero Mopp). - Import Crypto.Experiments.Loops. + Import Crypto.Util.Loops. Import Coq.ZArith.BinInt. Lemma to_x_inv00 (HFinv:Finv 0 = 0) x z : to_x (pair x z) = x * Finv z. @@ -281,34 +281,34 @@ Module M. Proof. cbv beta delta [M.montladder]. (* [while.by_invariant] expects a goal like [?P (while _ _ _ _)], make it so: *) - lazymatch goal with |- context [while ?t ?b ?l ?i] => pattern (while t b l i) end. - eapply (while.by_invariant + lazymatch goal with |- context [while ?t ?b ?l ?ii] => pattern (while t b l ii) end. + eapply (while.by_invariant_fuel (fun '(x2, z2, x3, z3, swap, i) => (i < scalarbits)%Z /\ z2 = 0 /\ if dec (Logic.eq i (Z.pred scalarbits)) then x3 = 0 else z3 = 0) - (fun s => Z.to_nat (Z.succ (snd s))) (* decreasing measure *) ). - { (* invariant holds in the beginning *) cbn. - split; [lia|split;[reflexivity|t]]. } + (fun s => Z.to_nat (Z.succ (snd s)))). + { split. + (* invariant holds in the beginning *) + { cbn; split; [lia|split;[reflexivity|t]]. } + { (* fuel <= measure *) cbn. rewrite Z.succ_pred. reflexivity. } } { intros [ [ [ [ [x2 z2] x3] z3] swap] i] [Hi [Hz2 Hx3z3]]. destruct (i >=? 0)%Z eqn:Hbranch; (* did the loop continue? *) rewrite Z.geb_ge_iff in Hbranch. - { (* if loop continued, invariant is preserved *) - destruct (dec (Logic.eq i (Z.pred scalarbits))). - { (* first loop iteration *) - cbv -[xzladderstep xorb Z.testbit Z.pred dec Z.lt]; - destruct (xorb swap (Z.testbit n i)); - split; [lia|t|lia|t]. } - { (* subsequent loop iterations *) - cbv -[xzladderstep xorb Z.testbit Z.pred dec Z.lt]. - destruct (xorb swap (Z.testbit n i)); - (split; [lia| split; [t| break_match;[lia|t]]]). } } + { split. (* if loop continued, invariant is preserved *) + { destruct (dec (Logic.eq i (Z.pred scalarbits))). + { (* first loop iteration *) + cbv -[xzladderstep xorb Z.testbit Z.pred dec Z.lt]; + destruct (xorb swap (Z.testbit n i)); + split; [lia|t|lia|t]. } + { (* subsequent loop iterations *) + cbv -[xzladderstep xorb Z.testbit Z.pred dec Z.lt]. + destruct (xorb swap (Z.testbit n i)); + (split; [lia| split; [t| break_match;[lia|t]]]). } } + { (* measure decreases *) + cbv [Let_In]; break_match; cbn; rewrite Z.succ_pred; apply Znat.Z2Nat.inj_lt; lia. } } { (* if loop exited, invariant implies postcondition *) break_match; break_match_hyps; setoid_subst_rel Feq; fsatz. } } - { (* fuel <= measure *) cbn. rewrite Z.succ_pred. reflexivity. } - { (* measure decreases *) intros [? i]. - destruct (i >=? 0)%Z eqn:Hbranch;rewrite Z.geb_ge_iff in Hbranch; [|exact I]. - cbv [Let_In]; break_match; cbn; rewrite Z.succ_pred; apply Znat.Z2Nat.inj_lt; lia. } Qed. Lemma montladder_correct_nz @@ -325,7 +325,7 @@ Module M. cbv beta delta [M.montladder]. (* [while.by_invariant] expects a goal like [?P (while _ _ _ _)], make it so: *) lazymatch goal with |- context [while ?t ?b ?l ?i] => pattern (while t b l i) end. - eapply (while.by_invariant + eapply (while.by_invariant_fuel (fun '(x2, z2, x3, z3, swap, i) => (i >= -1)%Z /\ projective (pair x2 z2) /\ @@ -337,12 +337,15 @@ Module M. eq q' (to_xz (scalarmult (Z.succ r) P)) /\ ladder_invariant point (scalarmult r P) (scalarmult (Z.succ r) P)) (fun s => Z.to_nat (Z.succ (snd s))) (* decreasing measure *) ). - { (* invariant holds in the beginning *) cbn. - rewrite ?Z.succ_pred, ?Z.lt_pow_2_shiftr, <-?Z.one_succ by tauto. - repeat split; [lia|t..]. } + { split; cbn. + { (* invariant holds in the beginning *) + rewrite ?Z.succ_pred, ?Z.lt_pow_2_shiftr, <-?Z.one_succ by tauto. + repeat split; [lia|t..]. } + { (* sufficient fuel *) rewrite Z.succ_pred. reflexivity. } } { intros [ [ [ [ [x2 z2] x3] z3] swap] i] [Hi [Hx2z2 [Hx3z3 [Hq [Hq' Hladder]]]]]. destruct (i >=? 0)%Z eqn:Hbranch; (* did the loop continue? *) rewrite Z.geb_ge_iff in Hbranch. + split. { (* if loop continued, invariant is preserved *) let group _ := ltac:(repeat rewrite ?scalarmult_add_l, ?scalarmult_0_l, ?scalarmult_1_l, ?Hierarchy.left_identity, ?Hierarchy.right_identity, ?Hierarchy.associative, ?(Hierarchy.commutative _ P); reflexivity) in destruct (Z.testbit n i) eqn:Hbit in *; @@ -373,16 +376,14 @@ Module M. => refine (proj1 (Proper_ladder_invariant _ _ (reflexivity _) _ _ _ _ _ _) (ladder_invariant_swap _ _ _ H)); group () | |- ?P => match type of P with Prop => split end end. } + { (* measure decreases *) + cbv [Let_In]; break_match; cbn; rewrite Z.succ_pred; apply Znat.Z2Nat.inj_lt; lia. } { (* if loop exited, invariant implies postcondition *) destruct_head' @and; autorewrite with cancel_pair in *. replace i with ((-(1))%Z) in * by lia; clear Hi Hbranch. rewrite Z.succ_m1, Z.shiftr_0_r in *. destruct swap eqn:Hswap; rewrite <-!to_x_inv00 by assumption; eauto using projective_to_xz, proper_to_x_projective. } } - { (* fuel <= measure *) cbn. rewrite Z.succ_pred. reflexivity. } - { (* measure decreases *) intros [? i]. - destruct (i >=? 0)%Z eqn:Hbranch;rewrite Z.geb_ge_iff in Hbranch; [|exact I]. - cbv [Let_In]; break_match; cbn; rewrite Z.succ_pred; apply Znat.Z2Nat.inj_lt; lia. } Qed. (* Using montladder_correct_0 in the combined correctness theorem requires diff --git a/src/Experiments/Loops.v b/src/Experiments/Loops.v deleted file mode 100644 index 94e127764..000000000 --- a/src/Experiments/Loops.v +++ /dev/null @@ -1,284 +0,0 @@ -Require Import Coq.Lists.List. -Require Import Lia. -Require Import Crypto.Util.Tactics.UniquePose. -Require Import Crypto.Util.Tactics.BreakMatch. -Require Import Crypto.Util.Tactics.DestructHead. -Require Import Crypto.Util.Decidable. -Require Import Crypto.Util.Prod. -Require Import Crypto.Util.Option. -Require Import Crypto.Util.Sum. -Require Import Crypto.Util.LetIn. -Require Crypto.Util.ListUtil. (* for tests *) - -Section Loops. - Context {continue_state break_state} - (body : continue_state -> break_state + continue_state) - (body_cps : continue_state -> - forall {T}, (break_state + continue_state -> T) - -> T). - - Definition funapp {A B} (f : A -> B) (x : A) := f x. - - Fixpoint loop_cps (fuel: nat) (start : continue_state) - {T} (ret : break_state -> T) : continue_state + T := - funapp - (body_cps start _) (fun next => - match next with - | inl state => inr (ret state) - | inr state => - match fuel with - | O => inl state - | S fuel' => - loop_cps fuel' state ret - end end). - - Fixpoint loop (fuel: nat) (start : continue_state) - : continue_state + break_state := - match (body start) with - | inl state => inr state - | inr state => - match fuel with - | O => inl state - | S fuel' => loop fuel' state - end end. - - Lemma loop_break_step fuel start state : - (body start = inl state) -> - loop fuel start = inr state. - Proof. - destruct fuel; simpl loop; break_match; intros; congruence. - Qed. - - Lemma loop_continue_step fuel start state : - (body start = inr state) -> - loop fuel start = - match fuel with | O => inl state | S fuel' => loop fuel' state end. - Proof. - destruct fuel; simpl loop; break_match; intros; congruence. - Qed. - - (* TODO: provide [invariant state] to proofs of this *) - Definition progress (measure : continue_state -> nat) := - forall state state', body state = inr state' -> measure state' < measure state. - Definition terminates fuel start := forall l, loop fuel start <> inl l. - Lemma terminates_by_measure measure (H : progress measure) : - forall fuel start, measure start <= fuel -> terminates fuel start. - Proof. - induction fuel; intros; - repeat match goal with - | _ => solve [ congruence | lia ] - | _ => progress cbv [progress terminates] in * - | _ => progress cbn [loop] - | _ => progress break_match - | H : forall _ _, body _ = inr _ -> _ , Heq : body _ = inr _ |- _ => specialize (H _ _ Heq) - | _ => eapply IHfuel - end. - Qed. - - Definition loop_default fuel start default - : break_state := - sum_rect - (fun _ => break_state) - (fun _ => default) - (fun result => result) - (loop fuel start). - - Lemma loop_default_eq fuel start default - (Hterm : terminates fuel start) : - loop fuel start = inr (loop_default fuel start default). - Proof. - cbv [terminates loop_default sum_rect] in *; break_match; congruence. - Qed. - - Lemma invariant_iff fuel start default (H : terminates fuel start) P : - P (loop_default fuel start default) <-> - (exists (inv : continue_state -> Prop), - inv start - /\ (forall s s', body s = inr s' -> inv s -> inv s') - /\ (forall s s', body s = inl s' -> inv s -> P s')). - Proof. - split; - [ exists (fun st => exists f e, (loop f st = inr e /\ P e )) - | destruct 1 as [?[??]]; revert dependent start; induction fuel ]; - repeat match goal with - | _ => solve [ trivial | congruence | eauto ] - | _ => progress destruct_head' @ex - | _ => progress destruct_head' @and - | _ => progress intros - | _ => progress cbv [loop_default terminates] in * - | _ => progress cbn [loop] in * - | _ => progress erewrite loop_default_eq by eassumption - | _ => progress erewrite loop_continue_step in * by eassumption - | _ => progress erewrite loop_break_step in * by eassumption - | _ => progress break_match_hyps - | _ => progress break_match - | _ => progress eexists - | H1:_, c:_ |- _ => progress specialize (H1 c); congruence - end. - Qed. -End Loops. - -Definition by_invariant {continue_state break_state body fuel start default} - invariant measure P invariant_start invariant_continue invariant_break le_start progress - := proj2 (@invariant_iff continue_state break_state body fuel start default (terminates_by_measure body measure progress fuel start le_start) P) - (ex_intro _ invariant (conj invariant_start (conj invariant_continue invariant_break))). -Arguments terminates_by_measure {_ _ _}. - -Module while. - Section While. - Context {state} - (test : state -> bool) - (body : state -> state). - - Fixpoint while (fuel: nat) (s : state) {struct fuel} : state := - if test s - then - let s := body s in - match fuel with - | O => s - | S fuel' => while fuel' s - end - else s. - - Section AsLoop. - Local Definition lbody := fun s => if test s then inr (body s) else inl s. - - Lemma eq_loop : forall fuel start, while fuel start = loop_default lbody fuel start (while fuel start). - Proof. - induction fuel; intros; - cbv [lbody loop_default sum_rect id] in *; - cbn [while loop]; [|rewrite IHfuel]; break_match; auto. - Qed. - - Lemma by_invariant fuel start - (invariant : state -> Prop) (measure : state -> nat) (P : state -> Prop) - (_: invariant start) - (_: forall s, invariant s -> if test s then invariant (body s) else P s) - (_: measure start <= fuel) - (_: forall s, if test s then measure (body s) < measure s else True) - : P (while fuel start). - Proof. - rewrite eq_loop; cbv [lbody]. - eapply (by_invariant invariant measure); - repeat match goal with - | [ H : forall s, invariant s -> _, G: invariant ?s |- _ ] => unique pose proof (H _ G) - | [ H : forall s, if ?f s then _ else _, G: ?f ?s = _ |- _ ] => unique pose proof (H s) - | _ => solve [ trivial | congruence ] - | _ => progress cbv [progress] - | _ => progress intros - | _ => progress subst - | _ => progress inversion_sum - | _ => progress break_match_hyps (* FIXME: this must be last? *) - end. - Qed. - End AsLoop. - End While. - Arguments by_invariant {_ _ _ _ _}. -End while. -Notation while := while.while. - -Definition for2 {state} (test : state -> bool) (increment body : state -> state) - := while test (fun s => increment (body s)). - -Definition for3 {state} init test increment body fuel := - @for2 state test increment body fuel init. - -Module _test. - Section GCD. - Definition gcd_step := - fun '(a, b) => if Nat.ltb a b - then inr (a, b-a) - else if Nat.ltb b a - then inr (a-b, b) - else inl a. - - Definition gcd fuel a b := loop_default gcd_step fuel (a,b) 0. - - (* Eval cbv [gcd loop_default loop gcd_step] in (gcd 10 5 7). *) - - Example gcd_test : gcd 1000 28 35 = 7 := eq_refl. - - Definition gcd_step_cps - : (nat * nat) -> forall T, (nat + (nat * nat) -> T) -> T - := - fun st T ret => - let a := fst st in - let b := snd st in - if Nat.ltb a b - then ret (inr (a, b-a)) - else if Nat.ltb b a - then ret (inr (a-b, b)) - else ret (inl a). - - Definition gcd_cps fuel a b {T} (ret:nat->T) - := loop_cps gcd_step_cps fuel (a,b) ret. - - Example gcd_test2 : gcd_cps 1000 28 35 id = inr 7 := eq_refl. - - (* Eval cbv [gcd_cps loop_cps gcd_step_cps id] in (gcd_cps 2 5 7 id). *) - - End GCD. - - (* simple example--set all elements in a list to 0 *) - Section ZeroLoop. - Import Crypto.Util.ListUtil. - - Definition zero_body (state : nat * list nat) : - list nat + (nat * list nat) := - if dec (fst state < length (snd state)) - then inr (S (fst state), set_nth (fst state) 0 (snd state)) - else inl (snd state). - - Lemma zero_body_progress (arr : list nat) : - progress zero_body (fun state : nat * list nat => length (snd state) - fst state). - Proof. - cbv [zero_body progress]; intros until 0; - repeat match goal with - | _ => progress autorewrite with cancel_pair distr_length - | _ => progress subst - | _ => progress break_match; intros - | _ => congruence - | H: inl _ = inl _ |- _ => injection H; intros; subst; clear H - | H: inr _ = inr _ |- _ => injection H; intros; subst ;clear H - | _ => lia - end. - Qed. - - Definition zero_loop (arr : list nat) : list nat := - loop_default zero_body (length arr) (0,arr) nil. - - Definition zero_invariant (state : nat * list nat) := - fst state <= length (snd state) - /\ forall n, n < fst state -> nth_default 0 (snd state) n = 0. - - Lemma zero_correct (arr : list nat) : - forall n, nth_default 0 (zero_loop arr) n = 0. - Proof. - intros. cbv [zero_loop]. - eapply (by_invariant zero_invariant); eauto using zero_body_progress; - [ cbv [zero_invariant]; autorewrite with cancel_pair; split; intros; lia | ..]; - cbv [zero_invariant zero_body]; - intros until 0; - break_match; intros; - repeat match goal with - | _ => congruence - | H: inl _ = inl _ |- _ => injection H; intros; subst; clear H - | H: inr _ = inr _ |- _ => injection H; intros; subst ;clear H - | _ => progress split - | _ => progress intros - | _ => progress subst - | _ => progress (autorewrite with cancel_pair distr_length in * ) - | _ => rewrite set_nth_nth_default by lia - | _ => progress break_match - | H : _ /\ _ |- _ => destruct H - | H : (_,_) = ?x |- _ => - destruct x; inversion H; subst; destruct H - | H : _ |- _ => apply H; lia - | _ => lia - end. - destruct (Compare_dec.lt_dec n (fst s)). - apply H1; lia. - apply nth_default_out_of_bounds; lia. - Qed. - End ZeroLoop. -End _test.
\ No newline at end of file diff --git a/src/Util/ForLoop.v b/src/Util/ForLoop.v deleted file mode 100644 index db12608e2..000000000 --- a/src/Util/ForLoop.v +++ /dev/null @@ -1,89 +0,0 @@ -(** * Definition and Notations for [for (int i = i₀; i < i∞; i += Δi)] *) -Require Import Coq.ZArith.BinInt. -Require Import Crypto.Util.Notations. -(** Note: These definitions are fairly tricky. See - https://github.com/mit-plv/fiat-crypto/issues/164 and - https://github.com/mit-plv/fiat-crypto/pull/163 for more - discussion. - - TODO: Fix the definitions to make them more obviously right. *) - -Section with_body. - Context {stateT : Type} - (body : nat -> stateT -> stateT). - - Fixpoint repeat_function (count : nat) (st : stateT) : stateT - := match count with - | O => st - | S count' => repeat_function count' (body count st) - end. -End with_body. - -Local Open Scope bool_scope. -Local Open Scope Z_scope. - -Definition for_loop {stateT} (i0 finish : Z) (step : Z) (initial : stateT) (body : Z -> stateT -> stateT) - : stateT - := let count := Z.to_nat (Z.quot (finish - i0 + step - Z.sgn step) step) in - repeat_function (fun c => body (i0 + step * Z.of_nat (count - c))) count initial. - - -Notation "'for' i (:= i0 ; += step ; < finish ) 'updating' ( state := initial ) {{ body }}" - := (for_loop i0 finish step initial (fun i state => body)) - : core_scope. - -Module Import ForNotationConstants. - Definition eq := @eq Z. - Module Z. - Definition ltb := Z.ltb. - Definition ltb' := Z.ltb. - Definition gtb := Z.gtb. - Definition gtb' := Z.gtb. - End Z. -End ForNotationConstants. - -Delimit Scope for_notation_scope with for_notation. -Notation "x += y" := (eq x (Z.pos y)) : for_notation_scope. -Notation "x -= y" := (eq x (Z.neg y)) : for_notation_scope. -Notation "++ x" := (x += 1)%for_notation : for_notation_scope. -Notation "-- x" := (x -= 1)%for_notation : for_notation_scope. -Notation "x ++" := (x += 1)%for_notation : for_notation_scope. -Notation "x --" := (x -= 1)%for_notation : for_notation_scope. -Infix "<" := Z.ltb : for_notation_scope. -Infix ">" := Z.gtb : for_notation_scope. -Notation "x <= y" := (Z.ltb' x (y + 1)) : for_notation_scope. -Notation "x >= y" := (Z.gtb' x (y - 1)) : for_notation_scope. - -Class class_eq {A} (x y : A) := make_class_eq : x = y. -Global Instance class_eq_refl {A x} : @class_eq A x x := eq_refl. - -Class for_loop_is_good (i0 : Z) (step : Z) (finish : Z) (cmp : Z -> Z -> bool) - := make_good : - ((Z.sgn step =? Z.sgn (finish - i0)) - && (cmp i0 finish)) - = true. -Hint Extern 0 (for_loop_is_good _ _ _ _) => vm_compute; reflexivity : typeclass_instances. - -Definition for_loop_notation {i0 : Z} {step : Z} {finish : Z} {stateT} {initial : stateT} - {cmp : Z -> Z -> bool} - step_expr finish_expr (body : Z -> stateT -> stateT) - {Hstep : class_eq (fun i => i = step) step_expr} - {Hfinish : class_eq (fun i => cmp i finish) finish_expr} - {Hgood : for_loop_is_good i0 step finish cmp} - : stateT - := for_loop i0 finish step initial body. - -Notation "'for' ( 'int' i = i0 ; finish_expr ; step_expr ) 'updating' ( state1 .. staten = initial ) {{ body }}" - := (@for_loop_notation - i0%Z _ _ _ initial%Z _ - (fun i : Z => step_expr%for_notation) - (fun i : Z => finish_expr%for_notation) - (fun (i : Z) => (fun state1 => .. (fun staten => body) .. )) - _ _ _). -Notation "'for' ( 'int' i = i0 ; finish_expr ; step_expr ) 'updating' ( state1 .. staten = initial ) {{ body }}" - := (@for_loop_notation - i0%Z _ _ _ initial%Z _ - (fun i : Z => step_expr%for_notation) - (fun i : Z => finish_expr%for_notation) - (fun (i : Z) => (fun state1 => .. (fun staten => body) .. )) - eq_refl eq_refl _). diff --git a/src/Util/ForLoop/Instances.v b/src/Util/ForLoop/Instances.v deleted file mode 100644 index 0a1f65e29..000000000 --- a/src/Util/ForLoop/Instances.v +++ /dev/null @@ -1,67 +0,0 @@ -Require Import Coq.omega.Omega. -Require Import Coq.Classes.Morphisms. -Require Import Crypto.Util.ForLoop. -Require Import Crypto.Util.Notations. - -Lemma repeat_function_Proper_rel_le {stateT} R f g n - (Hfg : forall c, 0 < c <= n -> forall s1 s2, R s1 s2 -> R (f c s1) (g c s2)) - s1 s2 (Hs : R s1 s2) - : R (@repeat_function stateT f n s1) (@repeat_function stateT g n s2). -Proof. - revert s1 s2 Hs. - induction n; simpl; auto. - intros; apply IHn; auto; - intros; apply Hfg; auto; - omega. -Qed. - -Global Instance repeat_function_Proper_rel {stateT} R - : Proper (pointwise_relation _ (R ==> R) ==> eq ==> R ==> R) (@repeat_function stateT) | 10. -Proof. - unfold pointwise_relation, respectful. - intros body1 body2 Hbody c y ?; subst y. - induction c; simpl; auto. -Qed. - -Lemma repeat_function_Proper_le {stateT} f g n - (Hfg : forall c, 0 < c <= n -> forall st, f c st = g c st) - st - : @repeat_function stateT f n st = @repeat_function stateT g n st. -Proof. - apply repeat_function_Proper_rel_le; try reflexivity; intros; subst; auto. -Qed. - -Global Instance repeat_function_Proper {stateT} - : Proper (pointwise_relation _ (pointwise_relation _ eq) ==> eq ==> eq ==> eq) (@repeat_function stateT). -Proof. - intros ???; eapply repeat_function_Proper_rel; repeat intro; subst. - unfold pointwise_relation, respectful in *; auto. -Qed. -About for_loop. - -Global Instance for_loop_Proper_rel {stateT} R i0 final step - : Proper (R ==> pointwise_relation _ (R ==> R) ==> R) (@for_loop stateT i0 final step) | 10. -Proof. - intros ?? Hinitial ?? Hbody; revert Hinitial. - unfold for_loop; eapply repeat_function_Proper_rel; - unfold pointwise_relation, respectful in *; auto. -Qed. - -Global Instance for_loop_Proper_rel_full {stateT} R - : Proper (eq ==> eq ==> eq ==> R ==> pointwise_relation _ (R ==> R) ==> R) (@for_loop stateT) | 20. -Proof. - intros ?????????; subst; apply for_loop_Proper_rel. -Qed. - -Global Instance for_loop_Proper {stateT} i0 final step initial - : Proper (pointwise_relation _ (pointwise_relation _ eq) ==> eq) (@for_loop stateT i0 final step initial). -Proof. - unfold pointwise_relation. - intros ???; eapply for_loop_Proper_rel; try reflexivity; repeat intro; subst; auto. -Qed. - -Global Instance for_loop_Proper_full {stateT} - : Proper (eq ==> eq ==> eq ==> eq ==> pointwise_relation _ (pointwise_relation _ eq) ==> eq) (@for_loop stateT) | 5. -Proof. - intros ????????????; subst; apply for_loop_Proper. -Qed. diff --git a/src/Util/ForLoop/InvariantFramework.v b/src/Util/ForLoop/InvariantFramework.v deleted file mode 100644 index 2bff6bef9..000000000 --- a/src/Util/ForLoop/InvariantFramework.v +++ /dev/null @@ -1,369 +0,0 @@ -(** * Proving properties of for-loops via loop-invariants *) -Require Import Coq.micromega.Psatz. -Require Import Coq.ZArith.ZArith. -Require Import Crypto.Util.ForLoop. -Require Import Crypto.Util.ForLoop.Unrolling. -Require Import Crypto.Util.ZUtil. -Require Import Crypto.Util.Bool. -Require Import Crypto.Util.Tactics.UniquePose. -Require Import Crypto.Util.Notations. - -Lemma repeat_function_ind {stateT} (P : nat -> stateT -> Prop) - (body : nat -> stateT -> stateT) - (count : nat) (st : stateT) - (Pbefore : P count st) - (Pbody : forall c st, c < count -> P (S c) st -> P c (body (S c) st)) - : P 0 (repeat_function body count st). -Proof. - revert dependent st; revert dependent body; revert dependent P. - induction count as [|? IHcount]; intros P body Pbody st Pbefore; [ exact Pbefore | ]. - { rewrite repeat_function_unroll1_end; apply Pbody; [ omega | ]. - apply (IHcount (fun c => P (S c))); auto with omega. } -Qed. - -Local Open Scope bool_scope. -Local Open Scope Z_scope. - -Section for_loop. - Context (i0 finish : Z) (step : Z) {stateT} (initial : stateT) (body : Z -> stateT -> stateT) - (P : Z -> stateT -> Prop) - (Pbefore : P i0 initial) - (Pbody : forall c st n, c = i0 + n * step -> i0 <= c < finish \/ finish < c <= i0 -> P c st -> P (c + step) (body c st)) - (Hgood : Z.sgn step = Z.sgn (finish - i0)). - - Let countZ := (Z.quot (finish - i0 + step - Z.sgn step) step). - Let count := Z.to_nat countZ. - Let of_nat_count c := (i0 + step * Z.of_nat (count - c)). - Let nat_body := (fun c => body (of_nat_count c)). - - Local Arguments Z.mul !_ !_. - Local Arguments Z.add !_ !_. - Local Arguments Z.sub !_ !_. - - Local Lemma Hgood_complex : Z.sgn step = Z.sgn (finish - i0 + step - Z.sgn step). - Proof using Hgood. - clear -Hgood. - revert Hgood. - generalize dependent (finish - i0); intro z; intros. - destruct step, z; simpl in * |- ; try (simpl; omega); - repeat change (Z.sgn (Z.pos _)) with 1; - repeat change (Z.sgn (Z.neg _)) with (-1); - symmetry; - [ apply Z.sgn_pos_iff | apply Z.sgn_neg_iff ]; - lia. - Qed. - - Local Lemma Hcount_nonneg : 0 <= countZ. - Proof using Hgood. - apply Z.quot_nonneg_same_sgn. - symmetry; apply Hgood_complex. - Qed. - - Lemma for_loop_ind - : P (finish - Z.sgn (finish - i0 + step - Z.sgn step) * (Z.abs (finish - i0 + step - Z.sgn step) mod Z.abs step) + step - Z.sgn step) - (for_loop i0 finish step initial body). - Proof using Pbody Pbefore Hgood. - destruct (Z_zerop step). - { subst; unfold for_loop; simpl in *. - rewrite Z.quot_div_full; simpl. - symmetry in Hgood; rewrite Z.sgn_null_iff in Hgood. - assert (finish = i0) by omega; subst. - simpl; autorewrite with zsimplify_const; simpl; auto. } - assert (Hsgn_step : Z.sgn step <> 0) by (rewrite Z.sgn_null_iff; auto). - assert (Hsgn : Z.sgn ((finish - i0 + step - Z.sgn step) / step) = Z.sgn ((finish - i0 + step - Z.sgn step) / step) * Z.sgn (finish - i0 + step - Z.sgn step) * Z.sgn step) - by (rewrite <- Hgood_complex, <- Z.mul_assoc, <- Z.sgn_mul, (Z.sgn_pos (_ * _)) by nia; omega). - assert (Hfis_div : 0 <= (finish - i0 + step - Z.sgn step) / step) - by (apply Z.sgn_nonneg; rewrite Hsgn; apply Zdiv_sgn). - clear Hsgn. - let rhs := match goal with |- ?P ?rhs _ => rhs end in - assert (Heq : i0 + step * Z.of_nat count = rhs). - { unfold count, countZ. - rewrite Z.mod_eq by (rewrite Z.abs_0_iff; assumption). - rewrite Z.quot_div_full, <- !Z.sgn_abs, <- !Hgood_complex, !Zdiv_mult_cancel_r, !Z.mul_sub_distr_l by auto. - rewrite <- !Z.sgn_mul, !(Z.mul_comm _ (Z.sgn _)), !(Z.mul_assoc (Z.sgn _) _), <- Z.sgn_mul, Z.sgn_pos, !Z.mul_1_l by nia. - repeat rewrite ?Z.sub_add_distr, ?Z.sub_sub_distr; rewrite Z.sub_diag. - autorewrite with zsimplify_const. - rewrite Z2Nat.id by omega. - omega. } - rewrite <- Heq; clear Heq. - unfold for_loop. - generalize (@repeat_function_ind stateT (fun c => P (of_nat_count c)) nat_body count initial); - cbv beta in *. - unfold of_nat_count in *; cbv beta in *. - rewrite Nat.sub_diag, !Nat.sub_0_r. - autorewrite with zsimplify_const. - intro H; specialize (H Pbefore). - destruct (Z_dec' i0 finish) as [ Hs | Hs]. - { apply H; clear H Pbefore. - { intros c st Hc. - assert (Hclt : 0 < Z.of_nat (count - c)) by (apply (inj_lt 0); omega). - intro H'; specialize (fun pf' n pf => Pbody _ _ n pf pf' H'). - move Pbody at bottom. - { let T := match type of Pbody with ?T -> _ => T end in - let H := fresh in - cut T; [ intro H; specialize (Pbody H) | ]. - { revert Pbody. - subst nat_body; cbv beta. - rewrite Nat.sub_succ_r, Nat2Z.inj_pred by omega. - rewrite <- Z.sub_1_r, Z.mul_sub_distr_l, Z.mul_1_r. - rewrite <- !Z.add_assoc, !Z.sub_add in *. - refine (fun p => p (Z.of_nat (count - c) - 1) _). - lia. } - { destruct Hs; [ left | right ]. - { assert (Hstep : 0 < step) - by (rewrite <- Z.sgn_pos_iff, Hgood, Z.sgn_pos_iff; omega). - assert (0 < Z.of_nat (S c)) by (apply (inj_lt 0); omega). - assert (0 <= (finish - i0 + step - Z.sgn step) mod step) by auto with zarith. - assert (0 < step <= step * Z.of_nat (S c)) by nia. - split; [ nia | ]. - rewrite Nat2Z.inj_sub, Z.mul_sub_distr_l by omega. - unfold count. - rewrite Z2Nat.id by auto using Hcount_nonneg. - unfold countZ. - rewrite Z.mul_quot_eq_full by auto. - rewrite <- !Hgood_complex, Z.abs_sgn. - rewrite !Z.add_sub_assoc, !Z.add_assoc, Zplus_minus. - rewrite Z.sgn_pos in * by omega. - omega. } - { assert (Hstep : step < 0) - by (rewrite <- Z.sgn_neg_iff, Hgood, Z.sgn_neg_iff; omega). - assert (Hcsc0 : 0 <= Z.of_nat (count - S c)) by auto with zarith. - assert (Hsc0 : 0 < Z.of_nat (S c)) by lia. - assert (step * Z.of_nat (count - S c) <= 0) by (clear -Hcsc0 Hstep; nia). - assert (step * Z.of_nat (S c) <= step < 0) by (clear -Hsc0 Hstep; nia). - assert (finish - i0 < 0) - by (rewrite <- Z.sgn_neg_iff, <- Hgood, Z.sgn_neg_iff; omega). - assert (finish - i0 + step - Z.sgn step < 0) - by (rewrite <- Z.sgn_neg_iff, <- Hgood_complex, Z.sgn_neg_iff; omega). - assert ((finish - i0 + step - Z.sgn step) mod step <= 0) by (apply Z_mod_neg; auto with zarith). - split; [ | nia ]. - rewrite Nat2Z.inj_sub, Z.mul_sub_distr_l by omega. - unfold count. - rewrite Z2Nat.id by auto using Hcount_nonneg. - unfold countZ. - rewrite Z.mul_quot_eq_full by auto. - rewrite <- !Hgood_complex, Z.abs_sgn. - rewrite Z.sgn_neg in * by omega. - omega. } } } } } - { subst. - subst count nat_body countZ. - repeat first [ assumption - | rewrite Z.sub_diag - | progress autorewrite with zsimplify_const in * - | rewrite Z.quot_sub_sgn ]. } - Qed. -End for_loop. - -Lemma for_loop_notation_ind {stateT} (P : Z -> stateT -> Prop) - {i0 : Z} {step : Z} {finish : Z} {initial : stateT} - {cmp : Z -> Z -> bool} - {step_expr finish_expr} (body : Z -> stateT -> stateT) - {Hstep : class_eq (fun i => i = step) step_expr} - {Hfinish : class_eq (fun i => cmp i finish) finish_expr} - {Hgood : for_loop_is_good i0 step finish cmp} - (Pbefore : P i0 initial) - (Pbody : forall c st n, c = i0 + n * step -> i0 <= c < finish \/ finish < c <= i0 -> P c st -> P (c + step) (body c st)) - : P (finish - Z.sgn (finish - i0 + step - Z.sgn step) * (Z.abs (finish - i0 + step - Z.sgn step) mod Z.abs step) + step - Z.sgn step) - (@for_loop_notation i0 step finish _ initial cmp step_expr finish_expr body Hstep Hfinish Hgood). -Proof. - unfold for_loop_notation, for_loop_is_good in *; split_andb; Z.ltb_to_lt. - apply for_loop_ind; auto. -Qed. - -Local Ltac pre_t := - lazymatch goal with - | [ Pbefore : ?P ?i0 ?initial - |- ?P _ (@for_loop_notation ?i0 ?step ?finish _ ?initial _ ?step_expr ?finish_expr ?body ?Hstep ?Hfinish ?Hgood) ] - => generalize (@for_loop_notation_ind - _ P i0 step finish initial _ step_expr finish_expr body Hstep Hfinish Hgood Pbefore) - end. -Local Ltac t_step := - first [ progress unfold for_loop_is_good, for_loop_notation in * - | progress split_andb - | progress Z.ltb_to_lt - | rewrite Z.sgn_pos by lia - | rewrite Z.abs_eq by lia - | rewrite Z.sgn_neg by lia - | rewrite Z.abs_neq by lia - | progress autorewrite with zsimplify_const - | match goal with - | [ Hsgn : Z.sgn ?step = Z.sgn _ |- _ ] - => unique assert (0 < step) by (rewrite <- Z.sgn_pos_iff, Hsgn, Z.sgn_pos_iff; omega); clear Hsgn - | [ Hsgn : Z.sgn ?step = Z.sgn _ |- _ ] - => unique assert (step < 0) by (rewrite <- Z.sgn_neg_iff, Hsgn, Z.sgn_neg_iff; omega); clear Hsgn - | [ |- (_ -> ?P ?x ?y) -> ?P ?x' ?y' ] - => replace x with x' by lia; let H' := fresh in intro H'; apply H'; clear H' - | [ |- (_ -> _) -> _ ] - => let H := fresh "Hbody" in intro H; progress Z.replace_all_neg_with_pos; revert H - end - | rewrite !Z.opp_sub_distr - | rewrite !Z.opp_add_distr - | rewrite !Z.opp_involutive - | rewrite !Z.sub_opp_r - | rewrite (Z.add_opp_r _ 1) - | progress (push_Zmod; pull_Zmod) - | progress Z.replace_all_neg_with_pos - | solve [ eauto with omega ] ]. -Local Ltac t := pre_t; repeat t_step. - -Lemma for_loop_ind_lt {stateT} (P : Z -> stateT -> Prop) - {i0 : Z} {step : Z} {finish : Z} {initial : stateT} - {step_expr finish_expr} (body : Z -> stateT -> stateT) - {Hstep : class_eq (fun i => i = step) step_expr} - {Hfinish : class_eq (fun i => Z.ltb i finish) finish_expr} - {Hgood : for_loop_is_good i0 step finish Z.ltb} - (Pbefore : P i0 initial) - (Pbody : forall c st n, c = i0 + n * step -> i0 <= c < finish -> P c st -> P (c + step) (body c st)) - : P (finish + step - 1 - ((finish - i0 - 1) mod step)) - (@for_loop_notation i0 step finish _ initial Z.ltb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish Hgood). -Proof. t. Qed. - -Lemma for_loop_ind_gt {stateT} (P : Z -> stateT -> Prop) - {i0 : Z} {step : Z} {finish : Z} {initial : stateT} - {step_expr finish_expr} (body : Z -> stateT -> stateT) - {Hstep : class_eq (fun i => i = step) step_expr} - {Hfinish : class_eq (fun i => Z.gtb i finish) finish_expr} - {Hgood : for_loop_is_good i0 step finish Z.gtb} - (Pbefore : P i0 initial) - (Pbody : forall c st n, c = i0 + n * step -> finish < c <= i0 -> P c st -> P (c + step) (body c st)) - : P (finish + step + 1 + (i0 - finish - step - 1) mod (-step)) - (@for_loop_notation i0 step finish _ initial Z.gtb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish Hgood). -Proof. - replace (i0 - finish) with (-(finish - i0)) by omega. - t. -Qed. - -Lemma for_loop_ind_lt1 {stateT} (P : Z -> stateT -> Prop) - {i0 : Z} {finish : Z} {initial : stateT} - (body : Z -> stateT -> stateT) - {Hgood : for_loop_is_good i0 1 finish _} - (Pbefore : P i0 initial) - (Pbody : forall c st, i0 <= c < finish -> P c st -> P (c + 1) (body c st)) - : P finish - (for (int i = i0; i < finish; i++) updating (st = initial) {{ - body i st - }}). -Proof. - generalize (@for_loop_ind_lt - stateT P i0 1 finish initial _ _ body eq_refl eq_refl Hgood Pbefore). - rewrite Z.mod_1_r, Z.sub_0_r, Z.add_simpl_r. - auto. -Qed. - -Lemma for_loop_ind_gt1 {stateT} (P : Z -> stateT -> Prop) - {i0 : Z} {finish : Z} {initial : stateT} - (body : Z -> stateT -> stateT) - {Hgood : for_loop_is_good i0 (-1) finish _} - (Pbefore : P i0 initial) - (Pbody : forall c st, finish < c <= i0 -> P c st -> P (c - 1) (body c st)) - : P finish - (for (int i = i0; i > finish; i--) updating (st = initial) {{ - body i st - }}). -Proof. - generalize (@for_loop_ind_gt - stateT P i0 (-1) finish initial _ _ body eq_refl eq_refl Hgood Pbefore). - simpl; rewrite Z.mod_1_r, Z.add_0_r, (Z.add_opp_r _ 1), Z.sub_simpl_r. - intro H; apply H; intros *. - rewrite (Z.add_opp_r _ 1); auto. -Qed. - -Local Hint Extern 1 (for_loop_is_good (?i0 + ?step) ?step ?finish _) -=> refine (for_loop_is_good_step_lt _); try assumption : typeclass_instances. -Local Hint Extern 1 (for_loop_is_good (?i0 + ?step) ?step ?finish _) -=> refine (for_loop_is_good_step_gt _); try assumption : typeclass_instances. -Local Hint Extern 1 (for_loop_is_good (?i0 - ?step') _ ?finish _) -=> refine (for_loop_is_good_step_gt (step:=-step') _); try assumption : typeclass_instances. -Local Hint Extern 1 (for_loop_is_good (?i0 + 1) 1 ?finish _) -=> refine (for_loop_is_good_step_lt' _); try assumption : typeclass_instances. -Local Hint Extern 1 (for_loop_is_good (?i0 - 1) (-1) ?finish _) -=> refine (for_loop_is_good_step_gt' _); try assumption : typeclass_instances. - -(** The Hoare-logic-like conditions for ≤ and ≥ loops seem slightly - unnatural; you have to choose either to state your correctness - property in terms of [i + 1], or talk about the correctness - condition when the loop counter is [i₀ - 1] (which is strange; - it's like saying the loop has run -1 times), or give the - correctness condition after the first run of the loop body, rather - than before it. We give lemmas for the second two options; if - you're using the first one, Coq probably won't be able to infer - the motive ([P], below) automatically, and you might as well use - the vastly more general version [for_loop_ind_lt] / - [for_loop_ind_gt]. *) -Lemma for_loop_ind_le1 {stateT} (P : Z -> stateT -> Prop) - {i0 : Z} {finish : Z} {initial : stateT} - (body : Z -> stateT -> stateT) - {Hgood : for_loop_is_good i0 1 (finish+1) _} - (Pbefore : P i0 (body i0 initial)) - (Pbody : forall c st, i0 <= c <= finish -> P (c-1) st -> P c (body c st)) - : P finish - (for (int i = i0; i <= finish; i++) updating (st = initial) {{ - body i st - }}). -Proof. - rewrite for_loop_le1_unroll1. - edestruct Sumbool.sumbool_of_bool; Z.ltb_to_lt; cbv zeta. - { generalize (@for_loop_ind_lt - stateT (fun n => P (n - 1)) (i0+1) 1 (finish+1) (body i0 initial) _ _ body eq_refl eq_refl _). - rewrite Z.mod_1_r, Z.sub_0_r, !Z.add_simpl_r. - intro H; apply H; auto with omega; intros *. - rewrite !Z.add_simpl_r; auto with omega. } - { unfold for_loop_is_good, ForNotationConstants.Z.ltb', ForNotationConstants.Z.ltb in *; split_andb; Z.ltb_to_lt. - assert (i0 = finish) by omega; subst. - assumption. } -Qed. - -Lemma for_loop_ind_le1_offset {stateT} (P : Z -> stateT -> Prop) - {i0 : Z} {finish : Z} {initial : stateT} - (body : Z -> stateT -> stateT) - {Hgood : for_loop_is_good i0 1 (finish+1) _} - (Pbefore : P (i0-1) initial) - (Pbody : forall c st, i0 <= c <= finish -> P (c-1) st -> P c (body c st)) - : P finish - (for (int i = i0; i <= finish; i++) updating (st = initial) {{ - body i st - }}). -Proof. - apply for_loop_ind_le1; auto with omega. - unfold for_loop_is_good, ForNotationConstants.Z.ltb', ForNotationConstants.Z.ltb in *; split_andb; Z.ltb_to_lt. - auto with omega. -Qed. - -Lemma for_loop_ind_ge1 {stateT} (P : Z -> stateT -> Prop) - {i0 : Z} {finish : Z} {initial : stateT} - (body : Z -> stateT -> stateT) - {Hgood : for_loop_is_good i0 (-1) (finish-1) _} - (Pbefore : P i0 (body i0 initial)) - (Pbody : forall c st, finish <= c <= i0 -> P (c+1) st -> P c (body c st)) - : P finish - (for (int i = i0; i >= finish; i--) updating (st = initial) {{ - body i st - }}). -Proof. - rewrite for_loop_ge1_unroll1. - edestruct Sumbool.sumbool_of_bool; Z.ltb_to_lt; cbv zeta. - { generalize (@for_loop_ind_gt - stateT (fun n => P (n + 1)) (i0-1) (-1) (finish-1) (body i0 initial) _ _ body eq_refl eq_refl _). - simpl; rewrite Z.mod_1_r, Z.add_0_r, (Z.add_opp_r _ 1), !Z.sub_simpl_r. - intro H; apply H; intros *; auto with omega. - rewrite (Z.add_opp_r _ 1), !Z.sub_simpl_r; auto with omega. } - { unfold for_loop_is_good, ForNotationConstants.Z.gtb', ForNotationConstants.Z.gtb in *; split_andb; Z.ltb_to_lt. - assert (i0 = finish) by omega; subst. - assumption. } -Qed. - -Lemma for_loop_ind_ge1_offset {stateT} (P : Z -> stateT -> Prop) - {i0 : Z} {finish : Z} {initial : stateT} - (body : Z -> stateT -> stateT) - {Hgood : for_loop_is_good i0 (-1) (finish-1) _} - (Pbefore : P (i0+1) initial) - (Pbody : forall c st, finish <= c <= i0 -> P (c+1) st -> P c (body c st)) - : P finish - (for (int i = i0; i >= finish; i--) updating (st = initial) {{ - body i st - }}). -Proof. - apply for_loop_ind_ge1; auto with omega. - unfold for_loop_is_good, ForNotationConstants.Z.gtb', ForNotationConstants.Z.gtb in *; split_andb; Z.ltb_to_lt. - auto with omega. -Qed. diff --git a/src/Util/ForLoop/Tests.v b/src/Util/ForLoop/Tests.v deleted file mode 100644 index 1061f1958..000000000 --- a/src/Util/ForLoop/Tests.v +++ /dev/null @@ -1,55 +0,0 @@ -Require Import Coq.ZArith.BinInt. -Require Import Coq.micromega.Psatz. -Require Import Crypto.Util.ForLoop. -Require Import Crypto.Util.ForLoop.InvariantFramework. -Require Import Crypto.Util.ZUtil. - -Local Open Scope Z_scope. - -Check (for i (:= 0; += 1; < 10) updating (v := 5) {{ v + i }}). -Check (for (int i = 0; i < 5; i++) updating ( '(v1, v2) = (0, 0) ) {{ (v1 + i, v2 + i) }}). - -Compute for (int i = 0; i < 5; i++) updating (v = 0) {{ v + i }}. -Compute for (int i = 0; i <= 5; i++) updating (v = 0) {{ v + i }}. -Compute for (int i = 5; i > -1; i--) updating (v = 0) {{ v + i }}. -Compute for (int i = 5; i >= 0; i--) updating (v = 0) {{ v + i }}. -Compute for (int i = 0; i < 5; i += 2) updating (v = 0) {{ v + i }}. -Compute for (int i = 0; i <= 5; i += 2) updating (v = 0) {{ v + i }}. -Compute for (int i = 5; i > -1; i -= 2) updating (v = 0) {{ v + i }}. -Compute for (int i = 5; i >= 0; i -= 2) updating (v = 0) {{ v + i }}. -Compute for (int i = 0; i < 6; i += 2) updating (v = 0) {{ v + i }}. -Compute for (int i = 0; i <= 6; i += 2) updating (v = 0) {{ v + i }}. -Compute for (int i = 6; i > -1; i -= 2) updating (v = 0) {{ v + i }}. -Compute for (int i = 6; i >= 0; i -= 2) updating (v = 0) {{ v + i }}. -Check eq_refl : for (int i = 0; i <= 5; i++) updating (v = 0) {{ v + i }} = 15. -Check eq_refl : for (int i = 0; i < 5; i++) updating (v = 0) {{ v + i }} = 10. -Check eq_refl : for (int i = 5; i >= 0; i--) updating (v = 0) {{ v + i }} = 15. -Check eq_refl : for (int i = 5; i > -1; i--) updating (v = 0) {{ v + i }} = 15. -Check eq_refl : for (int i = 0; i <= 5; i += 2) updating (v = 0) {{ v + i }} = 6. -Check eq_refl : for (int i = 0; i < 5; i += 2) updating (v = 0) {{ v + i }} = 6. -Check eq_refl : for (int i = 5; i > -1; i -= 2) updating (v = 0) {{ v + i }} = 9. -Check eq_refl : for (int i = 5; i >= 0; i -= 2) updating (v = 0) {{ v + i }} = 9. -Check eq_refl : for (int i = 0; i <= 6; i += 2) updating (v = 0) {{ v + i }} = 12. -Check eq_refl : for (int i = 0; i < 6; i += 2) updating (v = 0) {{ v + i }} = 6. -Check eq_refl : for (int i = 6; i > -1; i -= 2) updating (v = 0) {{ v + i }} = 12. -Check eq_refl : for (int i = 6; i >= 0; i -= 2) updating (v = 0) {{ v + i }} = 12. - -Local Notation for_sumT n' - := (let n := Z.pos n' in - (2 * - for (int i = 0; i <= n; i++) updating (v = 0) {{ - v + i - }})%Z - = n * (n + 1)) - (only parsing). - -Check eq_refl : for_sumT 5. - -(** Here we show that if we add the numbers from 0 to n, we get [n * (n + 1) / 2] *) -Example for_sum n' : for_sumT n'. -Proof. - intro n. - apply for_loop_ind_le1. - { compute; reflexivity. } - { intros; nia. } -Qed. diff --git a/src/Util/ForLoop/Unrolling.v b/src/Util/ForLoop/Unrolling.v deleted file mode 100644 index 95b46e711..000000000 --- a/src/Util/ForLoop/Unrolling.v +++ /dev/null @@ -1,314 +0,0 @@ -(** * Proving properties of for-loops via loop-unrolling *) -Require Import Coq.micromega.Psatz. -Require Import Coq.ZArith.ZArith. -Require Import Crypto.Util.ForLoop. -Require Import Crypto.Util.ForLoop.Instances. -Require Import Crypto.Util.ZUtil. -Require Import Crypto.Util.Bool. -Require Import Crypto.Util.Tactics.RewriteHyp. -Require Import Crypto.Util.Tactics.BreakMatch. -Require Import Crypto.Util.Tactics.DestructHead. -Require Import Crypto.Util.Notations. - -Section with_body. - Context {stateT : Type} - (body : nat -> stateT -> stateT). - - Lemma unfold_repeat_function (count : nat) (st : stateT) - : repeat_function body count st - = match count with - | O => st - | S count' => repeat_function body count' (body count st) - end. - Proof using Type. destruct count; reflexivity. Qed. - - Lemma repeat_function_unroll1_start (count : nat) (st : stateT) - : repeat_function body (S count) st - = repeat_function body count (body (S count) st). - Proof using Type. rewrite unfold_repeat_function; reflexivity. Qed. - - Lemma repeat_function_unroll1_end (count : nat) (st : stateT) - : repeat_function body (S count) st - = body 1 (repeat_function (fun count => body (S count)) count st). - Proof using Type. - revert st; induction count as [|? IHcount]; [ reflexivity | ]. - intros; simpl in *; rewrite <- IHcount; reflexivity. - Qed. - - Lemma repeat_function_unroll1_start_match (count : nat) (st : stateT) - : repeat_function body count st - = match count with - | 0 => st - | S count' => repeat_function body count' (body count st) - end. - Proof using Type. destruct count; [ reflexivity | apply repeat_function_unroll1_start ]. Qed. - - Lemma repeat_function_unroll1_end_match (count : nat) (st : stateT) - : repeat_function body count st - = match count with - | 0 => st - | S count' => body 1 (repeat_function (fun count => body (S count)) count' st) - end. - Proof using Type. destruct count; [ reflexivity | apply repeat_function_unroll1_end ]. Qed. -End with_body. - -Local Open Scope bool_scope. -Local Open Scope Z_scope. - -Section for_loop. - Context (i0 finish : Z) (step : Z) {stateT} (initial : stateT) (body : Z -> stateT -> stateT) - (Hgood : Z.sgn step = Z.sgn (finish - i0)). - - Let countZ := (Z.quot (finish - i0 + step - Z.sgn step) step). - Let count := Z.to_nat countZ. - Let of_nat_count c := (i0 + step * Z.of_nat (count - c)). - Let nat_body := (fun c => body (of_nat_count c)). - - Lemma for_loop_empty - (Heq : finish = i0) - : for_loop i0 finish step initial body = initial. - Proof. - subst; unfold for_loop. - rewrite Z.sub_diag, Z.quot_sub_sgn; autorewrite with zsimplify_const. - reflexivity. - Qed. - - Lemma for_loop_unroll1 - : for_loop i0 finish step initial body - = if finish =? i0 - then initial - else let initial' := body i0 initial in - if Z.abs (finish - i0) <=? Z.abs step - then initial' - else for_loop (i0 + step) finish step initial' body. - Proof. - break_innermost_match_step; Z.ltb_to_lt. - { apply for_loop_empty; assumption. } - { unfold for_loop. - rewrite repeat_function_unroll1_start_match. - destruct (Z_zerop step); - repeat first [ progress break_innermost_match - | congruence - | lia - | progress Z.ltb_to_lt - | progress subst - | progress rewrite Nat.sub_diag - | progress autorewrite with zsimplify_const in * - | progress rewrite Z.quot_small_iff in * by omega - | progress rewrite Z.quot_small_abs in * by lia - | rewrite Nat.sub_succ_l by omega - | progress destruct_head' and - | rewrite !Z.sub_add_distr - | match goal with - | [ H : ?x = Z.of_nat _ |- context[?x] ] => rewrite H - | [ H : Z.abs ?x <= 0 |- _ ] => assert (x = 0) by lia; clear H - | [ H : 0 = Z.sgn ?x |- _ ] => assert (x = 0) by lia; clear H - | [ H : ?x - ?y = 0 |- _ ] => is_var x; assert (x = y) by omega; subst x - | [ H : Z.to_nat _ = _ |- _ ] => apply Nat2Z.inj_iff in H - | [ H : context[Z.abs ?x] |- _ ] => rewrite (Z.abs_eq x) in H by omega - | [ H : context[Z.abs ?x] |- _ ] => rewrite (Z.abs_neq x) in H by omega - | [ H : Z.of_nat (Z.to_nat _) = _ |- _ ] - => rewrite Z2Nat.id in H by (apply Z.quot_nonneg_same_sgn; lia) - | [ H : _ = Z.of_nat (S ?x) |- _ ] - => is_var x; destruct x; [ reflexivity | ] - | [ H : ?x + 1 = Z.of_nat (S ?y) |- _ ] - => assert (x = Z.of_nat y) by lia; clear H - | [ |- repeat_function _ ?x ?y = repeat_function _ ?x ?y ] - => apply repeat_function_Proper_le; intros - | [ |- ?f _ ?x = ?f _ ?x ] - => is_var f; apply f_equal2; [ | reflexivity ] - end - | progress rewrite Z.quot_add_sub_sgn_small in * |- by lia - | progress autorewrite with zsimplify ]. } - Qed. -End for_loop. - -Lemma for_loop_notation_empty {stateT} - {i0 : Z} {step : Z} {finish : Z} {initial : stateT} - {cmp : Z -> Z -> bool} - {step_expr finish_expr} (body : Z -> stateT -> stateT) - {Hstep : class_eq (fun i => i = step) step_expr} - {Hfinish : class_eq (fun i => cmp i finish) finish_expr} - {Hgood : for_loop_is_good i0 step finish cmp} - (Heq : i0 = finish) - : @for_loop_notation i0 step finish _ initial cmp step_expr finish_expr body Hstep Hfinish Hgood = initial. -Proof. - unfold for_loop_notation, for_loop_is_good in *; split_andb; Z.ltb_to_lt. - apply for_loop_empty; auto. -Qed. - -Local Notation adjust_bool b p - := (match b as b' return b' = true -> b' = true with - | true => fun _ => eq_refl - | false => fun x => x - end p). - -Lemma for_loop_is_good_step_gen - cmp - (Hcmp : cmp = Z.ltb \/ cmp = Z.gtb) - {i0 step finish} - {H : for_loop_is_good i0 step finish cmp} - (H' : cmp (i0 + step) finish = true) - : for_loop_is_good (i0 + step) step finish cmp. -Proof. - unfold for_loop_is_good in *. - rewrite H', Bool.andb_true_r. - destruct Hcmp; subst; - split_andb; Z.ltb_to_lt; - [ rewrite (Z.sgn_pos (finish - i0)) in * by omega - | rewrite (Z.sgn_neg (finish - i0)) in * by omega ]; - destruct step; simpl in *; try congruence; - symmetry; - [ apply Z.sgn_pos_iff | apply Z.sgn_neg_iff ] - ; omega. -Qed. - -Definition for_loop_is_good_step_lt - {i0 step finish} - {H : for_loop_is_good i0 step finish Z.ltb} - (H' : Z.ltb (i0 + step) finish = true) - : for_loop_is_good (i0 + step) step finish Z.ltb - := for_loop_is_good_step_gen Z.ltb (or_introl eq_refl) (H:=H) H'. -Definition for_loop_is_good_step_gt - {i0 step finish} - {H : for_loop_is_good i0 step finish Z.gtb} - (H' : Z.gtb (i0 + step) finish = true) - : for_loop_is_good (i0 + step) step finish Z.gtb - := for_loop_is_good_step_gen Z.gtb (or_intror eq_refl) (H:=H) H'. -Definition for_loop_is_good_step_lt' - {i0 finish} - {H : for_loop_is_good i0 1 (finish + 1) Z.ltb} - (H' : Z.ltb i0 finish = true) - : for_loop_is_good (i0 + 1) 1 (finish + 1) Z.ltb. -Proof. - apply for_loop_is_good_step_lt; Z.ltb_to_lt; omega. -Qed. -Definition for_loop_is_good_step_gt' - {i0 finish} - {H : for_loop_is_good i0 (-1) (finish - 1) Z.gtb} - (H' : Z.gtb i0 finish = true) - : for_loop_is_good (i0 - 1) (-1) (finish - 1) Z.gtb. -Proof. - apply for_loop_is_good_step_gt; Z.ltb_to_lt; omega. -Qed. - -Local Hint Extern 1 (for_loop_is_good (?i0 + ?step) ?step ?finish _) -=> refine (adjust_bool _ (for_loop_is_good_step_lt _)); try assumption : typeclass_instances. -Local Hint Extern 1 (for_loop_is_good (?i0 + ?step) ?step ?finish _) -=> refine (adjust_bool _ (for_loop_is_good_step_gt _)); try assumption : typeclass_instances. -Local Hint Extern 1 (for_loop_is_good (?i0 - ?step') _ ?finish _) -=> refine (adjust_bool _ (for_loop_is_good_step_gt (step:=-step') _)); try assumption : typeclass_instances. -Local Hint Extern 1 (for_loop_is_good (?i0 + 1) 1 ?finish _) -=> refine (adjust_bool _ (for_loop_is_good_step_lt' _)); try assumption : typeclass_instances. -Local Hint Extern 1 (for_loop_is_good (?i0 - 1) (-1) ?finish _) -=> refine (adjust_bool _ (for_loop_is_good_step_gt' _)); try assumption : typeclass_instances. - -Local Ltac t := - repeat match goal with - | _ => progress unfold for_loop_is_good, for_loop_notation in * - | _ => progress rewrite for_loop_unroll1 by auto - | _ => omega - | _ => progress subst - | _ => reflexivity - | _ => progress split_andb - | _ => progress Z.ltb_to_lt - | _ => progress break_innermost_match_step - | [ H : context[Z.abs ?x] |- _ ] => rewrite (Z.abs_eq x) in H by omega - | [ H : context[Z.abs ?x] |- _ ] => rewrite (Z.abs_neq x) in H by omega - | [ H : context[Z.sgn ?x] |- _ ] => rewrite (Z.sgn_pos x) in H by omega - | [ H : context[Z.sgn ?x] |- _ ] => rewrite (Z.sgn_neg x) in H by omega - | [ H : Z.sgn _ = 1 |- _ ] => apply Z.sgn_pos_iff in H - | [ H : Z.sgn _ = -1 |- _ ] => apply Z.sgn_neg_iff in H - end. - -Lemma for_loop_lt_unroll1 {stateT} - {i0 : Z} {step : Z} {finish : Z} {initial : stateT} - {step_expr finish_expr} (body : Z -> stateT -> stateT) - {Hstep : class_eq (fun i => i = step) step_expr} - {Hfinish : class_eq (fun i => Z.ltb i finish) finish_expr} - {Hgood : for_loop_is_good i0 step finish Z.ltb} - : (@for_loop_notation i0 step finish _ initial Z.ltb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish Hgood) - = let initial' := body i0 initial in - if Sumbool.sumbool_of_bool (Z.ltb (i0 + step) finish) - then @for_loop_notation (i0 + step) step finish _ initial' Z.ltb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish _ - else initial'. -Proof. t. Qed. - -Lemma for_loop_gt_unroll1 {stateT} - {i0 : Z} {step : Z} {finish : Z} {initial : stateT} - {step_expr finish_expr} (body : Z -> stateT -> stateT) - {Hstep : class_eq (fun i => i = step) step_expr} - {Hfinish : class_eq (fun i => Z.gtb i finish) finish_expr} - {Hgood : for_loop_is_good i0 step finish Z.gtb} - : (@for_loop_notation i0 step finish _ initial Z.gtb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish Hgood) - = let initial' := body i0 initial in - if Sumbool.sumbool_of_bool (Z.gtb (i0 + step) finish) - then @for_loop_notation (i0 + step) step finish _ initial' Z.gtb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish _ - else initial'. -Proof. t. Qed. - -Lemma for_loop_lt1_unroll1 {stateT} - {i0 : Z} {finish : Z} {initial : stateT} - {body : Z -> stateT -> stateT} - {Hgood : for_loop_is_good i0 1 finish _} - : for (int i = i0; i < finish; i++) updating (st = initial) {{ - body i st - }} - = let initial' := body i0 initial in - if Sumbool.sumbool_of_bool (Z.ltb (i0 + 1) finish) - then for (int i = i0+1; i < finish; i++) updating (st = initial') {{ - body i st - }} - else initial'. -Proof. apply for_loop_lt_unroll1. Qed. - -Lemma for_loop_gt1_unroll1 {stateT} - {i0 : Z} {finish : Z} {initial : stateT} - {body : Z -> stateT -> stateT} - {Hgood : for_loop_is_good i0 (-1) finish _} - : for (int i = i0; i > finish; i--) updating (st = initial) {{ - body i st - }} - = let initial' := body i0 initial in - if Sumbool.sumbool_of_bool (Z.gtb (i0 - 1) finish) - then for (int i = i0-1; i > finish; i--) updating (st = initial') {{ - body i st - }} - else initial'. -Proof. apply for_loop_gt_unroll1. Qed. - -Lemma for_loop_le1_unroll1 {stateT} - {i0 : Z} {finish : Z} {initial : stateT} - {body : Z -> stateT -> stateT} - {Hgood : for_loop_is_good i0 1 (finish+1) _} - : for (int i = i0; i <= finish; i++) updating (st = initial) {{ - body i st - }} - = let initial' := body i0 initial in - if Sumbool.sumbool_of_bool (Z.ltb i0 finish) - then for (int i = i0+1; i <= finish; i++) updating (st = initial') {{ - body i st - }} - else initial'. -Proof. - rewrite for_loop_lt_unroll1; unfold for_loop_notation. - break_innermost_match; Z.ltb_to_lt; try omega; try reflexivity. -Qed. - -Lemma for_loop_ge1_unroll1 {stateT} - {i0 : Z} {finish : Z} {initial : stateT} - {body : Z -> stateT -> stateT} - {Hgood : for_loop_is_good i0 (-1) (finish-1) _} - : for (int i = i0; i >= finish; i--) updating (st = initial) {{ - body i st - }} - = let initial' := body i0 initial in - if Sumbool.sumbool_of_bool (Z.gtb i0 finish) - then for (int i = i0-1; i >= finish; i--) updating (st = initial') {{ - body i st - }} - else initial'. -Proof. - rewrite for_loop_gt_unroll1; unfold for_loop_notation. - break_innermost_match; Z.ltb_to_lt; try omega; try reflexivity. -Qed. diff --git a/src/Util/Loop.v b/src/Util/Loop.v deleted file mode 100644 index 2ee90d11f..000000000 --- a/src/Util/Loop.v +++ /dev/null @@ -1,480 +0,0 @@ -(** * Definition and Notations for [do { body }] *) -Require Import Coq.ZArith.BinInt. -Require Import Coq.Classes.Morphisms. -Require Import Coq.micromega.Lia. -Require Import Coq.omega.Omega. -Require Import Crypto.Util.ZUtil. -Require Import Crypto.Util.ZUtil.Z2Nat. -Require Import Crypto.Util.Notations Crypto.Util.CPSNotations. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.Tactics.SpecializeBy. -Require Import Crypto.Util.Tactics.DestructHead. -Require Import Crypto.Util.Tactics.BreakMatch. -Require Import Crypto.Util.Tactics.SplitInContext. -Require Import Crypto.Util.Tactics.UniquePose. - -Section with_state. - Import CPSNotations. - Context {state : Type}. - - Definition loop_cps_step - (loop_cps - : forall (max_iter : nat) - (initial : state) - (body : state -> forall {T} (continue : state -> T) (break : state -> T), T), - ~> state) - (max_iter : nat) - : forall (initial : state) - (body : state -> forall {T} (continue : state -> T) (break : state -> T), T), - ~> state - := fun st body - => match max_iter with - | 0 - => (return st) - | S max_iter' - => fun T ret - => body st T (fun st' => @loop_cps max_iter' st' body _ ret) ret - end. - - Fixpoint loop_cps (max_iter : nat) - : forall (initial : state) - (body : state -> forall {T} (continue : state -> T) (break : state -> T), T), - ~> state - := loop_cps_step loop_cps max_iter. - - Lemma unfold_loop_cps - (max_iter : nat) - : loop_cps max_iter - = loop_cps_step loop_cps max_iter. - Proof. - destruct max_iter; reflexivity. - Qed. - - Theorem loop_cps_def (max_iter : nat) - (initial : state) - (body : state -> forall {T} (continue : state -> T) (break : state -> T), T) - T ret - : loop_cps (S max_iter) initial body T ret = - body initial (fun st' => @loop_cps max_iter st' body _ ret) ret. - Proof. - reflexivity. - Qed. - - Theorem loop_cps_ind - (invariant : state -> Prop) - T (P : T -> Prop) n v0 body rest - : invariant v0 - -> (forall v continue break, - (forall v, invariant v -> P (continue v)) - -> (forall v, invariant v -> P (break v)) - -> invariant v - -> P (body v T continue break)) - -> (forall v, invariant v -> P (rest v)) - -> P (loop_cps n v0 body T rest). - Proof. - revert v0 rest. - induction n as [|n IHn]; intros v0 rest Hinv Hbody HP; simpl; cbv [cpsreturn]; auto. - Qed. - - Local Hint Extern 2 => omega. - - (** TODO(andreser): Remove this if we don't need it *) - Theorem loop_cps_wf_ind_break - (measure : state -> nat) - (invariant : state -> Prop) - T (P : T -> Prop) n v0 body rest - : invariant v0 - -> (forall v continue, - invariant v - -> (forall break, - (forall v', measure v' < measure v -> invariant v' -> P (continue v')) - -> P (body v T continue break)) - \/ P (body v T continue rest)) - -> measure v0 < n - -> P (loop_cps n v0 body T rest). - Proof. - revert v0 rest. - induction n as [|n IHn]; intros v0 rest Hinv Hbody Hmeasure; simpl; try omega. - edestruct Hbody as [Hbody'|Hbody']; eauto. - Qed. - - Theorem loop_cps_wf_ind - (measure : state -> nat) - (invariant : state -> Prop) - T (P : T -> Prop) n v0 body rest - : invariant v0 - -> (forall v continue, - invariant v - -> ((forall v', measure v' < measure v -> invariant v' -> P (continue v')) - -> P (body v T continue rest))) - -> measure v0 < n - -> P (loop_cps n v0 body T rest). - Proof. - revert v0. - induction n as [|n IHn]; intros v0 Hinv Hbody Hmeasure; simpl; try omega. - eauto. - Qed. - - Theorem eq_loop_cps_large_n - (measure : state -> nat) - n n' v0 body T rest - : measure v0 < n - -> measure v0 < n' - -> (forall v continue continue', - (forall v', measure v' < measure v -> continue v' = continue' v') - -> measure v <= measure v0 - -> body v T continue rest = body v T continue' rest) - -> loop_cps n v0 body T rest = loop_cps n' v0 body T rest. - Proof. - revert n n'. - match goal with - | [ |- forall n n', ?P ] => cut (forall n n', n <= n' -> P) - end. - { intros H n n' ???. - destruct (le_lt_dec n n'); [ | symmetry ]; auto. } - { intros n n' Hle Hn _. - revert n' Hle v0 Hn. - induction n as [|n IHn], n' as [|n']; simpl; - auto with omega. } - Qed. -End with_state. - -(** N.B. If the body is polymorphic (that is, if the type argument - shows up in the body), then we need to bind the name of the type - parameter somewhere in the notation for it to show up; we have a - separate notation for this case. *) -(** TODO: When these notations are finalized, reserve them in Notations.v and moving the level and formatting rules there *) -Notation "'loop' _{ fuel } ( state1 .. staten = initial ) 'labels' ( continue , break , @ T ) {{ body }} ; rest" - := (@loop_cps _ fuel initial - (fun state1 => .. (fun staten => id (fun T continue break => body)) .. ) - _ (fun state1 => .. (fun staten => rest) .. )) - (at level 200, state1 binder, staten binder, - format "'[v ' 'loop' _{ fuel } ( state1 .. staten = initial ) 'labels' ( continue , break , @ T ) {{ '//' body ']' '//' }} ; '//' rest"). -Notation "'loop' _{ fuel } ( state1 .. staten = initial ) 'labels' ( continue , break ) {{ body }} ; rest" - := (@loop_cps _ fuel initial - (fun state1 => .. (fun staten => id (fun T continue break => body)) .. ) - _ (fun state1 => .. (fun staten => rest) .. )) - (at level 200, state1 binder, staten binder, - format "'[v ' 'loop' _{ fuel } ( state1 .. staten = initial ) 'labels' ( continue , break ) {{ '//' body ']' '//' }} ; '//' rest"). -Notation "'loop' _{ fuel } ( state1 .. staten = initial ) 'labels' ( continue ) {{ body }} ; rest" - := (@loop_cps _ fuel initial - (fun state1 => .. (fun staten => id (fun T continue _ => body)) .. ) - _ (fun state1 => .. (fun staten => rest) .. )) - (at level 200, state1 binder, staten binder, - format "'[v ' 'loop' _{ fuel } ( state1 .. staten = initial ) 'labels' ( continue ) {{ '//' body ']' '//' }} ; '//' rest"). - -Section with_for_state. - Import CPSNotations. - Section with_loop_params. - Context {state : Type}. - Context (test : Z -> Z -> bool) (i_final : Z) (upd_i : Z -> Z) - (body : state -> Z -> forall {T} (continue : state -> T) (break : state -> T), T). - - (* we assume that [upd_i] is linear to compute the fuel *) - Definition for_cps (i0 : Z) (initial : state) - : ~> state - := fun T ret - => @loop_cps - (Z * state) - (S (S (Z.to_nat ((i_final - i0) / (upd_i 0%Z))))) - (i0, initial) - (fun '(i, st) T continue break - => if test i i_final - then @body st i T - (fun st' => continue (upd_i i, st')%Z) - (fun st' => break (i, st')) - else break (i, st)) - T (fun '(i, st) => ret st). - - Section lemmas. - Local Open Scope Z_scope. - Context (upd_linear : forall x, upd_i x = upd_i 0 + x) - (upd_nonzero : upd_i 0 <> 0) - (upd_signed : forall i0, test i0 i_final = true -> 0 <= (i_final - i0) / (upd_i 0)). - - (** TODO: Strengthen this to take into account the value of - the loop counter at the end of the loop; based on - [ForLoop.v], it should be something like [(finish - - Z.sgn (finish - i0 + step - Z.sgn step) * (Z.abs - (finish - i0 + step - Z.sgn step) mod Z.abs step) + - step - Z.sgn step)] *) - Theorem for_cps_ind - (invariant : Z -> state -> Prop) - T (P : (*Z ->*) T -> Prop) i0 v0 rest - : invariant i0 v0 - -> (forall i v continue, - test i i_final = true - -> (forall v, invariant (upd_i i) v -> P (continue v)) - -> invariant i v - -> P (@body v i T continue rest)) - -> (forall i v, test i i_final = false -> invariant i v -> P (rest v)) - -> P (for_cps i0 v0 T rest). - Proof. - unfold for_cps, cpscall, cpsreturn. - intros Hinv IH Hrest. - eapply @loop_cps_wf_ind with (T:=T) - (invariant := fun '(i, s) => invariant i s) - (measure := fun '(i, s) => Z.to_nat (1 + (i_final - i) / upd_i 0)); - [ assumption - | - | solve [ destruct (Z_le_gt_dec 0 ((i_final - i0) / upd_i 0)); - [ rewrite Z2Nat.inj_add by omega; simpl; omega - | rewrite Z2Nat.inj_nonpos by omega; omega ] ] ]. - intros [i st] continue Hinv' IH'. - destruct (test i i_final) eqn:Hi; [ | solve [ eauto ] ]. - pose proof (upd_signed _ Hi) as upd_signed'. - assert (upd_i 0 <> 0) - by (intro H'; rewrite H' in upd_signed'; autorewrite with zsimplify in upd_signed'; - omega). - specialize (IH i st (fun st' => continue (upd_i i, st')) Hi). - specialize (fun v pf => IH' (upd_i i, v) pf). - cbv beta iota in *. - specialize (fun pf v => IH' v pf). - rewrite upd_linear in IH'. - replace ((i_final - (upd_i 0 + i)) / upd_i 0) - with ((i_final - i) / upd_i 0 - 1) - in IH' - by (Z.div_mod_to_quot_rem; nia). - rewrite <- upd_linear, Zplus_minus, Z2Nat.inj_add in IH' by omega. - auto with omega. - Qed. - - Theorem for_cps_unroll1 - T i0 v0 rest - (body_Proper : Proper (eq ==> eq ==> forall_relation (fun T => (pointwise_relation _ eq) ==> eq ==> eq)) body) - : for_cps i0 v0 T rest - = if test i0 i_final - then @body v0 i0 T - (fun v => for_cps (upd_i i0) v T rest) - rest - else rest v0. - Proof. - unfold for_cps at 1. - rewrite loop_cps_def. - destruct (test i0 i_final) eqn:Hi; [ | reflexivity ]. - apply body_Proper; [ reflexivity | reflexivity | intro st | reflexivity ]. - assert (Hto0 : forall x, x <= 0 -> Z.to_nat x = 0%nat) - by (intros []; intros; simpl; lia). - apply eq_loop_cps_large_n with (measure := fun '(i, st) => Z.to_nat (1 + (i_final - i) / upd_i 0)); - repeat first [ progress intros - | reflexivity - | progress destruct_head'_prod - | progress unfold pointwise_relation - | break_innermost_match_step - | apply body_Proper - | omega - | rewrite Zdiv.Zdiv_0_r in * - | rewrite ?(Z.add_opp_r _ 1), Zplus_minus, <- ?(Z.add_opp_r _ 1) in * - | match goal with - | [ H : forall v, _ -> ?continue _ = ?continue' _ |- ?continue _ = ?continue' _ ] => apply H - | [ |- context[upd_i ?x] ] - => lazymatch x with - | 0 => fail - | _ => rewrite (upd_linear x) - end - | [ H : ?x = 0 |- context[?x] ] => rewrite H - | [ H : ?x = 0, H' : context[?x] |- _ ] => rewrite H in H' - | [ H : forall i, ?f i ?y = true -> ?R (_ / 0), H' : ?f _ ?y = true |- _ ] - => specialize (H _ H') - | [ |- context[?i_final - (upd_i 0 + ?i0)] ] - => replace (i_final - (upd_i 0 + i0)) with ((i_final - i0) + (-1) * upd_i 0) by omega; - rewrite Zdiv.Z_div_plus_full by assumption - end - | lazymatch goal with - | [ H : upd_i 0 = 0 |- _ ] => fail - | [ H : upd_i 0 <> 0 |- _ ] => fail - | _ => destruct (Z_zerop (upd_i 0)) - end - | match goal with - | [ |- context[S (Z.to_nat ?x)] ] - => destruct (Z_lt_le_dec 0 x); - [ rewrite <- (Z2Nat.inj_succ x) by omega - | rewrite !(Hto0 x) by omega ] - | [ |- (Z.to_nat _ < Z.to_nat _)%nat ] - => apply Z2Nat.inj_lt - | [ |- (Z.to_nat ?x < ?n)%nat ] - => apply (Z2Nat.inj_lt x (Z.of_nat n)); simpl - | [ H : forall i, ?f i ?y = true -> 0 <= _ / _, H' : ?f _ ?y = true |- _ ] - => specialize (H _ H') - end ]. - Qed. - End lemmas. - End with_loop_params. -End with_for_state. - -Delimit Scope for_upd_scope with for_upd. -Delimit Scope for_test_scope with for_test. -Notation "i += k" := (Z.add i k) : for_upd_scope. -Notation "i -= k" := (Z.sub i k) : for_upd_scope. -Notation "i ++" := (i += 1)%for_upd : for_upd_scope. -Notation "i --" := (i -= 1)%for_upd : for_upd_scope. -Notation "<" := Z.ltb (at level 71) : for_test_scope. -Notation ">" := Z.gtb (at level 71) : for_test_scope. -Notation "<=" := Z.leb (at level 71) : for_test_scope. -Notation ">=" := Z.geb (at level 71) : for_test_scope. -Notation "≤" := Z.leb (at level 71) : for_test_scope. -Notation "≥" := Z.geb (at level 71) : for_test_scope. -Global Close Scope for_test_scope. (* TODO: make these notations not print all over the place *) - -Definition force_idZ (f : Z -> Z) (pf : f = id) {T} (v : T) := v. -(** [lhs] and [cmp_expr] go at level 9 so that they bind more tightly - than application (so that [i (<)] sticks [i] in [lhs] and [(<)] in - [cmp_expr], rather than sticking [i (<)] in [lhs] and then - complaining about a missing value for [cmp_expr]. Unfortunately, - because the comparison operators need to be at level > 70 to not - conflict with their infix versions, putting [cmp_expr] at level 9 - forces us to wrap parentheses around the comparison operator. *) -(** TODO(andreser): If it's worth it, duplicate these notations for - each value of [cmp_expr] so that we don't need to wrap the - comparison operator in parentheses. *) -(** TODO: When these notations are finalized, reserve them in Notations.v and moving the level and formatting rules there *) -Notation "'for' ( i1 .. i2 = i0 ; lhs cmp_expr final ; upd_expr ) 'updating' ( state1 .. staten = initial ) 'labels' ( continue , break , @ T ) {{ body }} ; rest" - := (force_idZ - (fun i1 => .. (fun i2 => lhs) ..) - eq_refl - (@for_cps _ cmp_expr%for_test - final - (fun i1 => .. (fun i2 => upd_expr%for_upd) .. ) - (fun state1 => .. (fun staten => id (fun i1 => .. (fun i2 => id (fun T continue break => body)) .. )) .. ) - i0 - initial - _ (fun state1 => .. (fun staten => rest) .. ))) - (at level 200, state1 binder, staten binder, i1 binder, i2 binder, lhs at level 9, cmp_expr at level 9, - format "'[v ' 'for' ( i1 .. i2 = i0 ; lhs cmp_expr final ; upd_expr ) 'updating' ( state1 .. staten = initial ) 'labels' ( continue , break , @ T ) {{ '//' body ']' '//' }} ; '//' rest"). -Notation "'for' ( i1 .. i2 = i0 ; lhs cmp_expr final ; upd_expr ) 'updating' ( state1 .. staten = initial ) 'labels' ( continue , break ) {{ body }} ; rest" - := (force_idZ - (fun i1 => .. (fun i2 => lhs) ..) - eq_refl - (@for_cps _ cmp_expr%for_test - final - (fun i1 => .. (fun i2 => upd_expr%for_upd) .. ) - (fun state1 => .. (fun staten => id (fun i1 => .. (fun i2 => id (fun T continue break => body)) .. )) .. ) - i0 - initial - _ (fun state1 => .. (fun staten => rest) .. ))) - (at level 200, state1 binder, staten binder, i1 binder, i2 binder, lhs at level 9, cmp_expr at level 9, - format "'[v ' 'for' ( i1 .. i2 = i0 ; lhs cmp_expr final ; upd_expr ) 'updating' ( state1 .. staten = initial ) 'labels' ( continue , break ) {{ '//' body ']' '//' }} ; '//' rest"). -Notation "'for' ( i1 .. i2 = i0 ; lhs cmp_expr final ; upd_expr ) 'updating' ( state1 .. staten = initial ) 'labels' ( continue ) {{ body }} ; rest" - := (force_idZ - (fun i1 => .. (fun i2 => lhs) ..) - eq_refl - (@for_cps _ cmp_expr%for_test - final - (fun i1 => .. (fun i2 => upd_expr%for_upd) .. ) - (fun state1 => .. (fun staten => id (fun i1 => .. (fun i2 => id (fun T continue break => body)) .. )) .. ) - i0 - initial - _ (fun state1 => .. (fun staten => rest) .. ))) - (at level 200, state1 binder, staten binder, i1 binder, i2 binder, lhs at level 9, cmp_expr at level 9, - format "'[v ' 'for' ( i1 .. i2 = i0 ; lhs cmp_expr final ; upd_expr ) 'updating' ( state1 .. staten = initial ) 'labels' ( continue ) {{ '//' body ']' '//' }} ; '//' rest"). - -Section LoopTest. - Import CPSNotations. - Check loop _{ 10 } (x = 0) labels (continue, break) {{ continue (x + 1) }} ; x. - - Check - loop _{ 1234 } ('(i, a) = (0, 0)) labels (continue, break) - {{ - if i <? 10 - then - continue (i + 1, a+1) - else - break (0, a) - }}; - a. - - Context (f:nat~>nat). - - Check x <- f 0 ; return x + x. - Check x <- f 0 ; y <- f x; z <- f y; return (x,y,z). - - Check loop _{ 10 } (x = 0) labels (continue, break) {{ continue (x + 1) }} ; - return x. - - Check loop _{ 10 } (x = 0) labels (continue, break) {{ continue (x + 1) }} ; - x <- f x; - return x. - - Check loop _{ 10 } (x = 0) labels (continue, break) {{ x <- f x ; continue (x) }} ; x. - - - Context (s F : Z) (zero : nat). - Check for ( i = s; i (<) F; i++) updating (P = zero) labels (continue, break) - {{ - continue (P+P) - }}; - P. - - Check for ( i = s; i (<) F; i++) updating (P = zero) labels (continue) - {{ - P <- f P; - continue (P+P) - }}; - P. -End LoopTest. - -Module CPSBoilerplate. - Import CPSNotations. - Definition valid {R} (f:~>R) := - forall {T} (continuation:R->T), - (x <- f; continuation x) = (continuation (f _ id)). - Existing Class valid. -End CPSBoilerplate. - -Require Import Coq.Classes.Morphisms. -Require Import Crypto.Algebra.ScalarMult. -Section ScalarMult. - Import CPSNotations. - - Context {G} (zero:G) (k w : Z) (add_tbl : G -> Z -> Z ~> G) (nth_limb : Z ~> Z). (* k w-bit limbs *) - - Definition ScalarMultBase := - for ( i = 0; i (<) k; i++) updating (P = zero) labels (continue, break) - {{ - x <- nth_limb i; - P <- add_tbl P i x; - continue P - }}; - P. - - Context {Geq add opp} {Hmonoid:@Algebra.Hierarchy.group G Geq add zero opp}. - Local Notation smul := (@scalarmult_ref G add zero opp). - Context {nth_limb_valid : forall a, CPSBoilerplate.valid (nth_limb a)}. - Context {add_tbl_valid : forall a b c, CPSBoilerplate.valid (add_tbl a b c)}. - Context {Proper_add_tbl : Proper (Geq ==> eq ==> eq ==> Geq) (fun a b c => add_tbl a b c _ id)}. - Context (B:G). - Context {limb_good} - {nth_limb_good: forall i, (0 <= i < k)%Z -> limb_good i (nth_limb i _ id)} - {add_tbl_correct : forall P i limb, - limb_good i limb -> Geq (add_tbl P i limb G id) (add P (smul (2 ^ i * limb) B))}. - - Definition n_upto t : Z := - for ( i = 0; i (<) t; i++) updating (n = 0%Z) labels (continue) - {{ - x <- nth_limb i; - continue (n + (2^i)*x)%Z - }}; - n. - - - Lemma ScalarMultBase_correct : Geq ScalarMultBase (smul (n_upto k) B). - cbv [ScalarMultBase]. - eapply for_cps_ind with (invariant := fun i P => Geq P (smul (n_upto i) B )%Z). - - intros; omega. - - omega. - - intros; rewrite Z.ltb_lt in H; autorewrite with zsimplify; omega. - - autorewrite with zsimplify. symmetry; eapply (scalarmult_0_l(add:=add)). - - cbv [force_idZ id]; intros. clear H. - setoid_rewrite nth_limb_valid; setoid_rewrite add_tbl_valid. - setoid_rewrite <-H0; [reflexivity|]; clear H0. - - etransitivity. - eapply Proper_add_tbl; [eapply H1|reflexivity|reflexivity]. - clear H1. - replace (n_upto (i+1))%Z with (n_upto i + (2^i)*(nth_limb i _ id))%Z by admit. - rewrite scalarmult_add_l. - rewrite add_tbl_correct; [reflexivity|]. - apply nth_limb_good. - admit. - Admitted. -End ScalarMult. diff --git a/src/Util/Loops.v b/src/Util/Loops.v new file mode 100644 index 000000000..56b27dac7 --- /dev/null +++ b/src/Util/Loops.v @@ -0,0 +1,526 @@ +Require Import Coq.Lists.List. +Require Import Lia. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.DestructHead. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.Sum. +Require Import Crypto.Util.LetIn. + +Require Import Crypto.Util.CPSNotations. + +Module Import core. + Section Loops. + Context {A B : Type} (body : A -> A + B). + + (* the fuel parameter is only present to allow defining a loop + without proving its termination. The loop body does not have + access to the amount of remaining fuel, and thus increasing fuel + beyond termination cannot change the behavior. fuel counts full + loops -- the one that exacutes "break" is not included *) + + Fixpoint loop (fuel : nat) (s : A) {struct fuel} : A + B := + let s := body s in + match s with + | inl a => + match fuel with + | O => inl a + | S fuel' => loop fuel' a + end + | inr b => inr b + end. + + Context (body_cps : A ~> A + B). + + Fixpoint loop_cps (fuel : nat) (s : A) {struct fuel} :~> A + B := + s <- body_cps s; + match s with + | inl a => + match fuel with + | O => return (inl a) + | S fuel' => loop_cps fuel' a + end + | inr b => return (inr b) + end. + + Context (body_cps_ok : forall s {R} f, body_cps s R f = f (body s)). + Lemma loop_cps_ok n s {R} f : loop_cps n s R f = f (loop n s). + Proof. + revert f; revert R; revert s; induction n; + repeat match goal with + | _ => progress intros + | _ => progress cbv [cpsreturn cpscall] in * + | _ => progress cbn + | _ => progress rewrite ?body_cps_ok + | _ => progress rewrite ?IHn + | _ => progress break_match + | _ => reflexivity + end. + Qed. + + Context (body_cps2 : A -> forall {R}, (A -> R) -> (B -> R) -> R). + Fixpoint loop_cps2 (fuel : nat) (s : A) {R} (timeout:A->R) (ret:B->R) {struct fuel} : R := + body_cps2 s R + (fun a => + match fuel with + | O => timeout a + | S fuel' => @loop_cps2 fuel' a R timeout ret + end) + (fun b => ret b). + + Context (body_cps2_ok : forall s {R} continue break, + body_cps2 s R continue break = + match body s with + | inl a => continue a + | inr b => break b + end). + Lemma loop_cps2_ok n s {R} (timeout ret : _ -> R) : + @loop_cps2 n s R timeout ret = + match loop n s with + | inl a => timeout a + | inr b => ret b + end. + Proof. + revert timeout; revert ret; revert R; revert s; induction n; + repeat match goal with + | _ => progress intros + | _ => progress cbv [cpsreturn cpscall] in * + | _ => progress cbn + | _ => progress rewrite ?body_cps2_ok + | _ => progress rewrite ?IHn + | _ => progress inversion_sum + | _ => progress subst + | _ => progress break_match + | _ => reflexivity + end. + Qed. + + Local Lemma loop_fuel_0 s : loop 0 s = body s. + Proof. cbv; break_match; reflexivity. Qed. + + Local Lemma loop_fuel_S_first n s : loop (S n) s = + match body s with + | inl a => loop n a + | inr b => inr b + end. + Proof. reflexivity. Qed. + + Local Lemma loop_fuel_S_last n s : loop (S n) s = + match loop n s with + | inl a => body a + | inr b => loop n s + end. + Proof. + revert s; induction n; cbn; intros s. + { break_match; reflexivity. } + { destruct (body s); cbn; rewrite <-?IHn; reflexivity. } + Qed. + + Local Lemma loop_fuel_S_stable n s b (H : loop n s = inr b) : loop (S n) s = inr b. + Proof. + revert H; revert b; revert s; induction n; intros ? ? H. + { cbn [loop nat_rect] in *; break_match_hyps; congruence_sum; congruence. } + { rewrite loop_fuel_S_last. + break_match; congruence_sum; reflexivity. } + Qed. + + Local Lemma loop_fuel_add_stable n m s b (H : loop n s = inr b) : loop (m+n) s = inr b. + Proof. + induction m; intros. + { rewrite PeanoNat.Nat.add_0_l. assumption. } + { rewrite PeanoNat.Nat.add_succ_l. + erewrite loop_fuel_S_stable; eauto. } + Qed. + + Lemma loop_fuel_irrelevant n m s bn bm + (Hn : loop n s = inr bn) + (Hm : loop m s = inr bm) + : bn = bm. + Proof. + destruct (Compare_dec.le_le_S_dec n m) as [H|H]; + destruct (PeanoNat.Nat.le_exists_sub _ _ H) as [d [? _]]; subst. + { erewrite loop_fuel_add_stable in Hm by eassumption; congruence. } + { erewrite loop_fuel_add_stable in Hn. + { congruence_sum. reflexivity. } + { erewrite loop_fuel_S_stable by eassumption. congruence. } } + Qed. + + Local Lemma by_invariant_fuel' (inv:_->Prop) measure P f s0 + (init : inv s0 /\ measure s0 <= f) + (step : forall s, inv s -> match body s with + | inl s' => inv s' /\ measure s' < measure s + | inr s' => P s' + end) + : match loop f s0 with + | inl a => False + | inr s => P s + end. + Proof. + revert dependent s0; induction f; intros; destruct_head'_and. + { specialize (step s0 H); cbv; break_innermost_match; [lia|assumption]. } + { rewrite loop_fuel_S_first. + specialize (step s0 H); destruct (body s0); [|assumption]. + destruct step. + exact (IHf a ltac:(split; (assumption || lia))). } + Qed. + + Lemma by_invariant_fuel (inv:_->Prop) measure P f s0 + (init : inv s0 /\ measure s0 <= f) + (step : forall s, inv s -> match body s with + | inl s' => inv s' /\ measure s' < measure s + | inr s' => P s' + end) + : exists b, loop f s0 = inr b /\ P b. + Proof. + pose proof (by_invariant_fuel' inv measure P f s0); + specialize_by assumption; break_match_hyps; [contradiction|eauto]. + Qed. + + Lemma by_invariant (inv:_->Prop) measure P s0 + (init : inv s0) + (step : forall s, inv s -> match body s with + | inl s' => inv s' /\ measure s' < measure s + | inr s' => P s' + end) + : exists b, loop (measure s0) s0 = inr b /\ P b. + Proof. eapply by_invariant_fuel; eauto. Qed. + + (* Completeness proof *) + + Definition iterations_required fuel s : option nat := + nat_rect _ None + (fun n r => + match r with + | Some _ => r + | None => + match loop n s with + | inl a => None + | inr b => Some n + end + end + ) fuel. + + Lemma iterations_required_correct fuel s : + (forall m, iterations_required fuel s = Some m -> + m < fuel /\ + exists b, forall n, (n < m -> exists a, loop n s = inl a) /\ (m <= n -> loop n s = inr b)) + /\ + (iterations_required fuel s = None -> forall n, n < fuel -> exists a, loop n s = inl a). + Proof. + induction fuel; intros. + { cbn. split; intros; inversion_option; lia. } + { change (iterations_required (S fuel) s) + with (match iterations_required fuel s with + | None => match loop fuel s with + | inl _ => None + | inr _ => Some fuel + end + | Some _ => iterations_required fuel s + end) in *. + destruct (iterations_required fuel s) in *. + { split; intros; inversion_option; subst. + destruct (proj1 IHfuel _ eq_refl); split; [lia|assumption]. } + { destruct (loop fuel s) eqn:HSf; split; intros; inversion_option; subst. + { intros. destruct (PeanoNat.Nat.eq_dec n fuel); subst; eauto; []. + assert (n < fuel) by lia. eapply IHfuel; congruence. } + { split; [lia|]. + exists b; intros; split; intros. + { eapply IHfuel; congruence || lia. } + { eapply PeanoNat.Nat.le_exists_sub in H; destruct H as [?[]]; subst. + eauto using loop_fuel_add_stable. } } } } + Qed. + + Lemma iterations_required_step fuel s s' n + (Hs : iterations_required fuel s = Some (S n)) + (Hstep : body s = inl s') + : iterations_required fuel s' = Some n. + Proof. + eapply iterations_required_correct in Hs. + destruct Hs as [Hn [b Hs]]. + pose proof (proj2 (Hs (S n)) ltac:(lia)) as H. + rewrite loop_fuel_S_first, Hstep in H. + destruct (iterations_required fuel s') as [x|] eqn:Hs' in *; [f_equal|exfalso]. + { eapply iterations_required_correct in Hs'; destruct Hs' as [Hx Hs']. + destruct Hs' as [b' Hs']. + destruct (Compare_dec.le_lt_dec n x) as [Hc|Hc]. + { destruct (Compare_dec.le_lt_dec x n) as [Hc'|Hc']; try lia; []. + destruct (proj1 (Hs' n) Hc'); congruence. } + { destruct (proj1 (Hs (S x)) ltac:(lia)) as [? HX]. + rewrite loop_fuel_S_first, Hstep in HX. + pose proof (proj2 (Hs' x) ltac:(lia)). + congruence. } } + { eapply iterations_required_correct in Hs'; destruct Hs' as [? Hs']; + [|exact Hn]. + rewrite loop_fuel_S_last, H in Hs'; congruence. } + Qed. + + Local Lemma invariant_complete (P:_->Prop) f s0 b (H:loop f s0 = inr b) (HP:P b) + : exists inv measure, + (inv s0 /\ measure s0 <= f) + /\ forall s, inv s -> match body s with + | inl s' => inv s' /\ measure s' < measure s + | inr s' => P s' + end. + Proof. + set (measure s := match iterations_required (S f) s with None => 0 | Some n => n end). + exists (fun s => match loop (measure s) s with + | inl a => False + | inr r => r = b end). + exists (measure); split; [ |repeat match goal with |- _ /\ _ => split end..]. + { cbv [measure]. + destruct (iterations_required (S f) s0) eqn:Hs0; + eapply iterations_required_correct in Hs0; + [ .. | exact (ltac:(lia):f <S f)]; [|destruct_head'_ex; congruence]. + destruct Hs0 as [? [? Hs0]]; split; [|lia]. + pose proof (proj2 (Hs0 n) ltac:(lia)) as HH; rewrite HH. + exact (loop_fuel_irrelevant _ _ _ _ _ HH H). } + { intros s Hinv; destruct (body s) as [s'|c] eqn:Hstep. + { destruct (loop (measure s) s) eqn:Hs; [contradiction|subst]. + cbv [measure] in *. + destruct (iterations_required (S f) s) eqn:Hs' in *; try destruct n; + try (rewrite loop_fuel_0 in Hs; congruence); []. + pose proof (iterations_required_step _ _ s' _ Hs' Hstep) as HA. + rewrite HA. + destruct (proj1 (iterations_required_correct _ _) _ HA) as [? [? [? HE']]]. + pose proof (HE' ltac:(constructor)) as HE; clear HE'. + split; [|lia]. + rewrite loop_fuel_S_first, Hstep in Hs. + break_match; congruence. } + { destruct (loop (measure s) s) eqn:Hs; [contradiction|]. + assert (HH: loop 1 s = inr c) by (cbn; rewrite Hstep; reflexivity). + rewrite (loop_fuel_irrelevant _ _ _ _ _ HH Hs); congruence. } } + Qed. + + Lemma invariant_iff P f s0 : + (exists b, loop f s0 = inr b /\ P b) + <-> + (exists inv measure, + (inv s0 /\ measure s0 <= f) + /\ forall s, inv s -> match body s with + | inl s' => inv s' /\ measure s' < measure s + | inr s' => P s' + end). + Proof. + repeat (intros || split || destruct_head'_ex || destruct_head'_and); + eauto using invariant_complete, by_invariant_fuel. + Qed. + End Loops. + + Global Arguments loop_cps_ok {A B body body_cps}. + Global Arguments loop_cps2_ok {A B body body_cps2}. + Global Arguments by_invariant_fuel {A B body} inv measure P. + Global Arguments by_invariant {A B body} inv measure P. + Global Arguments invariant_iff {A B body} P f s0. + Global Arguments iterations_required_correct {A B body} fuel s. +End core. + +Module default. + Section Default. + Context {A B} (default : B) (body : A -> A + B). + Definition loop fuel s : B := + match loop body fuel s with + | inl s => default + | inr s => s + end. + + Lemma by_invariant_fuel inv measure (P:_->Prop) f s0 + (init : inv s0 /\ measure s0 <= f) + (step: forall s, inv s -> match body s with + | inl s' => inv s' /\ measure s' < measure s + | inr s' => P s' + end) + : P (loop f s0). + Proof. + edestruct (by_invariant_fuel (body:=body) inv measure P f s0) as [x [HA HB]]; eauto; []. + apply (f_equal (fun r : A + B => match r with inl s => default | inr s => s end)) in HA. + cbv [loop]; break_match; congruence. + Qed. + + Lemma by_invariant (inv:_->Prop) measure P s0 + (init : inv s0) + (step: forall s, inv s -> match body s with + | inl s' => inv s' /\ measure s' < measure s + | inr s' => P s' + end) + : P (loop (measure s0) s0). + Proof. eapply by_invariant_fuel; eauto. Qed. + End Default. + Global Arguments by_invariant_fuel {A B default body} inv measure P. + Global Arguments by_invariant {A B default body} inv measure P. +End default. + +Module silent. + Section Silent. + Context {state} (body : state -> state + state). + Definition loop fuel s : state := + match loop body fuel s with + | inl s => s + | inr s => s + end. + + Lemma by_invariant_fuel inv measure (P:_->Prop) f s0 + (init : inv s0 /\ measure s0 <= f) + (step: forall s, inv s -> match body s with + | inl s' => inv s' /\ measure s' < measure s + | inr s' => P s' + end) + : P (loop f s0). + Proof. + edestruct (by_invariant_fuel (body:=body) inv measure P f s0) as [x [A B]]; eauto; []. + apply (f_equal (fun r : state + state => match r with inl s => s | inr s => s end)) in A. + cbv [loop]; break_match; congruence. + Qed. + + Lemma by_invariant (inv:_->Prop) measure P s0 + (init : inv s0) + (step: forall s, inv s -> match body s with + | inl s' => inv s' /\ measure s' < measure s + | inr s' => P s' + end) + : P (loop (measure s0) s0). + Proof. eapply by_invariant_fuel; eauto. Qed. + End Silent. + + Global Arguments by_invariant_fuel {state body} inv measure P. + Global Arguments by_invariant {state body} inv measure P. +End silent. + +Module while. + Section While. + Context {state} + (test : state -> bool) + (body : state -> state). + + Fixpoint while f s := + if test s + then + let s := body s in + match f with + | O => s + | S f => while f s + end + else s. + + Local Definition lbody := fun s => if test s then inl (body s) else inr s. + + Lemma eq_loop f s : while f s = silent.loop lbody f s. + Proof. + revert s; induction f; intros s; + repeat match goal with + | _ => progress cbn in * + | _ => progress cbv [silent.loop lbody] in * + | _ => rewrite IHf + | _ => progress break_match + | _ => congruence + end. + Qed. + + Lemma by_invariant_fuel inv measure P f s0 + (init : inv s0 /\ measure s0 <= f) + (step : forall s, inv s -> if test s + then inv (body s) /\ measure (body s) < measure s + else P s) + : P (while f s0). + Proof. + rewrite eq_loop. + eapply silent.by_invariant_fuel; eauto; []; intros s H; cbv [lbody]. + specialize (step s H); destruct (test s); eauto. + Qed. + + Lemma by_invariant (inv:_->Prop) measure P s0 + (init : inv s0) + (step : forall s, inv s -> if test s + then inv (body s) /\ measure (body s) < measure s + else P s) + : P (while (measure s0) s0). + Proof. eapply by_invariant_fuel; eauto. Qed. + + Context (body_cps : state ~> state). + + Fixpoint while_cps f s :~> state := + if test s + then + s <- body_cps s; + match f with + | O => return s + | S f =>while_cps f s + end + else return s. + + Context (body_cps_ok : forall s {R} f, body_cps s R f = f (body s)). + Lemma loop_cps_ok n s {R} f : while_cps n s R f = f (while n s). + Proof. + revert s; induction n; intros s; + repeat match goal with + | _ => progress intros + | _ => progress cbv [cpsreturn cpscall] in * + | _ => progress cbn + | _ => progress rewrite ?body_cps_ok + | _ => progress rewrite ?IHn + | _ => progress inversion_sum + | _ => progress break_match + | _ => reflexivity + end. + Qed. + End While. + Global Arguments by_invariant_fuel {state test body} inv measure P. + Global Arguments by_invariant {state test body} inv measure P. +End while. +Notation while := while.while. + +Definition for2 {state} (test : state -> bool) (increment body : state -> state) + := while test (fun s => let s := body s in increment s). + +Definition for3 {state} init test increment body fuel := + @for2 state test increment body fuel init. + +(* TODO: we probably want notations for these. here are some ideas: *) + +(* notation for core.loop_cps2: + loop '(state1, state2, ...) := init + decreasing measure {{ + continue: + body + }} break; + cont +where + continue, break are binders (for continuations passed by core.loop_cps2) + state1, state2 are binders for tuple fields + measure is a function of state1, state2, ... (or the tuple bound by them) + body is a function of continue_label, state1, state2, ... + cont is the contiuation passed to core.loop_cps2 + "loop" and "decreasing" are delimiters +*) + +(* notation for while.while_cps: + + loop '(state1, state2, ...) := init + while test + decreasing measure {{ + continue: + body + }}; + cont +where + state1, state2, continue are binders + body is a function of all of them + test and measure are functions of state1, state2 (or the tuple bound by them) + "loop", "while" and "decreasing" are delimiters + cont is a continuation to the entire loop + *) + + +(* idea for notation for some for-like loop: +loop '(state1, state2) := init +for (i := 0; i <? n; i+1) {{ + continue: + body +}}; +cont + +where the first i is a binder for all uses of i, and the test and increment are parsed as functions of some argument *) + +(* ideally it should be possible to explicitly indicate the type of the loop state somewhere, in all these examples *)
\ No newline at end of file |