aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v75
1 files changed, 51 insertions, 24 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 0a0db7dfd..bafe666e4 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -1617,16 +1617,19 @@ Module Rows.
let pq_a := Associational.sat_mul base p_a q_a in
flatten m (from_associational m pq_a).
+ (* TODO : move sat_reduce and repeat_sat_reduce to Saturated.Associational *)
Definition sat_reduce base s c (p : list (Z * Z)) :=
let lo_hi := Associational.split s p in
fst lo_hi ++ (Associational.sat_mul_const base c (snd lo_hi)).
- (* TODO : hardcoded to 2 reductions; fix *)
- Definition mulmod base s c n (p q : list Z) :=
+ Definition repeat_sat_reduce base s c (p : list (Z * Z)) n :=
+ fold_right (fun _ q => sat_reduce base s c q) p (seq 0 n).
+
+ Definition mulmod base s c n nreductions (p q : list Z) :=
let p_a := Positional.to_associational weight n p in
let q_a := Positional.to_associational weight n q in
let pq_a := Associational.sat_mul base p_a q_a in
- let r_a := sat_reduce base s c (sat_reduce base s c pq_a) in
+ let r_a := repeat_sat_reduce base s c pq_a nreductions in
flatten n (from_associational n r_a).
Hint Rewrite Associational.eval_sat_mul_const Associational.eval_sat_mul Associational.eval_split using solve [auto] : push_eval.
@@ -1677,19 +1680,38 @@ Module Rows.
fst (mul base n m p q) = partition m (Positional.eval weight n p * Positional.eval weight n q).
Proof. solver. Qed.
- Lemma eval_mulmod base s c n p q :
+ Lemma eval_sat_reduce base s c p :
+ base <> 0 -> s - Associational.eval c <> 0 -> s <> 0 ->
+ Associational.eval (sat_reduce base s c p) mod (s - Associational.eval c)
+ = Associational.eval p mod (s - Associational.eval c).
+ Proof.
+ intros; cbv [sat_reduce].
+ autorewrite with push_eval.
+ rewrite <-Associational.reduction_rule by omega.
+ autorewrite with push_eval; reflexivity.
+ Qed.
+ Hint Rewrite eval_sat_reduce using auto : push_eval.
+
+ Lemma eval_repeat_sat_reduce base s c p n :
+ base <> 0 -> s - Associational.eval c <> 0 -> s <> 0 ->
+ Associational.eval (repeat_sat_reduce base s c p n) mod (s - Associational.eval c)
+ = Associational.eval p mod (s - Associational.eval c).
+ Proof.
+ intros; cbv [repeat_sat_reduce].
+ apply fold_right_invariant; intros; autorewrite with push_eval; auto.
+ Qed.
+ Hint Rewrite eval_repeat_sat_reduce using auto : push_eval.
+
+ Lemma eval_mulmod base s c n nreductions p q :
base <> 0 -> s <> 0 -> s - Associational.eval c <> 0 ->
n <> 0%nat -> length p = n -> length q = n ->
- (Positional.eval weight n (fst (mulmod base s c n p q))
- + weight n * (snd (mulmod base s c n p q))) mod (s - Associational.eval c)
+ (Positional.eval weight n (fst (mulmod base s c n nreductions p q))
+ + weight n * (snd (mulmod base s c n nreductions p q))) mod (s - Associational.eval c)
= (Positional.eval weight n p * Positional.eval weight n q) mod (s - Associational.eval c).
Proof.
solver.
rewrite <-Z.div_mod'' by auto.
- rewrite <-Associational.reduction_rule by omega.
- autorewrite with push_eval.
- rewrite <-Associational.reduction_rule by omega.
- autorewrite with push_eval; auto.
+ autorewrite with push_eval; reflexivity.
Qed.
End Ops.
End Rows.
@@ -8478,7 +8500,7 @@ Module SaturatedSolinas.
Context (s : Z) (c : list (Z * Z))
(s_nz : s <> 0) (modulus_nz : s - Associational.eval c <> 0).
Context (log2base : Z) (log2base_pos : 0 < log2base)
- (n : nat) (n_nz : n <> 0%nat).
+ (n nreductions : nat) (n_nz : n <> 0%nat).
Let weight := weight log2base 1.
Let props : @weight_properties weight := wprops log2base 1 ltac:(omega).
@@ -8493,7 +8515,7 @@ Module SaturatedSolinas.
As eval_mulmod.
Proof.
intros.
- rewrite <-Rows.eval_mulmod with (base:=2^log2base) (s:=s) (c:=c) by auto using base_nz.
+ rewrite <-Rows.eval_mulmod with (base:=2^log2base) (s:=s) (c:=c) (nreductions:=nreductions) by auto using base_nz.
eapply f_equal2; [|trivial].
(* expand_lists (). *) (* uncommenting this line removes some unused multiplications but also inlines a bunch of carry stuff at the end *)
subst mulmod. reflexivity.
@@ -8502,15 +8524,15 @@ Module SaturatedSolinas.
End MulMod.
Derive mulmod_gen
- SuchThat (forall (log2base s : Z) (c : list (Z * Z)) (n m : nat)
+ SuchThat (forall (log2base s : Z) (c : list (Z * Z)) (n nreductions : nat)
(f g : list Z),
Interp (t:=type.reify_type_of mulmod')
- mulmod_gen s c log2base n f g
- = mulmod' s c log2base n f g)
+ mulmod_gen s c log2base n nreductions f g
+ = mulmod' s c log2base n nreductions f g)
As mulmod_gen_correct.
Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed.
Module Export ReifyHints.
- Global Hint Extern 1 (_ = mulmod' _ _ _ _ _ _) => simple apply mulmod_gen_correct : reify_gen_cache.
+ Global Hint Extern 1 (_ = mulmod' _ _ _ _ _ _ _) => simple apply mulmod_gen_correct : reify_gen_cache.
End ReifyHints.
Section rmulmod.
@@ -8521,7 +8543,15 @@ Module SaturatedSolinas.
Definition relax_zrange_of_machine_wordsize
:= relax_zrange_gen [1; machine_wordsize]%Z.
- Let n : nat := Z.to_nat (Z.log2 s / machine_wordsize).
+ Let n : nat := Z.to_nat (Qceiling (Z.log2_up s / machine_wordsize)).
+ (* Number of reductions is calculated as follows :
+ Let i be the highest limb index of c. Then, each reduction
+ decreases the number of extra limbs by (n-i). So, to go from
+ the n extra limbs we have post-multiplication down to 0, we
+ need ceil (n / (n - i)) reductions. *)
+ Let nreductions : nat :=
+ let i := fold_right Z.max 0 (map (fun t => Z.log2 (fst t) / machine_wordsize) c) in
+ Z.to_nat (Qceiling (Z.of_nat n / (Z.of_nat n - i))).
Let relax_zrange := relax_zrange_of_machine_wordsize.
Let bound := Some r[0 ~> (2^machine_wordsize - 1)]%zrange.
Let boundsn : list (ZRange.type.option.interp type.Z)
@@ -8563,7 +8593,7 @@ Module SaturatedSolinas.
:= BoundsPipeline_correct
(Some boundsn, Some boundsn)
(Some boundsn)
- (mulmod' s c machine_wordsize n).
+ (mulmod' s c machine_wordsize n nreductions).
Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _).
Definition rmulmod_correctT rv : Prop
@@ -8839,6 +8869,8 @@ mulmod = fun var : type -> Type => λ x : var (type.list (type.type_primitive ty
End P192_32.
+(* TODO : Too slow! Many, many terms in this one. *)
+(*
Module P256_32.
Definition s := 2^256.
Definition c := [(2^224, 1); (2^192, -1); (2^96, -1); (1,1)].
@@ -8854,14 +8886,9 @@ Module P256_32.
Set Printing Width 100000.
Print mulmod.
- (* TODO : this one should use more reductions *)
- (* Since 224 = 7*32, first reduce will leave 7 extra words
- Each reduce changes that number by only 1
- therefore we need 8 reductions
- *)
-
End P256_32.
+*)
Module MontgomeryReduction.
Section MontRed'.