aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/SimplyTypedArithmetic.v
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-04-17 14:33:00 +0200
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-30 04:20:04 -0400
commitc95a7ed3892bed4e86500a5a7339181ce2b6d00c (patch)
treedc8cc1fe8698b88cdd576fa386b74a743ad2e330 /src/Experiments/SimplyTypedArithmetic.v
parent6cbd9dba8e259d7ee3eda867fd1b0f5512da90f6 (diff)
first stab at reifying barrett
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v448
1 files changed, 342 insertions, 106 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 853b28545..98d445e6e 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -7797,13 +7797,14 @@ Module BarrettReduction.
(muSelect_correct: muSelect = mu mod 2 ^ k * (x / 2 ^ (k - 1) / 2 ^ k)).
Definition qt :=
- let q1 := shiftr xt (k - 1) in
- let twoq := mul_high mut q1 muSelect in
+ dlet_nd muSelect := muSelect in (* makes sure muSelect is not inlined in the output *)
+ dlet_nd q1 := shiftr xt (k - 1) in
+ dlet_nd twoq := mul_high mut q1 muSelect in
shiftr twoq 1.
Definition reduce :=
- let r2 := mul (low qt) M in
- let r := sub xt r2 in
- let q3 := cond_sub1 r M in
+ dlet_nd r2 := mul (low qt) M in
+ dlet_nd r := sub xt r2 in
+ dlet_nd q3 := cond_sub1 r M in
cond_sub2 q3 M.
Lemma looser_bound : M * 2 ^ k < 2 ^ (2*k).
@@ -7881,7 +7882,7 @@ Module BarrettReduction.
pose proof looser_bound. pose proof r_bounds. pose proof q_bounds.
assert (2 * M < 2^k * 2^k) by nia.
rewrite barrett_reduction_small with (k:=k) (m:=mu) (offset:=1) (b:=2) by (auto; omega).
- cbv [reduce].
+ cbv [reduce Let_In].
erewrite low_correct by eauto. Z.rewrite_mod_small.
erewrite two_conditional_subtracts by solve_rep.
rewrite !cond_sub2_correct.
@@ -7890,14 +7891,13 @@ Module BarrettReduction.
End Generic.
Section BarrettReduction.
- Context (k : Z) (Hk_positive : 0 < k).
+ Context (k : Z) (k_bound : 2 <= k).
Context (M muLow : Z).
Context (M_pos : 0 < M)
(muLow_eq : muLow + 2^k = 2^(2*k) / M)
(muLow_bounds : 0 <= muLow < 2^k)
(M_bound1 : 2 ^ (k - 1) < M < 2^k)
(M_bound2: 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2^k)).
- Context (pow2_k_bound : 4 <= 2^k).
Context (n:nat) (Hn_nz: n <> 0%nat) (n_le_k : Z.of_nat n <= k).
Context (nout : nat) (Hnout : nout = 2%nat).
@@ -7907,30 +7907,37 @@ Module BarrettReduction.
Hint Rewrite Positional.eval_nil Positional.eval_snoc : push_eval.
- Let T : Type := Z * Z.
+ Definition low (t : list Z) : Z := nth_default 0 t 0.
+ Definition high (t : list Z) : Z := nth_default 0 t 1.
+ Definition represents (t : list Z) (x : Z) :=
+ t = [x mod 2^k; x / 2^k] /\ 0 <= x < 2^k * 2^k.
- Definition represents (t : T) (x : Z) :=
- fst t = x mod 2^k /\ snd t = x / 2^k /\ 0 <= x < 2^k * 2^k.
-
- Lemma represents_fst t x :
- represents t x -> fst t = x mod 2^k.
+ Lemma represents_eq t x :
+ represents t x -> t = [x mod 2^k; x / 2^k].
Proof. cbv [represents]; tauto. Qed.
- Lemma represents_snd t x :
- represents t x -> snd t = x / 2^k.
- Proof. cbv [represents]; tauto. Qed.
+ Lemma represents_length t x : represents t x -> length t = 2%nat.
+ Proof. cbv [represents]; intuition. subst t; reflexivity. Qed.
+
+ Lemma represents_low t x :
+ represents t x -> low t = x mod 2^k.
+ Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed.
- Lemma represents_fst_range t x :
+ Lemma represents_high t x :
+ represents t x -> high t = x / 2^k.
+ Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed.
+
+ Lemma represents_low_range t x :
represents t x -> 0 <= x mod 2^k < 2^k.
Proof. auto with zarith. Qed.
- Hint Resolve represents_fst_range.
- Lemma represents_snd_range t x :
+
+ Lemma represents_high_range t x :
represents t x -> 0 <= x / 2^k < 2^k.
Proof.
destruct 1 as [? [? ?] ]; intros.
auto using Z.div_lt_upper_bound with zarith.
Qed.
- Hint Resolve represents_snd_range.
+ Hint Resolve represents_length represents_low_range represents_high_range.
Lemma represents_range t x :
represents t x -> 0 <= x < 2^k*2^k.
@@ -7938,7 +7945,7 @@ Module BarrettReduction.
Lemma represents_id x :
0 <= x < 2^k * 2^k ->
- represents (x mod 2^k, x / 2^k) x.
+ represents [x mod 2^k; x / 2^k] x.
Proof.
intros; cbv [represents]; autorewrite with cancel_pair.
Z.rewrite_mod_small; tauto.
@@ -7946,14 +7953,14 @@ Module BarrettReduction.
Local Ltac push_rep :=
repeat match goal with
- | H : represents ?t ?x |- _ => unique pose proof (represents_fst_range _ _ H)
- | H : represents ?t ?x |- _ => unique pose proof (represents_snd_range _ _ H)
- | H : represents ?t ?x |- _ => rewrite (represents_fst t x) in * by assumption
- | H : represents ?t ?x |- _ => rewrite (represents_snd t x) in * by assumption
+ | H : represents ?t ?x |- _ => unique pose proof (represents_low_range _ _ H)
+ | H : represents ?t ?x |- _ => unique pose proof (represents_high_range _ _ H)
+ | H : represents ?t ?x |- _ => rewrite (represents_low t x) in * by assumption
+ | H : represents ?t ?x |- _ => rewrite (represents_high t x) in * by assumption
end.
- Definition shiftr (t : T) (n : Z) : T :=
- (Z.rshi (2^k) (snd t) (fst t) n, Z.rshi (2^k) 0 (snd t) n).
+ Definition shiftr (t : list Z) (n : Z) : list Z :=
+ [Z.rshi (2^k) (high t) (low t) n; Z.rshi (2^k) 0 (high t) n].
Lemma shiftr_represents a i x :
represents a x ->
@@ -7971,7 +7978,7 @@ Module BarrettReduction.
| _ => rewrite <-Z.div_mod''' by auto with zarith
| _ => progress autorewrite with zsimplify_fast
| _ => progress Z.rewrite_mod_small
- | |- context [represents ((?a / ?c) mod ?b, ?a / ?b / ?c)] =>
+ | |- context [represents [(?a / ?c) mod ?b; ?a / ?b / ?c] ] =>
rewrite (Z.div_div_comm a b c) by auto with zarith
| _ => solve [auto using represents_id, Z.div_lt_upper_bound with zarith lia]
end.
@@ -7980,42 +7987,34 @@ Module BarrettReduction.
Context (Hw : forall i, w i = (2 ^ k) ^ Z.of_nat i).
Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r.
- Definition wideadd t1 t2 :=
- let sum := fst (Rows.add w 2 [fst t1; snd t1] [fst t2; snd t2]) in
- (nth_default 0 sum 0, nth_default 0 sum 1).
+ Definition wideadd t1 t2 := fst (Rows.add w 2 t1 t2).
+ Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2).
+ Definition widemul := BaseConversion.widemul k n nout.
- Definition widesub t1 t2 :=
- let sum := fst (Rows.sub w 2 [fst t1; snd t1] [fst t2; snd t2]) in
- (nth_default 0 sum 0, nth_default 0 sum 1).
-
- Definition widemul x y :=
- let xy := BaseConversion.widemul k n nout x y in
- (nth_default 0 xy 0, nth_default 0 xy 1).
-
- Lemma partition_represents x y :
- 0 <= y < 2^k*2^k ->
- x = Rows.partition w 2 y ->
- represents (nth_default 0 x 0, nth_default 0 x 1) y.
+ Lemma partition_represents x :
+ 0 <= x < 2^k*2^k ->
+ represents (Rows.partition w 2 x) x.
Proof.
- intros; subst x; cbv [represents Rows.partition].
- cbn; change_weight. Z.rewrite_mod_small.
- auto with zarith.
+ intros; cbn. change_weight.
+ Z.rewrite_mod_small.
+ autorewrite with zsimplify_fast.
+ auto using represents_id.
Qed.
Lemma eval_represents t x :
- represents t x ->
- eval w 2 [fst t; snd t] = x.
+ represents t x -> eval w 2 t = x.
Proof.
- intros; cbn. change_weight; push_rep.
+ intros; rewrite (represents_eq t x) by assumption.
+ cbn. change_weight; push_rep.
autorewrite with zsimplify. reflexivity.
Qed.
Ltac wide_op partitions_pf :=
repeat match goal with
- | _ => apply partition_represents; auto with zarith; [ ]
- | _ => rewrite partitions_pf by auto
+ | _ => rewrite partitions_pf by eauto
+ | _ => rewrite partitions_pf by auto with zarith
| _ => erewrite eval_represents by eauto
- | _ => reflexivity
+ | _ => solve [auto using partition_represents, represents_id]
end.
Lemma wideadd_represents t1 t2 x y :
@@ -8038,30 +8037,29 @@ Module BarrettReduction.
represents (widemul x y) (x * y).
Proof.
intros; cbv [widemul].
- rewrite BaseConversion.widemul_correct by auto with zarith.
- autorewrite with push_nth_default.
- auto using represents_id with zarith.
+ assert (0 <= x * y < 2^k*2^k) by auto with zarith.
+ wide_op BaseConversion.widemul_correct.
Qed.
- Definition mul_high (a b : T) a0b1 : T :=
- let a0b0 := widemul (fst a) (fst b) in
- let ab := wideadd (snd a0b0, snd b) (fst b, 0) in
- wideadd ab (a0b1, 0).
+ Definition mul_high (a b : list Z) a0b1 : list Z :=
+ dlet_nd a0b0 := widemul (low a) (low b) in
+ dlet_nd ab := wideadd [high a0b0; high b] [low b; 0] in
+ wideadd ab [a0b1; 0].
- Lemma mul_high_idea s a b a0 a1 b0 b1 :
- s <> 0 ->
- a = s * a1 + a0 ->
- b = s * b1 + b0 ->
- (a * b) / s = a0 * b0 / s + s * a1 * b1 + a1 * b0 + a0 * b1.
+ Lemma mul_high_idea d a b a0 a1 b0 b1 :
+ d <> 0 ->
+ a = d * a1 + a0 ->
+ b = d * b1 + b0 ->
+ (a * b) / d = a0 * b0 / d + d * a1 * b1 + a1 * b0 + a0 * b1.
Proof.
intros. subst a b. autorewrite with push_Zmul.
ring_simplify_subterms. rewrite Z.pow_2_r.
rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega).
repeat match goal with
- | |- context [s * ?a * ?b * ?c] =>
- replace (s * a * b * c) with (a * b * c * s) by ring
- | |- context [s * ?a * ?b] =>
- replace (s * a * b) with (a * b * s) by ring
+ | |- context [d * ?a * ?b * ?c] =>
+ replace (d * a * b * c) with (a * b * c * d) by ring
+ | |- context [d * ?a * ?b] =>
+ replace (d * a * b) with (a * b * d) by ring
end.
rewrite !Z.div_add by omega.
autorewrite with zsimplify.
@@ -8074,16 +8072,18 @@ Module BarrettReduction.
represents t x.
Proof. congruence. Qed.
- Lemma represents_add a b x y :
- a = x -> b = y ->
+ Lemma represents_add x y :
0 <= x < 2 ^ k ->
0 <= y < 2 ^ k ->
- represents (a,b) (x + 2^k*y).
- Proof. intros; subst a b; repeat split; autorewrite with cancel_pair zsimplify; nia. Qed.
+ represents [x;y] (x + 2^k*y).
+ Proof.
+ intros; cbv [represents]; autorewrite with zsimplify.
+ repeat split; (reflexivity || nia).
+ Qed.
Lemma represents_small x :
0 <= x < 2^k ->
- represents (x, 0) x.
+ represents [x; 0] x.
Proof.
intros.
eapply represents_trans.
@@ -8099,10 +8099,11 @@ Module BarrettReduction.
a0b1 = x mod 2^k * (y / 2^k) ->
represents (mul_high a b a0b1) ((x * y) / 2^k).
Proof.
- cbv [mul_high]; rewrite Z.pow_add_r, Z.pow_1_r by omega; intros.
+ cbv [mul_high Let_In]; rewrite Z.pow_add_r, Z.pow_1_r by omega; intros.
+ assert (4 <= 2 ^ k) by (transitivity (Z.pow 2 2); auto with zarith).
assert (0 <= x * y / 2^k < 2^k*2^k) by (Z.div_mod_to_quot_rem; nia).
- rewrite mul_high_idea with (a:=x) (b:=y) (a0 := fst a) (a1 := snd a) (b0 := fst b) (b1 := snd b) in *
+ rewrite mul_high_idea with (a:=x) (b:=y) (a0 := low a) (a1 := high a) (b0 := low b) (b1 := high b) in *
by (push_rep; Z.div_mod_to_quot_rem; lia).
push_rep. subst a0b1.
@@ -8113,7 +8114,7 @@ Module BarrettReduction.
eapply represents_trans.
{ repeat (apply wideadd_represents;
[ | apply represents_small; Z.div_mod_to_quot_rem; nia| ]).
- erewrite represents_snd; [ | apply widemul_represents; solve [ auto with zarith ] ].
+ erewrite represents_high; [ | apply widemul_represents; solve [ auto with zarith ] ].
{ apply represents_add; try reflexivity; solve [auto with zarith]. }
{ match goal with H : 0 <= ?x + ?y < ?z |- 0 <= ?x < ?z =>
split; [ solve [Z.zero_bounds] | ];
@@ -8123,11 +8124,11 @@ Module BarrettReduction.
{ ring. }
Qed.
- Definition cond_sub1 (a : T) y : Z :=
- let maybe_y := Z.zselect (Z.cc_l (snd a)) 0 y in
- fst (Z.sub_get_borrow_full (2^k) (fst a) maybe_y).
+ Definition cond_sub1 (a : list Z) y : Z :=
+ dlet_nd maybe_y := Z.zselect (Z.cc_l (high a)) 0 y in
+ fst (Z.sub_get_borrow_full (2^k) (low a) maybe_y).
- Lemma cc_l_only_bit x s: 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s.
+ Lemma cc_l_only_bit : forall x s, 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s.
Proof.
cbv [Z.cc_l]; intros.
rewrite Z.div_between_0_if by omega.
@@ -8140,7 +8141,7 @@ Module BarrettReduction.
0 <= y < 2 ^ k ->
cond_sub1 a y = if (x <? 2 ^ k) then x else x - y.
Proof.
- intros; cbv [cond_sub1]. rewrite Z.zselect_correct. push_rep.
+ intros; cbv [cond_sub1 Let_In]. rewrite Z.zselect_correct. push_rep.
break_match; Z.ltb_to_lt; rewrite cc_l_only_bit in *; try omega;
autorewrite with zsimplify_fast to_div_mod pull_Zmod; auto with zarith.
Qed.
@@ -8155,7 +8156,7 @@ Module BarrettReduction.
Section Defn.
Context (xLow xHigh : Z) (xLow_bounds : 0 <= xLow < 2^k) (xHigh_bounds : 0 <= xHigh < M).
- Let xt := (xLow, xHigh).
+ Let xt := [xLow; xHigh].
Let x := xLow + 2^k * xHigh.
Lemma x_rep : represents xt x.
@@ -8167,6 +8168,8 @@ Module BarrettReduction.
Definition muSelect := Z.zselect (Z.cc_m (2 ^ k) xHigh) 0 muLow.
Local Hint Resolve Z.div_nonneg Z.div_lt_upper_bound.
+ Local Hint Resolve shiftr_represents mul_high_represents widemul_represents widesub_represents
+ cond_sub1_correct cond_sub2_correct represents_low represents_add.
Lemma muSelect_correct :
muSelect = (2 ^ (2 * k) / M) mod 2 ^ k * ((x / 2 ^ (k - 1)) / 2 ^ k).
@@ -8177,9 +8180,10 @@ Module BarrettReduction.
assert (0 <= x / (2 ^ k * (2 ^ k / 2)) < 2) by (Z.div_mod_to_quot_rem; auto with nia).
assert (0 < 2 ^ k / 2) by Z.zero_bounds.
assert (2 ^ (k - 1) <> 0) by auto with zarith.
+ assert (2 < 2 ^ k) by (eapply Z.le_lt_trans with (m:=2 ^ 1); auto with zarith).
cbv [muSelect]. rewrite <-muLow_eq.
- rewrite Z.zselect_correct, Z.cc_m_eq by nia.
+ rewrite Z.zselect_correct, Z.cc_m_eq by auto with zarith.
replace xHigh with (x / 2^k) by (subst x; autorewrite with zsimplify; lia).
autorewrite with pull_Zdiv push_Zpow.
rewrite (Z.mul_comm (2 ^ k / 2)).
@@ -8188,22 +8192,180 @@ Module BarrettReduction.
autorewrite with zsimplify; reflexivity.
Qed.
- Definition barrett_reduce : Z :=
- reduce k fst shiftr mul_high widemul widesub cond_sub1 cond_sub2 xt (muLow, 1) M muSelect.
+ Lemma mu_rep : represents [muLow; 1] (2 ^ (2 * k) / M).
+ Proof. rewrite <-muLow_eq. eapply represents_trans; auto with zarith. Qed.
- Lemma barrett_reduce_correct :
- barrett_reduce = x mod M.
+ Derive barrett_reduce
+ SuchThat (barrett_reduce = x mod M)
+ As barrett_reduce_correct.
Proof.
- intros; cbv [barrett_reduce].
- apply reduce_correct with (rep:=represents); try omega;
- auto using shiftr_represents, mul_high_represents, widemul_represents, widesub_represents,
- cond_sub1_correct, cond_sub2_correct, x_bounds, muSelect_correct, represents_fst, x_rep.
- rewrite <-muLow_eq. cbv [represents]; repeat split; autorewrite with cancel_pair zsimplify; nia.
+ erewrite <-reduce_correct with (rep:=represents) (muSelect:=muSelect) (k0:=k) (mut:=[muLow;1]) (xt0:=xt)
+ by (auto using x_bounds, muSelect_correct, x_rep, mu_rep; omega).
+ subst barrett_reduce. reflexivity.
Qed.
End Defn.
End BarrettReduction.
+
+ (* all the list operations from for_reification.ident *)
+ Strategy 100 [length seq repeat combine map flat_map partition app rev fold_right update_nth nth_default ].
+
+ Derive barrett_red_gen
+ SuchThat (forall (k M muLow : Z)
+ (n nout: nat)
+ (xLow xHigh : Z),
+ Interp (t:=type.reify_type_of barrett_reduce)
+ barrett_red_gen k M muLow n nout xLow xHigh
+ = barrett_reduce k M muLow n nout xLow xHigh)
+ As barrett_red_gen_correct.
+ Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed.
+ (* TODO : reification here is still quite slow (~40s on a beefy machine). Possibly just due to size of term, but warrants further investigation. *)
+ Module Export ReifyHints.
+ Global Hint Extern 1 (_ = barrett_reduce _ _ _ _ _ _ _) => simple apply barrett_red_gen_correct : reify_gen_cache.
+ End ReifyHints.
+
+ Section rbarrett_red.
+ Context (M : Z)
+ (machine_wordsize : Z).
+
+ Let bound := Some r[0 ~> (2^machine_wordsize - 1)%Z]%zrange.
+ Let mu := (2 ^ (2 * machine_wordsize)) / M.
+ Let muLow := mu mod (2 ^ machine_wordsize).
+
+ Check barrett_reduce_correct.
+ Print Pipeline.Values_not_provably_distinct.
+
+ Definition relax_zrange_of_machine_wordsize
+ := relax_zrange_gen [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize]%Z.
+ Local Arguments relax_zrange_of_machine_wordsize / .
+
+ Let relax_zrange := relax_zrange_of_machine_wordsize.
+
+ Definition check_args {T} (res : Pipeline.ErrorT T)
+ : Pipeline.ErrorT T
+ := if (mu / (2 ^ machine_wordsize) =? 0)
+ then Pipeline.Error (Pipeline.Values_not_provably_distinct "mu / 2 ^ k ≠ 0" (mu / 2 ^ machine_wordsize) 0)
+ else if (machine_wordsize <? 2)
+ then Pipeline.Error (Pipeline.Value_not_le "~ (2 <=k)" 2 machine_wordsize)
+ else if (negb (Z.log2 M + 1 =? machine_wordsize))
+ then Pipeline.Error
+ (Pipeline.Values_not_provably_equal "log2(M)+1 != k" (Z.log2 M + 1) machine_wordsize)
+ else if (2 ^ (machine_wordsize + 1) - mu <? 2 * (2 ^ (2 * machine_wordsize) mod M))
+ then Pipeline.Error
+ (Pipeline.Value_not_le "~ (2 * (2 ^ (2*k) mod M) <= 2^(k + 1) - mu)"
+ (2 * (2 ^ (2*machine_wordsize) mod M))
+ (2^(machine_wordsize + 1) - mu))
+ else res.
+
+ Notation BoundsPipeline_correct in_bounds out_bounds op
+ := (fun rv (rop : Expr (type.reify_type_of op)) Hrop
+ => @Pipeline.BoundsPipeline_correct_trans
+ false (* subst01 *)
+ relax_zrange
+ (relax_zrange_gen_good _)
+ _
+ rop
+ in_bounds
+ out_bounds
+ op
+ Hrop rv)
+ (only parsing).
+
+ Definition rbarrett_red_correct
+ := BoundsPipeline_correct
+ (bound, bound)
+ bound
+ (barrett_reduce machine_wordsize M muLow 2 2).
+
+ Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _).
+ Definition rbarrett_red_correctT rv : Prop
+ := type_of_strip_3arrow (@rbarrett_red_correct rv).
+ End rbarrett_red.
End BarrettReduction.
+Ltac solve_rbarrett_red := solve_rop BarrettReduction.rbarrett_red_correct.
+Ltac solve_rbarrett_red_nocache := solve_rop_nocache BarrettReduction.rbarrett_red_correct.
+
+Module Barrett256.
+
+ Definition M := (2^256-2^224+2^192+2^96-1).
+ Definition machine_wordsize := 256.
+
+ (* TODO : why does this not bounds check?
+ Let F := Some r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange.
+ Eval vm_compute in (
+ partial.bounds.expr.extract' (fun t : type => id)
+ (λ x : partial.data (type.type_primitive type.Z * type.type_primitive type.Z),
+ ident.Z.cast2
+ (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange,
+ r[0 ~> 1]%zrange)%core @@
+ (ident.Z.add_get_carry_concrete
+ 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@
+ (ident.Z.cast r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935] @@
+ (ident.fst @@ (Var x)), ident.primitive (t:=type.Z) 26959946667150639793205513449348445388433292963828203772348655992835 @@ TT ))) (F, F)).
+ *)
+ Derive barrett_red256
+ SuchThat (BarrettReduction.rbarrett_red_correctT M machine_wordsize barrett_red256)
+ As barrett_red256_correct.
+ Proof. Time solve_rbarrett_red machine_wordsize. Time Qed.
+
+ Import PrintingNotations.
+ Open Scope expr_scope.
+ Set Printing Width 100000.
+
+ Print barrett_red256.
+ (* TODO: the ADD/ADDC instructions containing Z.opp should be translated to SUB/SUBB in partial evaluation *)
+ (*
+barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype,
+ expr_let x0 := SELM (x₂, 0, 26959946667150639793205513449348445388433292963828203772348655992835) in
+ expr_let x1 := RSHI (0, x₂, 255) in
+ expr_let x2 := RSHI (x₂, x₁, 255) in
+ expr_let x3 := (uint128)(x2 >> 128) in
+ expr_let x4 := ((uint128)(x2) & 340282366920938463463374607431768211455) in
+ expr_let x5 := 79228162514264337589248983038 *₂₅₆ x4 in
+ expr_let x6 := (uint128)(x5 >> 128) in
+ expr_let x7 := ((uint128)(x5) & 340282366920938463463374607431768211455) in
+ expr_let x8 := 340282366841710300930663525764514709507 *₂₅₆ x3 in
+ expr_let x9 := (uint128)(x8 >> 128) in
+ expr_let x10 := ((uint128)(x8) & 340282366920938463463374607431768211455) in
+ expr_let x11 := 79228162514264337589248983038 *₂₅₆ x3 in
+ expr_let x12 := (uint256)(x7 << 128) in
+ expr_let x13 := (uint256)(x10 << 128) in
+ expr_let x14 := 340282366841710300930663525764514709507 *₂₅₆ x4 in
+ expr_let x15 := ADD_256 (x13, x14) in
+ expr_let x16 := ADDC_128 (x15₂, x6, x9) in
+ expr_let x17 := ADD_256 (x12, x15₁) in
+ expr_let x18 := ADDC_256 (x17₂, x11, x16₁) in
+ expr_let x19 := ADD_256 (x2, x18₁) in
+ expr_let x20 := ADDC_128 (x19₂, 0, x1) in
+ expr_let x21 := ADD_256 (x0, x19₁) in
+ expr_let x22 := ADDC_128 (x21₂, 0, x20₁) in
+ expr_let x23 := RSHI (x22₁, x21₁, 1) in
+ expr_let x24 := (uint128)(x23 >> 128) in
+ expr_let x25 := ((uint128)(x23) & 340282366920938463463374607431768211455) in
+ expr_let x26 := 79228162514264337593543950335 *₂₅₆ x24 in
+ expr_let x27 := (uint128)(x26 >> 128) in
+ expr_let x28 := ((uint128)(x26) & 340282366920938463463374607431768211455) in
+ expr_let x29 := 340282366841710300967557013911933812736 *₂₅₆ x25 in
+ expr_let x30 := (uint128)(x29 >> 128) in
+ expr_let x31 := ((uint128)(x29) & 340282366920938463463374607431768211455) in
+ expr_let x32 := 340282366841710300967557013911933812736 *₂₅₆ x24 in
+ expr_let x33 := (uint256)(x28 << 128) in
+ expr_let x34 := (uint256)(x31 << 128) in
+ expr_let x35 := 79228162514264337593543950335 *₂₅₆ x25 in
+ expr_let x36 := ADD_256 (x34, x35) in
+ expr_let x37 := ADDC_256 (x36₂, x27, x30) in
+ expr_let x38 := ADD_256 (x33, x36₁) in
+ expr_let x39 := ADDC_256 (x38₂, x32, x37₁) in
+ expr_let x40 := ADD_256 (Z.opp @@ (fst @@ x38), x₁) in
+ expr_let x41 := ADDC_256 (x40₂, Z.opp @@ (fst @@ x39), x₂) in
+ expr_let x42 := SELL (x41₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
+ expr_let x43 := Z.cast uint256 @@ (fst @@ SUB_256 (x40₁, x42)) in
+ ADDM (x43, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
+ : Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z))
+ *)
+
+End Barrett256.
+
Module MontgomeryReduction.
Section MontRed'.
Context (N R N' R' : Z).
@@ -8412,7 +8574,7 @@ montred256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z *
End Montgomery256.
(* Extra-specialized ad-hoc pretty-printing *)
-Module Montgomery256PrintingNotations.
+Module FancyPrintingNotations.
Export ident.
Open Scope expr_scope.
Open Scope ctype_scope.
@@ -8420,6 +8582,10 @@ Module Montgomery256PrintingNotations.
(AppIdent
(primitive 115792089210356248762697446949407573530086143415290314195533631308867097853951)
TT) (only printing, at level 9) : expr_scope.
+ Notation "'RegMuLow'" :=
+ (AppIdent
+ (primitive 26959946667150639793205513449348445388433292963828203772348655992835)
+ TT) (only printing, at level 9) : expr_scope.
Notation "'RegPinv'" :=
(AppIdent
(primitive 115792089210356248768974548684794254293921932838497980611635986753331132366849)
@@ -8437,6 +8603,14 @@ Module Montgomery256PrintingNotations.
(AppIdent
(primitive 340282366841710300967557013911933812736)
TT) (only printing, at level 9, format "'RegMod' '>>' '128'") : expr_scope.
+ Notation "'Lower128{RegMuLow}'" :=
+ (AppIdent
+ (primitive 340282366841710300930663525764514709507)
+ TT) (only printing, at level 9) : expr_scope.
+ Notation "'RegMuLow' '>>' '128'" :=
+ (AppIdent
+ (primitive 79228162514264337589248983038)
+ TT) (only printing, at level 9, format "'RegMuLow' '>>' '128'") : expr_scope.
Notation "'Lower128{RegPinv}'" :=
(AppIdent
(primitive 79228162514264337593543950337)
@@ -8455,6 +8629,8 @@ Module Montgomery256PrintingNotations.
Notation "$ n '_hi'" := (snd @@ (Var n))%expr (at level 10, format "$ n _hi") : expr_scope.
Notation "$ n '_lo'" := (Z.cast _ @@ (fst @@ (Var n)))%expr (at level 10, format "$ n _lo") : expr_scope.
Notation "$ n '_hi'" := (Z.cast _ @@ (snd @@ (Var n)))%expr (at level 10, format "$ n _hi") : expr_scope.
+ Notation "$ n '_lo'" := (fst @@ (Z.cast2 _ @@ Var n))%expr (at level 10, format "$ n _lo") : expr_scope.
+ Notation "$ n '_hi'" := (snd @@ (Z.cast2 _ @@ Var n))%expr (at level 10, format "$ n _hi") : expr_scope.
Notation "$ n '_lo'" := (Z.cast _ @@ (fst @@ (Z.cast2 _ @@ Var n)))%expr (at level 10, format "$ n _lo") : expr_scope.
Notation "$ n '_hi'" := (Z.cast _ @@ (snd @@ (Z.cast2 _ @@ Var n)))%expr (at level 10, format "$ n _hi") : expr_scope.
Notation "'c.Mul128x128(' '$' n ',' x ',' y ');' f" :=
@@ -8472,12 +8648,21 @@ Module Montgomery256PrintingNotations.
Notation "'c.Add64(' '$' n ',' x ',' y ');' f" :=
(expr_let n := Z.cast uint128 @@ (Z.add @@ (x, y)) in
f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add64(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
- Notation "'c.Addc(' '$' n ',' x ',' y ');' f" :=
- (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.add_with_get_carry_concrete $R @@ (_, x, y)) in
- f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
- Notation "'c.Selc(' '$' n ',' y ',' z ');' f" :=
- (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (_, y, z)) in
- f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Selc(' '$' n ',' y ',' z ');' ']' '//' f") : expr_scope.
+ Notation "'c.Addc256(' '$' n ',' x ',' y ',' z ');' f" :=
+ (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.add_with_get_carry_concrete $R @@ (x, y, z)) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc256(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope.
+ Notation "'c.Addc128(' '$' n ',' x ',' y ',' z ');' f" :=
+ (expr_let n := Z.cast2 (uint128, _)%core @@ (Z.add_with_get_carry_concrete $R @@ (x, y, z)) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc128(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope.
+ Notation "'c.Selc(' '$' n ',' x ',' y ',' z ');' f" :=
+ (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (x , y, z)) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Selc(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope.
+ Notation "'c.Selm(' '$' n ',' x ',' y ',' z ');' f" :=
+ (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (Z.cast (r[0 ~>1]) @@ ((Z.cc_m_concrete _) @@ x), y, z)) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Selm(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope.
+ Notation "'c.Sell(' '$' n ',' x ',' y ',' z ');' f" :=
+ (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (Z.cast (r[0~>1]) @@ (Z.land 1 @@ _ x), y, z)) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Sell(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope.
Notation "'c.Sub(' '$' n ',' x ',' y ');' f" :=
(expr_let n := Z.cast uint256 @@ (fst @@ (Z.cast2 (uint256, _)%core @@ (Z.sub_get_borrow_concrete $R @@ (x, y)))) in
f)%expr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$' n ',' x ',' y ');' '//' f") : expr_scope.
@@ -8485,6 +8670,8 @@ Module Montgomery256PrintingNotations.
(Z.cast uint256 @@ (Z.add_modulo @@ (x, y, z)))%expr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : expr_scope.
Notation "'c.ShiftR(' '$' n ',' x ',' y ');' f" :=
(expr_let n := Z.cast _ @@ (Z.shiftr y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
+ Notation "'c.Rshi(' '$' n ',' x ',' y ',' m ');' f" :=
+ (expr_let n := Z.cast _ @@ (Z.rshi_concrete $R m @@ (x, y)) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Rshi(' '$' n ',' x ',' y ',' m ');' ']' '//' f") : expr_scope.
Notation "'c.ShiftL(' '$' n ',' x ',' y ');' f" :=
(expr_let n := Z.cast _ @@ (Z.shiftl y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
Notation "'c.Lower128(' '$' n ',' x ');' f" :=
@@ -8510,11 +8697,60 @@ Module Montgomery256PrintingNotations.
:= (Z.cast uint256 @@ (Z.mul @@ (x, y)))
: expr_scope.
*)
-End Montgomery256PrintingNotations.
+End FancyPrintingNotations.
-Import Montgomery256PrintingNotations.
+Import FancyPrintingNotations.
Local Open Scope expr_scope.
+Print Barrett256.barrett_red256.
+(*
+c.Selm($x0, $x_hi, RegZero, RegMuLow);
+c.Rshi($x1, RegZero, $x_hi, 255);
+c.Rshi($x2, $x_hi, $x_lo, 255);
+c.ShiftR($x3, $x2, 128);
+c.Lower128($x4, $x2);
+c.Mul128x128($x5, RegMuLow >> 128, $x4);
+c.ShiftR($x6, $x5, 128);
+c.Lower128($x7, $x5);
+c.Mul128x128($x8, Lower128{RegMuLow}, $x3);
+c.ShiftR($x9, $x8, 128);
+c.Lower128($x10, $x8);
+c.Mul128x128($x11, RegMuLow >> 128, $x3);
+c.ShiftL($x12, $x7, 128);
+c.ShiftL($x13, $x10, 128);
+c.Mul128x128($x14, Lower128{RegMuLow}, $x4);
+c.Add256($x15, $x13, $x14);
+c.Addc128($x16, $x15_hi, $x6, $x9);
+c.Add256($x17, $x12, $x15_lo);
+c.Addc256($x18, $x17_hi, $x11, $x16_lo);
+c.Add256($x19, $x2, $x18_lo);
+c.Addc128($x20, $x19_hi, RegZero, $x1);
+c.Add256($x21, $x0, $x19_lo);
+c.Addc128($x22, $x21_hi, RegZero, $x20_lo);
+c.Rshi($x23, $x22_lo, $x21_lo, 1);
+c.ShiftR($x24, $x23, 128);
+c.Lower128($x25, $x23);
+c.Mul128x128($x26, Lower128{RegMod}, $x24);
+c.ShiftR($x27, $x26, 128);
+c.Lower128($x28, $x26);
+c.Mul128x128($x29, RegMod >> 128, $x25);
+c.ShiftR($x30, $x29, 128);
+c.Lower128($x31, $x29);
+c.Mul128x128($x32, RegMod >> 128, $x24);
+c.ShiftL($x33, $x28, 128);
+c.ShiftL($x34, $x31, 128);
+c.Mul128x128($x35, Lower128{RegMod}, $x25);
+c.Add256($x36, $x34, $x35);
+c.Addc256($x37, $x36_hi, $x27, $x30);
+c.Add256($x38, $x33, $x36_lo);
+c.Addc256($x39, $x38_hi, $x32, $x37_lo);
+c.Add256($x40, Z.opp @@ $x38_lo, $x_lo);
+c.Addc256($x41, $x40_hi, Z.opp @@ $x39_lo, $x_hi);
+c.Sell($x42, $x41_lo, RegZero, RegMod);
+c.Sub($x43, $x40_lo, $x42);
+c.AddM($ret, $x43, RegZero, RegMod);
+*)
+
Print Montgomery256.montred256.
(*
c.ShiftR($x0, $x_lo, 128);
@@ -8541,12 +8777,12 @@ c.ShiftL($x20, $x15, 128);
c.ShiftL($x21, $x18, 128);
c.Mul128x128($x22, Lower128{RegMod}, $x12);
c.Add256($x23, $x21, $x22);
-c.Addc($x24, $x14, $x17);
+c.Addc256($x24, $x23_hi, $x14, $x17);
c.Add256($x25, $x20, $x23_lo);
-c.Addc($x26, $x19, $x24_lo);
+c.Addc256($x26, $x25_hi, $x19, $x24_lo);
c.Add256($x27, $x25_lo, $x_lo);
-c.Addc($x28, $x26_lo, $x_hi);
-c.Selc($x29,RegZero, RegMod);
+c.Addc256($x28, $x27_hi, $x26_lo, $x_hi);
+c.Selc($x29, $x28_hi, RegZero, RegMod);
c.Sub($x30, $x28_lo, $x29);
c.AddM($ret, $x30, RegZero, RegMod);
*)