diff options
-rw-r--r-- | _CoqProject | 3 | ||||
-rw-r--r-- | src/Util/ForLoop.v | 34 | ||||
-rw-r--r-- | src/Util/ForLoop/Instances.v | 67 | ||||
-rw-r--r-- | src/Util/ForLoop/InvariantFramework.v | 369 | ||||
-rw-r--r-- | src/Util/ForLoop/Tests.v | 48 | ||||
-rw-r--r-- | src/Util/ForLoop/Unrolling.v | 314 |
6 files changed, 826 insertions, 9 deletions
diff --git a/_CoqProject b/_CoqProject index b41e8f6b9..416165ad1 100644 --- a/_CoqProject +++ b/_CoqProject @@ -248,7 +248,10 @@ src/Util/Unit.v src/Util/WordUtil.v src/Util/ZRange.v src/Util/ZUtil.v +src/Util/ForLoop/Instances.v +src/Util/ForLoop/InvariantFramework.v src/Util/ForLoop/Tests.v +src/Util/ForLoop/Unrolling.v src/Util/Logic/ImplAnd.v src/Util/Sigma/Associativity.v src/Util/Sigma/Lift.v diff --git a/src/Util/ForLoop.v b/src/Util/ForLoop.v index caa853a9a..9ec6f5ba4 100644 --- a/src/Util/ForLoop.v +++ b/src/Util/ForLoop.v @@ -9,35 +9,44 @@ Section with_body. Fixpoint repeat_function (count : nat) (st : stateT) : stateT := match count with | O => st - | S count' => repeat_function count' (body count' st) + | S count' => repeat_function count' (body count st) end. End with_body. Local Open Scope bool_scope. Local Open Scope Z_scope. -Definition for_loop (i0 finish : Z) (step : Z) {stateT} (initial : stateT) (body : Z -> stateT -> stateT) +Definition for_loop {stateT} (i0 finish : Z) (step : Z) (initial : stateT) (body : Z -> stateT -> stateT) : stateT - := let signed_step := (finish - i0) / step in - let count := Z.to_nat ((finish - i0) / signed_step) in - repeat_function (fun c => body (i0 + signed_step * Z.of_nat (count - c))) count initial. + := let count := Z.to_nat (Z.quot (finish - i0 + step - Z.sgn step) step) in + repeat_function (fun c => body (i0 + step * Z.of_nat (count - c))) count initial. Notation "'for' i (:= i0 ; += step ; < finish ) 'updating' ( state := initial ) {{ body }}" := (for_loop i0 finish step initial (fun i state => body)) : core_scope. +Module Import ForNotationConstants. + Definition eq := @eq Z. + Module Z. + Definition ltb := Z.ltb. + Definition ltb' := Z.ltb. + Definition gtb := Z.gtb. + Definition gtb' := Z.gtb. + End Z. +End ForNotationConstants. + Delimit Scope for_notation_scope with for_notation. -Notation "x += y" := (x = Z.pos y) : for_notation_scope. -Notation "x -= y" := (x = Z.neg y) : for_notation_scope. +Notation "x += y" := (eq x (Z.pos y)) : for_notation_scope. +Notation "x -= y" := (eq x (Z.neg y)) : for_notation_scope. Notation "++ x" := (x += 1)%for_notation : for_notation_scope. Notation "-- x" := (x -= 1)%for_notation : for_notation_scope. Notation "x ++" := (x += 1)%for_notation : for_notation_scope. Notation "x --" := (x -= 1)%for_notation : for_notation_scope. Infix "<" := Z.ltb : for_notation_scope. Infix ">" := Z.gtb : for_notation_scope. -Infix "<=" := Z.leb : for_notation_scope. -Infix ">=" := Z.geb : for_notation_scope. +Notation "x <= y" := (Z.ltb' x (y + 1)) : for_notation_scope. +Notation "x >= y" := (Z.gtb' x (y - 1)) : for_notation_scope. Class class_eq {A} (x y : A) := make_class_eq : x = y. Global Instance class_eq_refl {A x} : @class_eq A x x := eq_refl. @@ -64,4 +73,11 @@ Notation "'for' ( 'int' i = i0 ; finish_expr ; step_expr ) 'updating' ( state1 . (fun i : Z => step_expr%for_notation) (fun i : Z => finish_expr%for_notation) (fun (i : Z) => (fun state1 => .. (fun staten => body) .. )) + _ _ _). +Notation "'for' ( 'int' i = i0 ; finish_expr ; step_expr ) 'updating' ( state1 .. staten = initial ) {{ body }}" + := (@for_loop_notation + i0%Z _ _ _ initial%Z _ + (fun i : Z => step_expr%for_notation) + (fun i : Z => finish_expr%for_notation) + (fun (i : Z) => (fun state1 => .. (fun staten => body) .. )) eq_refl eq_refl _). diff --git a/src/Util/ForLoop/Instances.v b/src/Util/ForLoop/Instances.v new file mode 100644 index 000000000..0a1f65e29 --- /dev/null +++ b/src/Util/ForLoop/Instances.v @@ -0,0 +1,67 @@ +Require Import Coq.omega.Omega. +Require Import Coq.Classes.Morphisms. +Require Import Crypto.Util.ForLoop. +Require Import Crypto.Util.Notations. + +Lemma repeat_function_Proper_rel_le {stateT} R f g n + (Hfg : forall c, 0 < c <= n -> forall s1 s2, R s1 s2 -> R (f c s1) (g c s2)) + s1 s2 (Hs : R s1 s2) + : R (@repeat_function stateT f n s1) (@repeat_function stateT g n s2). +Proof. + revert s1 s2 Hs. + induction n; simpl; auto. + intros; apply IHn; auto; + intros; apply Hfg; auto; + omega. +Qed. + +Global Instance repeat_function_Proper_rel {stateT} R + : Proper (pointwise_relation _ (R ==> R) ==> eq ==> R ==> R) (@repeat_function stateT) | 10. +Proof. + unfold pointwise_relation, respectful. + intros body1 body2 Hbody c y ?; subst y. + induction c; simpl; auto. +Qed. + +Lemma repeat_function_Proper_le {stateT} f g n + (Hfg : forall c, 0 < c <= n -> forall st, f c st = g c st) + st + : @repeat_function stateT f n st = @repeat_function stateT g n st. +Proof. + apply repeat_function_Proper_rel_le; try reflexivity; intros; subst; auto. +Qed. + +Global Instance repeat_function_Proper {stateT} + : Proper (pointwise_relation _ (pointwise_relation _ eq) ==> eq ==> eq ==> eq) (@repeat_function stateT). +Proof. + intros ???; eapply repeat_function_Proper_rel; repeat intro; subst. + unfold pointwise_relation, respectful in *; auto. +Qed. +About for_loop. + +Global Instance for_loop_Proper_rel {stateT} R i0 final step + : Proper (R ==> pointwise_relation _ (R ==> R) ==> R) (@for_loop stateT i0 final step) | 10. +Proof. + intros ?? Hinitial ?? Hbody; revert Hinitial. + unfold for_loop; eapply repeat_function_Proper_rel; + unfold pointwise_relation, respectful in *; auto. +Qed. + +Global Instance for_loop_Proper_rel_full {stateT} R + : Proper (eq ==> eq ==> eq ==> R ==> pointwise_relation _ (R ==> R) ==> R) (@for_loop stateT) | 20. +Proof. + intros ?????????; subst; apply for_loop_Proper_rel. +Qed. + +Global Instance for_loop_Proper {stateT} i0 final step initial + : Proper (pointwise_relation _ (pointwise_relation _ eq) ==> eq) (@for_loop stateT i0 final step initial). +Proof. + unfold pointwise_relation. + intros ???; eapply for_loop_Proper_rel; try reflexivity; repeat intro; subst; auto. +Qed. + +Global Instance for_loop_Proper_full {stateT} + : Proper (eq ==> eq ==> eq ==> eq ==> pointwise_relation _ (pointwise_relation _ eq) ==> eq) (@for_loop stateT) | 5. +Proof. + intros ????????????; subst; apply for_loop_Proper. +Qed. diff --git a/src/Util/ForLoop/InvariantFramework.v b/src/Util/ForLoop/InvariantFramework.v new file mode 100644 index 000000000..503cfcdc3 --- /dev/null +++ b/src/Util/ForLoop/InvariantFramework.v @@ -0,0 +1,369 @@ +(** * Proving properties of for-loops via loop-invariants *) +Require Import Coq.micromega.Psatz. +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Util.ForLoop. +Require Import Crypto.Util.ForLoop.Unrolling. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Bool. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.Notations. + +Lemma repeat_function_ind {stateT} (P : nat -> stateT -> Prop) + (body : nat -> stateT -> stateT) + (count : nat) (st : stateT) + (Pbefore : P count st) + (Pbody : forall c st, c < count -> P (S c) st -> P c (body (S c) st)) + : P 0 (repeat_function body count st). +Proof. + revert dependent st; revert dependent body; revert dependent P. + induction count; intros; [ exact Pbefore | ]. + { rewrite repeat_function_unroll1_end; apply Pbody; [ omega | ]. + apply (IHcount (fun c => P (S c))); auto with omega. } +Qed. + +Local Open Scope bool_scope. +Local Open Scope Z_scope. + +Section for_loop. + Context (i0 finish : Z) (step : Z) {stateT} (initial : stateT) (body : Z -> stateT -> stateT) + (P : Z -> stateT -> Prop) + (Pbefore : P i0 initial) + (Pbody : forall c st n, c = i0 + n * step -> i0 <= c < finish \/ finish < c <= i0 -> P c st -> P (c + step) (body c st)) + (Hgood : Z.sgn step = Z.sgn (finish - i0)). + + Let countZ := (Z.quot (finish - i0 + step - Z.sgn step) step). + Let count := Z.to_nat countZ. + Let of_nat_count c := (i0 + step * Z.of_nat (count - c)). + Let nat_body := (fun c => body (of_nat_count c)). + + Local Arguments Z.mul !_ !_. + Local Arguments Z.add !_ !_. + Local Arguments Z.sub !_ !_. + + Local Lemma Hgood_complex : Z.sgn step = Z.sgn (finish - i0 + step - Z.sgn step). + Proof using Hgood. + clear -Hgood. + revert Hgood. + generalize dependent (finish - i0); intro z; intros. + destruct step, z; simpl in * |- ; try (simpl; omega); + repeat change (Z.sgn (Z.pos _)) with 1; + repeat change (Z.sgn (Z.neg _)) with (-1); + symmetry; + [ apply Z.sgn_pos_iff | apply Z.sgn_neg_iff ]; + lia. + Qed. + + Local Lemma Hcount_nonneg : 0 <= countZ. + Proof using Hgood. + apply Z.quot_nonneg_same_sgn. + symmetry; apply Hgood_complex. + Qed. + + Lemma for_loop_ind + : P (finish - Z.sgn (finish - i0 + step - Z.sgn step) * (Z.abs (finish - i0 + step - Z.sgn step) mod Z.abs step) + step - Z.sgn step) + (for_loop i0 finish step initial body). + Proof using Pbody Pbefore Hgood. + destruct (Z_zerop step). + { subst; unfold for_loop; simpl in *. + rewrite Z.quot_div_full; simpl. + symmetry in Hgood; rewrite Z.sgn_null_iff in Hgood. + assert (finish = i0) by omega; subst. + simpl; autorewrite with zsimplify_const; simpl; auto. } + assert (Hsgn_step : Z.sgn step <> 0) by (rewrite Z.sgn_null_iff; auto). + assert (Hsgn : Z.sgn ((finish - i0 + step - Z.sgn step) / step) = Z.sgn ((finish - i0 + step - Z.sgn step) / step) * Z.sgn (finish - i0 + step - Z.sgn step) * Z.sgn step) + by (rewrite <- Hgood_complex, <- Z.mul_assoc, <- Z.sgn_mul, (Z.sgn_pos (_ * _)) by nia; omega). + assert (Hfis_div : 0 <= (finish - i0 + step - Z.sgn step) / step) + by (apply Z.sgn_nonneg; rewrite Hsgn; apply Zdiv_sgn). + clear Hsgn. + let rhs := match goal with |- ?P ?rhs _ => rhs end in + assert (Heq : i0 + step * Z.of_nat count = rhs). + { unfold count, countZ. + rewrite Z.mod_eq by (rewrite Z.abs_0_iff; assumption). + rewrite Z.quot_div_full, <- !Z.sgn_abs, <- !Hgood_complex, !Zdiv_mult_cancel_r, !Z.mul_sub_distr_l by auto. + rewrite <- !Z.sgn_mul, !(Z.mul_comm _ (Z.sgn _)), !(Z.mul_assoc (Z.sgn _) _), <- Z.sgn_mul, Z.sgn_pos, !Z.mul_1_l by nia. + repeat rewrite ?Z.sub_add_distr, ?Z.sub_sub_distr; rewrite Z.sub_diag. + autorewrite with zsimplify_const. + rewrite Z2Nat.id by omega. + omega. } + rewrite <- Heq; clear Heq. + unfold for_loop. + generalize (@repeat_function_ind stateT (fun c => P (of_nat_count c)) nat_body count initial); + cbv beta in *. + unfold of_nat_count in *; cbv beta in *. + rewrite Nat.sub_diag, !Nat.sub_0_r. + autorewrite with zsimplify_const. + intro H; specialize (H Pbefore). + destruct (Z_dec' i0 finish) as [ Hs | Hs]. + { apply H; clear H Pbefore. + { intros c st Hc. + assert (Hclt : 0 < Z.of_nat (count - c)) by (apply (inj_lt 0); omega). + intro H'; specialize (fun pf' n pf => Pbody _ _ n pf pf' H'). + move Pbody at bottom. + { let T := match type of Pbody with ?T -> _ => T end in + let H := fresh in + cut T; [ intro H; specialize (Pbody H) | ]. + { revert Pbody. + subst nat_body; cbv beta. + rewrite Nat.sub_succ_r, Nat2Z.inj_pred by omega. + rewrite <- Z.sub_1_r, Z.mul_sub_distr_l, Z.mul_1_r. + rewrite <- !Z.add_assoc, !Z.sub_add in *. + refine (fun p => p (Z.of_nat (count - c) - 1) _). + lia. } + { destruct Hs; [ left | right ]. + { assert (Hstep : 0 < step) + by (rewrite <- Z.sgn_pos_iff, Hgood, Z.sgn_pos_iff; omega). + assert (0 < Z.of_nat (S c)) by (apply (inj_lt 0); omega). + assert (0 <= (finish - i0 + step - Z.sgn step) mod step) by auto with zarith. + assert (0 < step <= step * Z.of_nat (S c)) by nia. + split; [ nia | ]. + rewrite Nat2Z.inj_sub, Z.mul_sub_distr_l by omega. + unfold count. + rewrite Z2Nat.id by auto using Hcount_nonneg. + unfold countZ. + rewrite Z.mul_quot_eq_full by auto. + rewrite <- !Hgood_complex, Z.abs_sgn. + rewrite !Z.add_sub_assoc, !Z.add_assoc, Zplus_minus. + rewrite Z.sgn_pos in * by omega. + omega. } + { assert (Hstep : step < 0) + by (rewrite <- Z.sgn_neg_iff, Hgood, Z.sgn_neg_iff; omega). + assert (Hcsc0 : 0 <= Z.of_nat (count - S c)) by auto with zarith. + assert (Hsc0 : 0 < Z.of_nat (S c)) by lia. + assert (step * Z.of_nat (count - S c) <= 0) by (clear -Hcsc0 Hstep; nia). + assert (step * Z.of_nat (S c) <= step < 0) by (clear -Hsc0 Hstep; nia). + assert (finish - i0 < 0) + by (rewrite <- Z.sgn_neg_iff, <- Hgood, Z.sgn_neg_iff; omega). + assert (finish - i0 + step - Z.sgn step < 0) + by (rewrite <- Z.sgn_neg_iff, <- Hgood_complex, Z.sgn_neg_iff; omega). + assert ((finish - i0 + step - Z.sgn step) mod step <= 0) by (apply Z_mod_neg; auto with zarith). + split; [ | nia ]. + rewrite Nat2Z.inj_sub, Z.mul_sub_distr_l by omega. + unfold count. + rewrite Z2Nat.id by auto using Hcount_nonneg. + unfold countZ. + rewrite Z.mul_quot_eq_full by auto. + rewrite <- !Hgood_complex, Z.abs_sgn. + rewrite Z.sgn_neg in * by omega. + omega. } } } } } + { subst. + subst count nat_body countZ. + repeat first [ assumption + | rewrite Z.sub_diag + | progress autorewrite with zsimplify_const in * + | rewrite Z.quot_sub_sgn ]. } + Qed. +End for_loop. + +Lemma for_loop_notation_ind {stateT} (P : Z -> stateT -> Prop) + {i0 : Z} {step : Z} {finish : Z} {initial : stateT} + {cmp : Z -> Z -> bool} + {step_expr finish_expr} (body : Z -> stateT -> stateT) + {Hstep : class_eq (fun i => i = step) step_expr} + {Hfinish : class_eq (fun i => cmp i finish) finish_expr} + {Hgood : for_loop_is_good i0 step finish cmp} + (Pbefore : P i0 initial) + (Pbody : forall c st n, c = i0 + n * step -> i0 <= c < finish \/ finish < c <= i0 -> P c st -> P (c + step) (body c st)) + : P (finish - Z.sgn (finish - i0 + step - Z.sgn step) * (Z.abs (finish - i0 + step - Z.sgn step) mod Z.abs step) + step - Z.sgn step) + (@for_loop_notation i0 step finish _ initial cmp step_expr finish_expr body Hstep Hfinish Hgood). +Proof. + unfold for_loop_notation, for_loop_is_good in *; split_andb; Z.ltb_to_lt. + apply for_loop_ind; auto. +Qed. + +Local Ltac pre_t := + lazymatch goal with + | [ Pbefore : ?P ?i0 ?initial + |- ?P _ (@for_loop_notation ?i0 ?step ?finish _ ?initial _ ?step_expr ?finish_expr ?body ?Hstep ?Hfinish ?Hgood) ] + => generalize (@for_loop_notation_ind + _ P i0 step finish initial _ step_expr finish_expr body Hstep Hfinish Hgood Pbefore) + end. +Local Ltac t_step := + first [ progress unfold for_loop_is_good, for_loop_notation in * + | progress split_andb + | progress Z.ltb_to_lt + | rewrite Z.sgn_pos by lia + | rewrite Z.abs_eq by lia + | rewrite Z.sgn_neg by lia + | rewrite Z.abs_neq by lia + | progress autorewrite with zsimplify_const + | match goal with + | [ Hsgn : Z.sgn ?step = Z.sgn _ |- _ ] + => unique assert (0 < step) by (rewrite <- Z.sgn_pos_iff, Hsgn, Z.sgn_pos_iff; omega); clear Hsgn + | [ Hsgn : Z.sgn ?step = Z.sgn _ |- _ ] + => unique assert (step < 0) by (rewrite <- Z.sgn_neg_iff, Hsgn, Z.sgn_neg_iff; omega); clear Hsgn + | [ |- (_ -> ?P ?x ?y) -> ?P ?x' ?y' ] + => replace x with x' by lia; let H' := fresh in intro H'; apply H'; clear H' + | [ |- (_ -> _) -> _ ] + => let H := fresh "Hbody" in intro H; progress Z.replace_all_neg_with_pos; revert H + end + | rewrite !Z.opp_sub_distr + | rewrite !Z.opp_add_distr + | rewrite !Z.opp_involutive + | rewrite !Z.sub_opp_r + | rewrite (Z.add_opp_r _ 1) + | progress (push_Zmod; pull_Zmod) + | progress Z.replace_all_neg_with_pos + | solve [ eauto with omega ] ]. +Local Ltac t := pre_t; repeat t_step. + +Lemma for_loop_ind_lt {stateT} (P : Z -> stateT -> Prop) + {i0 : Z} {step : Z} {finish : Z} {initial : stateT} + {step_expr finish_expr} (body : Z -> stateT -> stateT) + {Hstep : class_eq (fun i => i = step) step_expr} + {Hfinish : class_eq (fun i => Z.ltb i finish) finish_expr} + {Hgood : for_loop_is_good i0 step finish Z.ltb} + (Pbefore : P i0 initial) + (Pbody : forall c st n, c = i0 + n * step -> i0 <= c < finish -> P c st -> P (c + step) (body c st)) + : P (finish + step - 1 - ((finish - i0 - 1) mod step)) + (@for_loop_notation i0 step finish _ initial Z.ltb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish Hgood). +Proof. t. Qed. + +Lemma for_loop_ind_gt {stateT} (P : Z -> stateT -> Prop) + {i0 : Z} {step : Z} {finish : Z} {initial : stateT} + {step_expr finish_expr} (body : Z -> stateT -> stateT) + {Hstep : class_eq (fun i => i = step) step_expr} + {Hfinish : class_eq (fun i => Z.gtb i finish) finish_expr} + {Hgood : for_loop_is_good i0 step finish Z.gtb} + (Pbefore : P i0 initial) + (Pbody : forall c st n, c = i0 + n * step -> finish < c <= i0 -> P c st -> P (c + step) (body c st)) + : P (finish + step + 1 + (i0 - finish - step - 1) mod (-step)) + (@for_loop_notation i0 step finish _ initial Z.gtb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish Hgood). +Proof. + replace (i0 - finish) with (-(finish - i0)) by omega. + t. +Qed. + +Lemma for_loop_ind_lt1 {stateT} (P : Z -> stateT -> Prop) + {i0 : Z} {finish : Z} {initial : stateT} + (body : Z -> stateT -> stateT) + {Hgood : for_loop_is_good i0 1 finish _} + (Pbefore : P i0 initial) + (Pbody : forall c st, i0 <= c < finish -> P c st -> P (c + 1) (body c st)) + : P finish + (for (int i = i0; i < finish; i++) updating (st = initial) {{ + body i st + }}). +Proof. + generalize (@for_loop_ind_lt + stateT P i0 1 finish initial _ _ body eq_refl eq_refl Hgood Pbefore). + rewrite Z.mod_1_r, Z.sub_0_r, Z.add_simpl_r. + auto. +Qed. + +Lemma for_loop_ind_gt1 {stateT} (P : Z -> stateT -> Prop) + {i0 : Z} {finish : Z} {initial : stateT} + (body : Z -> stateT -> stateT) + {Hgood : for_loop_is_good i0 (-1) finish _} + (Pbefore : P i0 initial) + (Pbody : forall c st, finish < c <= i0 -> P c st -> P (c - 1) (body c st)) + : P finish + (for (int i = i0; i > finish; i--) updating (st = initial) {{ + body i st + }}). +Proof. + generalize (@for_loop_ind_gt + stateT P i0 (-1) finish initial _ _ body eq_refl eq_refl Hgood Pbefore). + simpl; rewrite Z.mod_1_r, Z.add_0_r, (Z.add_opp_r _ 1), Z.sub_simpl_r. + intro H; apply H; intros *. + rewrite (Z.add_opp_r _ 1); auto. +Qed. + +Local Hint Extern 1 (for_loop_is_good (?i0 + ?step) ?step ?finish _) +=> refine (for_loop_is_good_step_lt _); try assumption : typeclass_instances. +Local Hint Extern 1 (for_loop_is_good (?i0 + ?step) ?step ?finish _) +=> refine (for_loop_is_good_step_gt _); try assumption : typeclass_instances. +Local Hint Extern 1 (for_loop_is_good (?i0 - ?step') _ ?finish _) +=> refine (for_loop_is_good_step_gt (step:=-step') _); try assumption : typeclass_instances. +Local Hint Extern 1 (for_loop_is_good (?i0 + 1) 1 ?finish _) +=> refine (for_loop_is_good_step_lt' _); try assumption : typeclass_instances. +Local Hint Extern 1 (for_loop_is_good (?i0 - 1) (-1) ?finish _) +=> refine (for_loop_is_good_step_gt' _); try assumption : typeclass_instances. + +(** The Hoare-logic-like conditions for ≤ and ≥ loops seem slightly + unnatural; you have to choose either to state your correctness + property in terms of [i + 1], or talk about the correctness + condition when the loop counter is [i₀ - 1] (which is strange; + it's like saying the loop has run -1 times), or give the + correctness condition after the first run of the loop body, rather + than before it. We give lemmas for the second two options; if + you're using the first one, Coq probably won't be able to infer + the motive ([P], below) automatically, and you might as well use + the vastly more general version [for_loop_ind_lt] / + [for_loop_ind_gt]. *) +Lemma for_loop_ind_le1 {stateT} (P : Z -> stateT -> Prop) + {i0 : Z} {finish : Z} {initial : stateT} + (body : Z -> stateT -> stateT) + {Hgood : for_loop_is_good i0 1 (finish+1) _} + (Pbefore : P i0 (body i0 initial)) + (Pbody : forall c st, i0 <= c <= finish -> P (c-1) st -> P c (body c st)) + : P finish + (for (int i = i0; i <= finish; i++) updating (st = initial) {{ + body i st + }}). +Proof. + rewrite for_loop_le1_unroll1. + edestruct Sumbool.sumbool_of_bool; Z.ltb_to_lt; cbv zeta. + { generalize (@for_loop_ind_lt + stateT (fun n => P (n - 1)) (i0+1) 1 (finish+1) (body i0 initial) _ _ body eq_refl eq_refl _). + rewrite Z.mod_1_r, Z.sub_0_r, !Z.add_simpl_r. + intro H; apply H; auto with omega; intros *. + rewrite !Z.add_simpl_r; auto with omega. } + { unfold for_loop_is_good, ForNotationConstants.Z.ltb', ForNotationConstants.Z.ltb in *; split_andb; Z.ltb_to_lt. + assert (i0 = finish) by omega; subst. + assumption. } +Qed. + +Lemma for_loop_ind_le1_offset {stateT} (P : Z -> stateT -> Prop) + {i0 : Z} {finish : Z} {initial : stateT} + (body : Z -> stateT -> stateT) + {Hgood : for_loop_is_good i0 1 (finish+1) _} + (Pbefore : P (i0-1) initial) + (Pbody : forall c st, i0 <= c <= finish -> P (c-1) st -> P c (body c st)) + : P finish + (for (int i = i0; i <= finish; i++) updating (st = initial) {{ + body i st + }}). +Proof. + apply for_loop_ind_le1; auto with omega. + unfold for_loop_is_good, ForNotationConstants.Z.ltb', ForNotationConstants.Z.ltb in *; split_andb; Z.ltb_to_lt. + auto with omega. +Qed. + +Lemma for_loop_ind_ge1 {stateT} (P : Z -> stateT -> Prop) + {i0 : Z} {finish : Z} {initial : stateT} + (body : Z -> stateT -> stateT) + {Hgood : for_loop_is_good i0 (-1) (finish-1) _} + (Pbefore : P i0 (body i0 initial)) + (Pbody : forall c st, finish <= c <= i0 -> P (c+1) st -> P c (body c st)) + : P finish + (for (int i = i0; i >= finish; i--) updating (st = initial) {{ + body i st + }}). +Proof. + rewrite for_loop_ge1_unroll1. + edestruct Sumbool.sumbool_of_bool; Z.ltb_to_lt; cbv zeta. + { generalize (@for_loop_ind_gt + stateT (fun n => P (n + 1)) (i0-1) (-1) (finish-1) (body i0 initial) _ _ body eq_refl eq_refl _). + simpl; rewrite Z.mod_1_r, Z.add_0_r, (Z.add_opp_r _ 1), !Z.sub_simpl_r. + intro H; apply H; intros *; auto with omega. + rewrite (Z.add_opp_r _ 1), !Z.sub_simpl_r; auto with omega. } + { unfold for_loop_is_good, ForNotationConstants.Z.gtb', ForNotationConstants.Z.gtb in *; split_andb; Z.ltb_to_lt. + assert (i0 = finish) by omega; subst. + assumption. } +Qed. + +Lemma for_loop_ind_ge1_offset {stateT} (P : Z -> stateT -> Prop) + {i0 : Z} {finish : Z} {initial : stateT} + (body : Z -> stateT -> stateT) + {Hgood : for_loop_is_good i0 (-1) (finish-1) _} + (Pbefore : P (i0+1) initial) + (Pbody : forall c st, finish <= c <= i0 -> P (c+1) st -> P c (body c st)) + : P finish + (for (int i = i0; i >= finish; i--) updating (st = initial) {{ + body i st + }}). +Proof. + apply for_loop_ind_ge1; auto with omega. + unfold for_loop_is_good, ForNotationConstants.Z.gtb', ForNotationConstants.Z.gtb in *; split_andb; Z.ltb_to_lt. + auto with omega. +Qed. diff --git a/src/Util/ForLoop/Tests.v b/src/Util/ForLoop/Tests.v index 7b800ddbc..1061f1958 100644 --- a/src/Util/ForLoop/Tests.v +++ b/src/Util/ForLoop/Tests.v @@ -1,7 +1,55 @@ Require Import Coq.ZArith.BinInt. +Require Import Coq.micromega.Psatz. Require Import Crypto.Util.ForLoop. +Require Import Crypto.Util.ForLoop.InvariantFramework. +Require Import Crypto.Util.ZUtil. Local Open Scope Z_scope. Check (for i (:= 0; += 1; < 10) updating (v := 5) {{ v + i }}). Check (for (int i = 0; i < 5; i++) updating ( '(v1, v2) = (0, 0) ) {{ (v1 + i, v2 + i) }}). + +Compute for (int i = 0; i < 5; i++) updating (v = 0) {{ v + i }}. +Compute for (int i = 0; i <= 5; i++) updating (v = 0) {{ v + i }}. +Compute for (int i = 5; i > -1; i--) updating (v = 0) {{ v + i }}. +Compute for (int i = 5; i >= 0; i--) updating (v = 0) {{ v + i }}. +Compute for (int i = 0; i < 5; i += 2) updating (v = 0) {{ v + i }}. +Compute for (int i = 0; i <= 5; i += 2) updating (v = 0) {{ v + i }}. +Compute for (int i = 5; i > -1; i -= 2) updating (v = 0) {{ v + i }}. +Compute for (int i = 5; i >= 0; i -= 2) updating (v = 0) {{ v + i }}. +Compute for (int i = 0; i < 6; i += 2) updating (v = 0) {{ v + i }}. +Compute for (int i = 0; i <= 6; i += 2) updating (v = 0) {{ v + i }}. +Compute for (int i = 6; i > -1; i -= 2) updating (v = 0) {{ v + i }}. +Compute for (int i = 6; i >= 0; i -= 2) updating (v = 0) {{ v + i }}. +Check eq_refl : for (int i = 0; i <= 5; i++) updating (v = 0) {{ v + i }} = 15. +Check eq_refl : for (int i = 0; i < 5; i++) updating (v = 0) {{ v + i }} = 10. +Check eq_refl : for (int i = 5; i >= 0; i--) updating (v = 0) {{ v + i }} = 15. +Check eq_refl : for (int i = 5; i > -1; i--) updating (v = 0) {{ v + i }} = 15. +Check eq_refl : for (int i = 0; i <= 5; i += 2) updating (v = 0) {{ v + i }} = 6. +Check eq_refl : for (int i = 0; i < 5; i += 2) updating (v = 0) {{ v + i }} = 6. +Check eq_refl : for (int i = 5; i > -1; i -= 2) updating (v = 0) {{ v + i }} = 9. +Check eq_refl : for (int i = 5; i >= 0; i -= 2) updating (v = 0) {{ v + i }} = 9. +Check eq_refl : for (int i = 0; i <= 6; i += 2) updating (v = 0) {{ v + i }} = 12. +Check eq_refl : for (int i = 0; i < 6; i += 2) updating (v = 0) {{ v + i }} = 6. +Check eq_refl : for (int i = 6; i > -1; i -= 2) updating (v = 0) {{ v + i }} = 12. +Check eq_refl : for (int i = 6; i >= 0; i -= 2) updating (v = 0) {{ v + i }} = 12. + +Local Notation for_sumT n' + := (let n := Z.pos n' in + (2 * + for (int i = 0; i <= n; i++) updating (v = 0) {{ + v + i + }})%Z + = n * (n + 1)) + (only parsing). + +Check eq_refl : for_sumT 5. + +(** Here we show that if we add the numbers from 0 to n, we get [n * (n + 1) / 2] *) +Example for_sum n' : for_sumT n'. +Proof. + intro n. + apply for_loop_ind_le1. + { compute; reflexivity. } + { intros; nia. } +Qed. diff --git a/src/Util/ForLoop/Unrolling.v b/src/Util/ForLoop/Unrolling.v new file mode 100644 index 000000000..e0518f39a --- /dev/null +++ b/src/Util/ForLoop/Unrolling.v @@ -0,0 +1,314 @@ +(** * Proving properties of for-loops via loop-unrolling *) +Require Import Coq.micromega.Psatz. +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Util.ForLoop. +Require Import Crypto.Util.ForLoop.Instances. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Bool. +Require Import Crypto.Util.Tactics.RewriteHyp. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.DestructHead. +Require Import Crypto.Util.Notations. + +Section with_body. + Context {stateT : Type} + (body : nat -> stateT -> stateT). + + Lemma unfold_repeat_function (count : nat) (st : stateT) + : repeat_function body count st + = match count with + | O => st + | S count' => repeat_function body count' (body count st) + end. + Proof using Type. destruct count; reflexivity. Qed. + + Lemma repeat_function_unroll1_start (count : nat) (st : stateT) + : repeat_function body (S count) st + = repeat_function body count (body (S count) st). + Proof using Type. rewrite unfold_repeat_function; reflexivity. Qed. + + Lemma repeat_function_unroll1_end (count : nat) (st : stateT) + : repeat_function body (S count) st + = body 1 (repeat_function (fun count => body (S count)) count st). + Proof using Type. + revert st; induction count; [ reflexivity | ]. + intros; simpl in *; rewrite <- IHcount; reflexivity. + Qed. + + Lemma repeat_function_unroll1_start_match (count : nat) (st : stateT) + : repeat_function body count st + = match count with + | 0 => st + | S count' => repeat_function body count' (body count st) + end. + Proof using Type. destruct count; [ reflexivity | apply repeat_function_unroll1_start ]. Qed. + + Lemma repeat_function_unroll1_end_match (count : nat) (st : stateT) + : repeat_function body count st + = match count with + | 0 => st + | S count' => body 1 (repeat_function (fun count => body (S count)) count' st) + end. + Proof using Type. destruct count; [ reflexivity | apply repeat_function_unroll1_end ]. Qed. +End with_body. + +Local Open Scope bool_scope. +Local Open Scope Z_scope. + +Section for_loop. + Context (i0 finish : Z) (step : Z) {stateT} (initial : stateT) (body : Z -> stateT -> stateT) + (Hgood : Z.sgn step = Z.sgn (finish - i0)). + + Let countZ := (Z.quot (finish - i0 + step - Z.sgn step) step). + Let count := Z.to_nat countZ. + Let of_nat_count c := (i0 + step * Z.of_nat (count - c)). + Let nat_body := (fun c => body (of_nat_count c)). + + Lemma for_loop_empty + (Heq : finish = i0) + : for_loop i0 finish step initial body = initial. + Proof. + subst; unfold for_loop. + rewrite Z.sub_diag, Z.quot_sub_sgn; autorewrite with zsimplify_const. + reflexivity. + Qed. + + Lemma for_loop_unroll1 + : for_loop i0 finish step initial body + = if finish =? i0 + then initial + else let initial' := body i0 initial in + if Z.abs (finish - i0) <=? Z.abs step + then initial' + else for_loop (i0 + step) finish step initial' body. + Proof. + break_innermost_match_step; Z.ltb_to_lt. + { apply for_loop_empty; assumption. } + { unfold for_loop. + rewrite repeat_function_unroll1_start_match. + destruct (Z_zerop step); + repeat first [ progress break_innermost_match + | congruence + | lia + | progress Z.ltb_to_lt + | progress subst + | progress rewrite Nat.sub_diag + | progress autorewrite with zsimplify_const in * + | progress rewrite Z.quot_small_iff in * by omega + | progress rewrite Z.quot_small_abs in * by lia + | rewrite Nat.sub_succ_l by omega + | progress destruct_head' and + | rewrite !Z.sub_add_distr + | match goal with + | [ H : ?x = Z.of_nat _ |- context[?x] ] => rewrite H + | [ H : Z.abs ?x <= 0 |- _ ] => assert (x = 0) by lia; clear H + | [ H : 0 = Z.sgn ?x |- _ ] => assert (x = 0) by lia; clear H + | [ H : ?x - ?y = 0 |- _ ] => is_var x; assert (x = y) by omega; subst x + | [ H : Z.to_nat _ = _ |- _ ] => apply Nat2Z.inj_iff in H + | [ H : context[Z.abs ?x] |- _ ] => rewrite (Z.abs_eq x) in H by omega + | [ H : context[Z.abs ?x] |- _ ] => rewrite (Z.abs_neq x) in H by omega + | [ H : Z.of_nat (Z.to_nat _) = _ |- _ ] + => rewrite Z2Nat.id in H by (apply Z.quot_nonneg_same_sgn; lia) + | [ H : _ = Z.of_nat (S ?x) |- _ ] + => is_var x; destruct x; [ reflexivity | ] + | [ H : ?x + 1 = Z.of_nat (S ?y) |- _ ] + => assert (x = Z.of_nat y) by lia; clear H + | [ |- repeat_function _ ?x ?y = repeat_function _ ?x ?y ] + => apply repeat_function_Proper_le; intros + | [ |- ?f _ ?x = ?f _ ?x ] + => is_var f; apply f_equal2; [ | reflexivity ] + end + | progress rewrite Z.quot_add_sub_sgn_small in * |- by lia + | progress autorewrite with zsimplify ]. } + Qed. +End for_loop. + +Lemma for_loop_notation_empty {stateT} + {i0 : Z} {step : Z} {finish : Z} {initial : stateT} + {cmp : Z -> Z -> bool} + {step_expr finish_expr} (body : Z -> stateT -> stateT) + {Hstep : class_eq (fun i => i = step) step_expr} + {Hfinish : class_eq (fun i => cmp i finish) finish_expr} + {Hgood : for_loop_is_good i0 step finish cmp} + (Heq : i0 = finish) + : @for_loop_notation i0 step finish _ initial cmp step_expr finish_expr body Hstep Hfinish Hgood = initial. +Proof. + unfold for_loop_notation, for_loop_is_good in *; split_andb; Z.ltb_to_lt. + apply for_loop_empty; auto. +Qed. + +Local Notation adjust_bool b p + := (match b as b' return b' = true -> b' = true with + | true => fun _ => eq_refl + | false => fun x => x + end p). + +Lemma for_loop_is_good_step_gen + cmp + (Hcmp : cmp = Z.ltb \/ cmp = Z.gtb) + {i0 step finish} + {H : for_loop_is_good i0 step finish cmp} + (H' : cmp (i0 + step) finish = true) + : for_loop_is_good (i0 + step) step finish cmp. +Proof. + unfold for_loop_is_good in *. + rewrite H', Bool.andb_true_r. + destruct Hcmp; subst; + split_andb; Z.ltb_to_lt; + [ rewrite (Z.sgn_pos (finish - i0)) in * by omega + | rewrite (Z.sgn_neg (finish - i0)) in * by omega ]; + destruct step; simpl in *; try congruence; + symmetry; + [ apply Z.sgn_pos_iff | apply Z.sgn_neg_iff ] + ; omega. +Qed. + +Definition for_loop_is_good_step_lt + {i0 step finish} + {H : for_loop_is_good i0 step finish Z.ltb} + (H' : Z.ltb (i0 + step) finish = true) + : for_loop_is_good (i0 + step) step finish Z.ltb + := for_loop_is_good_step_gen Z.ltb (or_introl eq_refl) (H:=H) H'. +Definition for_loop_is_good_step_gt + {i0 step finish} + {H : for_loop_is_good i0 step finish Z.gtb} + (H' : Z.gtb (i0 + step) finish = true) + : for_loop_is_good (i0 + step) step finish Z.gtb + := for_loop_is_good_step_gen Z.gtb (or_intror eq_refl) (H:=H) H'. +Definition for_loop_is_good_step_lt' + {i0 finish} + {H : for_loop_is_good i0 1 (finish + 1) Z.ltb} + (H' : Z.ltb i0 finish = true) + : for_loop_is_good (i0 + 1) 1 (finish + 1) Z.ltb. +Proof. + apply for_loop_is_good_step_lt; Z.ltb_to_lt; omega. +Qed. +Definition for_loop_is_good_step_gt' + {i0 finish} + {H : for_loop_is_good i0 (-1) (finish - 1) Z.gtb} + (H' : Z.gtb i0 finish = true) + : for_loop_is_good (i0 - 1) (-1) (finish - 1) Z.gtb. +Proof. + apply for_loop_is_good_step_gt; Z.ltb_to_lt; omega. +Qed. + +Local Hint Extern 1 (for_loop_is_good (?i0 + ?step) ?step ?finish _) +=> refine (adjust_bool _ (for_loop_is_good_step_lt _)); try assumption : typeclass_instances. +Local Hint Extern 1 (for_loop_is_good (?i0 + ?step) ?step ?finish _) +=> refine (adjust_bool _ (for_loop_is_good_step_gt _)); try assumption : typeclass_instances. +Local Hint Extern 1 (for_loop_is_good (?i0 - ?step') _ ?finish _) +=> refine (adjust_bool _ (for_loop_is_good_step_gt (step:=-step') _)); try assumption : typeclass_instances. +Local Hint Extern 1 (for_loop_is_good (?i0 + 1) 1 ?finish _) +=> refine (adjust_bool _ (for_loop_is_good_step_lt' _)); try assumption : typeclass_instances. +Local Hint Extern 1 (for_loop_is_good (?i0 - 1) (-1) ?finish _) +=> refine (adjust_bool _ (for_loop_is_good_step_gt' _)); try assumption : typeclass_instances. + +Local Ltac t := + repeat match goal with + | _ => progress unfold for_loop_is_good, for_loop_notation in * + | _ => progress rewrite for_loop_unroll1 by auto + | _ => omega + | _ => progress subst + | _ => reflexivity + | _ => progress split_andb + | _ => progress Z.ltb_to_lt + | _ => progress break_innermost_match_step + | [ H : context[Z.abs ?x] |- _ ] => rewrite (Z.abs_eq x) in H by omega + | [ H : context[Z.abs ?x] |- _ ] => rewrite (Z.abs_neq x) in H by omega + | [ H : context[Z.sgn ?x] |- _ ] => rewrite (Z.sgn_pos x) in H by omega + | [ H : context[Z.sgn ?x] |- _ ] => rewrite (Z.sgn_neg x) in H by omega + | [ H : Z.sgn _ = 1 |- _ ] => apply Z.sgn_pos_iff in H + | [ H : Z.sgn _ = -1 |- _ ] => apply Z.sgn_neg_iff in H + end. + +Lemma for_loop_lt_unroll1 {stateT} + {i0 : Z} {step : Z} {finish : Z} {initial : stateT} + {step_expr finish_expr} (body : Z -> stateT -> stateT) + {Hstep : class_eq (fun i => i = step) step_expr} + {Hfinish : class_eq (fun i => Z.ltb i finish) finish_expr} + {Hgood : for_loop_is_good i0 step finish Z.ltb} + : (@for_loop_notation i0 step finish _ initial Z.ltb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish Hgood) + = let initial' := body i0 initial in + if Sumbool.sumbool_of_bool (Z.ltb (i0 + step) finish) + then @for_loop_notation (i0 + step) step finish _ initial' Z.ltb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish _ + else initial'. +Proof. t. Qed. + +Lemma for_loop_gt_unroll1 {stateT} + {i0 : Z} {step : Z} {finish : Z} {initial : stateT} + {step_expr finish_expr} (body : Z -> stateT -> stateT) + {Hstep : class_eq (fun i => i = step) step_expr} + {Hfinish : class_eq (fun i => Z.gtb i finish) finish_expr} + {Hgood : for_loop_is_good i0 step finish Z.gtb} + : (@for_loop_notation i0 step finish _ initial Z.gtb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish Hgood) + = let initial' := body i0 initial in + if Sumbool.sumbool_of_bool (Z.gtb (i0 + step) finish) + then @for_loop_notation (i0 + step) step finish _ initial' Z.gtb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish _ + else initial'. +Proof. t. Qed. + +Lemma for_loop_lt1_unroll1 {stateT} + {i0 : Z} {finish : Z} {initial : stateT} + {body : Z -> stateT -> stateT} + {Hgood : for_loop_is_good i0 1 finish _} + : for (int i = i0; i < finish; i++) updating (st = initial) {{ + body i st + }} + = let initial' := body i0 initial in + if Sumbool.sumbool_of_bool (Z.ltb (i0 + 1) finish) + then for (int i = i0+1; i < finish; i++) updating (st = initial') {{ + body i st + }} + else initial'. +Proof. apply for_loop_lt_unroll1. Qed. + +Lemma for_loop_gt1_unroll1 {stateT} + {i0 : Z} {finish : Z} {initial : stateT} + {body : Z -> stateT -> stateT} + {Hgood : for_loop_is_good i0 (-1) finish _} + : for (int i = i0; i > finish; i--) updating (st = initial) {{ + body i st + }} + = let initial' := body i0 initial in + if Sumbool.sumbool_of_bool (Z.gtb (i0 - 1) finish) + then for (int i = i0-1; i > finish; i--) updating (st = initial') {{ + body i st + }} + else initial'. +Proof. apply for_loop_gt_unroll1. Qed. + +Lemma for_loop_le1_unroll1 {stateT} + {i0 : Z} {finish : Z} {initial : stateT} + {body : Z -> stateT -> stateT} + {Hgood : for_loop_is_good i0 1 (finish+1) _} + : for (int i = i0; i <= finish; i++) updating (st = initial) {{ + body i st + }} + = let initial' := body i0 initial in + if Sumbool.sumbool_of_bool (Z.ltb i0 finish) + then for (int i = i0+1; i <= finish; i++) updating (st = initial') {{ + body i st + }} + else initial'. +Proof. + rewrite for_loop_lt_unroll1; unfold for_loop_notation. + break_innermost_match; Z.ltb_to_lt; try omega; try reflexivity. +Qed. + +Lemma for_loop_ge1_unroll1 {stateT} + {i0 : Z} {finish : Z} {initial : stateT} + {body : Z -> stateT -> stateT} + {Hgood : for_loop_is_good i0 (-1) (finish-1) _} + : for (int i = i0; i >= finish; i--) updating (st = initial) {{ + body i st + }} + = let initial' := body i0 initial in + if Sumbool.sumbool_of_bool (Z.gtb i0 finish) + then for (int i = i0-1; i >= finish; i--) updating (st = initial') {{ + body i st + }} + else initial'. +Proof. + rewrite for_loop_gt_unroll1; unfold for_loop_notation. + break_innermost_match; Z.ltb_to_lt; try omega; try reflexivity. +Qed. |