From 6cbd9dba8e259d7ee3eda867fd1b0f5512da90f6 Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Tue, 17 Apr 2018 14:13:22 +0200 Subject: fix definitions of saturated operations to avoid unnecessary work, and make Montgomery use them --- src/Experiments/SimplyTypedArithmetic.v | 95 ++++++++++++++++++++++----------- 1 file changed, 63 insertions(+), 32 deletions(-) (limited to 'src/Experiments/SimplyTypedArithmetic.v') diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 8e635c524..853b28545 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -1543,34 +1543,68 @@ Module Rows. End Flatten. Section Ops. - Definition add n p q := - let p_a := Positional.to_associational weight n p in - let q_a := Positional.to_associational weight n q in - flatten n (from_associational n (p_a ++ q_a)). + Definition add n p q := flatten n [p; q]. - Definition sub n p q := - let p_a := Positional.to_associational weight n p in - let q_a := Positional.to_associational weight n q in - flatten n (from_associational n (p_a ++ Associational.negate_snd q_a)). + (* TODO: Although cleaner, using Positional.negate snd inserts + dlets which prevent add-opp=>sub transformation in partial + evaluation. Should probably either make partial evaluation + handle that or remove the dlet in + Positional.from_associational. *) + Definition sub n p q := flatten n [p; map (fun x => dlet y := x in Z.opp y) q]. + + Hint Rewrite eval_cons eval_nil using solve [auto] : push_eval. Lemma add_partitions n p q : n <> 0%nat -> length p = n -> length q = n -> fst (add n p q) = partition n (Positional.eval weight n p + Positional.eval weight n q). Proof. intros; cbv [add]. - rewrite flatten_partitions' by eauto using length_from_associational. - rewrite eval_from_associational by auto. - autorewrite with push_eval; reflexivity. + rewrite flatten_partitions' by (intros; In_cases; subst; distr_length). + autorewrite with push_eval. ring_simplify_subterms. + reflexivity. Qed. + Lemma add_div n p q : + n <> 0%nat -> length p = n -> length q = n -> + snd (add n p q) = (Positional.eval weight n p + Positional.eval weight n q) / weight n. + Proof. + intros; cbv [add]. + rewrite flatten_div by (intros; In_cases; subst; distr_length). + autorewrite with push_eval. ring_simplify_subterms. + reflexivity. + Qed. + + Lemma eval_map_opp q : + forall n, length q = n -> + Positional.eval weight n (map Z.opp q) = - Positional.eval weight n q. + Proof. + induction q using rev_ind; intros; + repeat match goal with + | _ => progress autorewrite with push_map push_eval + | _ => erewrite !Positional.eval_snoc with (n:=length q) by distr_length + | _ => rewrite IHq by auto + | _ => ring + end. + Qed. Hint Rewrite eval_map_opp using solve [auto]: push_eval. + Lemma sub_partitions n p q : n <> 0%nat -> length p = n -> length q = n -> fst (sub n p q) = partition n (Positional.eval weight n p - Positional.eval weight n q). Proof. intros; cbv [sub]. - rewrite flatten_partitions' by eauto using length_from_associational. - rewrite eval_from_associational by auto. - autorewrite with push_eval; reflexivity. + rewrite flatten_partitions' by (intros; In_cases; subst; distr_length). + autorewrite with push_eval. ring_simplify_subterms. + reflexivity. + Qed. + + Lemma sub_div n p q : + n <> 0%nat -> length p = n -> length q = n -> + snd (sub n p q) = (Positional.eval weight n p - Positional.eval weight n q) / weight n. + Proof. + intros; cbv [sub]. + rewrite flatten_div by (intros; In_cases; subst; distr_length). + autorewrite with push_eval. ring_simplify_subterms. + reflexivity. Qed. End Ops. End Rows. @@ -8187,10 +8221,9 @@ Module MontgomeryReduction. Definition montred' (lo_hi : (Z * Z)) := dlet_nd y := nth_default 0 (BaseConversion.widemul Zlog2R n nout (fst lo_hi) N') 0 in dlet_nd t1_t2 := (BaseConversion.widemul Zlog2R n nout y N) in - dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in - dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in - dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in - dlet_nd lo'' := fst (Z.sub_get_borrow_full R (fst hi'_carry) y') in + dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [fst lo_hi; snd lo_hi] t1_t2 in + dlet_nd y' := Z.zselect (snd sum_carry) 0 N in + dlet_nd lo'' := fst (Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y') in Z.add_modulo lo'' 0 N. Local Lemma Hw : forall i, w i = R ^ Z.of_nat i. @@ -8212,6 +8245,9 @@ Module MontgomeryReduction. | _ => omega end. + Local Lemma eval2 x y : eval w 2 [x;y] = x + R * y. + Proof. cbn. change_weight. ring. Qed. + Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N) (Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R): montred' lo_hi = reduce_via_partial N R N' T. @@ -8223,23 +8259,18 @@ Module MontgomeryReduction. rewrite !BaseConversion.widemul_correct by (rewrite ?BaseConversion.widemul_correct; autorewrite with push_nth_default; solve_range). + rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega). rewrite R_two_pow. - autorewrite with push_nth_default. + cbv [Rows.partition seq]. rewrite !eval2. + autorewrite with push_nth_default push_map. autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct. + change_weight. (* pull out value before last modular reduction *) match goal with |- (if (?n <=? ?x)%Z then ?x - ?n else ?x) = (if (?n <=? ?y) then ?y - ?n else ?y)%Z => let P := fresh "H" in assert (x = y) as P; [|rewrite P; reflexivity] end. - match goal with - |- context [if R * R <=? ?x then _ else _] => - match goal with |- context [if dec (?xHigh / R = 0) then _ else _] => - assert (x / R = xHigh) as cond_equiv end end. - { apply Z.mul_cancel_r with (p:=R); [omega|]. - autorewrite with push_Zmul zdiv_to_mod push_Zmod; ring. } - rewrite <-cond_equiv. rewrite ?Z.mod_pull_div, ?Z.div_div by omega. - assert (0 < R * R)%Z by Z.zero_bounds. - + autorewrite with zsimplify. break_match; try reflexivity; autorewrite with ltb_to_lt in *; rewrite Z.div_small_iff in * by omega; repeat match goal with | _ => progress autorewrite with zsimplify_fast @@ -8371,8 +8402,8 @@ montred256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * expr_let x24 := ADDC_256 (x23₂, x14, x17) in expr_let x25 := ADD_256 (x20, x23₁) in expr_let x26 := ADDC_256 (x25₂, x19, x24₁) in - expr_let x27 := ADD_256 (x₁, x25₁) in - expr_let x28 := ADDC_256 (x27₂, x₂, x26₁) in + expr_let x27 := ADD_256 (x25₁, x₁) in + expr_let x28 := ADDC_256 (x27₂, x26₁, x₂) in expr_let x29 := SELC (x28₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in expr_let x30 := Z.cast uint256 @@ (fst @@ SUB_256 (x28₁, x29)) in ADDM (x30, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) @@ -8513,8 +8544,8 @@ c.Add256($x23, $x21, $x22); c.Addc($x24, $x14, $x17); c.Add256($x25, $x20, $x23_lo); c.Addc($x26, $x19, $x24_lo); -c.Add256($x27, $x_lo, $x25_lo); -c.Addc($x28, $x_hi, $x26_lo); +c.Add256($x27, $x25_lo, $x_lo); +c.Addc($x28, $x26_lo, $x_hi); c.Selc($x29,RegZero, RegMod); c.Sub($x30, $x28_lo, $x29); c.AddM($ret, $x30, RegZero, RegMod); -- cgit v1.2.3