diff options
author | Jade Philipoom <jadep@google.com> | 2018-03-06 13:31:56 +0100 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2018-04-03 09:00:55 -0400 |
commit | 05567335df0a787e66877a222b2284975b0f7f0a (patch) | |
tree | e22634de820dc5c14c675141a663d2af6913cf84 /src | |
parent | fcf5f782aade5339ad91e077f23010e1dd27d98c (diff) |
move mul_converted to its own module
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 188 |
1 files changed, 95 insertions, 93 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index c301fc67f..2e46d0dd8 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -946,95 +946,6 @@ Module Columns. fst (flatten (from_associational m pq_a)). End mul. End Columns. - - Section mul_converted. - Context (w w' : nat -> Z). - Context (w'_0 : w' 0%nat = 1) - (w'_nonzero : forall i, w' i <> 0) - (w'_positive : forall i, w' i > 0) - (w'_divides : forall i : nat, w' (S i) / w' i > 0). - Context (w_0 : w 0%nat = 1) - (w_nonzero : forall i, w i <> 0) - (w_positive : forall i, w i > 0) - (w_multiples : forall i, w (S i) mod w i = 0) - (w_divides : forall i : nat, w (S i) / w i > 0). - - (* takes in inputs in base w, converts to w', multiplies in that - format, converts to w again, then flattens. *) - Definition mul_converted - n1 n2 (* lengths in original format *) - m1 m2 (* lengths in converted format *) - (n3 : nat) (* final length *) - (idxs : list nat) (* carries to do -- this helps preemptively line up weights *) - (p1 p2 : list Z) := - let p1' := BaseConversion.convert_bases w w' n1 m1 p1 in - let p2' := BaseConversion.convert_bases w w' n2 m2 p2 in - let p1_a := Positional.to_associational w' m1 p1' in - let p2_a := Positional.to_associational w' m2 p2' in - let p3_a := Associational.mul p1_a p2_a in - (* important not to use Positional.carry here; we don't want to accumulate yet *) - let p3'_a := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p3_a (rev idxs) in - fst (flatten w (from_associational w n3 p3'_a)). - - Hint Rewrite - @Columns.eval_from_associational - @Associational.eval_carry - @Associational.eval_mul - @Positional.eval_to_associational - @BaseConversion.eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval. - - Lemma eval_carries p idxs : - Associational.eval (fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p idxs) = - Associational.eval p. - Proof. apply fold_right_invariant; intros; autorewrite with push_eval; congruence. Qed. - Hint Rewrite eval_carries: push_eval. - - Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): - length p1 = n1 -> length p2 = n2 -> - 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> - Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2). - Proof. - cbv [mul_converted]; intros. - rewrite Columns.flatten_mod by auto using Columns.length_from_associational. - autorewrite with push_eval. auto using Z.mod_small. - Qed. - Hint Rewrite eval_mul_converted : push_eval. - - Hint Rewrite @length_from_associational : distr_length. - - Lemma mul_converted_mod n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): - length p1 = n1 -> length p2 = n2 -> - 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> - nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1). - Proof. - intros; cbv [mul_converted]. - erewrite flatten_partitions by (auto; distr_length). - autorewrite with distr_length push_eval natsimplify. - rewrite w_0; autorewrite with zsimplify. - reflexivity. - Qed. - - Lemma mul_converted_div n1 n2 m1 m2 n3 idxs p1 p2: - m1 <> 0%nat -> m2 <> 0%nat -> n3 = 2%nat -> - length p1 = n1 -> length p2 = n2 -> - 0 <= Positional.eval w n1 p1 -> - 0 <= Positional.eval w n2 p2 -> - 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> - nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1). - Proof. - intros; subst n3; cbv [mul_converted]. - erewrite flatten_partitions by (auto; distr_length). - autorewrite with distr_length push_eval. - rewrite Z.mod_small; omega. - Qed. - - (* shortcut definition for convert-mul-convert for cases when we are halving the bitwidth before multiplying. *) - (* the most important feature here is the carries--we carry from all the odd indices after multiplying, - thus pre-aligning everything with the double-size bitwidth *) - Definition mul_converted_halve n n2 := - mul_converted n n n2 n2 n2 (map (fun x => 2*x + 1)%nat (seq 0 n)). - - End mul_converted. End Columns. Module Rows. @@ -1508,6 +1419,97 @@ Module Rows. End Rows. End Rows. +Module MulConverted. + Section mul_converted. + Context (w w' : nat -> Z). + Context (w'_0 : w' 0%nat = 1) + (w'_nonzero : forall i, w' i <> 0) + (w'_positive : forall i, w' i > 0) + (w'_divides : forall i : nat, w' (S i) / w' i > 0). + Context (w_0 : w 0%nat = 1) + (w_nonzero : forall i, w i <> 0) + (w_positive : forall i, w i > 0) + (w_multiples : forall i, w (S i) mod w i = 0) + (w_divides : forall i : nat, w (S i) / w i > 0). + + (* takes in inputs in base w, converts to w', multiplies in that + format, converts to w again, then flattens. *) + Definition mul_converted + n1 n2 (* lengths in original format *) + m1 m2 (* lengths in converted format *) + (n3 : nat) (* final length *) + (idxs : list nat) (* carries to do -- this helps preemptively line up weights *) + (p1 p2 : list Z) := + let p1' := BaseConversion.convert_bases w w' n1 m1 p1 in + let p2' := BaseConversion.convert_bases w w' n2 m2 p2 in + let p1_a := Positional.to_associational w' m1 p1' in + let p2_a := Positional.to_associational w' m2 p2' in + let p3_a := Associational.mul p1_a p2_a in + (* important not to use Positional.carry here; we don't want to accumulate yet *) + let p3'_a := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p3_a (rev idxs) in + fst (Columns.flatten w (Columns.from_associational w n3 p3'_a)). + + Hint Rewrite + @Columns.eval_from_associational + @Associational.eval_carry + @Associational.eval_mul + @Positional.eval_to_associational + @BaseConversion.eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval. + + Lemma eval_carries p idxs : + Associational.eval (fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p idxs) = + Associational.eval p. + Proof. apply fold_right_invariant; intros; autorewrite with push_eval; congruence. Qed. + Hint Rewrite eval_carries: push_eval. + + Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + length p1 = n1 -> length p2 = n2 -> + 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> + Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2). + Proof. + cbv [mul_converted]; intros. + rewrite Columns.flatten_mod by auto using Columns.length_from_associational. + autorewrite with push_eval. auto using Z.mod_small. + Qed. + Hint Rewrite eval_mul_converted : push_eval. + + Hint Rewrite @Columns.length_from_associational : distr_length. + + Lemma mul_converted_mod n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + length p1 = n1 -> length p2 = n2 -> + 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> + nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1). + Proof. + intros; cbv [mul_converted]. + erewrite Columns.flatten_partitions by (auto; distr_length). + autorewrite with distr_length push_eval natsimplify. + rewrite w_0; autorewrite with zsimplify. + reflexivity. + Qed. + + Lemma mul_converted_div n1 n2 m1 m2 n3 idxs p1 p2: + m1 <> 0%nat -> m2 <> 0%nat -> n3 = 2%nat -> + length p1 = n1 -> length p2 = n2 -> + 0 <= Positional.eval w n1 p1 -> + 0 <= Positional.eval w n2 p2 -> + 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> + nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1). + Proof. + intros; subst n3; cbv [mul_converted]. + erewrite Columns.flatten_partitions by (auto; distr_length). + autorewrite with distr_length push_eval. + rewrite Z.mod_small; omega. + Qed. + + (* shortcut definition for convert-mul-convert for cases when we are halving the bitwidth before multiplying. *) + (* the most important feature here is the carries--we carry from all the odd indices after multiplying, + thus pre-aligning everything with the double-size bitwidth *) + Definition mul_converted_halve n n2 := + mul_converted n n n2 n2 n2 (map (fun x => 2*x + 1)%nat (seq 0 n)). + + End mul_converted. +End MulConverted. + Module Import MOVEME. Fixpoint fold_andb_map {A B} (f : A -> B -> bool) (ls1 : list A) (ls2 : list B) : bool @@ -7854,8 +7856,8 @@ Module MontgomeryReduction. Context (n:nat) (Hn: n = 2%nat). Definition montred' (lo_hi : (Z * Z)) := - dlet_nd y := nth_default 0 (Columns.mul_converted_halve w w_half 1%nat n [fst lo_hi] [N']) 0 in - dlet_nd t1_t2 := Columns.mul_converted_halve w w_half 1%nat n [y] [N] in + dlet_nd y := nth_default 0 (MulConverted.mul_converted_halve w w_half 1%nat n [fst lo_hi] [N']) 0 in + dlet_nd t1_t2 := MulConverted.mul_converted_halve w w_half 1%nat n [y] [N] in dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in @@ -7882,7 +7884,7 @@ Module MontgomeryReduction. end. Hint Rewrite - Columns.mul_converted_mod Columns.mul_converted_div using (solve [auto; autorewrite with mul_conv; solve_range]) + MulConverted.mul_converted_mod MulConverted.mul_converted_div using (solve [auto; autorewrite with mul_conv; solve_range]) : mul_conv. Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N) @@ -7893,7 +7895,7 @@ Module MontgomeryReduction. cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In]. rewrite Hlo, Hhi. subst n. assert (0 <= T mod R * N' < w 2) by (solve_range). - cbv [Columns.mul_converted_halve]. cbn. + cbv [MulConverted.mul_converted_halve]. cbn [seq map]. autorewrite with mul_conv. rewrite Hw, ?Z.pow_1_r. autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct. |