diff options
-rw-r--r-- | src/Arithmetic/Saturated.v | 161 | ||||
-rw-r--r-- | src/Util/Tuple.v | 181 |
2 files changed, 228 insertions, 114 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v index 884a59ef7..addfe08e5 100644 --- a/src/Arithmetic/Saturated.v +++ b/src/Arithmetic/Saturated.v @@ -127,6 +127,12 @@ Module Columns. Proof. reflexivity. Qed. Hint Rewrite eval_unit : push_basesystem_eval. + Lemma eval_single (x:list Z) : eval (n:=1) x = sum x. + Proof. + cbv [eval]. simpl map. cbv - [Z.mul Z.add sum]. + rewrite weight_0; ring. + Qed. Hint Rewrite eval_single : push_basesystem_eval. + Definition eval_from {n} (offset:nat) (x : (list Z)^n) : Z := B.Positional.eval (fun i => weight (i+offset)) (Tuple.map sum x). @@ -226,65 +232,103 @@ Module Columns. rewrite Z.div_add' by auto; nsatz. Qed. - Definition compact_invariant n i (starter rem:Z) (inp : tuple (list Z) n) (out : tuple Z n) := - B.Positional.eval_from weight i out + weight (i + n) * rem = eval_from i inp + weight i*starter. - - Lemma compact_invariant_holds n i starter rem inp out : - compact_invariant n (S i) (fst (compact_step_cps i starter (hd inp) id)) rem (tl inp) out -> - compact_invariant (S n) i starter rem inp (append (snd (compact_step_cps i starter (hd inp) id)) out). - Proof using Type*. - cbv [compact_invariant B.Positional.eval_from]; intros. - repeat match goal with - | _ => rewrite B.Positional.eval_step - | _ => rewrite eval_from_S - | _ => rewrite sum_cons - | _ => rewrite weight_multiples - | _ => rewrite Nat.add_succ_l in * - | _ => rewrite Nat.add_succ_r in * - | _ => (rewrite fst_fst_compact_step in * ) - | _ => progress ring_simplify - | _ => rewrite ZUtil.Z.mul_div_eq_full by apply weight_nonzero - | _ => cbv [compact_step_cps] in *; - autorewrite with uncps push_id in *; - rewrite !compact_digit_mod, !compact_digit_div in * - | _ => progress (autorewrite with natsimplify in * ) - end; - rewrite B.Positional.eval_wt_equiv with (wtb := fun i0 => weight (i0 + S i)) by (intros; f_equal; try omega). - { - rewrite Z.mod_eq by auto using Z.positive_is_nonzero. - rewrite sum_cons in H. - ring_simplify. - match type of H with - context [?y * (?a / (?y / ?x))] => - replace (y * (a / (y / x))) with (x * (y / x) * (a / (y / x))) in H - by (rewrite Z.mul_div_eq_full by auto using Z.positive_is_nonzero; - rewrite weight_multiples; ring) - end. - nsatz. - } + (* TODO : move to Core? *) + Lemma Pos_eval_unit : B.Positional.eval (n:=0) weight tt = 0. + Proof. reflexivity. Qed. + Hint Rewrite Pos_eval_unit B.Positional.eval_single + @B.Positional.eval_step : push_basesystem_eval. + (* TODO : move to Core? *) + Lemma Pos_eval_left_append {n} : forall wt x xs, + B.Positional.eval wt (left_append (n:=n) x xs) + = wt n * x + B.Positional.eval wt xs. + Proof. + induction n; intros; try destruct xs; + unfold left_append; fold @left_append; + autorewrite with push_basesystem_eval; [ring|]. + rewrite IHn. + rewrite (subst_append xs), hd_append, tl_append. + rewrite B.Positional.eval_step. + ring. + Qed. + Hint Rewrite @Pos_eval_left_append : push_basesystem_eval. + + Lemma small_mod_eq a b n: a mod n = b mod n -> 0 <= a < n -> a = b mod n. + Proof. intros; rewrite <-(Z.mod_small a n); auto. Qed. + + (* helper for some of the modular logic in compact *) + Lemma compact_mod_step a b c d: 0 < a -> 0 < b -> + a * ((c / a + d) mod b) + c mod a = (a * d + c) mod (a * b). + Proof. + intros Ha Hb. assert (a <= a * b) by (apply Z.le_mul_diag_r; omega). + pose proof (Z.mod_pos_bound c a Ha). + pose proof (Z.mod_pos_bound (c/a+d) b Hb). + apply small_mod_eq. + { rewrite <-(Z.mod_small (c mod a) (a * b)) by omega. + rewrite <-Z.mul_mod_distr_l with (c:=a) by omega. + rewrite Z.mul_add_distr_l, Z.mul_div_eq, <-Z.add_mod_full by omega. + f_equal; ring. } + { split; [zero_bounds|]. + apply Z.lt_le_trans with (m:=a*(b-1)+a); [|ring_simplify; omega]. + apply Z.add_le_lt_mono; try apply Z.mul_le_mono_nonneg_l; omega. } Qed. - - Lemma compact_invariant_base i rem : compact_invariant 0 i rem rem tt tt. - Proof using Type. cbv [compact_invariant]. simpl. repeat (f_equal; try omega). Qed. - - Lemma compact_invariant_end {n} start (input : (list Z)^n): - compact_invariant n 0%nat start (fst (mapi_with_cps compact_step_cps start input id)) input (snd (mapi_with_cps compact_step_cps start input id)). - Proof using Type*. - autorewrite with uncps push_id. - apply (mapi_with_invariant _ compact_invariant - compact_invariant_holds compact_invariant_base). + Lemma compact_div_step a b c d : 0 < a -> 0 < b -> + (c / a + d) / b = (a * d + c) / (a * b). + Proof. + intros Ha Hb. + rewrite <-Z.div_div by omega. + rewrite Z.div_add_l' by omega. + f_equal; ring. Qed. - Lemma eval_compact {n} (xs : tuple (list Z) n) : - B.Positional.eval weight (snd (compact xs)) = eval xs - (weight n * fst (compact xs)). - Proof using Type*. - pose proof (compact_invariant_end 0 xs) as Hinv. - cbv [compact_invariant] in Hinv. - simpl in Hinv. autorewrite with zsimplify natsimplify in Hinv. - rewrite eval_from_0, B.Positional.eval_from_0 in Hinv. nsatz. + Lemma compact_div_mod {n} inp : + (B.Positional.eval weight (snd (compact inp)) + = (eval inp) mod (weight n)) + /\ (fst (compact inp) = eval (n:=n) inp / weight n). + Proof. + cbv [compact compact_cps compact_step compact_step_cps]; + autorewrite with uncps push_id. + change (fun i s a => compact_digit_cps i (s :: a) id) + with (fun i s a => compact_digit i (s :: a)). + + apply mapi_with'_linvariant; [|tauto]. + + clear n inp. intros until 0. intros Hst Hys [Hmod Hdiv]. + pose proof (weight_positive n). pose proof (weight_divides n). + autorewrite with push_basesystem_eval. + destruct n; cbv [mapi_with] in *; simpl tuple in *; + [destruct xs, ys; subst; simpl| cbv [eval] in *]; + repeat match goal with + | _ => rewrite mapi_with'_left_step + | _ => rewrite compact_digit_div, sum_cons + | _ => rewrite compact_digit_mod, sum_cons + | _ => rewrite map_left_append + | _ => rewrite eval_left_append + | _ => rewrite weight_0, ?Z.div_1_r, ?Z.mod_1_r + | _ => rewrite Hdiv + | _ => rewrite Hmod + | _ => progress subst + | _ => progress autorewrite with natsimplify cancel_pair push_basesystem_eval + | _ => solve [split; ring_simplify; f_equal; ring] + end. + remember (weight (S (S n)) / weight (S n)) as bound. + replace (weight (S (S n))) with (weight (S n) * bound) + by (subst bound; rewrite Z.mul_div_eq by omega; + rewrite weight_multiples; ring). + split; [apply compact_mod_step | apply compact_div_step]; omega. Qed. + Lemma compact_mod {n} inp : + (B.Positional.eval weight (snd (compact inp)) + = (eval (n:=n) inp) mod (weight n)). + Proof. apply (proj1 (compact_div_mod inp)). Qed. + Hint Rewrite @compact_mod : push_basesystem_eval. + + Lemma compact_div {n} inp : + fst (compact inp) = eval (n:=n) inp / weight n. + Proof. apply (proj2 (compact_div_mod inp)). Qed. + Hint Rewrite @compact_div : push_basesystem_eval. + Definition cons_to_nth_cps {n} i (x:Z) (t:(list Z)^n) {T} (f:(list Z)^n->T) := @on_tuple_cps _ _ nil (update_nth_cps i (cons x)) n n t _ f. @@ -404,7 +448,8 @@ Hint Rewrite @Columns.from_associational_id : uncps. Hint Rewrite - @Columns.eval_compact + @Columns.compact_mod + @Columns.compact_div @Columns.eval_cons_to_nth @Columns.eval_from_associational @Columns.eval_nils @@ -526,9 +571,13 @@ Section Freeze. = B.Positional.eval weight p + (if (dec (cond = 0)) then 0 else B.Positional.eval weight q) - weight n * (fst (conditional_add mask cond p q)). Proof. cbv [conditional_add_cps conditional_add]; - repeat progress autounfold; rewrite ?Hmask, ?map_land_zero; + repeat progress autounfold in *; rewrite ?Hmask, ?map_land_zero; autorewrite with uncps push_id push_basesystem_eval; - break_match; ring. + break_match; + match goal with + |- context [weight ?n * (?x / weight ?n)] => + pose proof (Z.div_mod x (weight n) (weight_nonzero n)) + end; omega. Qed. Hint Rewrite @eval_conditional_add using (omega || assumption) : push_basesystem_eval. diff --git a/src/Util/Tuple.v b/src/Util/Tuple.v index eab1736a5..3fa0dabde 100644 --- a/src/Util/Tuple.v +++ b/src/Util/Tuple.v @@ -1,6 +1,7 @@ Require Import Coq.Classes.Morphisms. Require Import Coq.Relations.Relation_Definitions. Require Import Coq.Lists.List. +Require Import Coq.omega.Omega. Require Import Crypto.Util.Option. Require Import Crypto.Util.Prod. Require Import Crypto.Util.Tactics.DestructHead. @@ -794,6 +795,81 @@ Fixpoint nth_default {A m} (d:A) n : tuple A m -> A := | S m', S n' => fun x => nth_default d n' (tl x) end. +Fixpoint left_tl {T n} : tuple T (S n) -> tuple T n := + match n with + | O => fun _ => tt + | S n' => fun xs => append (hd xs) (left_tl (tl xs)) + end. + +Fixpoint left_hd {T n} : tuple T (S n) -> T := + match n with + | O => fun x => x + | S n' => fun xs => left_hd (tl xs) + end. + +Fixpoint left_append {T n} (x : T) : tuple T n -> tuple T (S n) := + match n with + | O => fun _ => x + | S n' => fun xs => append (hd xs) (left_append x (tl xs)) + end. + +Lemma left_append_left_hd {T n} (xs : tuple T n) x : + left_hd (left_append x xs) = x. +Proof. induction n; [reflexivity | apply IHn]. Qed. + +Lemma left_append_left_tl {T n} (xs : tuple T n) x : + left_tl (left_append x xs) = xs. +Proof. + induction n; [destruct xs; reflexivity|]. + simpl. rewrite IHn. + symmetry; apply subst_append. +Qed. + +Lemma left_append_append {T n} (xs : tuple T n) r l : + left_append l (append r xs) = append r (left_append l xs). +Proof. destruct n; reflexivity. Qed. + +Lemma left_tl_append {T n} (xs : tuple T (S n)) x: + left_tl (append x xs) = append x (left_tl xs). +Proof. reflexivity. Qed. + +Lemma left_hd_append {T n} (xs : tuple T (S n)) x: + left_hd (append x xs) = left_hd xs. +Proof. reflexivity. Qed. + +Lemma tl_left_append {T n} (xs : tuple T (S n)) x : + tl (left_append x xs) = left_append x (tl xs). +Proof. destruct n; reflexivity. Qed. + +Lemma hd_left_append {T n} (xs : tuple T (S n)) x : + hd (left_append x xs) = hd xs. +Proof. destruct n; reflexivity. Qed. + +Lemma map'_left_append {A B n} f xs x0 : + @map' A B f (S n) (left_append (n:=S n) x0 xs) + = left_append (n:=S n) (f x0) (map' f xs). +Proof. + induction n; try reflexivity. + transitivity (map' f (tl (left_append x0 xs)), f (hd (left_append x0 xs))); [reflexivity|]. + rewrite tl_left_append, IHn. reflexivity. +Qed. + +Lemma map_left_append {A B n} f xs x0 : + @map (S n) A B f (left_append x0 xs) + = left_append (f x0) (map f xs). +Proof. + destruct n; [ destruct xs; reflexivity| apply map'_left_append]. +Qed. + +Lemma subst_left_append {T n} (xs : tuple T (S n)) : + xs = left_append (left_hd xs) (left_tl xs). +Proof. + induction n; [reflexivity|]. + simpl tuple in *; destruct xs as [xs x0]. + simpl; rewrite hd_append, tl_append. + rewrite <-IHn; reflexivity. +Qed. + (* map operation that carries state *) (* first argument to f is index in tuple *) Fixpoint mapi_with' {T A B n} i (f: nat->T->A->T*B) (start:T) @@ -813,37 +889,55 @@ Fixpoint mapi_with {T A B n} (f: nat->T->A->T*B) (start:T) | S n' => fun ys => mapi_with' 0 f start ys end. -Lemma mapi_with'_invariant {T A B} (f: nat->T->A->T*B) - (P : forall n, nat -> T -> T -> tuple A n -> tuple B n -> Prop) - (P_holds : forall n i starter rem inp out, - P n (S i) (fst (f i starter (hd inp))) rem (tl inp) out - -> P (S n) i starter rem inp (append (snd (f i starter (hd inp))) out)) - (P_base : forall i rem, P 0%nat i rem rem tt tt) - : - forall {n} i (start : T) (input : tuple A (S n)), - P (S n) i start (fst (mapi_with' i f start input)) input (snd (mapi_with' i f start input)). -Proof. - induction n; intros. - { specialize (P_holds 0%nat i start (fst (f i start input)) input tt). - apply P_holds. apply P_base. } - { specialize (P_holds (S n) i start (fst (mapi_with' i f start input)) input). - apply P_holds. apply IHn. } -Qed. +Lemma mapi_with'_step {T A B n} i f start inp : + @mapi_with' T A B (S n) i f start inp = + (fst (mapi_with' (S i) f (fst (f i start (hd inp))) (tl inp)), + (snd (mapi_with'(S i) f (fst (f i start (hd inp))) (tl inp)), snd (f i start (hd inp)))). +Proof. reflexivity. Qed. -Lemma mapi_with_invariant {T A B} (f: nat->T->A->T*B) - (P : forall n, nat -> T -> T -> tuple A n -> tuple B n -> Prop) - (P_holds : forall n i starter rem inp out, - P n (S i) (fst (f i starter (hd inp))) rem (tl inp) out - -> P (S n) i starter rem inp (append (snd (f i starter (hd inp))) out)) - (P_base : forall i rem, P 0%nat i rem rem tt tt) - : - forall {n} (start : T) (input : tuple A n), - P n 0%nat start (fst (mapi_with f start input)) input (snd (mapi_with f start input)). -Proof. - destruct n; [intros; destruct input; apply P_base|]; - apply mapi_with'_invariant; auto. +Lemma mapi_with'_left_step {T A B n} f a0: + forall i start (xs : tuple' A n), + @mapi_with' T A B (S n) i f start (left_append (n:=S n) a0 xs) + = (fst (f (i + S n)%nat (fst (mapi_with' i f start xs)) a0), + left_append (n:=S n) + (snd (f (i + S n)%nat + (fst (mapi_with' i f start xs)) a0)) + (snd (mapi_with' i f start xs))). +Proof. + induction n; intros; [subst; simpl; repeat f_equal; omega|]. + rewrite mapi_with'_step; autorewrite with cancel_pair. + rewrite tl_left_append, hd_left_append. + erewrite IHn by reflexivity; subst; autorewrite with cancel_pair. + match goal with |- context [(?xs ,?x0)] => + change (xs, x0) with (append x0 xs) end. + rewrite <-left_append_append. + repeat (f_equal; try omega). +Qed. + +Lemma mapi_with'_linvariant {T A B n} start f + (P : forall n, T -> tuple A n -> tuple B n -> Prop) + (P_holds : forall n st x0 xs ys, + (st = fst (mapi_with f start xs)) -> + (ys = snd (mapi_with f start xs)) -> + P n st xs ys -> + P (S n) (fst (f n st x0)) + (left_append x0 xs) + (left_append (snd (f n st x0)) ys)) + (P_base : P 0%nat start tt tt) (inp : tuple A n): + P n (fst (mapi_with f start inp)) inp (snd (mapi_with f start inp)). +Proof. + induction n; [destruct inp; apply P_base |]. + rewrite (subst_left_append inp). + cbv [mapi_with]. specialize (P_holds n). + destruct n. + { apply (P_holds _ inp tt tt (eq_refl _) (eq_refl _)). + apply P_base. } + { rewrite mapi_with'_left_step. + autorewrite with cancel_pair natsimplify. + apply P_holds; try apply IHn; reflexivity. } Qed. + Fixpoint repeat {A} (a:A) n : tuple A n := match n with | O => tt @@ -859,35 +953,6 @@ Qed. Lemma to_list_repeat {A} (a:A) n : to_list _ (repeat a n) = List.repeat a n. Proof. induction n; [reflexivity|destruct n; simpl in *; congruence]. Qed. -Fixpoint lastn {A m} n : n <= m -> tuple A m -> tuple A n := - match n as n0 return (n0 <= m -> tuple A m -> tuple A n0) with - | O => fun _ _ => tt - | S n' => - match m as m0 return (S n' <= m0 -> tuple A m0 -> tuple A (S n')) with - | O => fun (H : S n' <= 0) _ => - False_rect _ (NPeano.Nat.nle_succ_0 _ H) - | S m' => fun (H : S n' <= S m') xs => - append (hd xs) (lastn n' (le_S_n _ _ H) (tl xs)) - end - end. - -Lemma to_list_lastn {A} n: forall {m} H xs, - to_list n (@lastn A m n H xs) = firstn n (to_list m xs). -Proof. - induction n; intros; destruct m; try rewrite (subst_append xs); - repeat match goal with - | _ => rewrite to_list_append - | _ => rewrite hd_append - | _ => rewrite tl_append - | _ => progress simpl lastn - | _ => progress simpl firstn - | _ => reflexivity - | _ => solve [distr_length] - end. - rewrite IHn. reflexivity. -Qed. - -Definition nth {A} Global Instance map'_Proper {n A B} : Proper (pointwise_relation _ eq ==> eq ==> eq) (fun f => @map' A B f n) | 10. Proof. @@ -903,4 +968,4 @@ Global Instance map_Proper {n A B} : Proper (pointwise_relation _ eq ==> eq ==> Proof. destruct n; [ | apply map'_Proper ]. { repeat (intros [] || intro); auto. } -Qed.
\ No newline at end of file +Qed. |