diff options
author | Jade Philipoom <jadep@google.com> | 2018-04-09 13:06:01 +0200 |
---|---|---|
committer | Jade Philipoom <jadep@google.com> | 2018-04-09 13:06:01 +0200 |
commit | e2bf39c696a8ecc35c6d9b31011f54854ec3142a (patch) | |
tree | 9d81b6d756e5054c170d9f4201674e4698f5eecc /src | |
parent | 0fc77cb8ae1d0c85ac6d5ed59aaca85472735c4a (diff) |
package properties of weight functions into a record
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 198 |
1 files changed, 64 insertions, 134 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 95c6929d2..4c0ff4e95 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -427,6 +427,15 @@ Module Positional. Section Positional. Hint Rewrite @length_sub @length_opp : distr_length. End Positional. End Positional. +Record weight_properties {weight : nat -> Z} := + { + weight_0 : weight 0%nat = 1; + weight_positive : forall i, 0 < weight i; + weight_multiples : forall i, weight (S i) mod weight i = 0; + weight_divides : forall i : nat, 0 < weight (S i) / weight i; + }. +Hint Resolve weight_0 weight_positive weight_multiples weight_divides. + Section mod_ops. Import Positional. Local Coercion Z.of_nat : nat >-> Z. @@ -463,34 +472,26 @@ Section mod_ops. try reflexivity; try lia. Qed. - Lemma weight_0 : weight 0 = 1. - Proof. - clear. - cbv [weight Z.of_nat]; autorewrite with zsimplify_fast; reflexivity. - Qed. - Local Hint Immediate weight_0. - Local Ltac t_weight_with lem := clear -limbwidth_good; intros; rewrite !weight_ZQ_correct; apply lem; try omega; Q_cbv; destruct limbwidth_den; cbn; try lia. - Local Lemma weight_nz : forall i, weight i <> 0. - Proof. t_weight_with (@pow_ceil_mul_nat_nonzero 2). Qed. - Local Hint Immediate weight_nz. - - Local Lemma weight_div_nz : forall i : nat, weight (S i) / weight i <> 0. - Proof. t_weight_with (@pow_ceil_mul_nat_divide_nonzero 2). Qed. - Local Hint Immediate weight_div_nz. - - (* lemmas for montred *) - Local Lemma weight_divides : forall i, weight (S i) / weight i > 0. - Proof. t_weight_with (@pow_ceil_mul_nat_divide 2). Qed. - Local Lemma weight_positive : forall i, weight i > 0. - Proof. t_weight_with (@pow_ceil_mul_nat_pos 2). Qed. - Local Lemma weight_multiples : forall i, weight (S i) mod weight i = 0. - Proof. t_weight_with (@pow_ceil_mul_nat_multiples 2). Qed. + Definition wprops : @weight_properties weight. + Proof. + constructor. + { cbv [weight Z.of_nat]; autorewrite with zsimplify_fast; reflexivity. } + { intros; apply Z.gt_lt. t_weight_with (@pow_ceil_mul_nat_pos 2). } + { t_weight_with (@pow_ceil_mul_nat_multiples 2). } + { intros; apply Z.gt_lt. t_weight_with (@pow_ceil_mul_nat_divide 2). } + Defined. + Local Hint Immediate (weight_0 wprops). + Local Hint Immediate (weight_positive wprops). + Local Hint Immediate (weight_multiples wprops). + Local Hint Immediate (weight_divides wprops). + Local Hint Resolve Z.positive_is_nonzero Z.lt_gt. + Local Lemma weight_1_gt_1 : weight 1 > 1. Proof. clear -limbwidth_good. @@ -596,13 +597,11 @@ Section mod_ops. End mod_ops. Module Saturated. + Hint Resolve weight_positive weight_0 weight_multiples weight_divides. + Hint Resolve Z.positive_is_nonzero Z.lt_gt Nat2Z.is_nonneg. + Section Weight. - Context (weight : nat->Z) - {weight_0 : weight 0%nat = 1} - {weight_nonzero : forall i, weight i <> 0} - {weight_positive : forall i, weight i > 0} - {weight_multiples : forall i, weight (S i) mod weight i = 0} - {weight_divides : forall i : nat, weight (S i) / weight i > 0}. + Context weight {wprops : @weight_properties weight}. Lemma weight_multiples_full' j : forall i, weight (i+j) mod weight i = 0. Proof. @@ -623,8 +622,8 @@ Module Saturated. apply weight_multiples_full'. Qed. - Lemma weight_divides_full j i : (i <= j)%nat -> weight j / weight i > 0. - Proof. auto using Z.div_positive_gt_0, weight_multiples_full. Qed. + Lemma weight_divides_full j i : (i <= j)%nat -> 0 < weight j / weight i. + Proof. auto using Z.gt_lt, Z.div_positive_gt_0, weight_multiples_full. Qed. Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i). Proof. intros. apply Z.div_exact; auto using weight_multiples_full. Qed. @@ -738,12 +737,7 @@ End Saturated. Module Columns. Import Saturated. Section Columns. - Context (weight : nat->Z) - {weight_0 : weight 0%nat = 1} - {weight_nonzero : forall i, weight i <> 0} - {weight_positive : forall i, weight i > 0} - {weight_multiples : forall i, weight (S i) mod weight i = 0} - {weight_divides : forall i : nat, weight (S i) / weight i > 0}. + Context weight {wprops : @weight_properties weight}. Definition eval n (x : list (list Z)) : Z := Positional.eval weight n (map sum x). @@ -953,7 +947,7 @@ Module Columns. eval n (from_associational n p) = Associational.eval p. Proof. erewrite <-Positional.eval_from_associational by eauto. - induction p; [ autorewrite with push_eval; congruence |]. + induction p; [ autorewrite with push_eval; solve [auto] |]. cbv [from_associational Positional.from_associational]; autorewrite with push_fold_right. fold (from_associational n p); fold (Positional.from_associational weight n p). cbv [Let_In]. @@ -989,12 +983,7 @@ End Columns. Module Rows. Import Saturated. Section Rows. - Context (weight : nat->Z) - {weight_0 : weight 0%nat = 1} - {weight_nonzero : forall i, weight i <> 0} - {weight_positive : forall i, weight i > 0} - {weight_multiples : forall i, weight (S i) mod weight i = 0} - {weight_divides : forall i : nat, weight (S i) / weight i > 0}. + Context weight {wprops : @weight_properties weight}. Local Notation rows := (list (list Z)) (only parsing). Local Notation cols := (list (list Z)) (only parsing). @@ -1003,7 +992,6 @@ Module Rows. Positional.eval_to_associational Columns.eval_nil Columns.eval_snoc using (auto; solve [distr_length]) : push_eval. Hint Resolve in_eq in_cons. - Hint Resolve Z.gt_lt. Definition eval n (inp : rows) := sum (map (Positional.eval weight n) inp). @@ -1359,7 +1347,7 @@ Module Rows. | _ => rewrite <-(Z.add_assoc _ x1 x2) end. { rewrite div_step by auto using Z.gt_lt. - rewrite Z.mul_div_eq_full by auto; rewrite weight_multiples. push. } + rewrite Z.mul_div_eq_full by auto; rewrite weight_multiples by auto. push. } { rewrite weight_div_mod with (j:=length (fst start_state)) (i:=S j) by (auto; omega). push_Zmod. autorewrite with zsimplify_fast. reflexivity. } { push. replace (length (fst start_state)) with j in * by omega. @@ -1418,7 +1406,7 @@ Module Rows. | _ => progress In_cases | |- _ /\ _ => split | |- context [?x mod ?y] => unique pose proof (Z.mul_div_eq_full x y ltac:(auto)); lia - | _ => solve [repeat (f_equal; try ring)] + | _ => solve [repeat (ring_simplify; f_equal; try ring)] | _ => congruence | _ => solve [eauto] end. @@ -1436,7 +1424,7 @@ Module Rows. destruct (dec (inp = nil)); [subst inp; cbv [is_div_mod] | eapply is_div_mod_result_equal; try apply IHinp]; push. { autorewrite with zsimplify; push. } - { autorewrite with zsimplify; push. } + { rewrite Z.div_add' by auto; push. } Qed. Hint Rewrite (@Positional.length_zeros weight) : distr_length. @@ -1494,7 +1482,7 @@ Module Rows. { subst inp; push. rewrite sum_rows_partitions with (n:=n) by eauto. push. } { erewrite IHinp; push. rewrite add_mod_l_multiple by auto using weight_divides_full, weight_multiples_full. - repeat (f_equal; try ring). } + push. } Qed. Lemma flatten_partitions inp n : @@ -1574,15 +1562,10 @@ End Rows. Module BaseConversion. Import Positional. Section BaseConversion. + Hint Resolve Z.gt_lt. Context (sw dw : nat -> Z) (* source/destination weight functions *) - (dw_0 : dw 0%nat = 1) - (sw_0 : sw 0%nat = 1) - (dw_nz : forall i, dw i <> 0) - (sw_nz : forall i, sw i <> 0) - (sw_pos : forall i, sw i > 0) - (sw_multiples : forall i, sw (S i) mod sw i = 0) - (sw_divides : forall i, sw (S i) / sw i > 0) - (dw_divides : forall i, dw (S i) / dw i > 0). + {swprops : @weight_properties sw} + {dwprops : @weight_properties dw}. Definition convert_bases (sn dn : nat) (p : list Z) : list Z := let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in @@ -1691,41 +1674,31 @@ Module BaseConversion. (* multiply two (n*k)-bit numbers by converting them to n k-bit limbs each, multiplying, then converting back *) Section widemul. - Context (dw : nat -> Z) (n : nat) (n_nz : n <> 0%nat). - Context (dw_0 : dw 0%nat = 1) - (dw_pos : forall i, dw i > 0) - (dw_multiples : forall i, dw (S i) mod dw i = 0) - (dw_divides : forall i, dw (S i) / dw i > 0). - Context (nout : nat) (nout_nz : nout <> 0%nat). (* this is always 2, but reification has trouble if it's a constant *) - Let sw i := (dw i) ^ Z.of_nat n. + Context (log2base : Z) (log2base_pos : 0 < log2base). + Context (n : nat) (n_nz : n <> 0%nat) (n_le_log2base : Z.of_nat n <= log2base) + (nout : nat) (nout_2 : nout = 2%nat). (* nout is always 2, but partial evaluation is overeager if it's a constant *) + Let dw : nat -> Z := weight (log2base / Z.of_nat n) 1. + Let sw : nat -> Z := weight log2base 1. + + Local Lemma base_bounds : 0 < 1 <= log2base. Proof. auto with zarith. Qed. + Local Lemma dbase_bounds : 0 < 1 <= log2base / Z.of_nat n. Proof. auto with zarith. Qed. + Let dwprops : @weight_properties dw := wprops (log2base / Z.of_nat n) 1 dbase_bounds. + Let swprops : @weight_properties sw := wprops log2base 1 base_bounds. Hint Resolve Z.gt_lt Z.positive_is_nonzero Nat2Z.is_nonneg. - (* TODO : There has got to be a cleaner way to do weight functions *) - Lemma sw_0 : sw 0 = 1. - Proof. subst sw; cbv beta. rewrite dw_0. auto using Z.pow_1_l. Qed. - Lemma sw_pos : forall i, sw i > 0. - Proof. subst sw; intros; apply Z.lt_gt. Z.zero_bounds. Qed. - Lemma sw_nz : forall i, sw i <> 0. - Proof. auto using sw_pos. Qed. - Lemma sw_multiples : forall i, sw (S i) mod sw i = 0. - Proof. - subst sw; cbv beta; intros. - rewrite Saturated.weight_div_mod with (weight := dw) (j:=S i) (i:=i) by auto. - rewrite Z.pow_mul_l. - push_Zmod. autorewrite with zsimplify. reflexivity. - Qed. - Lemma sw_divides : forall i, sw (S i) / sw i > 0. - Proof. intros; apply Z.div_positive_gt_0; auto using sw_pos, sw_multiples. Qed. - Hint Resolve sw_0 sw_pos sw_nz sw_multiples sw_divides. Definition widemul a b := mul_converted sw dw 1 1 n n nout (aligned_carries n nout) [a] [b]. - Lemma widemul_partitions a b : - widemul a b = Rows.partition sw nout (a * b). + Lemma widemul_correct a b : + 0 <= a * b < 2^log2base * 2^log2base -> + widemul a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base]. Proof. - cbv [widemul]. + cbv [widemul]; intros. rewrite mul_converted_partitions by auto with zarith. - cbn; rewrite sw_0; ring_simplify_subterms. reflexivity. + subst nout sw; cbv [weight]; cbn. + autorewrite with zsimplify. + rewrite Z.pow_mul_r, Z.pow_2_r by omega. + Z.rewrite_mod_small. reflexivity. Qed. End widemul. End BaseConversion. @@ -7789,45 +7762,11 @@ Module MontgomeryReduction. Context (R_big_enough : n <= Zlog2R) (R_two_pow : 2^Zlog2R = R). Let w_mul : nat -> Z := weight (Zlog2R / n) 1. - Local Lemma w_mul_pown : forall i, (w_mul i) ^ n = w i. - Proof. - cbv [w_mul w weight]; intro. - autorewrite with pull_Zpow zsimplify. - rewrite <-Z.pow_mul_r by Z.zero_bounds. apply f_equal. - rewrite (Z.div_mod Zlog2R n) at 2 by Z.zero_bounds. - rewrite n_good. lia. - Qed. - Local Lemma log2R_good : 0 < 1 <= Zlog2R. - Proof. clear -R_big_enough Hn_nz; lia. Qed. - Local Lemma half_log2R_good : 0 < 1 <= Zlog2R / n. - Proof. clear -R_big_enough Hn_nz; Z.div_mod_to_quot_rem; nia. Qed. - Let w_mul_0 : w_mul 0%nat = 1 := weight_0 _ _. - Let w_mul_nonzero : forall i, w_mul i <> 0 - := weight_nz _ _ half_log2R_good. - Let w_mul_positive : forall i, w_mul i > 0 - := weight_positive _ _ half_log2R_good. - Let w_mul_multiples : forall i, w_mul (S i) mod w_mul i = 0 - := weight_multiples _ _ half_log2R_good. - Let w_mul_divides : forall i : nat, w_mul (S i) / w_mul i > 0 - := weight_divides _ _ half_log2R_good. -(* - Let w_0 : w 0%nat = 1 := weight_0 _ _. - Let w_nonzero : forall i, w i <> 0 - := weight_nz _ _ log2R_good. - Let w_positive : forall i, w i > 0 - := weight_positive _ _ log2R_good. - Let w_multiples : forall i, w (S i) mod w i = 0 - := weight_multiples _ _ log2R_good. - Let w_divides : forall i : nat, w (S i) / w i > 0 - := weight_divides _ _ log2R_good. - Let w_1_gt1 : w 1 > 1 := weight_1_gt_1 _ _ log2R_good. -*) - Let w_mul_1_gt1 : w_mul 1 > 1 := weight_1_gt_1 _ _ half_log2R_good. Context (nout : nat) (Hnout : nout = 2%nat). Definition montred' (lo_hi : (Z * Z)) := - dlet_nd y := nth_default 0 (BaseConversion.widemul w_mul n nout (fst lo_hi) N') 0 in - dlet_nd t1_t2 := (BaseConversion.widemul w_mul n nout y N) in + dlet_nd y := nth_default 0 (BaseConversion.widemul Zlog2R n nout (fst lo_hi) N') 0 in + dlet_nd t1_t2 := (BaseConversion.widemul Zlog2R n nout y N) in dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in @@ -7841,7 +7780,7 @@ Module MontgomeryReduction. rewrite Z.pow_mul_r, R_two_pow by omega; reflexivity. Qed. - Local Ltac change_weight := rewrite ?w_mul_pown, !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r in *. + Local Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r, ?Z.pow_1_l in *. Local Ltac solve_range := repeat match goal with | _ => progress change_weight @@ -7850,20 +7789,9 @@ Module MontgomeryReduction. | |- 0 <= _ * _ < _ * _ => split; [ solve [Z.zero_bounds] | apply Z.mul_lt_mono_nonneg; omega ] | _ => solve [auto] - | _ => cbn - | _ => nia + | _ => omega end. - Lemma widemul_correct x y : - 0 <= x * y < w 2 -> BaseConversion.widemul w_mul n nout x y = [(x * y) mod R; (x * y) / R]. - Proof. - intros. - rewrite BaseConversion.widemul_partitions by (auto; omega). - subst nout. cbn. change_weight. - autorewrite with zsimplify. - Z.rewrite_mod_small. reflexivity. - Qed. - Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N) (Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R): montred' lo_hi = reduce_via_partial N R N' T. @@ -7873,7 +7801,9 @@ Module MontgomeryReduction. rewrite Hlo, Hhi. assert (0 <= T mod R * N' < w 2) by (solve_range). - rewrite !widemul_correct by (rewrite ?widemul_correct; autorewrite with push_nth_default; solve_range). + rewrite !BaseConversion.widemul_correct + by (rewrite ?BaseConversion.widemul_correct; autorewrite with push_nth_default; solve_range). + rewrite R_two_pow. autorewrite with push_nth_default. autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct. |