diff options
author | Jason Gross <jgross@mit.edu> | 2018-06-19 00:03:37 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-06-21 17:03:40 -0400 |
commit | dfbdd15176a1d50ffc5468bb066a18b1bd539588 (patch) | |
tree | aa2a365015e0a2a0de341aa15d3c0ae2b60e1614 /src | |
parent | 1759e06bdc9ef25125216e0398c4cabcd0c0b3f5 (diff) |
Add [freeze] to Arithmetic
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/NewPipeline/Arithmetic.v | 273 |
1 files changed, 239 insertions, 34 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v index a52a4ee52..44834e5e3 100644 --- a/src/Experiments/NewPipeline/Arithmetic.v +++ b/src/Experiments/NewPipeline/Arithmetic.v @@ -16,6 +16,7 @@ Require Import Crypto.Util.Tactics.RunTacticAsConstr. Require Import Crypto.Util.Tactics.Head. Require Import Crypto.Util.Option. Require Import Crypto.Util.OptionList. +Require Import Crypto.Util.Prod. Require Import Crypto.Util.Sum. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div Crypto.Util.ZUtil.Hints.Core. @@ -459,9 +460,62 @@ Module Positional. Section Positional. End sub. Hint Rewrite @eval_opp @eval_sub : push_eval. Hint Rewrite @length_sub @length_opp : distr_length. + + Section select. + Definition select (mask cond:Z) (p:list Z) := + dlet t := Z.zselect cond 0 mask in List.map (Z.land t) p. + + Lemma map_and_0 n (p:list Z) : length p = n -> map (Z.land 0) p = zeros n. + Proof. + intro; subst; induction p as [|x xs IHxs]; [reflexivity | ]. + cbn; f_equal; auto. + Qed. + Lemma eval_select n mask cond p (H:List.map (Z.land mask) p = p) : + length p = n + -> eval n (select mask cond p) = + if dec (cond = 0) then 0 else eval n p. + Proof. + cbv [select Let_In]. + rewrite Z.zselect_correct; break_match. + { intros; erewrite map_and_0 by eassumption. apply eval_zeros. } + { rewrite H; reflexivity. } + Qed. + Lemma length_select mask cond p : + length (select mask cond p) = length p. + Proof using Type. clear dependent weight. cbv [select Let_In]; break_match; intros; distr_length. Qed. + End select. End Positional. (* Hint Rewrite disappears after the end of a section *) -Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_sub @length_opp : distr_length. +Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_sub @length_opp @length_select : distr_length. +Hint Rewrite @eval_select : push_eval. +Section Positional_nonuniform. + Context (weight weight' : nat -> Z). + + Lemma eval_hd_tl n (xs:list Z) : + length xs = n -> + eval weight n xs = weight 0%nat * hd 0 xs + eval (fun i => weight (S i)) (pred n) (tl xs). + Proof. + intro; subst; destruct xs as [|x xs]; [ cbn; omega | ]. + cbv [eval to_associational Associational.eval] in *; cbn. + rewrite <- map_S_seq; reflexivity. + Qed. + + Lemma eval_cons n (x:Z) (xs:list Z) : + length xs = n -> + eval weight (S n) (x::xs) = weight 0%nat * x + eval (fun i => weight (S i)) n xs. + Proof. intro; subst; apply eval_hd_tl; reflexivity. Qed. + + Lemma eval_weight_mul n p k : + (forall i, In i (seq 0 n) -> weight i = k * weight' i) -> + eval weight n p = k * eval weight' n p. + Proof. + setoid_rewrite List.in_seq. + revert n weight weight'; induction p as [|x xs IHxs], n as [|n]; intros weight weight' Hwt; + cbv [eval to_associational Associational.eval] in *; cbn in *; try omega. + rewrite Hwt, Z.mul_add_distr_l, Z.mul_assoc by omega. + erewrite <- !map_S_seq, IHxs; [ reflexivity | ]; cbn; eauto with omega. + Qed. +End Positional_nonuniform. End Positional. Record weight_properties {weight : nat -> Z} := @@ -1006,7 +1060,6 @@ Module Columns. Qed. Hint Rewrite eval_cons_to_nth using (solve [distr_length]) : push_eval. Hint Rewrite Positional.eval_zeros : push_eval. - Hint Rewrite Positional.length_from_associational : distr_length. Hint Rewrite Positional.eval_add_to_nth using (solve [distr_length]): push_eval. (* from_associational *) @@ -1108,13 +1161,13 @@ Module Rows. destruct x; cbn [hd tl]; rewrite ?sum_nil, ?sum_cons; ring. Qed. Hint Rewrite eval_extract_row using (solve [distr_length]) : push_eval. - Lemma length_fst_extract_row n (inp : cols) : - length inp = n -> length (fst (extract_row inp)) = n. + Lemma length_fst_extract_row (inp : cols) : + length (fst (extract_row inp)) = length inp. Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed. Hint Rewrite length_fst_extract_row : distr_length. - Lemma length_snd_extract_row n (inp : cols) : - length inp = n -> length (snd (extract_row inp)) = n. + Lemma length_snd_extract_row (inp : cols) : + length (snd (extract_row inp)) = length inp. Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed. Hint Rewrite length_snd_extract_row : distr_length. @@ -1169,7 +1222,8 @@ Module Rows. | _ => progress In_cases | _ => split; try omega | H: _ /\ _ |- _ => destruct H - | _ => solve [auto using length_fst_extract_row, length_snd_extract_row] + | _ => progress distr_length + | _ => solve [auto] end. Qed. Lemma length_fst_from_columns' m st : @@ -1227,7 +1281,7 @@ Module Rows. eapply length_snd_from_columns'; eauto. autorewrite with cancel_pair; intros; In_cases. Qed. - Hint Rewrite length_from_columns : distr_length. + Hint Rewrite length_from_columns using eassumption : distr_length. (* from associational *) Definition from_associational n (p : list (Z * Z)) := from_columns (Columns.from_associational weight n p). @@ -1513,7 +1567,7 @@ Module Rows. Lemma flatten_div_mod inp n : (forall row, In row inp -> length row = n) -> is_div_mod (Positional.eval weight n) (flatten n inp) (eval n inp) (weight n). - Proof. + Proof using wprops. intros; cbv [flatten]. destruct inp; [|destruct inp]; cbn [hd tl]. { cbv [is_div_mod]; push. @@ -1556,7 +1610,7 @@ Module Rows. forall i, (i < n)%nat -> nth_default 0 (fst (flatten' start_state inp)) i = ((Positional.eval weight n (fst start_state) + eval n inp) mod weight (S i)) / (weight i). - Proof. + Proof using wprops. induction inp using rev_ind; push. destruct (dec (inp = nil)). { subst inp; push. rewrite sum_rows_partitions with (n:=n) by eauto. push. } @@ -1569,7 +1623,7 @@ Module Rows. (forall row, In row inp -> length row = n) -> forall i, (i < n)%nat -> nth_default 0 (fst (flatten n inp)) i = (eval n inp mod weight (S i)) / (weight i). - Proof. + Proof using wprops. intros; cbv [flatten]. intros; destruct inp as [| ? [| ? ?] ]; try congruence; cbn [hd tl] in *; try solve [push]. { cbn. autorewrite with push_nth_default. @@ -1587,7 +1641,7 @@ Module Rows. (forall i, (i < n)%nat -> nth_default 0 p i = (x mod weight (S i)) / weight i) -> length p = n -> p = partition n x. - Proof. + Proof using Type. cbv [partition]; induction p using rev_ind; intros; distr_length; subst n; [reflexivity|]. rewrite Nat.add_1_r, seq_snoc. autorewrite with natsimplify push_map. @@ -1601,18 +1655,18 @@ Module Rows. Lemma partition_step n x : partition (S n) x = partition n x ++ [(x mod weight (S n)) / weight n]. - Proof. + Proof using Type. cbv [partition]. rewrite seq_snoc. autorewrite with natsimplify push_map. reflexivity. Qed. Lemma length_partition n x : length (partition n x) = n. - Proof. cbv [partition]; distr_length. Qed. + Proof using Type. cbv [partition]; distr_length. Qed. Hint Rewrite length_partition : distr_length. Lemma eval_partition n x : Positional.eval weight n (partition n x) = x mod (weight n). - Proof. + Proof using wprops. induction n; intros. { cbn. rewrite (weight_0); auto with zarith. } { rewrite (Z.div_mod (x mod weight (S n)) (weight n)) by auto. @@ -1621,10 +1675,26 @@ Module Rows. omega. } Qed. + Lemma partition_Proper n : + Proper (Z.equiv_modulo (weight n) ==> eq) (partition n). + Proof using wprops. + cbv [Proper Z.equiv_modulo respectful]. + intros x y Hxy; induction n; intros. + { reflexivity. } + { assert (Hxyn : x mod weight n = y mod weight n). + { erewrite (Znumtheory.Zmod_div_mod _ (weight (S n)) x), (Znumtheory.Zmod_div_mod _ (weight (S n)) y), Hxy + by (try apply Z.mod_divide; auto); + reflexivity. } + rewrite !partition_step, IHn by eauto. + rewrite (Z.div_mod (x mod weight (S n)) (weight n)), (Z.div_mod (y mod weight (S n)) (weight n)) by auto. + rewrite <-!Znumtheory.Zmod_div_mod by (try apply Z.mod_divide; auto). + rewrite Hxy, Hxyn; reflexivity. } + Qed. + Lemma flatten_partitions' inp n : (forall row, In row inp -> length row = n) -> fst (flatten n inp) = partition n (eval n inp). - Proof. auto using nth_default_partitions, flatten_partitions, length_flatten. Qed. + Proof using wprops. auto using nth_default_partitions, flatten_partitions, length_flatten. Qed. End Flatten. Section Ops. @@ -1639,6 +1709,10 @@ Module Rows. fine; we should check this. *) Definition sub n p q := flatten n [p; map (fun x => dlet y := x in Z.opp y) q]. + Definition conditional_add n mask cond (p q:list Z) := + let qq := Positional.select mask cond q in + add n p qq. + Hint Rewrite eval_cons eval_nil using solve [auto] : push_eval. Definition mul base n m (p q : list Z) := @@ -1675,17 +1749,34 @@ Module Rows. Lemma add_partitions n p q : n <> 0%nat -> length p = n -> length q = n -> fst (add n p q) = partition n (Positional.eval weight n p + Positional.eval weight n q). - Proof. solver. Qed. + Proof using wprops. solver. Qed. Lemma add_div n p q : n <> 0%nat -> length p = n -> length q = n -> snd (add n p q) = (Positional.eval weight n p + Positional.eval weight n q) / weight n. - Proof. solver. Qed. + Proof using wprops. solver. Qed. + + Lemma conditional_add_partitions n mask cond p q : + n <> 0%nat -> length p = n -> length q = n -> map (Z.land mask) q = q -> + fst (conditional_add n mask cond p q) + = partition n (Positional.eval weight n p + if dec (cond = 0) then 0 else Positional.eval weight n q). + Proof using wprops. + cbv [conditional_add]; intros; rewrite add_partitions by (distr_length; auto). + autorewrite with push_eval; auto. + Qed. + + Lemma conditional_add_div n mask cond p q : + n <> 0%nat -> length p = n -> length q = n -> map (Z.land mask) q = q -> + snd (conditional_add n mask cond p q) = (Positional.eval weight n p + if dec (cond = 0) then 0 else Positional.eval weight n q) / weight n. + Proof using wprops. + cbv [conditional_add]; intros; rewrite add_div by (distr_length; auto). + autorewrite with push_eval; auto. + Qed. Lemma eval_map_opp q : forall n, length q = n -> Positional.eval weight n (map Z.opp q) = - Positional.eval weight n q. - Proof. + Proof using Type. induction q using rev_ind; intros; repeat match goal with | _ => progress autorewrite with push_map push_eval @@ -1698,23 +1789,23 @@ Module Rows. Lemma sub_partitions n p q : n <> 0%nat -> length p = n -> length q = n -> fst (sub n p q) = partition n (Positional.eval weight n p - Positional.eval weight n q). - Proof. solver. Qed. + Proof using wprops. solver. Qed. Lemma sub_div n p q : n <> 0%nat -> length p = n -> length q = n -> snd (sub n p q) = (Positional.eval weight n p - Positional.eval weight n q) / weight n. - Proof. solver. Qed. + Proof using wprops. solver. Qed. Lemma mul_partitions base n m p q : base <> 0 -> n <> 0%nat -> m <> 0%nat -> length p = n -> length q = n -> fst (mul base n m p q) = partition m (Positional.eval weight n p * Positional.eval weight n q). - Proof. solver. Qed. + Proof using wprops. solver. Qed. Lemma eval_sat_reduce base s c p : base <> 0 -> s - Associational.eval c <> 0 -> s <> 0 -> Associational.eval (sat_reduce base s c p) mod (s - Associational.eval c) = Associational.eval p mod (s - Associational.eval c). - Proof. + Proof using Type. intros; cbv [sat_reduce]. autorewrite with push_eval. rewrite <-Associational.reduction_rule by omega. @@ -1726,7 +1817,7 @@ Module Rows. base <> 0 -> s - Associational.eval c <> 0 -> s <> 0 -> Associational.eval (repeat_sat_reduce base s c p n) mod (s - Associational.eval c) = Associational.eval p mod (s - Associational.eval c). - Proof. + Proof using Type. intros; cbv [repeat_sat_reduce]. apply fold_right_invariant; intros; autorewrite with push_eval; auto. Qed. @@ -1738,13 +1829,16 @@ Module Rows. (Positional.eval weight n (fst (mulmod base s c n nreductions p q)) + weight n * (snd (mulmod base s c n nreductions p q))) mod (s - Associational.eval c) = (Positional.eval weight n p * Positional.eval weight n q) mod (s - Associational.eval c). - Proof. + Proof using wprops. solver. rewrite <-Z.div_mod'' by auto. autorewrite with push_eval; reflexivity. Qed. End Ops. End Rows. + Hint Rewrite length_from_columns using eassumption : distr_length. + Hint Rewrite length_sum_rows using solve [ reflexivity | eassumption | distr_length; eauto ] : distr_length. + Hint Rewrite length_fst_extract_row length_snd_extract_row length_flatten length_flatten' length_partition length_fst_from_columns' length_snd_from_columns' : distr_length. End Rows. Module BaseConversion. @@ -1762,7 +1856,7 @@ Module BaseConversion. Lemma eval_convert_bases sn dn p : (dn <> 0%nat) -> length p = sn -> eval dw dn (convert_bases sn dn p) = eval sw sn p. - Proof. + Proof using dwprops. cbv [convert_bases]; intros. rewrite eval_chained_carries_no_reduce; auto using ZUtil.Z.positive_is_nonzero. rewrite eval_from_associational; auto. @@ -1792,7 +1886,7 @@ Module BaseConversion. Lemma eval_reordering_carry w fw p (_:fw<>0): Associational.eval (reordering_carry w fw p) = Associational.eval p. - Proof. + Proof using Type. cbv [reordering_carry]. induction p; [reflexivity |]. autorewrite with push_fold_right. break_match; push_eval. Qed. @@ -1814,13 +1908,13 @@ Module BaseConversion. Lemma eval_to_associational n m p : m <> 0%nat -> length p = n -> Associational.eval (to_associational n m p) = Positional.eval sw n p. - Proof. cbv [to_associational]; push_eval. Qed. + Proof using dwprops. cbv [to_associational]; push_eval. Qed. Hint Rewrite eval_to_associational using solve [push_eval; distr_length] : push_eval. Lemma eval_from_associational idxs n p : n <> 0%nat -> 0 <= Associational.eval p < sw n -> Positional.eval sw n (from_associational idxs n p) = Associational.eval p. - Proof. + Proof using dwprops swprops. cbv [from_associational]; intros. rewrite Rows.flatten_mod by eauto using Rows.length_from_associational. rewrite Associational.bind_snd_correct. @@ -1831,7 +1925,7 @@ Module BaseConversion. Lemma from_associational_partitions n idxs p (_:n<>0%nat): forall i, (i < n)%nat -> nth_default 0 (from_associational idxs n p) i = (Associational.eval p) mod (sw (S i)) / sw i. - Proof. + Proof using dwprops swprops. intros; cbv [from_associational]. rewrite Rows.flatten_partitions with (n:=n) by (eauto using Rows.length_from_associational; omega). rewrite Associational.bind_snd_correct. @@ -1840,7 +1934,7 @@ Module BaseConversion. Lemma from_associational_eq n idxs p (_:n<>0%nat): from_associational idxs n p = Rows.partition sw n (Associational.eval p). - Proof. + Proof using dwprops swprops. intros. cbv [from_associational]. rewrite Rows.flatten_partitions' with (n:=n) by eauto using Rows.length_from_associational. rewrite Associational.bind_snd_correct. @@ -1898,13 +1992,13 @@ Module BaseConversion. length p1 = n1 -> length p2 = n2 -> 0 <= (Positional.eval sw n1 p1 * Positional.eval sw n2 p2) < sw n3 -> Positional.eval sw n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval sw n1 p1) * (Positional.eval sw n2 p2). - Proof. cbv [mul_converted]; push_eval. Qed. + Proof using dwprops swprops. cbv [mul_converted]; push_eval. Qed. Hint Rewrite eval_mul_converted : push_eval. Lemma mul_converted_partitions n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = Rows.partition sw n3 (Positional.eval sw n1 p1 * Positional.eval sw n2 p2). - Proof. + Proof using dwprops swprops. intros; cbv [mul_converted]. rewrite from_associational_eq by auto. push_eval. Qed. @@ -1931,7 +2025,7 @@ Module BaseConversion. Lemma widemul_correct a b : 0 <= a * b < 2^log2base * 2^log2base -> widemul a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base]. - Proof. + Proof using dwprops swprops. cbv [widemul]; intros. rewrite mul_converted_partitions by auto with zarith. subst nout sw; cbv [weight]; cbn. @@ -1980,3 +2074,114 @@ Module BaseConversion. Qed. End widemul. End BaseConversion. + +(* TODO: rename this module? (Should it be, e.g., [Rows.freeze]?) *) +Module Freeze. + Section Freeze. + Context weight {wprops : @weight_properties weight}. + + Definition freeze n mask (m p:list Z) : list Z := + let '(p, carry) := Rows.sub weight n p m in + let '(r, carry) := Rows.conditional_add weight n mask carry p m in + r. + + Lemma freezeZ m s c y : + m = s - c -> + 0 < c < s -> + s <> 0 -> + 0 <= y < 2*m -> + ((y - m) + (if (dec ((y - m) / s = 0)) then 0 else m)) mod s + = y mod m. + Proof using Type. + clear; intros. + transitivity ((y - m) mod m); + repeat first [ progress intros + | progress subst + | break_innermost_match_step + | progress autorewrite with zsimplify_fast + | rewrite Z.div_small_iff in * by auto + | progress (Z.rewrite_mod_small; push_Zmod; Z.rewrite_mod_small) + | progress destruct_head'_or + | omega ]. + Qed. + + Lemma length_freeze n mask m p : + length m = n -> length p = n -> length (freeze n mask m p) = n. + Proof using wprops. + cbv [freeze Rows.conditional_add Rows.add]; eta_expand; intros. + distr_length; try assumption; cbn; intros; destruct_head'_or; destruct_head' False; subst; + distr_length. + erewrite Rows.length_sum_rows by (reflexivity || eassumption || distr_length); distr_length. + Qed. + Lemma eval_freeze_eq n mask m p + (n_nonzero:n<>0%nat) + (Hmask : List.map (Z.land mask) m = m) + (Hplen : length p = n) + (Hmlen : length m = n) + : Positional.eval weight n (@freeze n mask m p) + = (Positional.eval weight n p - Positional.eval weight n m + + (if dec ((Positional.eval weight n p - Positional.eval weight n m) / weight n = 0) then 0 else Positional.eval weight n m)) + mod weight n. + (*if dec ((Positional.eval weight n p - Positional.eval weight n m) / weight n = 0) + then Positional.eval weight n p - Positional.eval weight n m + else Positional.eval weight n p mod weight n.*) + Proof using wprops. + pose proof (@weight_positive weight wprops n). + cbv [freeze Z.equiv_modulo]; eta_expand. + repeat first [ solve [auto] + | rewrite Rows.conditional_add_partitions + | rewrite Rows.sub_partitions + | rewrite Rows.sub_div + | rewrite Rows.eval_partition + | progress distr_length + | progress pull_Zmod (* + | progress break_innermost_match_step + | progress destruct_head'_or + | omega + | f_equal; omega + | rewrite Z.div_small_iff in * by (auto using (@weight_positive weight ltac:(assumption))) + | progress Z.rewrite_mod_small *) ]. + Qed. + + Lemma eval_freeze n c mask m p + (n_nonzero:n<>0%nat) + (Hc : 0 < Associational.eval c < weight n) + (Hmask : List.map (Z.land mask) m = m) + modulus (Hm : Positional.eval weight n m = Z.pos modulus) + (Hp : 0 <= Positional.eval weight n p < 2*(Z.pos modulus)) + (Hsc : Z.pos modulus = weight n - Associational.eval c) + (Hplen : length p = n) + (Hmlen : length m = n) + : Positional.eval weight n (@freeze n mask m p) + = Positional.eval weight n p mod (Z.pos modulus). + Proof using wprops. + pose proof (@weight_positive weight wprops n). + rewrite eval_freeze_eq by assumption. + erewrite freezeZ; try eassumption; try omega. + f_equal; omega. + Qed. + + Lemma freeze_partitions n c mask m p + (n_nonzero:n<>0%nat) + (Hc : 0 < Associational.eval c < weight n) + (Hmask : List.map (Z.land mask) m = m) + modulus (Hm : Positional.eval weight n m = Z.pos modulus) + (Hp : 0 <= Positional.eval weight n p < 2*(Z.pos modulus)) + (Hsc : Z.pos modulus = weight n - Associational.eval c) + (Hplen : length p = n) + (Hmlen : length m = n) + : @freeze n mask m p = Rows.partition weight n (Positional.eval weight n p mod (Z.pos modulus)). + Proof using wprops. + pose proof (@weight_positive weight wprops n). + pose proof (fun v => Z.mod_pos_bound v (weight n) ltac:(lia)). + pose proof (Z.mod_pos_bound (Positional.eval weight n p) (Z.pos modulus) ltac:(lia)). + erewrite <- eval_freeze by eassumption. + cbv [freeze]; eta_expand. + rewrite Rows.conditional_add_partitions by (auto; rewrite Rows.sub_partitions; auto; distr_length). + rewrite !Rows.eval_partition by assumption. + apply Rows.partition_Proper; [ assumption .. | ]. + cbv [Z.equiv_modulo]. + pull_Zmod; reflexivity. + Qed. + End Freeze. +End Freeze. |