aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-03-06 13:31:56 +0100
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-03 09:00:55 -0400
commit05567335df0a787e66877a222b2284975b0f7f0a (patch)
treee22634de820dc5c14c675141a663d2af6913cf84 /src
parentfcf5f782aade5339ad91e077f23010e1dd27d98c (diff)
move mul_converted to its own module
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v188
1 files changed, 95 insertions, 93 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index c301fc67f..2e46d0dd8 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -946,95 +946,6 @@ Module Columns.
fst (flatten (from_associational m pq_a)).
End mul.
End Columns.
-
- Section mul_converted.
- Context (w w' : nat -> Z).
- Context (w'_0 : w' 0%nat = 1)
- (w'_nonzero : forall i, w' i <> 0)
- (w'_positive : forall i, w' i > 0)
- (w'_divides : forall i : nat, w' (S i) / w' i > 0).
- Context (w_0 : w 0%nat = 1)
- (w_nonzero : forall i, w i <> 0)
- (w_positive : forall i, w i > 0)
- (w_multiples : forall i, w (S i) mod w i = 0)
- (w_divides : forall i : nat, w (S i) / w i > 0).
-
- (* takes in inputs in base w, converts to w', multiplies in that
- format, converts to w again, then flattens. *)
- Definition mul_converted
- n1 n2 (* lengths in original format *)
- m1 m2 (* lengths in converted format *)
- (n3 : nat) (* final length *)
- (idxs : list nat) (* carries to do -- this helps preemptively line up weights *)
- (p1 p2 : list Z) :=
- let p1' := BaseConversion.convert_bases w w' n1 m1 p1 in
- let p2' := BaseConversion.convert_bases w w' n2 m2 p2 in
- let p1_a := Positional.to_associational w' m1 p1' in
- let p2_a := Positional.to_associational w' m2 p2' in
- let p3_a := Associational.mul p1_a p2_a in
- (* important not to use Positional.carry here; we don't want to accumulate yet *)
- let p3'_a := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p3_a (rev idxs) in
- fst (flatten w (from_associational w n3 p3'_a)).
-
- Hint Rewrite
- @Columns.eval_from_associational
- @Associational.eval_carry
- @Associational.eval_mul
- @Positional.eval_to_associational
- @BaseConversion.eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval.
-
- Lemma eval_carries p idxs :
- Associational.eval (fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p idxs) =
- Associational.eval p.
- Proof. apply fold_right_invariant; intros; autorewrite with push_eval; congruence. Qed.
- Hint Rewrite eval_carries: push_eval.
-
- Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
- length p1 = n1 -> length p2 = n2 ->
- 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
- Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2).
- Proof.
- cbv [mul_converted]; intros.
- rewrite Columns.flatten_mod by auto using Columns.length_from_associational.
- autorewrite with push_eval. auto using Z.mod_small.
- Qed.
- Hint Rewrite eval_mul_converted : push_eval.
-
- Hint Rewrite @length_from_associational : distr_length.
-
- Lemma mul_converted_mod n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
- length p1 = n1 -> length p2 = n2 ->
- 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
- nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1).
- Proof.
- intros; cbv [mul_converted].
- erewrite flatten_partitions by (auto; distr_length).
- autorewrite with distr_length push_eval natsimplify.
- rewrite w_0; autorewrite with zsimplify.
- reflexivity.
- Qed.
-
- Lemma mul_converted_div n1 n2 m1 m2 n3 idxs p1 p2:
- m1 <> 0%nat -> m2 <> 0%nat -> n3 = 2%nat ->
- length p1 = n1 -> length p2 = n2 ->
- 0 <= Positional.eval w n1 p1 ->
- 0 <= Positional.eval w n2 p2 ->
- 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
- nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1).
- Proof.
- intros; subst n3; cbv [mul_converted].
- erewrite flatten_partitions by (auto; distr_length).
- autorewrite with distr_length push_eval.
- rewrite Z.mod_small; omega.
- Qed.
-
- (* shortcut definition for convert-mul-convert for cases when we are halving the bitwidth before multiplying. *)
- (* the most important feature here is the carries--we carry from all the odd indices after multiplying,
- thus pre-aligning everything with the double-size bitwidth *)
- Definition mul_converted_halve n n2 :=
- mul_converted n n n2 n2 n2 (map (fun x => 2*x + 1)%nat (seq 0 n)).
-
- End mul_converted.
End Columns.
Module Rows.
@@ -1508,6 +1419,97 @@ Module Rows.
End Rows.
End Rows.
+Module MulConverted.
+ Section mul_converted.
+ Context (w w' : nat -> Z).
+ Context (w'_0 : w' 0%nat = 1)
+ (w'_nonzero : forall i, w' i <> 0)
+ (w'_positive : forall i, w' i > 0)
+ (w'_divides : forall i : nat, w' (S i) / w' i > 0).
+ Context (w_0 : w 0%nat = 1)
+ (w_nonzero : forall i, w i <> 0)
+ (w_positive : forall i, w i > 0)
+ (w_multiples : forall i, w (S i) mod w i = 0)
+ (w_divides : forall i : nat, w (S i) / w i > 0).
+
+ (* takes in inputs in base w, converts to w', multiplies in that
+ format, converts to w again, then flattens. *)
+ Definition mul_converted
+ n1 n2 (* lengths in original format *)
+ m1 m2 (* lengths in converted format *)
+ (n3 : nat) (* final length *)
+ (idxs : list nat) (* carries to do -- this helps preemptively line up weights *)
+ (p1 p2 : list Z) :=
+ let p1' := BaseConversion.convert_bases w w' n1 m1 p1 in
+ let p2' := BaseConversion.convert_bases w w' n2 m2 p2 in
+ let p1_a := Positional.to_associational w' m1 p1' in
+ let p2_a := Positional.to_associational w' m2 p2' in
+ let p3_a := Associational.mul p1_a p2_a in
+ (* important not to use Positional.carry here; we don't want to accumulate yet *)
+ let p3'_a := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p3_a (rev idxs) in
+ fst (Columns.flatten w (Columns.from_associational w n3 p3'_a)).
+
+ Hint Rewrite
+ @Columns.eval_from_associational
+ @Associational.eval_carry
+ @Associational.eval_mul
+ @Positional.eval_to_associational
+ @BaseConversion.eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval.
+
+ Lemma eval_carries p idxs :
+ Associational.eval (fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p idxs) =
+ Associational.eval p.
+ Proof. apply fold_right_invariant; intros; autorewrite with push_eval; congruence. Qed.
+ Hint Rewrite eval_carries: push_eval.
+
+ Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
+ length p1 = n1 -> length p2 = n2 ->
+ 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
+ Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2).
+ Proof.
+ cbv [mul_converted]; intros.
+ rewrite Columns.flatten_mod by auto using Columns.length_from_associational.
+ autorewrite with push_eval. auto using Z.mod_small.
+ Qed.
+ Hint Rewrite eval_mul_converted : push_eval.
+
+ Hint Rewrite @Columns.length_from_associational : distr_length.
+
+ Lemma mul_converted_mod n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
+ length p1 = n1 -> length p2 = n2 ->
+ 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
+ nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1).
+ Proof.
+ intros; cbv [mul_converted].
+ erewrite Columns.flatten_partitions by (auto; distr_length).
+ autorewrite with distr_length push_eval natsimplify.
+ rewrite w_0; autorewrite with zsimplify.
+ reflexivity.
+ Qed.
+
+ Lemma mul_converted_div n1 n2 m1 m2 n3 idxs p1 p2:
+ m1 <> 0%nat -> m2 <> 0%nat -> n3 = 2%nat ->
+ length p1 = n1 -> length p2 = n2 ->
+ 0 <= Positional.eval w n1 p1 ->
+ 0 <= Positional.eval w n2 p2 ->
+ 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
+ nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1).
+ Proof.
+ intros; subst n3; cbv [mul_converted].
+ erewrite Columns.flatten_partitions by (auto; distr_length).
+ autorewrite with distr_length push_eval.
+ rewrite Z.mod_small; omega.
+ Qed.
+
+ (* shortcut definition for convert-mul-convert for cases when we are halving the bitwidth before multiplying. *)
+ (* the most important feature here is the carries--we carry from all the odd indices after multiplying,
+ thus pre-aligning everything with the double-size bitwidth *)
+ Definition mul_converted_halve n n2 :=
+ mul_converted n n n2 n2 n2 (map (fun x => 2*x + 1)%nat (seq 0 n)).
+
+ End mul_converted.
+End MulConverted.
+
Module Import MOVEME.
Fixpoint fold_andb_map {A B} (f : A -> B -> bool) (ls1 : list A) (ls2 : list B)
: bool
@@ -7854,8 +7856,8 @@ Module MontgomeryReduction.
Context (n:nat) (Hn: n = 2%nat).
Definition montred' (lo_hi : (Z * Z)) :=
- dlet_nd y := nth_default 0 (Columns.mul_converted_halve w w_half 1%nat n [fst lo_hi] [N']) 0 in
- dlet_nd t1_t2 := Columns.mul_converted_halve w w_half 1%nat n [y] [N] in
+ dlet_nd y := nth_default 0 (MulConverted.mul_converted_halve w w_half 1%nat n [fst lo_hi] [N']) 0 in
+ dlet_nd t1_t2 := MulConverted.mul_converted_halve w w_half 1%nat n [y] [N] in
dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in
dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in
dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in
@@ -7882,7 +7884,7 @@ Module MontgomeryReduction.
end.
Hint Rewrite
- Columns.mul_converted_mod Columns.mul_converted_div using (solve [auto; autorewrite with mul_conv; solve_range])
+ MulConverted.mul_converted_mod MulConverted.mul_converted_div using (solve [auto; autorewrite with mul_conv; solve_range])
: mul_conv.
Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N)
@@ -7893,7 +7895,7 @@ Module MontgomeryReduction.
cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In].
rewrite Hlo, Hhi. subst n.
assert (0 <= T mod R * N' < w 2) by (solve_range).
- cbv [Columns.mul_converted_halve]. cbn.
+ cbv [MulConverted.mul_converted_halve]. cbn [seq map].
autorewrite with mul_conv.
rewrite Hw, ?Z.pow_1_r.
autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct.