diff options
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 448 |
1 file changed, 342 insertions, 106 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 853b28545..98d445e6e 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -7797,13 +7797,14 @@ Module BarrettReduction. (muSelect_correct: muSelect = mu mod 2 ^ k * (x / 2 ^ (k - 1) / 2 ^ k)). Definition qt := - let q1 := shiftr xt (k - 1) in - let twoq := mul_high mut q1 muSelect in + dlet_nd muSelect := muSelect in (* makes sure muSelect is not inlined in the output *) + dlet_nd q1 := shiftr xt (k - 1) in + dlet_nd twoq := mul_high mut q1 muSelect in shiftr twoq 1. Definition reduce := - let r2 := mul (low qt) M in - let r := sub xt r2 in - let q3 := cond_sub1 r M in + dlet_nd r2 := mul (low qt) M in + dlet_nd r := sub xt r2 in + dlet_nd q3 := cond_sub1 r M in cond_sub2 q3 M. Lemma looser_bound : M * 2 ^ k < 2 ^ (2*k). @@ -7881,7 +7882,7 @@ Module BarrettReduction. pose proof looser_bound. pose proof r_bounds. pose proof q_bounds. assert (2 * M < 2^k * 2^k) by nia. rewrite barrett_reduction_small with (k:=k) (m:=mu) (offset:=1) (b:=2) by (auto; omega). - cbv [reduce]. + cbv [reduce Let_In]. erewrite low_correct by eauto. Z.rewrite_mod_small. erewrite two_conditional_subtracts by solve_rep. rewrite !cond_sub2_correct. @@ -7890,14 +7891,13 @@ Module BarrettReduction. End Generic. Section BarrettReduction. - Context (k : Z) (Hk_positive : 0 < k). + Context (k : Z) (k_bound : 2 <= k). Context (M muLow : Z). Context (M_pos : 0 < M) (muLow_eq : muLow + 2^k = 2^(2*k) / M) (muLow_bounds : 0 <= muLow < 2^k) (M_bound1 : 2 ^ (k - 1) < M < 2^k) (M_bound2: 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2^k)). - Context (pow2_k_bound : 4 <= 2^k). Context (n:nat) (Hn_nz: n <> 0%nat) (n_le_k : Z.of_nat n <= k). Context (nout : nat) (Hnout : nout = 2%nat). @@ -7907,30 +7907,37 @@ Module BarrettReduction. Hint Rewrite Positional.eval_nil Positional.eval_snoc : push_eval. - Let T : Type := Z * Z. 
+ Definition low (t : list Z) : Z := nth_default 0 t 0. + Definition high (t : list Z) : Z := nth_default 0 t 1. + Definition represents (t : list Z) (x : Z) := + t = [x mod 2^k; x / 2^k] /\ 0 <= x < 2^k * 2^k. - Definition represents (t : T) (x : Z) := - fst t = x mod 2^k /\ snd t = x / 2^k /\ 0 <= x < 2^k * 2^k. - - Lemma represents_fst t x : - represents t x -> fst t = x mod 2^k. + Lemma represents_eq t x : + represents t x -> t = [x mod 2^k; x / 2^k]. Proof. cbv [represents]; tauto. Qed. - Lemma represents_snd t x : - represents t x -> snd t = x / 2^k. - Proof. cbv [represents]; tauto. Qed. + Lemma represents_length t x : represents t x -> length t = 2%nat. + Proof. cbv [represents]; intuition. subst t; reflexivity. Qed. + + Lemma represents_low t x : + represents t x -> low t = x mod 2^k. + Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. - Lemma represents_fst_range t x : + Lemma represents_high t x : + represents t x -> high t = x / 2^k. + Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. + + Lemma represents_low_range t x : represents t x -> 0 <= x mod 2^k < 2^k. Proof. auto with zarith. Qed. - Hint Resolve represents_fst_range. - Lemma represents_snd_range t x : + + Lemma represents_high_range t x : represents t x -> 0 <= x / 2^k < 2^k. Proof. destruct 1 as [? [? ?] ]; intros. auto using Z.div_lt_upper_bound with zarith. Qed. - Hint Resolve represents_snd_range. + Hint Resolve represents_length represents_low_range represents_high_range. Lemma represents_range t x : represents t x -> 0 <= x < 2^k*2^k. @@ -7938,7 +7945,7 @@ Module BarrettReduction. Lemma represents_id x : 0 <= x < 2^k * 2^k -> - represents (x mod 2^k, x / 2^k) x. + represents [x mod 2^k; x / 2^k] x. Proof. intros; cbv [represents]; autorewrite with cancel_pair. Z.rewrite_mod_small; tauto. @@ -7946,14 +7953,14 @@ Module BarrettReduction. 
Local Ltac push_rep := repeat match goal with - | H : represents ?t ?x |- _ => unique pose proof (represents_fst_range _ _ H) - | H : represents ?t ?x |- _ => unique pose proof (represents_snd_range _ _ H) - | H : represents ?t ?x |- _ => rewrite (represents_fst t x) in * by assumption - | H : represents ?t ?x |- _ => rewrite (represents_snd t x) in * by assumption + | H : represents ?t ?x |- _ => unique pose proof (represents_low_range _ _ H) + | H : represents ?t ?x |- _ => unique pose proof (represents_high_range _ _ H) + | H : represents ?t ?x |- _ => rewrite (represents_low t x) in * by assumption + | H : represents ?t ?x |- _ => rewrite (represents_high t x) in * by assumption end. - Definition shiftr (t : T) (n : Z) : T := - (Z.rshi (2^k) (snd t) (fst t) n, Z.rshi (2^k) 0 (snd t) n). + Definition shiftr (t : list Z) (n : Z) : list Z := + [Z.rshi (2^k) (high t) (low t) n; Z.rshi (2^k) 0 (high t) n]. Lemma shiftr_represents a i x : represents a x -> @@ -7971,7 +7978,7 @@ Module BarrettReduction. | _ => rewrite <-Z.div_mod''' by auto with zarith | _ => progress autorewrite with zsimplify_fast | _ => progress Z.rewrite_mod_small - | |- context [represents ((?a / ?c) mod ?b, ?a / ?b / ?c)] => + | |- context [represents [(?a / ?c) mod ?b; ?a / ?b / ?c] ] => rewrite (Z.div_div_comm a b c) by auto with zarith | _ => solve [auto using represents_id, Z.div_lt_upper_bound with zarith lia] end. @@ -7980,42 +7987,34 @@ Module BarrettReduction. Context (Hw : forall i, w i = (2 ^ k) ^ Z.of_nat i). Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r. - Definition wideadd t1 t2 := - let sum := fst (Rows.add w 2 [fst t1; snd t1] [fst t2; snd t2]) in - (nth_default 0 sum 0, nth_default 0 sum 1). + Definition wideadd t1 t2 := fst (Rows.add w 2 t1 t2). + Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2). + Definition widemul := BaseConversion.widemul k n nout. 
- Definition widesub t1 t2 := - let sum := fst (Rows.sub w 2 [fst t1; snd t1] [fst t2; snd t2]) in - (nth_default 0 sum 0, nth_default 0 sum 1). - - Definition widemul x y := - let xy := BaseConversion.widemul k n nout x y in - (nth_default 0 xy 0, nth_default 0 xy 1). - - Lemma partition_represents x y : - 0 <= y < 2^k*2^k -> - x = Rows.partition w 2 y -> - represents (nth_default 0 x 0, nth_default 0 x 1) y. + Lemma partition_represents x : + 0 <= x < 2^k*2^k -> + represents (Rows.partition w 2 x) x. Proof. - intros; subst x; cbv [represents Rows.partition]. - cbn; change_weight. Z.rewrite_mod_small. - auto with zarith. + intros; cbn. change_weight. + Z.rewrite_mod_small. + autorewrite with zsimplify_fast. + auto using represents_id. Qed. Lemma eval_represents t x : - represents t x -> - eval w 2 [fst t; snd t] = x. + represents t x -> eval w 2 t = x. Proof. - intros; cbn. change_weight; push_rep. + intros; rewrite (represents_eq t x) by assumption. + cbn. change_weight; push_rep. autorewrite with zsimplify. reflexivity. Qed. Ltac wide_op partitions_pf := repeat match goal with - | _ => apply partition_represents; auto with zarith; [ ] - | _ => rewrite partitions_pf by auto + | _ => rewrite partitions_pf by eauto + | _ => rewrite partitions_pf by auto with zarith | _ => erewrite eval_represents by eauto - | _ => reflexivity + | _ => solve [auto using partition_represents, represents_id] end. Lemma wideadd_represents t1 t2 x y : @@ -8038,30 +8037,29 @@ Module BarrettReduction. represents (widemul x y) (x * y). Proof. intros; cbv [widemul]. - rewrite BaseConversion.widemul_correct by auto with zarith. - autorewrite with push_nth_default. - auto using represents_id with zarith. + assert (0 <= x * y < 2^k*2^k) by auto with zarith. + wide_op BaseConversion.widemul_correct. Qed. - Definition mul_high (a b : T) a0b1 : T := - let a0b0 := widemul (fst a) (fst b) in - let ab := wideadd (snd a0b0, snd b) (fst b, 0) in - wideadd ab (a0b1, 0). 
+ Definition mul_high (a b : list Z) a0b1 : list Z := + dlet_nd a0b0 := widemul (low a) (low b) in + dlet_nd ab := wideadd [high a0b0; high b] [low b; 0] in + wideadd ab [a0b1; 0]. - Lemma mul_high_idea s a b a0 a1 b0 b1 : - s <> 0 -> - a = s * a1 + a0 -> - b = s * b1 + b0 -> - (a * b) / s = a0 * b0 / s + s * a1 * b1 + a1 * b0 + a0 * b1. + Lemma mul_high_idea d a b a0 a1 b0 b1 : + d <> 0 -> + a = d * a1 + a0 -> + b = d * b1 + b0 -> + (a * b) / d = a0 * b0 / d + d * a1 * b1 + a1 * b0 + a0 * b1. Proof. intros. subst a b. autorewrite with push_Zmul. ring_simplify_subterms. rewrite Z.pow_2_r. rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega). repeat match goal with - | |- context [s * ?a * ?b * ?c] => - replace (s * a * b * c) with (a * b * c * s) by ring - | |- context [s * ?a * ?b] => - replace (s * a * b) with (a * b * s) by ring + | |- context [d * ?a * ?b * ?c] => + replace (d * a * b * c) with (a * b * c * d) by ring + | |- context [d * ?a * ?b] => + replace (d * a * b) with (a * b * d) by ring end. rewrite !Z.div_add by omega. autorewrite with zsimplify. @@ -8074,16 +8072,18 @@ Module BarrettReduction. represents t x. Proof. congruence. Qed. - Lemma represents_add a b x y : - a = x -> b = y -> + Lemma represents_add x y : 0 <= x < 2 ^ k -> 0 <= y < 2 ^ k -> - represents (a,b) (x + 2^k*y). - Proof. intros; subst a b; repeat split; autorewrite with cancel_pair zsimplify; nia. Qed. + represents [x;y] (x + 2^k*y). + Proof. + intros; cbv [represents]; autorewrite with zsimplify. + repeat split; (reflexivity || nia). + Qed. Lemma represents_small x : 0 <= x < 2^k -> - represents (x, 0) x. + represents [x; 0] x. Proof. intros. eapply represents_trans. @@ -8099,10 +8099,11 @@ Module BarrettReduction. a0b1 = x mod 2^k * (y / 2^k) -> represents (mul_high a b a0b1) ((x * y) / 2^k). Proof. - cbv [mul_high]; rewrite Z.pow_add_r, Z.pow_1_r by omega; intros. + cbv [mul_high Let_In]; rewrite Z.pow_add_r, Z.pow_1_r by omega; intros. 
+ assert (4 <= 2 ^ k) by (transitivity (Z.pow 2 2); auto with zarith). assert (0 <= x * y / 2^k < 2^k*2^k) by (Z.div_mod_to_quot_rem; nia). - rewrite mul_high_idea with (a:=x) (b:=y) (a0 := fst a) (a1 := snd a) (b0 := fst b) (b1 := snd b) in * + rewrite mul_high_idea with (a:=x) (b:=y) (a0 := low a) (a1 := high a) (b0 := low b) (b1 := high b) in * by (push_rep; Z.div_mod_to_quot_rem; lia). push_rep. subst a0b1. @@ -8113,7 +8114,7 @@ Module BarrettReduction. eapply represents_trans. { repeat (apply wideadd_represents; [ | apply represents_small; Z.div_mod_to_quot_rem; nia| ]). - erewrite represents_snd; [ | apply widemul_represents; solve [ auto with zarith ] ]. + erewrite represents_high; [ | apply widemul_represents; solve [ auto with zarith ] ]. { apply represents_add; try reflexivity; solve [auto with zarith]. } { match goal with H : 0 <= ?x + ?y < ?z |- 0 <= ?x < ?z => split; [ solve [Z.zero_bounds] | ]; @@ -8123,11 +8124,11 @@ Module BarrettReduction. { ring. } Qed. - Definition cond_sub1 (a : T) y : Z := - let maybe_y := Z.zselect (Z.cc_l (snd a)) 0 y in - fst (Z.sub_get_borrow_full (2^k) (fst a) maybe_y). + Definition cond_sub1 (a : list Z) y : Z := + dlet_nd maybe_y := Z.zselect (Z.cc_l (high a)) 0 y in + fst (Z.sub_get_borrow_full (2^k) (low a) maybe_y). - Lemma cc_l_only_bit x s: 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s. + Lemma cc_l_only_bit : forall x s, 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s. Proof. cbv [Z.cc_l]; intros. rewrite Z.div_between_0_if by omega. @@ -8140,7 +8141,7 @@ Module BarrettReduction. 0 <= y < 2 ^ k -> cond_sub1 a y = if (x <? 2 ^ k) then x else x - y. Proof. - intros; cbv [cond_sub1]. rewrite Z.zselect_correct. push_rep. + intros; cbv [cond_sub1 Let_In]. rewrite Z.zselect_correct. push_rep. break_match; Z.ltb_to_lt; rewrite cc_l_only_bit in *; try omega; autorewrite with zsimplify_fast to_div_mod pull_Zmod; auto with zarith. Qed. @@ -8155,7 +8156,7 @@ Module BarrettReduction. Section Defn. 
Context (xLow xHigh : Z) (xLow_bounds : 0 <= xLow < 2^k) (xHigh_bounds : 0 <= xHigh < M). - Let xt := (xLow, xHigh). + Let xt := [xLow; xHigh]. Let x := xLow + 2^k * xHigh. Lemma x_rep : represents xt x. @@ -8167,6 +8168,8 @@ Module BarrettReduction. Definition muSelect := Z.zselect (Z.cc_m (2 ^ k) xHigh) 0 muLow. Local Hint Resolve Z.div_nonneg Z.div_lt_upper_bound. + Local Hint Resolve shiftr_represents mul_high_represents widemul_represents widesub_represents + cond_sub1_correct cond_sub2_correct represents_low represents_add. Lemma muSelect_correct : muSelect = (2 ^ (2 * k) / M) mod 2 ^ k * ((x / 2 ^ (k - 1)) / 2 ^ k). @@ -8177,9 +8180,10 @@ Module BarrettReduction. assert (0 <= x / (2 ^ k * (2 ^ k / 2)) < 2) by (Z.div_mod_to_quot_rem; auto with nia). assert (0 < 2 ^ k / 2) by Z.zero_bounds. assert (2 ^ (k - 1) <> 0) by auto with zarith. + assert (2 < 2 ^ k) by (eapply Z.le_lt_trans with (m:=2 ^ 1); auto with zarith). cbv [muSelect]. rewrite <-muLow_eq. - rewrite Z.zselect_correct, Z.cc_m_eq by nia. + rewrite Z.zselect_correct, Z.cc_m_eq by auto with zarith. replace xHigh with (x / 2^k) by (subst x; autorewrite with zsimplify; lia). autorewrite with pull_Zdiv push_Zpow. rewrite (Z.mul_comm (2 ^ k / 2)). @@ -8188,22 +8192,180 @@ Module BarrettReduction. autorewrite with zsimplify; reflexivity. Qed. - Definition barrett_reduce : Z := - reduce k fst shiftr mul_high widemul widesub cond_sub1 cond_sub2 xt (muLow, 1) M muSelect. + Lemma mu_rep : represents [muLow; 1] (2 ^ (2 * k) / M). + Proof. rewrite <-muLow_eq. eapply represents_trans; auto with zarith. Qed. - Lemma barrett_reduce_correct : - barrett_reduce = x mod M. + Derive barrett_reduce + SuchThat (barrett_reduce = x mod M) + As barrett_reduce_correct. Proof. - intros; cbv [barrett_reduce]. 
- apply reduce_correct with (rep:=represents); try omega; - auto using shiftr_represents, mul_high_represents, widemul_represents, widesub_represents, - cond_sub1_correct, cond_sub2_correct, x_bounds, muSelect_correct, represents_fst, x_rep. - rewrite <-muLow_eq. cbv [represents]; repeat split; autorewrite with cancel_pair zsimplify; nia. + erewrite <-reduce_correct with (rep:=represents) (muSelect:=muSelect) (k0:=k) (mut:=[muLow;1]) (xt0:=xt) + by (auto using x_bounds, muSelect_correct, x_rep, mu_rep; omega). + subst barrett_reduce. reflexivity. Qed. End Defn. End BarrettReduction. + + (* all the list operations from for_reification.ident *) + Strategy 100 [length seq repeat combine map flat_map partition app rev fold_right update_nth nth_default ]. + + Derive barrett_red_gen + SuchThat (forall (k M muLow : Z) + (n nout: nat) + (xLow xHigh : Z), + Interp (t:=type.reify_type_of barrett_reduce) + barrett_red_gen k M muLow n nout xLow xHigh + = barrett_reduce k M muLow n nout xLow xHigh) + As barrett_red_gen_correct. + Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed. + (* TODO : reification here is still quite slow (~40s on a beefy machine). Possibly just due to size of term, but warrants further investigation. *) + Module Export ReifyHints. + Global Hint Extern 1 (_ = barrett_reduce _ _ _ _ _ _ _) => simple apply barrett_red_gen_correct : reify_gen_cache. + End ReifyHints. + + Section rbarrett_red. + Context (M : Z) + (machine_wordsize : Z). + + Let bound := Some r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. + Let mu := (2 ^ (2 * machine_wordsize)) / M. + Let muLow := mu mod (2 ^ machine_wordsize). + + Check barrett_reduce_correct. + Print Pipeline.Values_not_provably_distinct. + + Definition relax_zrange_of_machine_wordsize + := relax_zrange_gen [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize]%Z. + Local Arguments relax_zrange_of_machine_wordsize / . 
+ + Let relax_zrange := relax_zrange_of_machine_wordsize. + + Definition check_args {T} (res : Pipeline.ErrorT T) + : Pipeline.ErrorT T + := if (mu / (2 ^ machine_wordsize) =? 0) + then Pipeline.Error (Pipeline.Values_not_provably_distinct "mu / 2 ^ k ≠ 0" (mu / 2 ^ machine_wordsize) 0) + else if (machine_wordsize <? 2) + then Pipeline.Error (Pipeline.Value_not_le "~ (2 <=k)" 2 machine_wordsize) + else if (negb (Z.log2 M + 1 =? machine_wordsize)) + then Pipeline.Error + (Pipeline.Values_not_provably_equal "log2(M)+1 != k" (Z.log2 M + 1) machine_wordsize) + else if (2 ^ (machine_wordsize + 1) - mu <? 2 * (2 ^ (2 * machine_wordsize) mod M)) + then Pipeline.Error + (Pipeline.Value_not_le "~ (2 * (2 ^ (2*k) mod M) <= 2^(k + 1) - mu)" + (2 * (2 ^ (2*machine_wordsize) mod M)) + (2^(machine_wordsize + 1) - mu)) + else res. + + Notation BoundsPipeline_correct in_bounds out_bounds op + := (fun rv (rop : Expr (type.reify_type_of op)) Hrop + => @Pipeline.BoundsPipeline_correct_trans + false (* subst01 *) + relax_zrange + (relax_zrange_gen_good _) + _ + rop + in_bounds + out_bounds + op + Hrop rv) + (only parsing). + + Definition rbarrett_red_correct + := BoundsPipeline_correct + (bound, bound) + bound + (barrett_reduce machine_wordsize M muLow 2 2). + + Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _). + Definition rbarrett_red_correctT rv : Prop + := type_of_strip_3arrow (@rbarrett_red_correct rv). + End rbarrett_red. End BarrettReduction. +Ltac solve_rbarrett_red := solve_rop BarrettReduction.rbarrett_red_correct. +Ltac solve_rbarrett_red_nocache := solve_rop_nocache BarrettReduction.rbarrett_red_correct. + +Module Barrett256. + + Definition M := (2^256-2^224+2^192+2^96-1). + Definition machine_wordsize := 256. + + (* TODO : why does this not bounds check? + Let F := Some r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange. 
+ Eval vm_compute in ( + partial.bounds.expr.extract' (fun t : type => id) + (λ x : partial.data (type.type_primitive type.Z * type.type_primitive type.Z), + ident.Z.cast2 + (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange, + r[0 ~> 1]%zrange)%core @@ + (ident.Z.add_get_carry_concrete + 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ + (ident.Z.cast r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935] @@ + (ident.fst @@ (Var x)), ident.primitive (t:=type.Z) 26959946667150639793205513449348445388433292963828203772348655992835 @@ TT ))) (F, F)). + *) + Derive barrett_red256 + SuchThat (BarrettReduction.rbarrett_red_correctT M machine_wordsize barrett_red256) + As barrett_red256_correct. + Proof. Time solve_rbarrett_red machine_wordsize. Time Qed. + + Import PrintingNotations. + Open Scope expr_scope. + Set Printing Width 100000. + + Print barrett_red256. + (* TODO: the ADD/ADDC instructions containing Z.opp should be translated to SUB/SUBB in partial evaluation *) + (* +barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, + expr_let x0 := SELM (x₂, 0, 26959946667150639793205513449348445388433292963828203772348655992835) in + expr_let x1 := RSHI (0, x₂, 255) in + expr_let x2 := RSHI (x₂, x₁, 255) in + expr_let x3 := (uint128)(x2 >> 128) in + expr_let x4 := ((uint128)(x2) & 340282366920938463463374607431768211455) in + expr_let x5 := 79228162514264337589248983038 *₂₅₆ x4 in + expr_let x6 := (uint128)(x5 >> 128) in + expr_let x7 := ((uint128)(x5) & 340282366920938463463374607431768211455) in + expr_let x8 := 340282366841710300930663525764514709507 *₂₅₆ x3 in + expr_let x9 := (uint128)(x8 >> 128) in + expr_let x10 := ((uint128)(x8) & 340282366920938463463374607431768211455) in + expr_let x11 := 79228162514264337589248983038 *₂₅₆ x3 in + expr_let x12 := (uint256)(x7 << 128) in + expr_let x13 := 
(uint256)(x10 << 128) in + expr_let x14 := 340282366841710300930663525764514709507 *₂₅₆ x4 in + expr_let x15 := ADD_256 (x13, x14) in + expr_let x16 := ADDC_128 (x15₂, x6, x9) in + expr_let x17 := ADD_256 (x12, x15₁) in + expr_let x18 := ADDC_256 (x17₂, x11, x16₁) in + expr_let x19 := ADD_256 (x2, x18₁) in + expr_let x20 := ADDC_128 (x19₂, 0, x1) in + expr_let x21 := ADD_256 (x0, x19₁) in + expr_let x22 := ADDC_128 (x21₂, 0, x20₁) in + expr_let x23 := RSHI (x22₁, x21₁, 1) in + expr_let x24 := (uint128)(x23 >> 128) in + expr_let x25 := ((uint128)(x23) & 340282366920938463463374607431768211455) in + expr_let x26 := 79228162514264337593543950335 *₂₅₆ x24 in + expr_let x27 := (uint128)(x26 >> 128) in + expr_let x28 := ((uint128)(x26) & 340282366920938463463374607431768211455) in + expr_let x29 := 340282366841710300967557013911933812736 *₂₅₆ x25 in + expr_let x30 := (uint128)(x29 >> 128) in + expr_let x31 := ((uint128)(x29) & 340282366920938463463374607431768211455) in + expr_let x32 := 340282366841710300967557013911933812736 *₂₅₆ x24 in + expr_let x33 := (uint256)(x28 << 128) in + expr_let x34 := (uint256)(x31 << 128) in + expr_let x35 := 79228162514264337593543950335 *₂₅₆ x25 in + expr_let x36 := ADD_256 (x34, x35) in + expr_let x37 := ADDC_256 (x36₂, x27, x30) in + expr_let x38 := ADD_256 (x33, x36₁) in + expr_let x39 := ADDC_256 (x38₂, x32, x37₁) in + expr_let x40 := ADD_256 (Z.opp @@ (fst @@ x38), x₁) in + expr_let x41 := ADDC_256 (x40₂, Z.opp @@ (fst @@ x39), x₂) in + expr_let x42 := SELL (x41₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in + expr_let x43 := Z.cast uint256 @@ (fst @@ SUB_256 (x40₁, x42)) in + ADDM (x43, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) + : Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z)) + *) + +End Barrett256. + Module MontgomeryReduction. Section MontRed'. Context (N R N' R' : Z). 
@@ -8412,7 +8574,7 @@ montred256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * End Montgomery256. (* Extra-specialized ad-hoc pretty-printing *) -Module Montgomery256PrintingNotations. +Module FancyPrintingNotations. Export ident. Open Scope expr_scope. Open Scope ctype_scope. @@ -8420,6 +8582,10 @@ Module Montgomery256PrintingNotations. (AppIdent (primitive 115792089210356248762697446949407573530086143415290314195533631308867097853951) TT) (only printing, at level 9) : expr_scope. + Notation "'RegMuLow'" := + (AppIdent + (primitive 26959946667150639793205513449348445388433292963828203772348655992835) + TT) (only printing, at level 9) : expr_scope. Notation "'RegPinv'" := (AppIdent (primitive 115792089210356248768974548684794254293921932838497980611635986753331132366849) @@ -8437,6 +8603,14 @@ Module Montgomery256PrintingNotations. (AppIdent (primitive 340282366841710300967557013911933812736) TT) (only printing, at level 9, format "'RegMod' '>>' '128'") : expr_scope. + Notation "'Lower128{RegMuLow}'" := + (AppIdent + (primitive 340282366841710300930663525764514709507) + TT) (only printing, at level 9) : expr_scope. + Notation "'RegMuLow' '>>' '128'" := + (AppIdent + (primitive 79228162514264337589248983038) + TT) (only printing, at level 9, format "'RegMuLow' '>>' '128'") : expr_scope. Notation "'Lower128{RegPinv}'" := (AppIdent (primitive 79228162514264337593543950337) @@ -8455,6 +8629,8 @@ Module Montgomery256PrintingNotations. Notation "$ n '_hi'" := (snd @@ (Var n))%expr (at level 10, format "$ n _hi") : expr_scope. Notation "$ n '_lo'" := (Z.cast _ @@ (fst @@ (Var n)))%expr (at level 10, format "$ n _lo") : expr_scope. Notation "$ n '_hi'" := (Z.cast _ @@ (snd @@ (Var n)))%expr (at level 10, format "$ n _hi") : expr_scope. + Notation "$ n '_lo'" := (fst @@ (Z.cast2 _ @@ Var n))%expr (at level 10, format "$ n _lo") : expr_scope. + Notation "$ n '_hi'" := (snd @@ (Z.cast2 _ @@ Var n))%expr (at level 10, format "$ n _hi") : expr_scope. 
Notation "$ n '_lo'" := (Z.cast _ @@ (fst @@ (Z.cast2 _ @@ Var n)))%expr (at level 10, format "$ n _lo") : expr_scope. Notation "$ n '_hi'" := (Z.cast _ @@ (snd @@ (Z.cast2 _ @@ Var n)))%expr (at level 10, format "$ n _hi") : expr_scope. Notation "'c.Mul128x128(' '$' n ',' x ',' y ');' f" := @@ -8472,12 +8648,21 @@ Module Montgomery256PrintingNotations. Notation "'c.Add64(' '$' n ',' x ',' y ');' f" := (expr_let n := Z.cast uint128 @@ (Z.add @@ (x, y)) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add64(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. - Notation "'c.Addc(' '$' n ',' x ',' y ');' f" := - (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.add_with_get_carry_concrete $R @@ (_, x, y)) in - f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. - Notation "'c.Selc(' '$' n ',' y ',' z ');' f" := - (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (_, y, z)) in - f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Selc(' '$' n ',' y ',' z ');' ']' '//' f") : expr_scope. + Notation "'c.Addc256(' '$' n ',' x ',' y ',' z ');' f" := + (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.add_with_get_carry_concrete $R @@ (x, y, z)) in + f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc256(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope. + Notation "'c.Addc128(' '$' n ',' x ',' y ',' z ');' f" := + (expr_let n := Z.cast2 (uint128, _)%core @@ (Z.add_with_get_carry_concrete $R @@ (x, y, z)) in + f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc128(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope. + Notation "'c.Selc(' '$' n ',' x ',' y ',' z ');' f" := + (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (x , y, z)) in + f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Selc(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope. 
+ Notation "'c.Selm(' '$' n ',' x ',' y ',' z ');' f" := + (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (Z.cast (r[0 ~>1]) @@ ((Z.cc_m_concrete _) @@ x), y, z)) in + f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Selm(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope. + Notation "'c.Sell(' '$' n ',' x ',' y ',' z ');' f" := + (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (Z.cast (r[0~>1]) @@ (Z.land 1 @@ _ x), y, z)) in + f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Sell(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope. Notation "'c.Sub(' '$' n ',' x ',' y ');' f" := (expr_let n := Z.cast uint256 @@ (fst @@ (Z.cast2 (uint256, _)%core @@ (Z.sub_get_borrow_concrete $R @@ (x, y)))) in f)%expr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$' n ',' x ',' y ');' '//' f") : expr_scope. @@ -8485,6 +8670,8 @@ Module Montgomery256PrintingNotations. (Z.cast uint256 @@ (Z.add_modulo @@ (x, y, z)))%expr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : expr_scope. Notation "'c.ShiftR(' '$' n ',' x ',' y ');' f" := (expr_let n := Z.cast _ @@ (Z.shiftr y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. + Notation "'c.Rshi(' '$' n ',' x ',' y ',' m ');' f" := + (expr_let n := Z.cast _ @@ (Z.rshi_concrete $R m @@ (x, y)) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Rshi(' '$' n ',' x ',' y ',' m ');' ']' '//' f") : expr_scope. Notation "'c.ShiftL(' '$' n ',' x ',' y ');' f" := (expr_let n := Z.cast _ @@ (Z.shiftl y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. Notation "'c.Lower128(' '$' n ',' x ');' f" := @@ -8510,11 +8697,60 @@ Module Montgomery256PrintingNotations. := (Z.cast uint256 @@ (Z.mul @@ (x, y))) : expr_scope. 
*) -End Montgomery256PrintingNotations. +End FancyPrintingNotations. -Import Montgomery256PrintingNotations. +Import FancyPrintingNotations. Local Open Scope expr_scope. +Print Barrett256.barrett_red256. +(* +c.Selm($x0, $x_hi, RegZero, RegMuLow); +c.Rshi($x1, RegZero, $x_hi, 255); +c.Rshi($x2, $x_hi, $x_lo, 255); +c.ShiftR($x3, $x2, 128); +c.Lower128($x4, $x2); +c.Mul128x128($x5, RegMuLow >> 128, $x4); +c.ShiftR($x6, $x5, 128); +c.Lower128($x7, $x5); +c.Mul128x128($x8, Lower128{RegMuLow}, $x3); +c.ShiftR($x9, $x8, 128); +c.Lower128($x10, $x8); +c.Mul128x128($x11, RegMuLow >> 128, $x3); +c.ShiftL($x12, $x7, 128); +c.ShiftL($x13, $x10, 128); +c.Mul128x128($x14, Lower128{RegMuLow}, $x4); +c.Add256($x15, $x13, $x14); +c.Addc128($x16, $x15_hi, $x6, $x9); +c.Add256($x17, $x12, $x15_lo); +c.Addc256($x18, $x17_hi, $x11, $x16_lo); +c.Add256($x19, $x2, $x18_lo); +c.Addc128($x20, $x19_hi, RegZero, $x1); +c.Add256($x21, $x0, $x19_lo); +c.Addc128($x22, $x21_hi, RegZero, $x20_lo); +c.Rshi($x23, $x22_lo, $x21_lo, 1); +c.ShiftR($x24, $x23, 128); +c.Lower128($x25, $x23); +c.Mul128x128($x26, Lower128{RegMod}, $x24); +c.ShiftR($x27, $x26, 128); +c.Lower128($x28, $x26); +c.Mul128x128($x29, RegMod >> 128, $x25); +c.ShiftR($x30, $x29, 128); +c.Lower128($x31, $x29); +c.Mul128x128($x32, RegMod >> 128, $x24); +c.ShiftL($x33, $x28, 128); +c.ShiftL($x34, $x31, 128); +c.Mul128x128($x35, Lower128{RegMod}, $x25); +c.Add256($x36, $x34, $x35); +c.Addc256($x37, $x36_hi, $x27, $x30); +c.Add256($x38, $x33, $x36_lo); +c.Addc256($x39, $x38_hi, $x32, $x37_lo); +c.Add256($x40, Z.opp @@ $x38_lo, $x_lo); +c.Addc256($x41, $x40_hi, Z.opp @@ $x39_lo, $x_hi); +c.Sell($x42, $x41_lo, RegZero, RegMod); +c.Sub($x43, $x40_lo, $x42); +c.AddM($ret, $x43, RegZero, RegMod); +*) + Print Montgomery256.montred256. 
(* c.ShiftR($x0, $x_lo, 128); @@ -8541,12 +8777,12 @@ c.ShiftL($x20, $x15, 128); c.ShiftL($x21, $x18, 128); c.Mul128x128($x22, Lower128{RegMod}, $x12); c.Add256($x23, $x21, $x22); -c.Addc($x24, $x14, $x17); +c.Addc256($x24, $x23_hi, $x14, $x17); c.Add256($x25, $x20, $x23_lo); -c.Addc($x26, $x19, $x24_lo); +c.Addc256($x26, $x25_hi, $x19, $x24_lo); c.Add256($x27, $x25_lo, $x_lo); -c.Addc($x28, $x26_lo, $x_hi); -c.Selc($x29,RegZero, RegMod); +c.Addc256($x28, $x27_hi, $x26_lo, $x_hi); +c.Selc($x29, $x28_hi, RegZero, RegMod); c.Sub($x30, $x28_lo, $x29); c.AddM($ret, $x30, RegZero, RegMod); *) |