aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/SimplyTypedArithmetic.v
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-04-17 14:13:22 +0200
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-30 04:20:04 -0400
commit6cbd9dba8e259d7ee3eda867fd1b0f5512da90f6 (patch)
tree090682a2b555a0195450cf8c3ec80df6532828be /src/Experiments/SimplyTypedArithmetic.v
parent026c09658d1554e8a24cbab8a147c7675deb961b (diff)
fix definitions of saturated operations to avoid unnecessary work, and make Montgomery use them
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v95
1 files changed, 63 insertions, 32 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 8e635c524..853b28545 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -1543,34 +1543,68 @@ Module Rows.
End Flatten.
Section Ops.
- Definition add n p q :=
- let p_a := Positional.to_associational weight n p in
- let q_a := Positional.to_associational weight n q in
- flatten n (from_associational n (p_a ++ q_a)).
+ Definition add n p q := flatten n [p; q].
- Definition sub n p q :=
- let p_a := Positional.to_associational weight n p in
- let q_a := Positional.to_associational weight n q in
- flatten n (from_associational n (p_a ++ Associational.negate_snd q_a)).
+ (* TODO: Although cleaner, using Positional.negate snd inserts
+ dlets which prevent add-opp=>sub transformation in partial
+ evaluation. Should probably either make partial evaluation
+ handle that or remove the dlet in
+ Positional.from_associational. *)
+ Definition sub n p q := flatten n [p; map (fun x => dlet y := x in Z.opp y) q].
+
+ Hint Rewrite eval_cons eval_nil using solve [auto] : push_eval.
Lemma add_partitions n p q :
n <> 0%nat -> length p = n -> length q = n ->
fst (add n p q) = partition n (Positional.eval weight n p + Positional.eval weight n q).
Proof.
intros; cbv [add].
- rewrite flatten_partitions' by eauto using length_from_associational.
- rewrite eval_from_associational by auto.
- autorewrite with push_eval; reflexivity.
+ rewrite flatten_partitions' by (intros; In_cases; subst; distr_length).
+ autorewrite with push_eval. ring_simplify_subterms.
+ reflexivity.
Qed.
+ Lemma add_div n p q :
+ n <> 0%nat -> length p = n -> length q = n ->
+ snd (add n p q) = (Positional.eval weight n p + Positional.eval weight n q) / weight n.
+ Proof.
+ intros; cbv [add].
+ rewrite flatten_div by (intros; In_cases; subst; distr_length).
+ autorewrite with push_eval. ring_simplify_subterms.
+ reflexivity.
+ Qed.
+
+ Lemma eval_map_opp q :
+ forall n, length q = n ->
+ Positional.eval weight n (map Z.opp q) = - Positional.eval weight n q.
+ Proof.
+ induction q using rev_ind; intros;
+ repeat match goal with
+ | _ => progress autorewrite with push_map push_eval
+ | _ => erewrite !Positional.eval_snoc with (n:=length q) by distr_length
+ | _ => rewrite IHq by auto
+ | _ => ring
+ end.
+ Qed. Hint Rewrite eval_map_opp using solve [auto]: push_eval.
+
Lemma sub_partitions n p q :
n <> 0%nat -> length p = n -> length q = n ->
fst (sub n p q) = partition n (Positional.eval weight n p - Positional.eval weight n q).
Proof.
intros; cbv [sub].
- rewrite flatten_partitions' by eauto using length_from_associational.
- rewrite eval_from_associational by auto.
- autorewrite with push_eval; reflexivity.
+ rewrite flatten_partitions' by (intros; In_cases; subst; distr_length).
+ autorewrite with push_eval. ring_simplify_subterms.
+ reflexivity.
+ Qed.
+
+ Lemma sub_div n p q :
+ n <> 0%nat -> length p = n -> length q = n ->
+ snd (sub n p q) = (Positional.eval weight n p - Positional.eval weight n q) / weight n.
+ Proof.
+ intros; cbv [sub].
+ rewrite flatten_div by (intros; In_cases; subst; distr_length).
+ autorewrite with push_eval. ring_simplify_subterms.
+ reflexivity.
Qed.
End Ops.
End Rows.
@@ -8187,10 +8221,9 @@ Module MontgomeryReduction.
Definition montred' (lo_hi : (Z * Z)) :=
dlet_nd y := nth_default 0 (BaseConversion.widemul Zlog2R n nout (fst lo_hi) N') 0 in
dlet_nd t1_t2 := (BaseConversion.widemul Zlog2R n nout y N) in
- dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in
- dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in
- dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in
- dlet_nd lo'' := fst (Z.sub_get_borrow_full R (fst hi'_carry) y') in
+ dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [fst lo_hi; snd lo_hi] t1_t2 in
+ dlet_nd y' := Z.zselect (snd sum_carry) 0 N in
+ dlet_nd lo'' := fst (Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y') in
Z.add_modulo lo'' 0 N.
Local Lemma Hw : forall i, w i = R ^ Z.of_nat i.
@@ -8212,6 +8245,9 @@ Module MontgomeryReduction.
| _ => omega
end.
+ Local Lemma eval2 x y : eval w 2 [x;y] = x + R * y.
+ Proof. cbn. change_weight. ring. Qed.
+
Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N)
(Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R):
montred' lo_hi = reduce_via_partial N R N' T.
@@ -8223,23 +8259,18 @@ Module MontgomeryReduction.
rewrite !BaseConversion.widemul_correct
by (rewrite ?BaseConversion.widemul_correct; autorewrite with push_nth_default; solve_range).
+ rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega).
rewrite R_two_pow.
- autorewrite with push_nth_default.
+ cbv [Rows.partition seq]. rewrite !eval2.
+ autorewrite with push_nth_default push_map.
autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct.
+ change_weight.
(* pull out value before last modular reduction *)
match goal with |- (if (?n <=? ?x)%Z then ?x - ?n else ?x) = (if (?n <=? ?y) then ?y - ?n else ?y)%Z =>
let P := fresh "H" in assert (x = y) as P; [|rewrite P; reflexivity] end.
- match goal with
- |- context [if R * R <=? ?x then _ else _] =>
- match goal with |- context [if dec (?xHigh / R = 0) then _ else _] =>
- assert (x / R = xHigh) as cond_equiv end end.
- { apply Z.mul_cancel_r with (p:=R); [omega|].
- autorewrite with push_Zmul zdiv_to_mod push_Zmod; ring. }
- rewrite <-cond_equiv. rewrite ?Z.mod_pull_div, ?Z.div_div by omega.
- assert (0 < R * R)%Z by Z.zero_bounds.
-
+ autorewrite with zsimplify.
break_match; try reflexivity; autorewrite with ltb_to_lt in *; rewrite Z.div_small_iff in * by omega;
repeat match goal with
| _ => progress autorewrite with zsimplify_fast
@@ -8371,8 +8402,8 @@ montred256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z *
expr_let x24 := ADDC_256 (x23₂, x14, x17) in
expr_let x25 := ADD_256 (x20, x23₁) in
expr_let x26 := ADDC_256 (x25₂, x19, x24₁) in
- expr_let x27 := ADD_256 (x₁, x25₁) in
- expr_let x28 := ADDC_256 (x27₂, x₂, x26₁) in
+ expr_let x27 := ADD_256 (x25₁, x₁) in
+ expr_let x28 := ADDC_256 (x27₂, x26₁, x₂) in
expr_let x29 := SELC (x28₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
expr_let x30 := Z.cast uint256 @@ (fst @@ SUB_256 (x28₁, x29)) in
ADDM (x30, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
@@ -8513,8 +8544,8 @@ c.Add256($x23, $x21, $x22);
c.Addc($x24, $x14, $x17);
c.Add256($x25, $x20, $x23_lo);
c.Addc($x26, $x19, $x24_lo);
-c.Add256($x27, $x_lo, $x25_lo);
-c.Addc($x28, $x_hi, $x26_lo);
+c.Add256($x27, $x25_lo, $x_lo);
+c.Addc($x28, $x26_lo, $x_hi);
c.Selc($x29,RegZero, RegMod);
c.Sub($x30, $x28_lo, $x29);
c.AddM($ret, $x30, RegZero, RegMod);