diff options
author | jadep <jade.philipoom@gmail.com> | 2018-09-14 18:30:13 -0400 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2018-09-17 21:34:36 -0400 |
commit | 53543a7653b3cdc1269d6b7c8f0d1bc5b442ed25 (patch) | |
tree | 87ae9b981e4f4b5080374027bf9de185c4380805 /src | |
parent | d112376dc841b153b0b3a0bc16c0f614f5988acb (diff) |
move partition and its proofs to a new module and use it for correctness of Columns.flatten
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/NewPipeline/Arithmetic.v | 286 |
1 files changed, 164 insertions, 122 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v index fae201547..518701202 100644 --- a/src/Experiments/NewPipeline/Arithmetic.v +++ b/src/Experiments/NewPipeline/Arithmetic.v @@ -1178,6 +1178,34 @@ Module Saturated. Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i). Proof using wprops. intros. apply Z.div_exact; auto using weight_multiples_full. Qed. + + Lemma weight_mod_pull_div n x : + x mod weight (S n) / weight n = + (x / weight n) mod (weight (S n) / weight n). + Proof using wprops. + pose proof (@weight_positive _ wprops n). + replace (weight (S n)) with (weight n * (weight (S n) / weight n)); + repeat match goal with + | _ => progress autorewrite with zdiv_to_mod zsimplify_fast + | _ => rewrite Z.mod_pull_div + | _ => rewrite weight_multiples by assumption + | _ => solve [auto using Z.lt_le_incl] + end. + Qed. + + Lemma weight_div_pull_div n x : + x / weight (S n) = + (x / weight n) / (weight (S n) / weight n). + Proof using wprops. + pose proof (@weight_positive _ wprops n). + replace (weight (S n)) with (weight n * (weight (S n) / weight n)); + repeat match goal with + | _ => progress autorewrite with zdiv_to_mod zsimplify_fast + | _ => rewrite Z.div_div by auto + | _ => rewrite weight_multiples by assumption + | _ => solve [auto using Z.lt_le_incl] + end. + Qed. End Weight. Module Associational. @@ -1328,8 +1356,81 @@ Module Saturated. End DivMod. End Saturated. +Module Partition. + Section Partition. + Context weight {wprops : @weight_properties weight}. + Definition partition n x := + map (fun i => (x mod weight (S i)) / weight i) (seq 0 n). + + (* TODO : move to ListUtil *) + Lemma nth_default_map_seq_equiv {A} l f d n + (Hlength : length l = n) + (Hnth_default : forall i, (i < n)%nat -> @nth_default A d l i = f i) : + l = map f (seq 0 n). + Proof using Type. + apply list_elementwise_eq. subst n. + intro i; destruct (lt_dec i (length l)). + { rewrite !nth_error_Some_nth_default with (x:=d) by distr_length. + autorewrite with push_nth_default. + rewrite map_nth_default with (x:=0%nat) by distr_length. + rewrite Hnth_default by distr_length. + rewrite nth_default_seq_inbounds by distr_length. + f_equal. } + { rewrite !nth_error_length_error by distr_length. + reflexivity. } + Qed. + + Lemma nth_default_partitions x : forall p n, + (forall i, (i < n)%nat -> nth_default 0 p i = (x mod weight (S i)) / weight i) -> + length p = n -> + p = partition n x. + Proof using Type. + cbv [partition]; intros. + eauto using nth_default_map_seq_equiv. + Qed. + + Lemma partition_step n x : + partition (S n) x = partition n x ++ [(x mod weight (S n)) / weight n]. + Proof using Type. + cbv [partition]. rewrite seq_snoc. + autorewrite with natsimplify push_map. reflexivity. + Qed. + + Lemma length_partition n x : length (partition n x) = n. + Proof using Type. cbv [partition]; distr_length. Qed. + Hint Rewrite length_partition : distr_length. + + Lemma eval_partition n x : + Positional.eval weight n (partition n x) = x mod (weight n). + Proof using wprops. + induction n; intros. + { cbn. rewrite (weight_0); auto with zarith. } + { rewrite (Z.div_mod (x mod weight (S n)) (weight n)) by auto. + rewrite <-Znumtheory.Zmod_div_mod by (try apply Z.mod_divide; auto). + rewrite partition_step, Positional.eval_snoc with (n:=n) by distr_length. + omega. } + Qed. + + Lemma partition_Proper n : + Proper (Z.equiv_modulo (weight n) ==> eq) (partition n). + Proof using wprops. + cbv [Proper Z.equiv_modulo respectful]. + intros x y Hxy; induction n; intros. + { reflexivity. } + { assert (Hxyn : x mod weight n = y mod weight n). + { erewrite (Znumtheory.Zmod_div_mod _ (weight (S n)) x), (Znumtheory.Zmod_div_mod _ (weight (S n)) y), Hxy + by (try apply Z.mod_divide; auto); + reflexivity. } + rewrite !partition_step, IHn by eauto. + rewrite (Z.div_mod (x mod weight (S n)) (weight n)), (Z.div_mod (y mod weight (S n)) (weight n)) by auto. + rewrite <-!Znumtheory.Zmod_div_mod by (try apply Z.mod_divide; auto). + rewrite Hxy, Hxyn; reflexivity. } + Qed. + End Partition. +End Partition. + Module Columns. - Import Saturated. + Import Saturated. Import Partition. Section Columns. Context weight {wprops : @weight_properties weight}. @@ -1439,25 +1540,46 @@ Module Columns. Proof using Type. cbv [flatten]. induction inp using rev_ind; push. Qed. Hint Rewrite length_flatten : distr_length. + Lemma flatten_snoc x inp : flatten (inp ++ [x]) = flatten_step x (flatten inp). + Proof using Type. cbv [flatten]. rewrite rev_unit. reflexivity. Qed. + + Lemma flatten_correct inp: + forall n, + length inp = n -> + flatten inp = (partition weight n (eval n inp), + eval n inp / (weight n)). + Proof. + induction inp using rev_ind; intros; + destruct n; distr_length; [ reflexivity | ]. + rewrite flatten_snoc. + rewrite partition_step. + erewrite IHinp with (n:=n) by distr_length. + push. + pose proof (@weight_positive _ wprops n). + repeat match goal with + | |- pair _ _ = pair _ _ => f_equal + | |- _ ++ _ = _ ++ _ => f_equal + | |- _ :: _ = _ :: _ => f_equal + | _ => apply partition_Proper; + [ assumption | cbv [Z.equiv_modulo ] ] + | _ => rewrite length_partition + | _ => rewrite weight_mod_pull_div by assumption + | _ => rewrite weight_div_pull_div by assumption + | _ => f_equal; ring + | _ => progress autorewrite with zsimplify + end. + Qed. + Lemma flatten_div_mod n inp : length inp = n -> (Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n)) /\ (snd (flatten inp) = eval n inp / weight n). Proof using wprops. - (* to make the invariant take the right form, we make everything depend on output length, not input length *) - intro. subst n. rewrite <-(length_flatten inp). cbv [flatten]. - induction inp using rev_ind; intros; [push|]. - repeat match goal with - | _ => rewrite Nat.add_1_r - | _ => progress (fold (flatten inp) in * ) - | _ => erewrite Positional.eval_snoc by (distr_length; reflexivity) - | H: _ = _ mod (weight _) |- _ => rewrite H - | H: _ = _ / (weight _) |- _ => rewrite H - | _ => progress rewrite ?mod_step, ?div_step by auto - | _ => progress autorewrite with cancel_pair to_div_mod push_sum list push_fold_right push_eval - | _ => progress (distr_length; push_fast) - end. + intros. + rewrite flatten_correct with (n:=n) by auto. + cbn [fst snd]. + rewrite eval_partition; auto. Qed. Lemma flatten_mod {n} inp : @@ -1470,29 +1592,6 @@ Module Columns. length inp = n -> snd (flatten inp) = eval n inp / weight n. Proof using wprops. apply flatten_div_mod. Qed. Hint Rewrite @flatten_div : push_eval. - - Lemma flatten_snoc x inp : flatten (inp ++ [x]) = flatten_step x (flatten inp). - Proof using Type. cbv [flatten]. rewrite rev_unit. reflexivity. Qed. - - Lemma flatten_partitions inp: - forall n i, length inp = n -> (i < n)%nat -> - nth_default 0 (fst (flatten inp)) i = ((eval n inp) mod (weight (S i))) / weight i. - Proof using wprops. - induction inp using rev_ind; intros; destruct n; distr_length. - rewrite flatten_snoc. - push; distr_length; - [rewrite IHinp with (n:=n) by omega; rewrite weight_div_mod with (j:=n) (i:=S i) by (eauto; omega); push_Zmod; push |]. - repeat match goal with - | _ => progress replace (length inp) with n by omega - | _ => progress replace i with n by omega - | _ => progress push - | _ => erewrite flatten_div by eauto - | _ => rewrite <-Z.div_add' by auto - | _ => rewrite Z.mul_div_eq' by auto - | _ => rewrite Z.mod_pull_div by auto using Z.lt_le_incl - | _ => progress autorewrite with push_nth_default natsimplify - end. - Qed. End Flatten. Section FromAssociational. @@ -1566,7 +1665,7 @@ Module Columns. End Columns. Module Rows. - Import Saturated. + Import Saturated. Import Partition. Section Rows. Context weight {wprops : @weight_properties weight}. @@ -2100,66 +2199,9 @@ Module Rows. { rewrite flatten'_partitions with (n:=n); push. } Qed. - Definition partition n x := - map (fun i => (x mod weight (S i)) / weight i) (seq 0 n). - - Lemma nth_default_partitions x : forall p n, - (forall i, (i < n)%nat -> nth_default 0 p i = (x mod weight (S i)) / weight i) -> - length p = n -> - p = partition n x. - Proof using Type. - cbv [partition]; induction p using rev_ind; intros; distr_length; subst n; [reflexivity|]. - rewrite Nat.add_1_r, seq_snoc. - autorewrite with natsimplify push_map. - rewrite <-IHp; auto; intros; - match goal with H : context [nth_default _ (p ++ [ _ ])] |- _ => - rewrite <-H by omega end. - { autorewrite with push_nth_default natsimplify. reflexivity. } - { autorewrite with push_nth_default natsimplify. - break_match; omega. } - Qed. - - Lemma partition_step n x : - partition (S n) x = partition n x ++ [(x mod weight (S n)) / weight n]. - Proof using Type. - cbv [partition]. rewrite seq_snoc. - autorewrite with natsimplify push_map. reflexivity. - Qed. - - Lemma length_partition n x : length (partition n x) = n. - Proof using Type. cbv [partition]; distr_length. Qed. - Hint Rewrite length_partition : distr_length. - - Lemma eval_partition n x : - Positional.eval weight n (partition n x) = x mod (weight n). - Proof using wprops. - induction n; intros. - { cbn. rewrite (weight_0); auto with zarith. } - { rewrite (Z.div_mod (x mod weight (S n)) (weight n)) by auto. - rewrite <-Znumtheory.Zmod_div_mod by (try apply Z.mod_divide; auto). - rewrite partition_step, Positional.eval_snoc with (n:=n) by distr_length. - omega. } - Qed. - - Lemma partition_Proper n : - Proper (Z.equiv_modulo (weight n) ==> eq) (partition n). - Proof using wprops. - cbv [Proper Z.equiv_modulo respectful]. - intros x y Hxy; induction n; intros. - { reflexivity. } - { assert (Hxyn : x mod weight n = y mod weight n). - { erewrite (Znumtheory.Zmod_div_mod _ (weight (S n)) x), (Znumtheory.Zmod_div_mod _ (weight (S n)) y), Hxy - by (try apply Z.mod_divide; auto); - reflexivity. } - rewrite !partition_step, IHn by eauto. - rewrite (Z.div_mod (x mod weight (S n)) (weight n)), (Z.div_mod (y mod weight (S n)) (weight n)) by auto. - rewrite <-!Znumtheory.Zmod_div_mod by (try apply Z.mod_divide; auto). - rewrite Hxy, Hxyn; reflexivity. } - Qed. - Lemma flatten_partitions' inp n : (forall row, In row inp -> length row = n) -> - fst (flatten n inp) = partition n (eval n inp). + fst (flatten n inp) = partition weight n (eval n inp). Proof using wprops. auto using nth_default_partitions, flatten_partitions, length_flatten. Qed. End Flatten. Hint Rewrite length_partition : distr_length. @@ -2228,7 +2270,7 @@ Module Rows. Lemma add_partitions n p q : n <> 0%nat -> length p = n -> length q = n -> - fst (add n p q) = partition n (Positional.eval weight n p + Positional.eval weight n q). + fst (add n p q) = partition weight n (Positional.eval weight n p + Positional.eval weight n q). Proof using wprops. solver. Qed. Lemma add_div n p q : @@ -2239,7 +2281,7 @@ Module Rows. Lemma conditional_add_partitions n mask cond p q : n <> 0%nat -> length p = n -> length q = n -> map (Z.land mask) q = q -> fst (conditional_add n mask cond p q) - = partition n (Positional.eval weight n p + if dec (cond = 0) then 0 else Positional.eval weight n q). + = partition weight n (Positional.eval weight n p + if dec (cond = 0) then 0 else Positional.eval weight n q). Proof using wprops. cbv [conditional_add]; intros; rewrite add_partitions by (distr_length; auto). autorewrite with push_eval; auto. @@ -2268,7 +2310,7 @@ Module Rows. Lemma sub_partitions n p q : n <> 0%nat -> length p = n -> length q = n -> - fst (sub n p q) = partition n (Positional.eval weight n p - Positional.eval weight n q). + fst (sub n p q) = partition weight n (Positional.eval weight n p - Positional.eval weight n q). Proof using wprops. solver. Qed. Lemma sub_div n p q : @@ -2278,7 +2320,7 @@ Module Rows. Lemma mul_partitions base n m p q : base <> 0 -> n <> 0%nat -> m <> 0%nat -> length p = n -> length q = n -> - fst (mul base n m p q) = partition m (Positional.eval weight n p * Positional.eval weight n q). + fst (mul base n m p q) = partition weight m (Positional.eval weight n p * Positional.eval weight n q). Proof using wprops. solver. Qed. Lemma mul_div base n m p q : @@ -2374,7 +2416,7 @@ hd 0 p). End Rows. Module BaseConversion. - Import Positional. + Import Positional. Import Partition. Section BaseConversion. Hint Resolve Z.gt_lt. Context (sw dw : nat -> Z) (* source/destination weight functions *) @@ -2465,7 +2507,7 @@ Module BaseConversion. Qed. Lemma from_associational_eq n idxs p (_:n<>0%nat): - from_associational idxs n p = Rows.partition sw n (Associational.eval p). + from_associational idxs n p = partition sw n (Associational.eval p). Proof using dwprops swprops. intros. cbv [from_associational]. rewrite Rows.flatten_partitions' with (n:=n) by eauto using Rows.length_from_associational. @@ -2529,7 +2571,7 @@ Module BaseConversion. Lemma mul_converted_partitions n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> - mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = Rows.partition sw n3 (Positional.eval sw n1 p1 * Positional.eval sw n2 p2). + mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = partition sw n3 (Positional.eval sw n1 p1 * Positional.eval sw n2 p2). Proof using dwprops swprops. intros; cbv [mul_converted]. rewrite from_associational_eq by auto. push_eval. @@ -2666,7 +2708,7 @@ Module Freeze. | rewrite Rows.conditional_add_partitions | rewrite Rows.sub_partitions | rewrite Rows.sub_div - | rewrite Rows.eval_partition + | rewrite Partition.eval_partition | progress distr_length | progress pull_Zmod (* | progress break_innermost_match_step @@ -2704,7 +2746,7 @@ Module Freeze. (Hp : 0 <= Positional.eval weight n p < 2*modulus) (Hplen : length p = n) (Hmlen : length m = n) - : @freeze n mask m p = Rows.partition weight n (Positional.eval weight n p mod modulus). + : @freeze n mask m p = Partition.partition weight n (Positional.eval weight n p mod modulus). Proof using wprops. pose proof (@weight_positive weight wprops n). pose proof (fun v => Z.mod_pos_bound v (weight n) ltac:(lia)). @@ -2713,8 +2755,8 @@ Module Freeze. erewrite <- eval_freeze by eassumption. cbv [freeze]; eta_expand. rewrite Rows.conditional_add_partitions by (auto; rewrite Rows.sub_partitions; auto; distr_length). - rewrite !Rows.eval_partition by assumption. - apply Rows.partition_Proper; [ assumption .. | ]. + rewrite !Partition.eval_partition by assumption. + apply Partition.partition_Proper; [ assumption .. | ]. cbv [Z.equiv_modulo]. pull_Zmod; reflexivity. Qed. @@ -2849,7 +2891,7 @@ Section freeze_mod_ops. Lemma to_bytes_partitions : forall (f : list Z) (Hf : length f = n), - to_bytes f = Rows.partition bytes_weight bytes_n (Positional.eval weight n f). + to_bytes f = Partition.partition bytes_weight bytes_n (Positional.eval weight n f). Proof using Hn_nz limbwidth_good. clear -Hn_nz limbwidth_good. intros; cbv [to_bytes]. @@ -2867,7 +2909,7 @@ Section freeze_mod_ops. (Hf : length f = n) (Hf_small : 0 <= eval weight n f < weight n), eval bytes_weight bytes_n (to_bytesmod f) = eval weight n f - /\ to_bytesmod f = Rows.partition bytes_weight bytes_n (Positional.eval weight n f). + /\ to_bytesmod f = Partition.partition bytes_weight bytes_n (Positional.eval weight n f). Proof using Hn_nz limbwidth_good. split; apply eval_to_bytes || apply to_bytes_partitions; assumption. Qed. @@ -2877,7 +2919,7 @@ Section freeze_mod_ops. (Hf : length f = n) (Hf_bounded : 0 <= eval weight n f < 2 * m), (eval bytes_weight bytes_n (freeze_to_bytesmod f)) = (eval weight n f) mod m - /\ freeze_to_bytesmod f = Rows.partition bytes_weight bytes_n (Positional.eval weight n f mod m). + /\ freeze_to_bytesmod f = Partition.partition bytes_weight bytes_n (Positional.eval weight n f mod m). Proof using m_enc_correct Hs limbwidth_good Hn_nz c_small Hm_enc_len m_enc_bounded. clear -m_enc_correct Hs limbwidth_good Hn_nz c_small Hm_enc_len m_enc_bounded. intros; subst m s. @@ -3044,7 +3086,7 @@ Module WordByWordMontgomery. Let R := (r^Z.of_nat R_numlimbs). Transparent T. Definition small {n} (v : T n) : Prop - := v = Rows.partition weight n (eval v). + := v = Partition.partition weight n (eval v). Context (small_N : small N) (N_lt_R : eval N < R) (N_nz : 0 < eval N) @@ -3068,18 +3110,18 @@ Module WordByWordMontgomery. Lemma length_small {n v} : @small n v -> length v = n. Proof using Type. clear; cbv [small]; intro H; rewrite H; autorewrite with distr_length; reflexivity. Qed. - Let partition_Proper := (@Rows.partition_Proper _ wprops). + Let partition_Proper := (@Partition.partition_Proper _ wprops). Local Existing Instance partition_Proper. Lemma eval_nonzero n A : @small n A -> nonzero A = 0 <-> @eval n A = 0. Proof using lgr_big. clear -lgr_big partition_Proper. cbv [nonzero eval small]; intro Heq. do 2 rewrite Heq. - rewrite !Rows.eval_partition, Z.mod_mod by auto. + rewrite !Partition.eval_partition, Z.mod_mod by auto. generalize (Positional.eval weight n A); clear Heq A. induction n as [|n IHn]. { cbn; rewrite weight_0 by auto; intros; autorewrite with zsimplify_const; omega. } - { intro; rewrite Rows.partition_step. + { intro; rewrite Partition.partition_step. rewrite fold_right_snoc, Z.lor_comm, <- fold_right_push, Z.lor_eq_0_iff by auto using Z.lor_assoc. assert (Heq : Z.equiv_modulo (weight n) (z mod weight (S n)) (z mod (weight n))). { cbv [Z.equiv_modulo]. @@ -3142,7 +3184,7 @@ Module WordByWordMontgomery. Qed. Local Lemma small_zero : forall n, small (@zero n). Proof using Type. - etransitivity; [ eapply Positional.zeros_ext_map | rewrite eval_zero ]; cbv [Rows.partition]; cbn; try reflexivity; autorewrite with distr_length; reflexivity. + etransitivity; [ eapply Positional.zeros_ext_map | rewrite eval_zero ]; cbv [Partition.partition]; cbn; try reflexivity; autorewrite with distr_length; reflexivity. Qed. Local Hint Immediate small_zero. Local Axiom eval_div : forall n v, small v -> eval (fst (@divmod n v)) = eval v / r. @@ -3192,7 +3234,7 @@ Module WordByWordMontgomery. clear -lgr_big Ha Hb. autounfold with loc in *; destruct (zerop n); subst. { destruct a as [| ? [|] ], b; cbn; try omega. - cbv [Rows.partition seq eval map] in Ha. + cbv [Partition.partition seq eval map] in Ha. cbn in Ha. rewrite (weight_0 wprops) in *. rewrite Z.add_with_get_carry_full_mod. @@ -3790,7 +3832,7 @@ Module WordByWordMontgomery. Let r := 2^bitwidth. Local Notation weight := (UniformWeight.uweight bitwidth). Local Notation eval := (@eval bitwidth n). - Let m_enc := Rows.partition weight n m. + Let m_enc := Partition.partition weight n m. Local Coercion Z.of_nat : nat >-> Z. Context (r' : Z) (m' : Z) @@ -3844,7 +3886,7 @@ Module WordByWordMontgomery. | _ => lia | _ => exact small_m_enc | [ H : small ?x |- context[eval ?x] ] - => rewrite H; cbv [eval]; rewrite Rows.eval_partition by auto + => rewrite H; cbv [eval]; rewrite Partition.eval_partition by auto | [ |- context[weight _] ] => rewrite UniformWeight.uweight_eq_alt by auto with omega | _=> progress Z.rewrite_mod_small | _ => progress Z.zero_bounds @@ -3875,7 +3917,7 @@ Module WordByWordMontgomery. t_fin. Qed. - Definition onemod : list Z := Rows.partition weight n 1. + Definition onemod : list Z := Partition.partition weight n 1. Definition onemod_correct : eval onemod = 1 /\ valid onemod. Proof using n_nz m_big bitwidth_big. @@ -3883,7 +3925,7 @@ Module WordByWordMontgomery. cbv [valid small onemod eval]; autorewrite with push_eval; t_fin. Qed. - Definition R2mod : list Z := Rows.partition weight n ((r^n * r^n) mod m). + Definition R2mod : list Z := Partition.partition weight n ((r^n * r^n) mod m). Definition R2mod_correct : eval R2mod mod m = (r^n*r^n) mod m /\ valid R2mod. Proof using n_nz m_small m_big m'_correct bitwidth_big. @@ -3938,7 +3980,7 @@ Module WordByWordMontgomery. Qed. Definition encodemod (v : Z) : list Z - := mulmod (Rows.partition weight n v) R2mod. + := mulmod (Partition.partition weight n v) R2mod. Local Ltac t_valid v := cbv [valid]; repeat apply conj; @@ -4032,7 +4074,7 @@ Module WordByWordMontgomery. Lemma to_bytesmod_correct : (forall a (_ : valid a), Positional.eval (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (to_bytesmod a) = eval a mod m) - /\ (forall a (_ : valid a), to_bytesmod a = Rows.partition (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (eval a mod m)). + /\ (forall a (_ : valid a), to_bytesmod a = Partition.partition (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (eval a mod m)). Proof using n_nz m_small bitwidth_big. clear -n_nz m_small bitwidth_big. generalize (@length_small bitwidth n); |