From bd34d676010ebf6dd4aa5076537f89572803dd3d Mon Sep 17 00:00:00 2001 From: jadep Date: Wed, 3 Apr 2019 14:45:32 -0400 Subject: partition -> Partition.partition to prevent confusion with List.partition --- src/Arithmetic/BarrettReduction.v | 52 +++++++++++++------------- src/Arithmetic/BaseConversion.v | 16 ++++---- src/Arithmetic/FancyMontgomeryReduction.v | 4 +- src/Arithmetic/Freeze.v | 18 ++++----- src/Arithmetic/Partition.v | 16 +++++--- src/Arithmetic/Saturated.v | 26 ++++++------- src/Arithmetic/UniformWeight.v | 16 ++++---- src/Arithmetic/WordByWordMontgomery.v | 22 +++++------ src/PushButtonSynthesis/BarrettReduction.v | 4 +- src/PushButtonSynthesis/Primitives.v | 2 +- src/PushButtonSynthesis/WordByWordMontgomery.v | 4 +- 11 files changed, 92 insertions(+), 88 deletions(-) diff --git a/src/Arithmetic/BarrettReduction.v b/src/Arithmetic/BarrettReduction.v index 2f1f272ac..54740308b 100644 --- a/src/Arithmetic/BarrettReduction.v +++ b/src/Arithmetic/BarrettReduction.v @@ -35,7 +35,7 @@ Section Generic. (width_pos : 0 < width) (strong_bound : b ^ 1 * (b ^ (2 * k) mod M) <= b ^ (k + 1) - mu). Local Notation weight := (uweight width). - Local Notation partition := (partition weight). + Local Notation partition := (Partition.partition weight). Context (q1 : list Z -> list Z) (q1_correct : forall x, @@ -116,7 +116,7 @@ Module Fancy. (k_eq : k = width * Z.of_nat sz). (* sz = 1, width = k = 256 *) Local Notation w := (uweight width). Local Notation eval := (Positional.eval w). - Context (mut Mt : list Z) (mut_correct : mut = partition w (sz+1) mu) (Mt_correct : Mt = partition w sz M). + Context (mut Mt : list Z) (mut_correct : mut = Partition.partition w (sz+1) mu) (Mt_correct : Mt = Partition.partition w sz M). Context (mu_eq : mu = 2 ^ (2 * k) / M) (muHigh_one : mu / w sz = 1) (M_range : 2^(k-1) < M < 2^k). Local Lemma wprops : @weight_properties w. Proof. apply uwprops; auto with lia. Qed. @@ -201,7 +201,7 @@ Module Fancy. Lemma shiftr'_correct m n : forall t tn, (m <= tn)%nat -> 0 <= t < w tn -> 0 <= n < width -> - shiftr' m (partition w tn t) n = partition w m (t / 2 ^ n). + shiftr' m (Partition.partition w tn t) n = Partition.partition w m (t / 2 ^ n). Proof. cbv [shiftr']. induction m; intros; [ reflexivity | ]. rewrite !partition_step, seq_snoc. @@ -232,7 +232,7 @@ Module Fancy. forall t tn, (Z.to_nat (n / width) <= tn)%nat -> (m <= tn - Z.to_nat (n / width))%nat -> 0 <= t < w tn -> 0 <= n -> - shiftr m (partition w tn t) n = partition w m (t / 2 ^ n). + shiftr m (Partition.partition w tn t) n = Partition.partition w m (t / 2 ^ n). Proof. cbv [shiftr]; intros. break_innermost_match; [ | solve [auto using shiftr'_correct with zarith] ]. @@ -256,7 +256,7 @@ Module Fancy. (* 2 ^ (k + 1) bits fit in sz + 1 limbs because we know 2^k bits fit in sz and 1 <= width *) Lemma q1_correct x : 0 <= x < w (sz * 2) -> - q1 (partition w (sz*2)%nat x) = partition w (sz+1)%nat (x / 2 ^ (k - 1)). + q1 (Partition.partition w (sz*2)%nat x) = Partition.partition w (sz+1)%nat (x / 2 ^ (k - 1)). Proof. cbv [q1]; intros. assert (1 <= Z.of_nat sz) by (destruct sz; lia). assert (Z.to_nat ((k-1) / width) < sz)%nat. { @@ -266,13 +266,13 @@ Module Fancy. autorewrite with pull_partition. reflexivity. Qed. - Lemma low_correct n a : (sz <= n)%nat -> low (partition w n a) = partition w sz a. + Lemma low_correct n a : (sz <= n)%nat -> low (Partition.partition w n a) = Partition.partition w sz a. Proof. cbv [low]; auto using uweight_firstn_partition with lia. Qed. - Lemma high_correct a : high (partition w (sz*2) a) = partition w sz (a / w sz). + Lemma high_correct a : high (Partition.partition w (sz*2) a) = Partition.partition w sz (a / w sz). Proof. cbv [high]. rewrite uweight_skipn_partition by lia. f_equal; lia. Qed. Lemma fill_correct n m a : (n <= m)%nat -> - fill m (partition w n a) = partition w m (a mod w n). + fill m (Partition.partition w n a) = Partition.partition w m (a mod w n). Proof. cbv [fill]; intros. distr_length. rewrite <-partition_0 with (weight:=w). @@ -282,21 +282,21 @@ Module Fancy. Hint Rewrite low_correct high_correct fill_correct using lia : pull_partition. Lemma wideadd_correct a b : - wideadd (partition w (sz*2) a) (partition w (sz*2) b) = partition w (sz*2) (a + b). + wideadd (Partition.partition w (sz*2) a) (Partition.partition w (sz*2) b) = Partition.partition w (sz*2) (a + b). Proof. cbv [wideadd]. rewrite Rows.add_partitions by (distr_length; auto). autorewrite with push_eval. apply partition_eq_mod; auto with zarith. Qed. Lemma widesub_correct a b : - widesub (partition w (sz*2) a) (partition w (sz*2) b) = partition w (sz*2) (a - b). + widesub (Partition.partition w (sz*2) a) (Partition.partition w (sz*2) b) = Partition.partition w (sz*2) (a - b). Proof. cbv [widesub]. rewrite Rows.sub_partitions by (distr_length; auto). autorewrite with push_eval. apply partition_eq_mod; auto with zarith. Qed. Lemma widemul_correct a b : - widemul (partition w sz a) (partition w sz b) = partition w (sz*2) ((a mod w sz) * (b mod w sz)). + widemul (Partition.partition w sz a) (Partition.partition w sz b) = Partition.partition w (sz*2) ((a mod w sz) * (b mod w sz)). Proof. cbv [widemul]. rewrite BaseConversion.widemul_inlined_correct; (distr_length; auto). autorewrite with push_eval. reflexivity. @@ -327,8 +327,8 @@ Module Fancy. Lemma mul_high_correct a b (Ha : a / w sz = 1) a0b1 (Ha0b1 : a0b1 = a mod w sz * (b / w sz)) : - mul_high (partition w (sz*2) a) (partition w (sz*2) b) (partition w (sz*2) a0b1) = - partition w (sz*2) (a * b / w sz). + mul_high (Partition.partition w (sz*2) a) (Partition.partition w (sz*2) b) (Partition.partition w (sz*2) a0b1) = + Partition.partition w (sz*2) (a * b / w sz). Proof. cbv [mul_high Let_In]. erewrite mul_high_idea by auto using Z.div_mod with zarith. @@ -359,7 +359,7 @@ Module Fancy. Lemma muSelect_correct x : 0 <= x < w (sz * 2) -> - muSelect (partition w (sz*2) x) = partition w sz (mu mod (w sz) * (x / 2 ^ (k - 1) / (w sz))). + muSelect (Partition.partition w (sz*2) x) = Partition.partition w sz (mu mod (w sz) * (x / 2 ^ (k - 1) / (w sz))). Proof. cbv [muSelect]; intros; repeat match goal with @@ -391,7 +391,7 @@ Module Fancy. Qed. Lemma q3_correct x (Hx : 0 <= x < w (sz * 2)) q1 (Hq1 : q1 = x / 2 ^ (k - 1)) : - q3 (partition w (sz*2) x) (partition w (sz+1) q1) = partition w (sz+1) ((mu*q1) / 2 ^ (k + 1)). + q3 (Partition.partition w (sz*2) x) (Partition.partition w (sz+1) q1) = Partition.partition w (sz+1) ((mu*q1) / 2 ^ (k + 1)). Proof. cbv [q3 Let_In]. intros. pose proof mu_q1_range x ltac:(lia). pose proof mu_range'. pose proof q1_range x ltac:(lia). @@ -408,8 +408,8 @@ Module Fancy. Qed. Lemma cond_sub_correct a b : - cond_sub (partition w (sz*2) a) (partition w sz b) - = partition w sz (if dec ((a / w sz) mod 2 = 0) + cond_sub (Partition.partition w (sz*2) a) (Partition.partition w sz b) + = Partition.partition w sz (if dec ((a / w sz) mod 2 = 0) then a else a - b). Proof. @@ -425,8 +425,8 @@ Module Fancy. Qed. Hint Rewrite cond_sub_correct : pull_partition. Lemma cond_subM_correct a : - cond_subM (partition w sz a) - = partition w sz (if dec (a mod w sz < M) + cond_subM (Partition.partition w sz a) + = Partition.partition w sz (if dec (a mod w sz < M) then a else a - M). Proof. @@ -479,7 +479,7 @@ Module Fancy. 0 <= x < M * 2 ^ k -> 0 <= q3 -> (exists b : bool, q3 = x / M + (if b then -1 else 0)) -> - r (partition w (sz*2) x) (partition w (sz+1) q3) = partition w sz (x mod M). + r (Partition.partition w (sz*2) x) (Partition.partition w (sz+1) q3) = Partition.partition w sz (x mod M). Proof. intros; cbv [r Let_In]. pose proof M_range'. assert (0 < 2^(k-1)) by Z.zero_bounds. autorewrite with pull_partition. Z.rewrite_mod_small. @@ -515,7 +515,7 @@ Module Fancy. Lemma fancy_reduce_muSelect_first_correct x : 0 <= x < M * 2^k -> 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu -> - fancy_reduce_muSelect_first (partition w (sz*2) x) = partition w sz (x mod M). + fancy_reduce_muSelect_first (Partition.partition w (sz*2) x) = Partition.partition w sz (x mod M). Proof. intros. pose proof w_eq_22k. erewrite <-reduce_correct with (b:=2) (k:=k) (mu:=mu) by @@ -528,7 +528,7 @@ Module Fancy. forall x, 0 <= x < M * 2^k -> 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu -> - fancy_reduce' (partition w (sz*2) x) = partition w sz (x mod M)) + fancy_reduce' (Partition.partition w (sz*2) x) = Partition.partition w sz (x mod M)) As fancy_reduce'_correct. Proof. intros. assert (k = width) as width_eq_k by nia. @@ -546,9 +546,9 @@ Module Fancy. Lemma partition_2 xLow xHigh : 0 <= xLow < 2 ^ k -> 0 <= xHigh < M -> - partition w 2 (xLow + 2^k * xHigh) = [xLow;xHigh]. + Partition.partition w 2 (xLow + 2^k * xHigh) = [xLow;xHigh]. Proof. - replace k with width in M_range |- * by nia; intros. cbv [partition map seq]. + replace k with width in M_range |- * by nia; intros. cbv [Partition.partition map seq]. rewrite !uweight_S, !weight_0 by auto with zarith lia. autorewrite with zsimplify. rewrite <-Z.mod_pull_div by Z.zero_bounds. @@ -567,10 +567,10 @@ Module Fancy. intros. cbv [fancy_reduce]. rewrite <-partition_2 by lia. replace 2%nat with (sz*2)%nat by lia. rewrite fancy_reduce'_correct by nia. - rewrite sz_eq_1; cbv [partition map seq hd]. + rewrite sz_eq_1; cbv [Partition.partition map seq hd]. rewrite !uweight_S, !weight_0 by auto with zarith lia. autorewrite with zsimplify. reflexivity. Qed. End Def. End Fancy. -End Fancy. \ No newline at end of file +End Fancy. diff --git a/src/Arithmetic/BaseConversion.v b/src/Arithmetic/BaseConversion.v index dde04ae2a..ca5890705 100644 --- a/src/Arithmetic/BaseConversion.v +++ b/src/Arithmetic/BaseConversion.v @@ -41,13 +41,13 @@ Module BaseConversion. Lemma convert_bases_partitions sn dn p (dw_unique : forall i j : nat, (i <= dn)%nat -> (j <= dn)%nat -> dw i = dw j -> i = j) (p_bounded : 0 <= eval sw sn p < dw dn) - : convert_bases sn dn p = partition dw dn (eval sw sn p). + : convert_bases sn dn p = Partition.partition dw dn (eval sw sn p). Proof using dwprops. apply list_elementwise_eq; intro i. destruct (lt_dec i dn); [ | now rewrite !nth_error_length_error by distr_length ]. erewrite !(@nth_error_Some_nth_default _ _ 0) by (break_match; distr_length). apply f_equal. - cbv [convert_bases partition]. + cbv [convert_bases Partition.partition]. unshelve erewrite map_nth_default, nth_default_chained_carries_no_reduce_pred; repeat first [ progress autorewrite with distr_length push_eval | rewrite eval_from_associational, eval_to_associational @@ -117,7 +117,7 @@ Module BaseConversion. Hint Rewrite eval_from_associational using solve [push_eval; distr_length] : push_eval. Lemma from_associational_partitions n idxs p (_:n<>0%nat): - from_associational idxs n p = partition sw n (Associational.eval p). + from_associational idxs n p = Partition.partition sw n (Associational.eval p). Proof using dwprops swprops. intros. cbv [from_associational]. rewrite Rows.flatten_correct with (n:=n) by eauto using Rows.length_from_associational. @@ -181,7 +181,7 @@ Module BaseConversion. Lemma mul_converted_partitions n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> - mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = partition sw n3 (Positional.eval sw n1 p1 * Positional.eval sw n2 p2). + mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = Partition.partition sw n3 (Positional.eval sw n1 p1 * Positional.eval sw n2 p2). Proof using dwprops swprops. intros; cbv [mul_converted]. rewrite from_associational_partitions by auto. push_eval. @@ -215,14 +215,14 @@ Module BaseConversion. Lemma widemul_correct a b : length a = m -> length b = m -> - widemul a b = partition sw nout (seval m a * seval m b). + widemul a b = Partition.partition sw nout (seval m a * seval m b). Proof. apply mul_converted_partitions; auto with zarith. Qed. Derive widemul_inlined SuchThat (forall a b, length a = m -> length b = m -> - widemul_inlined a b = partition sw nout (seval m a * seval m b)) + widemul_inlined a b = Partition.partition sw nout (seval m a * seval m b)) As widemul_inlined_correct. Proof. intros. @@ -238,7 +238,7 @@ Module BaseConversion. SuchThat (forall a b, length a = m -> length b = m -> - widemul_inlined_reverse a b = partition sw nout (seval m a * seval m b)) + widemul_inlined_reverse a b = Partition.partition sw nout (seval m a * seval m b)) As widemul_inlined_reverse_correct. Proof. intros. @@ -258,4 +258,4 @@ Module BaseConversion. reflexivity. } Qed. End widemul. -End BaseConversion. \ No newline at end of file +End BaseConversion. diff --git a/src/Arithmetic/FancyMontgomeryReduction.v b/src/Arithmetic/FancyMontgomeryReduction.v index 51b578bed..f69f7c1d4 100644 --- a/src/Arithmetic/FancyMontgomeryReduction.v +++ b/src/Arithmetic/FancyMontgomeryReduction.v @@ -91,7 +91,7 @@ Module MontgomeryReduction. autorewrite with widemul. rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega). (* rewrite R_two_pow. *) - cbv [partition seq]. + cbv [Partition.partition seq]. repeat match goal with | _ => progress rewrite ?eval1, ?eval2 | _ => progress rewrite ?Z.zselect_correct, ?Z.add_modulo_correct @@ -127,4 +127,4 @@ Module MontgomeryReduction. apply reduce_via_partial_in_range; omega. Qed. End MontRed'. -End MontgomeryReduction. \ No newline at end of file +End MontgomeryReduction. diff --git a/src/Arithmetic/Freeze.v b/src/Arithmetic/Freeze.v index e766e7aea..bda62617a 100644 --- a/src/Arithmetic/Freeze.v +++ b/src/Arithmetic/Freeze.v @@ -116,7 +116,7 @@ Module Freeze. (Hp : 0 <= Positional.eval weight n p < 2*modulus) (Hplen : length p = n) (Hmlen : length m = n) - : @freeze n mask m p = partition weight n (Positional.eval weight n p mod modulus). + : @freeze n mask m p = Partition.partition weight n (Positional.eval weight n p mod modulus). Proof using wprops. pose proof (@weight_positive weight wprops n). pose proof (fun v => Z.mod_pos_bound v (weight n) ltac:(lia)). @@ -252,7 +252,7 @@ Section freeze_mod_ops. : forall (f : list Z) (Hf : length f = n) (Hf_small : 0 <= eval weight n f < weight n), - to_bytes f = partition bytes_weight bytes_n (Positional.eval weight n f). + to_bytes f = Partition.partition bytes_weight bytes_n (Positional.eval weight n f). Proof using Hn_nz limbwidth_good. clear -Hn_nz limbwidth_good. intros; cbv [to_bytes]. @@ -265,7 +265,7 @@ Section freeze_mod_ops. (Hf : length f = n) (Hf_small : 0 <= eval weight n f < weight n), eval bytes_weight bytes_n (to_bytesmod f) = eval weight n f - /\ to_bytesmod f = partition bytes_weight bytes_n (Positional.eval weight n f). + /\ to_bytesmod f = Partition.partition bytes_weight bytes_n (Positional.eval weight n f). Proof using Hn_nz limbwidth_good. split; apply eval_to_bytes || apply to_bytes_partitions; assumption. Qed. @@ -275,7 +275,7 @@ Section freeze_mod_ops. (Hf : length f = n) (Hf_bounded : 0 <= eval weight n f < 2 * m), (eval bytes_weight bytes_n (freeze_to_bytesmod f)) = (eval weight n f) mod m - /\ freeze_to_bytesmod f = partition bytes_weight bytes_n (Positional.eval weight n f mod m). + /\ freeze_to_bytesmod f = Partition.partition bytes_weight bytes_n (Positional.eval weight n f mod m). Proof using m_enc_correct Hs limbwidth_good Hn_nz c_small Hm_enc_len m_enc_bounded. clear -m_enc_correct Hs limbwidth_good Hn_nz c_small Hm_enc_len m_enc_bounded. intros; subst m s. @@ -299,7 +299,7 @@ Section freeze_mod_ops. : forall (f : list Z) (Hf : length f = n) (Hf_bounded : 0 <= eval weight n f < 2 * m), - freeze_to_bytesmod f = partition bytes_weight bytes_n (Positional.eval weight n f mod m). + freeze_to_bytesmod f = Partition.partition bytes_weight bytes_n (Positional.eval weight n f mod m). Proof using m_enc_correct Hs limbwidth_good Hn_nz c_small Hm_enc_len m_enc_bounded. intros; now apply eval_freeze_to_bytesmod_and_partitions. Qed. @@ -320,7 +320,7 @@ Section freeze_mod_ops. Lemma from_bytes_partitions : forall (f : list Z) (Hf_small : 0 <= eval bytes_weight bytes_n f < weight n), - from_bytes f = partition weight n (Positional.eval bytes_weight bytes_n f). + from_bytes f = Partition.partition weight n (Positional.eval bytes_weight bytes_n f). Proof using limbwidth_good. clear -limbwidth_good. intros; cbv [from_bytes]. @@ -337,7 +337,7 @@ Section freeze_mod_ops. Lemma from_bytesmod_partitions : forall (f : list Z) (Hf_small : 0 <= eval bytes_weight bytes_n f < weight n), - from_bytesmod f = partition weight n (Positional.eval bytes_weight bytes_n f). + from_bytesmod f = Partition.partition weight n (Positional.eval bytes_weight bytes_n f). Proof using limbwidth_good. apply from_bytes_partitions. Qed. Lemma eval_from_bytesmod_and_partitions @@ -345,9 +345,9 @@ Section freeze_mod_ops. (Hf : length f = bytes_n) (Hf_small : 0 <= eval bytes_weight bytes_n f < weight n), eval weight n (from_bytesmod f) = eval bytes_weight bytes_n f - /\ from_bytesmod f = partition weight n (Positional.eval bytes_weight bytes_n f). + /\ from_bytesmod f = Partition.partition weight n (Positional.eval bytes_weight bytes_n f). Proof using limbwidth_good Hn_nz. now (split; [ apply eval_from_bytesmod | apply from_bytes_partitions ]). Qed. End freeze_mod_ops. -Hint Rewrite eval_freeze_to_bytesmod eval_to_bytes eval_to_bytesmod eval_from_bytes eval_from_bytesmod : push_eval. \ No newline at end of file +Hint Rewrite eval_freeze_to_bytesmod eval_to_bytes eval_to_bytesmod eval_from_bytes eval_from_bytesmod : push_eval. diff --git a/src/Arithmetic/Partition.v b/src/Arithmetic/Partition.v index 2d2fb87fa..ed2c45f9e 100644 --- a/src/Arithmetic/Partition.v +++ b/src/Arithmetic/Partition.v @@ -9,11 +9,15 @@ Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div. Require Import Crypto.Util.Notations. Import ListNotations Weight. Local Open Scope Z_scope. -Section Partition. - Context weight {wprops : @weight_properties weight}. - - Definition partition n x := +(* extra name wrapper so partition won't be confused with List.partition *) +Module Partition. + Definition partition (weight : nat -> Z) n x := map (fun i => (x mod weight (S i)) / weight i) (seq 0 n). +End Partition. + +Section PartitionProofs. + Context weight {wprops : @weight_properties weight}. + Local Notation partition := (Partition.partition weight). Lemma partition_step n x : partition (S n) x = partition n x ++ [(x mod weight (S n)) / weight n]. @@ -124,6 +128,6 @@ Section Partition. autorewrite with zsimplify; reflexivity. Qed. -End Partition. +End PartitionProofs. Hint Rewrite length_partition length_recursive_partition : distr_length. -Hint Rewrite eval_partition using (solve [auto; distr_length]) : push_eval. \ No newline at end of file +Hint Rewrite eval_partition using (solve [auto; distr_length]) : push_eval. diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v index c82f6afa9..dc258cfd9 100644 --- a/src/Arithmetic/Saturated.v +++ b/src/Arithmetic/Saturated.v @@ -219,7 +219,7 @@ Module Columns. Lemma flatten_correct inp: forall n, length inp = n -> - flatten inp = (partition weight n (eval n inp), + flatten inp = (Partition.partition weight n (eval n inp), eval n inp / (weight n)). Proof using wprops. induction inp using rev_ind; intros; @@ -654,10 +654,10 @@ Module Rows. nm = (n + m)%nat -> let eval := Positional.eval weight in snd (fst start_state) = (eval m row1' + eval m row2') / weight m -> - (fst (fst start_state) = partition weight m (eval m row1' + eval m row2')) -> + (fst (fst start_state) = Partition.partition weight m (eval m row1' + eval m row2')) -> let sum := eval nm (row1' ++ row1) + eval nm (row2' ++ row2) in sum_rows' start_state row1 row2 - = (partition weight nm sum, sum / weight nm, nm) . + = (Partition.partition weight nm sum, sum / weight nm, nm) . Proof using wprops. destruct start_state as [ [acc rem] m]. cbn [fst snd]. revert acc rem m. @@ -687,7 +687,7 @@ Module Rows. Lemma sum_rows_correct row1: forall row2 n, length row1 = n -> length row2 = n -> let sum := Positional.eval weight n row1 + Positional.eval weight n row2 in - sum_rows row1 row2 = (partition weight n sum, sum / weight n). + sum_rows row1 row2 = (Partition.partition weight n sum, sum / weight n). Proof using wprops. cbv [sum_rows]; intros. erewrite sum_rows'_correct with (nm:=n) (row1':=nil) (row2':=nil)by (cbn; distr_length; reflexivity). @@ -755,7 +755,7 @@ Module Rows. (forall row, In row inp -> length row = n) -> inp <> nil -> let sum := Positional.eval weight n (fst start_state) + eval n inp + weight n * snd start_state in - flatten' start_state inp = (partition weight n sum, sum / weight n). + flatten' start_state inp = (Partition.partition weight n sum, sum / weight n). Proof using wprops. induction inp using rev_ind; push. subst sum. destruct (dec (inp = nil)); [ subst inp; cbn | ]; @@ -776,7 +776,7 @@ Module Rows. Lemma flatten_correct inp n : (forall row, In row inp -> length row = n) -> - flatten n inp = (partition weight n (eval n inp), eval n inp / weight n). + flatten n inp = (Partition.partition weight n (eval n inp), eval n inp / weight n). Proof using wprops. intros; cbv [flatten]. destruct inp; [|destruct inp]; cbn [hd tl]; @@ -877,7 +877,7 @@ Module Rows. Lemma add_partitions n p q : length p = n -> length q = n -> - fst (add n p q) = partition weight n (Positional.eval weight n p + Positional.eval weight n q). + fst (add n p q) = Partition.partition weight n (Positional.eval weight n p + Positional.eval weight n q). Proof using wprops. solver. Qed. Lemma add_div n p q : @@ -888,7 +888,7 @@ Module Rows. Lemma conditional_add_partitions n mask cond p q : length p = n -> length q = n -> map (Z.land mask) q = q -> fst (conditional_add n mask cond p q) - = partition weight n (Positional.eval weight n p + if dec (cond = 0) then 0 else Positional.eval weight n q). + = Partition.partition weight n (Positional.eval weight n p + if dec (cond = 0) then 0 else Positional.eval weight n q). Proof using wprops. cbv [conditional_add]; intros; rewrite add_partitions by (distr_length; auto). autorewrite with push_eval; reflexivity. @@ -917,7 +917,7 @@ Module Rows. Lemma sub_partitions n p q : length p = n -> length q = n -> - fst (sub n p q) = partition weight n (Positional.eval weight n p - Positional.eval weight n q). + fst (sub n p q) = Partition.partition weight n (Positional.eval weight n p - Positional.eval weight n q). Proof using wprops. solver. Qed. Lemma sub_div n p q : @@ -926,10 +926,10 @@ Module Rows. Proof using wprops. solver. Qed. Lemma conditional_sub_partitions n p q - (Hp : p = partition weight n (Positional.eval weight n p)) : + (Hp : p = Partition.partition weight n (Positional.eval weight n p)) : length q = n -> 0 <= Positional.eval weight n q < weight n -> - conditional_sub n p q = partition weight n (if Positional.eval weight n q <=? Positional.eval weight n p then Positional.eval weight n p - Positional.eval weight n q else Positional.eval weight n p). + conditional_sub n p q = Partition.partition weight n (if Positional.eval weight n q <=? Positional.eval weight n p then Positional.eval weight n p - Positional.eval weight n q else Positional.eval weight n p). Proof using wprops. cbv [conditional_sub]; intros. rewrite (surjective_pairing (sub _ _ _)). @@ -952,7 +952,7 @@ Module Rows. map (Z.land mask) r = r -> 0 <= Positional.eval weight n p < weight n -> 0 <= Positional.eval weight n q < weight n -> - fst (sub_then_maybe_add n mask p q r) = partition weight n (sub_then_maybe_add_Z (Positional.eval weight n p) (Positional.eval weight n q) (Positional.eval weight n r)). + fst (sub_then_maybe_add n mask p q r) = Partition.partition weight n (sub_then_maybe_add_Z (Positional.eval weight n p) (Positional.eval weight n q) (Positional.eval weight n r)). Proof using wprops. cbv [sub_then_maybe_add]. subst sub_then_maybe_add_Z. intros. @@ -969,7 +969,7 @@ Module Rows. Lemma mul_partitions base n m p q : base <> 0 -> m <> 0%nat -> length p = n -> length q = n -> - fst (mul base n m p q) = partition weight m (Positional.eval weight n p * Positional.eval weight n q). + fst (mul base n m p q) = Partition.partition weight m (Positional.eval weight n p * Positional.eval weight n q). Proof using wprops. solver. Qed. Lemma mul_div base n m p q : diff --git a/src/Arithmetic/UniformWeight.v b/src/Arithmetic/UniformWeight.v index 9af083994..25d6fb92d 100644 --- a/src/Arithmetic/UniformWeight.v +++ b/src/Arithmetic/UniformWeight.v @@ -80,7 +80,7 @@ Proof using Type. erewrite IHn. reflexivity. Qed. Lemma uweight_recursive_partition_equiv lgr (Hr : 0 < lgr) n i x: - partition (uweight lgr) n x = + Partition.partition (uweight lgr) n x = recursive_partition (uweight lgr) n i x. Proof using Type. rewrite recursive_partition_equiv by auto using uwprops. @@ -88,9 +88,9 @@ Proof using Type. Qed. Lemma uweight_firstn_partition lgr (Hr : 0 < lgr) n x m (Hm : (m <= n)%nat) : - firstn m (partition (uweight lgr) n x) = partition (uweight lgr) m x. + firstn m (Partition.partition (uweight lgr) n x) = Partition.partition (uweight lgr) m x. Proof. - cbv [partition]; + cbv [Partition.partition]; repeat match goal with | _ => progress intros | _ => progress autorewrite with push_firstn natsimplify zsimplify_fast @@ -101,9 +101,9 @@ Proof. Qed. Lemma uweight_skipn_partition lgr (Hr : 0 < lgr) n x m : - skipn m (partition (uweight lgr) n x) = partition (uweight lgr) (n - m) (x / uweight lgr m). + skipn m (Partition.partition (uweight lgr) n x) = Partition.partition (uweight lgr) (n - m) (x / uweight lgr m). Proof. - cbv [partition]; + cbv [Partition.partition]; repeat match goal with | _ => progress intros | _ => progress autorewrite with push_skipn natsimplify zsimplify_fast @@ -116,7 +116,7 @@ Qed. Lemma uweight_partition_unique lgr (Hr : 0 < lgr) n ls : length ls = n -> (forall x, List.In x ls -> 0 <= x <= 2^lgr - 1) -> - ls = partition (uweight lgr) n (Positional.eval (uweight lgr) n ls). + ls = Partition.partition (uweight lgr) n (Positional.eval (uweight lgr) n ls). Proof using Type. intro; subst n. rewrite uweight_recursive_partition_equiv with (i:=0%nat) by assumption. @@ -169,8 +169,8 @@ Lemma uweight_eval_app lgr (Hr : 0 <= lgr) n m x y : Proof using Type. intros. subst m. apply uweight_eval_app'; lia. Qed. Lemma uweight_partition_app lgr (Hr : 0 < lgr) n m a b : - partition (uweight lgr) n a ++ partition (uweight lgr) m b - = partition (uweight lgr) (n+m) (a mod uweight lgr n + b * uweight lgr n). + Partition.partition (uweight lgr) n a ++ Partition.partition (uweight lgr) m b + = Partition.partition (uweight lgr) (n+m) (a mod uweight lgr n + b * uweight lgr n). Proof. assert (0 < uweight lgr n) by auto using uwprops. match goal with |- _ = ?rhs => rewrite <-(firstn_skipn n rhs) end. diff --git a/src/Arithmetic/WordByWordMontgomery.v b/src/Arithmetic/WordByWordMontgomery.v index 3fb3437c1..5b1a8a256 100644 --- a/src/Arithmetic/WordByWordMontgomery.v +++ b/src/Arithmetic/WordByWordMontgomery.v @@ -126,7 +126,7 @@ Module WordByWordMontgomery. Let R := (r^Z.of_nat R_numlimbs). Transparent T. Definition small {n} (v : T n) : Prop - := v = partition weight n (eval v). + := v = Partition.partition weight n (eval v). Context (small_N : small N) (N_lt_R : eval N < R) (N_nz : 0 < eval N) @@ -161,9 +161,9 @@ Module WordByWordMontgomery. Qed. Lemma mask_r_sub1 n x : - map (Z.land (r - 1)) (partition weight n x) = partition weight n x. + map (Z.land (r - 1)) (Partition.partition weight n x) = Partition.partition weight n x. Proof using lgr_big. - clear - lgr_big. cbv [partition]. + clear - lgr_big. cbv [Partition.partition]. rewrite map_map. apply map_ext; intros. rewrite uweight_S by omega. rewrite <-Z.mod_pull_div by auto with zarith. @@ -248,7 +248,7 @@ Module WordByWordMontgomery. Qed. Local Lemma small_zero : forall n, small (@zero n). Proof using Type. - etransitivity; [ eapply Positional.zeros_ext_map | rewrite eval_zero ]; cbv [partition]; cbn; try reflexivity; autorewrite with distr_length; reflexivity. + etransitivity; [ eapply Positional.zeros_ext_map | rewrite eval_zero ]; cbv [Partition.partition]; cbn; try reflexivity; autorewrite with distr_length; reflexivity. Qed. Local Hint Immediate small_zero. @@ -288,7 +288,7 @@ Module WordByWordMontgomery. Qed. Definition canon_rep {n} x (v : T n) : Prop := - (v = partition weight n x) /\ (0 <= x < weight n). + (v = Partition.partition weight n x) /\ (0 <= x < weight n). Lemma eval_canon_rep n x v : @canon_rep n x v -> eval v = x. Proof using lgr_big. clear - lgr_big. @@ -974,7 +974,7 @@ Module WordByWordMontgomery. Let r := 2^bitwidth. Local Notation weight := (uweight bitwidth). Local Notation eval := (@eval bitwidth n). - Let m_enc := partition weight n m. + Let m_enc := Partition.partition weight n m. Local Coercion Z.of_nat : nat >-> Z. Context (r' : Z) (m' : Z) @@ -1059,7 +1059,7 @@ Module WordByWordMontgomery. t_fin. Qed. - Definition onemod : list Z := partition weight n 1. + Definition onemod : list Z := Partition.partition weight n 1. Definition onemod_correct : eval onemod = 1 /\ valid onemod. Proof using n_nz m_big bitwidth_big. @@ -1070,7 +1070,7 @@ Module WordByWordMontgomery. Lemma eval_onemod : eval onemod = 1. Proof. apply onemod_correct. Qed. - Definition R2mod : list Z := partition weight n ((r^n * r^n) mod m). + Definition R2mod : list Z := Partition.partition weight n ((r^n * r^n) mod m). Definition R2mod_correct : eval R2mod mod m = (r^n*r^n) mod m /\ valid R2mod. Proof using n_nz m_small m_big m'_correct bitwidth_big. @@ -1137,7 +1137,7 @@ Module WordByWordMontgomery. Proof. apply squaremod_correct. Qed. Definition encodemod (v : Z) : list Z - := mulmod (partition weight n v) R2mod. + := mulmod (Partition.partition weight n v) R2mod. Local Ltac t_valid v := cbv [valid]; repeat apply conj; @@ -1260,7 +1260,7 @@ Module WordByWordMontgomery. Lemma to_bytesmod_correct : (forall a (_ : valid a), Positional.eval (uweight 8) (bytes_n bitwidth 1 n) (to_bytesmod a) = eval a mod m) - /\ (forall a (_ : valid a), to_bytesmod a = partition (uweight 8) (bytes_n bitwidth 1 n) (eval a mod m)). + /\ (forall a (_ : valid a), to_bytesmod a = Partition.partition (uweight 8) (bytes_n bitwidth 1 n) (eval a mod m)). Proof using n_nz m_small bitwidth_big. clear -n_nz m_small bitwidth_big. generalize (@length_small bitwidth n); @@ -1279,4 +1279,4 @@ Module WordByWordMontgomery. = eval a mod m). Proof. apply to_bytesmod_correct. Qed. End modops. -End WordByWordMontgomery. \ No newline at end of file +End WordByWordMontgomery. diff --git a/src/PushButtonSynthesis/BarrettReduction.v b/src/PushButtonSynthesis/BarrettReduction.v index 42be4b7d4..990a0de03 100644 --- a/src/PushButtonSynthesis/BarrettReduction.v +++ b/src/PushButtonSynthesis/BarrettReduction.v @@ -75,7 +75,7 @@ Section rbarrett_red. Lemma mut_correct : 0 < machine_wordsize -> - partition (uweight machine_wordsize) (1 + 1) (muLow + 2 ^ machine_wordsize) = [muLow; 1]. + Partition.partition (uweight machine_wordsize) (1 + 1) (muLow + 2 ^ machine_wordsize) = [muLow; 1]. Proof. intros; cbn. subst muLow. assert (0 < 2^machine_wordsize) by ZeroBounds.Z.zero_bounds. @@ -90,7 +90,7 @@ Section rbarrett_red. Lemma Mt_correct : 0 < machine_wordsize -> 2^(machine_wordsize - 1) < M < 2^machine_wordsize -> - partition (uweight machine_wordsize) 1 M = [M]. + Partition.partition (uweight machine_wordsize) 1 M = [M]. Proof. intros; cbn. assert (0 < 2^(machine_wordsize-1)) by ZeroBounds.Z.zero_bounds. rewrite !uweight_S, weight_0; auto using uwprops with lia. diff --git a/src/PushButtonSynthesis/Primitives.v b/src/PushButtonSynthesis/Primitives.v index 93c2e4c69..a3df39c25 100644 --- a/src/PushButtonSynthesis/Primitives.v +++ b/src/PushButtonSynthesis/Primitives.v @@ -574,7 +574,7 @@ Module CorrectnessStringification. Ltac stringify ctx correctness fname arg_var_data out_var_data := let G := match goal with |- ?G => G end in let correctness := (eval hnf in correctness) in - let correctness := (eval cbv [partition WordByWordMontgomery.valid WordByWordMontgomery.small] in correctness) in + let correctness := (eval cbv [Partition.partition WordByWordMontgomery.valid WordByWordMontgomery.small] in correctness) in let correctness := strip_bounds_info correctness in let arg_var_names := constr:(type.map_for_each_lhs_of_arrow (@ToString.C.OfPHOAS.names_of_var_data) arg_var_data) in let out_var_names := constr:(ToString.C.OfPHOAS.names_of_base_var_data out_var_data) in diff --git a/src/PushButtonSynthesis/WordByWordMontgomery.v b/src/PushButtonSynthesis/WordByWordMontgomery.v index ebf250664..ea67f8b9c 100644 --- a/src/PushButtonSynthesis/WordByWordMontgomery.v +++ b/src/PushButtonSynthesis/WordByWordMontgomery.v @@ -103,9 +103,9 @@ Section __. end. Let n_bytes := bytes_n machine_wordsize 1 n. Let prime_upperbound_list : list Z - := partition (uweight machine_wordsize) n (s-1). + := Partition.partition (uweight machine_wordsize) n (s-1). Let prime_bytes_upperbound_list : list Z - := partition (weight 8 1) n_bytes (s-1). + := Partition.partition (weight 8 1) n_bytes (s-1). Let upperbounds : list Z := prime_upperbound_list. Definition prime_bound : ZRange.type.option.interp (base.type.Z) := Some r[0~>m-1]%zrange. -- cgit v1.2.3