From 76f2195bfbb31caccac87203f65e4538f3a0dafc Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Mon, 9 Apr 2018 10:06:10 +0200 Subject: reorganization: move more things into BaseConversion --- coqprime | 2 +- src/Experiments/SimplyTypedArithmetic.v | 321 ++++++++++++++++---------------- 2 files changed, 162 insertions(+), 161 deletions(-) diff --git a/coqprime b/coqprime index 59e3bf69a..bd626ee33 160000 --- a/coqprime +++ b/coqprime @@ -1 +1 @@ -Subproject commit 59e3bf69a84c593ad733b83dbcfa90036f5d052a +Subproject commit bd626ee330cc28aadfc2d675772f5077b098f717 diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index dd52e96f4..8be01866a 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -595,30 +595,6 @@ Section mod_ops. Qed. End mod_ops. -Module BaseConversion. - Import Positional. - Section BaseConversion. - Context (sw dw : nat -> Z) (* source/destination weight functions *) - (dw_0 : dw 0%nat = 1) - (dw_nz : forall i, dw i <> 0). - Context (dw_divides : forall i : nat, dw (S i) / dw i > 0). - - Definition convert_bases (sn dn : nat) (p : list Z) : list Z := - let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in - chained_carries_no_reduce dw dn p' (seq 0 (pred dn)). - - Lemma eval_convert_bases sn dn p : - (dn <> 0%nat) -> length p = sn -> - eval dw dn (convert_bases sn dn p) = eval sw sn p. - Proof. - cbv [convert_bases]; intros. - rewrite eval_chained_carries_no_reduce; auto using ZUtil.Z.positive_is_nonzero. - rewrite eval_from_associational; auto. - Qed. - - End BaseConversion. -End BaseConversion. - Module Saturated. Section Weight. Context (weight : nat->Z) @@ -640,7 +616,7 @@ Module Saturated. | _ => reflexivity end. Qed. - + Lemma weight_multiples_full j i : (i <= j)%nat -> weight j mod weight i = 0. Proof. intros; replace j with (i + (j - i))%nat by omega. @@ -714,7 +690,7 @@ Module Saturated. rewrite Z.mod_pull_div, Z.mul_div_eq' by auto using Z.gt_lt. repeat (f_equal; try omega). Qed. - + Lemma add_mod_l_multiple a b n m: 0 < n / m -> m <> 0 -> n mod m = 0 -> (a mod n + b) mod m = (a + b) mod m. @@ -727,7 +703,7 @@ Module Saturated. Qed. Definition is_div_mod {T} (evalf : T -> Z) dm y n := - evalf (fst dm) = y mod n /\ snd dm = y / n. + evalf (fst dm) = y mod n /\ snd dm = y / n. Lemma is_div_mod_step {T} evalf1 evalf2 dm1 dm2 y1 y2 n1 n2 x : n1 > 0 -> @@ -842,14 +818,14 @@ Module Columns. | _ => progress autorewrite with pull_Zmod pull_Zdiv zsimplify_fast | _ => progress autorewrite with list distr_length push_eval end. - + Lemma flatten_column_mod fw (xs : list Z) : fst (flatten_column fw xs) = sum xs mod fw. Proof. induction xs; simpl flatten_column; cbv [Let_In]; repeat match goal with | _ => rewrite IHxs - | _ => progress push + | _ => progress push end. Qed. Hint Rewrite flatten_column_mod : to_div_mod. @@ -964,7 +940,7 @@ Module Columns. Hint Rewrite Positional.eval_zeros : push_eval. Hint Rewrite Positional.length_from_associational : distr_length. Hint Rewrite Positional.eval_add_to_nth using (solve [distr_length]): push_eval. - + (* from_associational *) Definition from_associational n (p:list (Z*Z)) : list (list Z) := List.fold_right (fun t ls => @@ -1062,7 +1038,7 @@ Module Rows. | _ => progress distr_length | _ => rewrite Positional.eval_snoc with (n:=n) by distr_length | _ => progress autorewrite with cancel_pair push_eval push_map in * - | _ => ring + | _ => ring end. rewrite IHinp by distr_length. destruct x; cbn [hd tl]; rewrite ?sum_nil, ?sum_cons; ring. @@ -1072,7 +1048,7 @@ Module Rows. length inp = n -> length (fst (extract_row inp)) = n. Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed. Hint Rewrite length_fst_extract_row : distr_length. - + Lemma length_snd_extract_row n (inp : cols) : length inp = n -> length (snd (extract_row inp)) = n. Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed. @@ -1084,7 +1060,7 @@ Module Rows. (* TODO: move to where list is defined *) Hint Rewrite @app_nil_l : list. Hint Rewrite <-@app_comm_cons: list. - + Lemma max_column_size_nil : max_column_size nil = 0%nat. Proof. reflexivity. Qed. Hint Rewrite max_column_size_nil : push_max_column_size. Lemma max_column_size_cons col (inp : cols) : @@ -1142,7 +1118,7 @@ Module Rows. Proof. apply eval_from_columns'_with_length. reflexivity. Qed. Hint Rewrite length_snd_from_columns' : distr_length. Lemma eval_from_columns' m st n : - (length (fst st) = n) -> + (length (fst st) = n) -> eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st) - Columns.eval weight n (fst (from_columns' m st)). Proof. apply eval_from_columns'_with_length. Qed. @@ -1191,7 +1167,7 @@ Module Rows. (* from associational *) Definition from_associational n (p : list (Z * Z)) := from_columns (Columns.from_associational weight n p). - + Lemma eval_from_associational n p: (n <> 0%nat \/ p = nil) -> eval n (from_associational n p) = Associational.eval p. Proof. @@ -1253,7 +1229,7 @@ Module Rows. rewrite <-nth_default_eq in *. autorewrite with push_nth_default in *. rewrite eq_nat_dec_refl in *. - congruence. + congruence. Qed. Lemma from_associational_nonnil n p : @@ -1334,7 +1310,7 @@ Module Rows. apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length in *; try omega. eapply is_div_mod_step with (x := x1 + x2); try eassumption; push. Qed. - + Lemma sum_rows_div_mod n row1 row2 : length row1 = n -> length row2 = n -> let eval := Positional.eval weight in @@ -1373,7 +1349,7 @@ Module Rows. = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i). Proof. induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; []. - + rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). apply IHrow1; clear IHrow1; push; repeat match goal with @@ -1390,7 +1366,7 @@ Module Rows. push. rewrite add_mod_div_multiple by auto using Z.lt_le_incl. push. } Qed. - + Lemma sum_rows_partitions row1: forall row2 n i, length row1 = n -> length row2 = n -> (i < n)%nat -> nth_default 0 (fst (sum_rows row1 row2)) i @@ -1493,7 +1469,7 @@ Module Rows. (forall row, In row inp -> length row = n) -> length (fst (flatten' start_state inp)) = n. Proof. apply flatten'_div_mod_length. Qed. - Hint Rewrite length_flatten' : distr_length. + Hint Rewrite length_flatten' : distr_length. Lemma length_flatten n inp : (forall row, In row inp -> length row = n) -> @@ -1551,8 +1527,8 @@ Module Rows. match goal with H : context [nth_default _ (p ++ [ _ ])] |- _ => rewrite <-H by omega end. { autorewrite with push_nth_default natsimplify. reflexivity. } - { autorewrite with push_nth_default natsimplify. - break_match; omega. } + { autorewrite with push_nth_default natsimplify. + break_match; omega. } Qed. Lemma flatten_partitions' inp n : @@ -1595,31 +1571,96 @@ Module Rows. End Rows. End Rows. -Module MulConverted. - Section mul_converted. - Context (w w' : nat -> Z). - Context (w'_0 : w' 0%nat = 1) - (w'_nonzero : forall i, w' i <> 0) - (w'_positive : forall i, w' i > 0) - (w'_divides : forall i : nat, w' (S i) / w' i > 0). - Context (w_0 : w 0%nat = 1) - (w_nonzero : forall i, w i <> 0) - (w_positive : forall i, w i > 0) - (w_multiples : forall i, w (S i) mod w i = 0) - (w_divides : forall i : nat, w (S i) / w i > 0). - - (* TODO : move this stuff to BaseConversion *) - Definition to_associational n m p : list (Z * Z) := - let p' := BaseConversion.convert_bases w w' n m p in - Positional.to_associational w' m p'. - - Definition from_associational idxs n (p : list (Z * Z)) : list Z := - (* important not to use Positional.carry here; we don't want to accumulate yet *) - let p' := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p (rev idxs) in - fst (Rows.flatten w n (Rows.from_associational w n p')). - - (* takes in inputs in base w, converts to w', multiplies in that - format, converts to w again, then flattens. *) +Module BaseConversion. + Import Positional. + Section BaseConversion. + Context (sw dw : nat -> Z) (* source/destination weight functions *) + (dw_0 : dw 0%nat = 1) + (sw_0 : sw 0%nat = 1) + (dw_nz : forall i, dw i <> 0) + (sw_nz : forall i, sw i <> 0) + (sw_pos : forall i, sw i > 0) + (sw_multiples : forall i, sw (S i) mod sw i = 0) + (sw_divides : forall i, sw (S i) / sw i > 0) + (dw_divides : forall i, dw (S i) / dw i > 0). + + Definition convert_bases (sn dn : nat) (p : list Z) : list Z := + let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in + chained_carries_no_reduce dw dn p' (seq 0 (pred dn)). + + Lemma eval_convert_bases sn dn p : + (dn <> 0%nat) -> length p = sn -> + eval dw dn (convert_bases sn dn p) = eval sw sn p. + Proof. + cbv [convert_bases]; intros. + rewrite eval_chained_carries_no_reduce; auto using ZUtil.Z.positive_is_nonzero. + rewrite eval_from_associational; auto. + Qed. + + Hint Rewrite + @Rows.eval_from_associational + @Associational.eval_carry + @Associational.eval_mul + @Positional.eval_to_associational + @eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval. + + Ltac push_eval := intros; autorewrite with push_eval; auto with zarith. + + (* convert from positional in one weight to the other, then to associational *) + Definition to_associational n m p : list (Z * Z) := + let p' := convert_bases n m p in + Positional.to_associational dw m p'. + + (* carry at specified indices in dw, then use Rows.flatten to convert to Positional with sw *) + Definition from_associational idxs n (p : list (Z * Z)) : list Z := + (* important not to use Positional.carry here; we don't want to accumulate yet *) + let p' := fold_right (fun i acc => Associational.carry (dw i) (dw (S i) / dw i) acc) p (rev idxs) in + fst (Rows.flatten sw n (Rows.from_associational sw n p')). + + Lemma eval_carries p idxs : + Associational.eval (fold_right (fun i acc => Associational.carry (dw i) (dw (S i) / dw i) acc) p idxs) = + Associational.eval p. + Proof. apply fold_right_invariant; push_eval. Qed. + Hint Rewrite eval_carries: push_eval. + + Lemma eval_to_associational n m p : + m <> 0%nat -> length p = n -> + Associational.eval (to_associational n m p) = Positional.eval sw n p. + Proof. cbv [to_associational]; push_eval. Qed. + Hint Rewrite eval_to_associational using solve [push_eval; distr_length] : push_eval. + + Lemma eval_from_associational idxs n p : + n <> 0%nat -> 0 <= Associational.eval p < sw n -> + Positional.eval sw n (from_associational idxs n p) = Associational.eval p. + Proof. + cbv [from_associational]; intros. + rewrite Rows.flatten_mod by eauto using Rows.length_from_associational. + push_eval. + Qed. + Hint Rewrite eval_from_associational using solve [push_eval; distr_length] : push_eval. + + Lemma from_associational_partitions n idxs p (_:n<>0%nat): + forall i, (i < n)%nat -> + nth_default 0 (from_associational idxs n p) i = (Associational.eval p) mod (sw (S i)) / sw i. + Proof. + intros; cbv [from_associational]. + rewrite Rows.flatten_partitions with (n:=n) by (eauto using Rows.length_from_associational; omega). + push_eval. + Qed. + + Lemma from_associational_eq n idxs p (_:n<>0%nat): + from_associational idxs n p = Rows.partition sw n (Associational.eval p). + Proof. + intros. cbv [from_associational]. + rewrite Rows.flatten_partitions' with (n:=n) by eauto using Rows.length_from_associational. + push_eval. + Qed. + + (* carry chain that aligns terms in the intermediate weight with the final weight *) + Definition aligned_carries (log_dw_sw nout : nat) + := (map (fun i => ((log_dw_sw * (i + 1)) - 1))%nat (seq 0 nout)). + + Section mul_converted. Definition mul_converted n1 n2 (* lengths in original format *) m1 m2 (* lengths in converted format *) @@ -1631,104 +1672,63 @@ Module MulConverted. let p3_a := Associational.mul p1_a p2_a in from_associational idxs n3 p3_a. - Hint Rewrite - @Rows.eval_from_associational - @Associational.eval_carry - @Associational.eval_mul - @Positional.eval_to_associational - @BaseConversion.eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval. - - Ltac push_eval := intros; autorewrite with push_eval; auto with zarith. - - Lemma eval_carries p idxs : - Associational.eval (fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p idxs) = - Associational.eval p. - Proof. apply fold_right_invariant; push_eval. Qed. - Hint Rewrite eval_carries: push_eval. - - Lemma eval_to_associational n m p : - m <> 0%nat -> length p = n -> - Associational.eval (to_associational n m p) = Positional.eval w n p. - Proof. cbv [to_associational]; push_eval. Qed. - Hint Rewrite eval_to_associational using solve [push_eval; distr_length] : push_eval. - - Lemma eval_from_associational idxs n p : - n <> 0%nat -> 0 <= Associational.eval p < w n -> - Positional.eval w n (from_associational idxs n p) = Associational.eval p. - Proof. - cbv [from_associational]; intros. - rewrite Rows.flatten_mod by eauto using Rows.length_from_associational. - push_eval. - Qed. - Hint Rewrite eval_from_associational using solve [push_eval; distr_length] : push_eval. - Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> - 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> - Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2). + 0 <= (Positional.eval sw n1 p1 * Positional.eval sw n2 p2) < sw n3 -> + Positional.eval sw n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval sw n1 p1) * (Positional.eval sw n2 p2). Proof. cbv [mul_converted]; push_eval. Qed. Hint Rewrite eval_mul_converted : push_eval. - Lemma from_associational_partitions n idxs p (_:n<>0%nat): - forall i, (i < n)%nat -> - nth_default 0 (from_associational idxs n p) i = (Associational.eval p) mod (w (S i)) / w i. - Proof. - intros; cbv [from_associational]. - rewrite Rows.flatten_partitions with (n:=n) by (eauto using Rows.length_from_associational; omega). - push_eval. - Qed. - Lemma mul_converted_partitions n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> - forall i, (i < n3)%nat -> - nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) i = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w (S i)) / w i. - Proof. - intros; cbv [mul_converted]. - rewrite from_associational_partitions by auto. push_eval. - Qed. - - Lemma from_associational_eq n idxs p (_:n<>0%nat): - from_associational idxs n p = Rows.partition w n (Associational.eval p). - Proof. - intros. cbv [from_associational]. - rewrite Rows.flatten_partitions' with (n:=n) by eauto using Rows.length_from_associational. - push_eval. - Qed. - - (* TODO: convert all _partitions proofs to this form? *) - Lemma mul_converted_eq n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): - length p1 = n1 -> length p2 = n2 -> - mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = Rows.partition w n3 (Positional.eval w n1 p1 * Positional.eval w n2 p2). + mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = Rows.partition sw n3 (Positional.eval sw n1 p1 * Positional.eval sw n2 p2). Proof. intros; cbv [mul_converted]. rewrite from_associational_eq by auto. push_eval. Qed. + End mul_converted. + End BaseConversion. - Section aligned. - Context (log_w'_w : nat) (Hlog_w'_w : forall i, (w' i) ^ Z.of_nat log_w'_w = w i). - Context (n n' nout : nat) (Hn : n <> 0%nat) (Hn' : n' <> 0%nat) (Hnout : nout <> 0%nat). - - (* carry chain that aligns terms in the intermediate weight with the final weight *) - Definition aligned_carries := (map (fun i => ((log_w'_w * (i + 1)) - 1))%nat (seq 0 nout)). - - Definition from_associational_aligned := from_associational aligned_carries nout. - - (* assumes both inputs have the same length *) - Definition mul_aligned := mul_converted n n n' n' nout aligned_carries. - - Lemma mul_aligned_eq p1 p2 : - length p1 = n -> length p2 = n -> - mul_aligned p1 p2 = Rows.partition w nout (Positional.eval w n p1 * Positional.eval w n p2). - Proof. cbv [mul_aligned aligned_carries]. auto using mul_converted_eq. Qed. - - Lemma eval_mul_aligned p1 p2 : - length p1 = n -> length p2 = n -> - 0 <= Positional.eval w n p1 * Positional.eval w n p2 < w nout -> - Positional.eval w nout (mul_aligned p1 p2) = Positional.eval w n p1 * Positional.eval w n p2. - Proof. cbv [mul_aligned aligned_carries]. auto using eval_mul_converted. Qed. - End aligned. - End mul_converted. -End MulConverted. + (* multiply two (n*k)-bit numbers by converting them to n k-bit limbs each, multiplying, then converting back *) + Section widemul. + Context (dw : nat -> Z) (n : nat) (n_nz : n <> 0%nat). + Context (dw_0 : dw 0%nat = 1) + (dw_pos : forall i, dw i > 0) + (dw_multiples : forall i, dw (S i) mod dw i = 0) + (dw_divides : forall i, dw (S i) / dw i > 0). + Context (nout : nat) (nout_nz : nout <> 0%nat). (* this is always 2, but reification has trouble if it's a constant *) + Let sw i := (dw i) ^ Z.of_nat n. + + Hint Resolve Z.gt_lt Z.positive_is_nonzero Nat2Z.is_nonneg. + (* TODO : There has got to be a cleaner way to do weight functions *) + Lemma sw_0 : sw 0 = 1. + Proof. subst sw; cbv beta. rewrite dw_0. auto using Z.pow_1_l. Qed. + Lemma sw_pos : forall i, sw i > 0. + Proof. subst sw; intros; apply Z.lt_gt. Z.zero_bounds. Qed. + Lemma sw_nz : forall i, sw i <> 0. + Proof. auto using sw_pos. Qed. + Lemma sw_multiples : forall i, sw (S i) mod sw i = 0. + Proof. + subst sw; cbv beta; intros. + rewrite Saturated.weight_div_mod with (weight := dw) (j:=S i) (i:=i) by auto. + rewrite Z.pow_mul_l. + push_Zmod. autorewrite with zsimplify. reflexivity. + Qed. + Lemma sw_divides : forall i, sw (S i) / sw i > 0. + Proof. intros; apply Z.div_positive_gt_0; auto using sw_pos, sw_multiples. Qed. + Hint Resolve sw_0 sw_pos sw_nz sw_multiples sw_divides. + + Definition widemul a b := mul_converted sw dw 1 1 n n nout (aligned_carries n nout) [a] [b]. + + Lemma widemul_partitions a b : + widemul a b = Rows.partition sw nout (a * b). + Proof. + cbv [widemul]. + rewrite mul_converted_partitions by auto with zarith. + cbn; rewrite sw_0; ring_simplify_subterms. reflexivity. + Qed. + End widemul. +End BaseConversion. Module Import MOVEME. Fixpoint fold_andb_map {A B} (f : A -> B -> bool) (ls1 : list A) (ls2 : list B) @@ -7810,6 +7810,7 @@ Module MontgomeryReduction. := weight_multiples _ _ half_log2R_good. Let w_mul_divides : forall i : nat, w_mul (S i) / w_mul i > 0 := weight_divides _ _ half_log2R_good. +(* Let w_0 : w 0%nat = 1 := weight_0 _ _. Let w_nonzero : forall i, w i <> 0 := weight_nz _ _ log2R_good. @@ -7820,14 +7821,13 @@ Module MontgomeryReduction. Let w_divides : forall i : nat, w (S i) / w i > 0 := weight_divides _ _ log2R_good. Let w_1_gt1 : w 1 > 1 := weight_1_gt_1 _ _ log2R_good. +*) Let w_mul_1_gt1 : w_mul 1 > 1 := weight_1_gt_1 _ _ half_log2R_good. Context (nout : nat) (Hnout : nout = 2%nat). - Definition widemul a b := MulConverted.mul_aligned w w_mul n 1%nat n nout [a] [b]. - Definition montred' (lo_hi : (Z * Z)) := - dlet_nd y := nth_default 0 (widemul (fst lo_hi) N') 0 in - dlet_nd t1_t2 := widemul y N in + dlet_nd y := nth_default 0 (BaseConversion.widemul w_mul n nout (fst lo_hi) N') 0 in + dlet_nd t1_t2 := (BaseConversion.widemul w_mul n nout y N) in dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in @@ -7841,7 +7841,7 @@ Module MontgomeryReduction. rewrite Z.pow_mul_r, R_two_pow by omega; reflexivity. Qed. - Local Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r in *. + Local Ltac change_weight := rewrite ?w_mul_pown, !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r in *. Local Ltac solve_range := repeat match goal with | _ => progress change_weight @@ -7855,9 +7855,10 @@ Module MontgomeryReduction. end. Lemma widemul_correct x y : - 0 <= x * y < w 2 -> widemul x y = [(x * y) mod R; (x * y) / R]. + 0 <= x * y < w 2 -> BaseConversion.widemul w_mul n nout x y = [(x * y) mod R; (x * y) / R]. Proof. - intros; cbv [widemul]. rewrite MulConverted.mul_aligned_eq by (auto; distr_length). + intros. + rewrite BaseConversion.widemul_partitions by (auto; omega). subst nout. cbn. change_weight. autorewrite with zsimplify. Z.rewrite_mod_small. reflexivity. -- cgit v1.2.3