aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-05-30 17:02:46 +0200
committerGravatar Jade Philipoom <jadep@google.com>2018-05-31 15:13:41 +0200
commite6119c9595326a910d177488bf44aab3cc275e49 (patch)
treebcc9c2f55725f90dfde2c3acca33adbf31762bdf /src
parent6f2493f77f61b3922f3bc01ce3ea613f2a70230c (diff)
temporary workaround for #352
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v107
1 files changed, 92 insertions, 15 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 76a885d70..bb2255547 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -5009,7 +5009,7 @@ Module Compilers.
| Some x, Some y, Some z
=> type.option.Some
(t:=(type.Z*type.Z)%ctype)
- (let b := ZRange.split_bounds (ZRange.eight_corners (fun x y z => (x - y - z)%Z) x y z) split_at in
+ (let b := ZRange.split_bounds (ZRange.eight_corners (fun x y z => (y - z - x)%Z) x y z) split_at in
(* N.B. sub_get_borrow returns - ((x - y) / split_at) as the borrow, so we need to negate *)
(fst b, ZRange.opp (snd b)))
| _, _, _ => type.option.None
@@ -7631,6 +7631,7 @@ Module PrintingNotations.
Notation "'ADDC_256' ( x , y , z )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.add_with_get_carry_concrete TwoPow256 @@ (x, y, z)))%expr : expr_scope.
Notation "'ADDC_128' ( x , y , z )" := (ident.Z.cast2 (uint128, bool)%core @@ (ident.Z.add_with_get_carry_concrete TwoPow256 @@ (x, y, z)))%expr : expr_scope.
Notation "'SUB_256' ( x , y )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.sub_get_borrow_concrete TwoPow256 @@ (x, y)))%expr : expr_scope.
+ Notation "'SUBB_256' ( x , y , z )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.sub_with_get_borrow_concrete TwoPow256 @@ (x, y, z)))%expr : expr_scope.
Notation "'ADDM' ( x , y , z )" := (ident.Z.cast uint256 @@ (ident.Z.add_modulo @@ (x, y, z)))%expr : expr_scope.
Notation "'RSHI' ( x , y , z )" := (ident.Z.cast _ @@ (ident.Z.rshi_concrete _ z @@ (x, y)))%expr : expr_scope.
Notation "'SELC' ( x , y , z )" := (ident.Z.cast uint256 @@ (ident.Z.zselect @@ (x, y, z)))%expr : expr_scope.
@@ -8614,6 +8615,7 @@ Module PreFancy.
| add (imm : BinInt.Z) : ident (Z * Z) (Z * Z)
| addc (imm : BinInt.Z) : ident (Z * Z * Z) (Z * Z)
| sub (imm : BinInt.Z) : ident (Z * Z) (Z * Z)
+ | subb (imm : BinInt.Z) : ident (Z * Z * Z) (Z * Z)
| mulll : ident (Z * Z) Z
| mullh : ident (Z * Z) Z
| mulhl : ident (Z * Z) Z
@@ -8774,6 +8776,19 @@ Module PreFancy.
| _ => dummy _
end
else dummy _
+ | ident.Z.sub_with_get_borrow_concrete w =>
+ fun t r x f =>
+ if w =? wordmax
+ then
+ match x with
+ | Pair (type.prod Z Z) Z (Pair Z Z xb xl) xr =>
+ match invert_shift xr with
+ | Some (xr', imm) => LetInAppIdentZZ r (subb imm) (Pair (Pair xb xl) xr') f
+ | None => LetInAppIdentZZ r (subb 0) (Pair (Pair xb xl) xr) f
+ end
+ | _ => dummy _
+ end
+ else dummy _
| ident.Z.rshi_concrete w n =>
fun _ r x f =>
if w =? wordmax
@@ -8840,6 +8855,7 @@ Module PreFancy.
| add imm => fun x => Z.add_get_carry_full wordmax (fst x) (shift (snd x) imm)
| addc imm => fun x => Z.add_with_get_carry_full wordmax (fst (fst x)) (snd (fst x)) (shift (snd x) imm)
| sub imm => fun x => Z.sub_get_borrow_full wordmax (fst x) (shift (snd x) imm)
+ | subb imm => fun x => Z.sub_with_get_borrow_full wordmax (fst (fst x)) (snd (fst x)) (shift (snd x) imm)
| mulll => fun x => low (fst x) * low (snd x)
| mullh => fun x => low (fst x) * high (snd x)
| mulhl => fun x => high (fst x) * low (snd x)
@@ -8990,6 +9006,16 @@ Module PreFancy.
(Pair x y)
(word_range, flag_range)
(ident.Z.sub_get_borrow_concrete wordmax)
+ | ok_subb :
+ forall b x y : scalar type.Z,
+ in_flag_range (get_range b) ->
+ in_word_range (get_range x) ->
+ in_word_range (get_range y) ->
+ ok_ident _
+ (type.prod type.Z type.Z)
+ (Pair (Pair b x) y)
+ (word_range, flag_range)
+ (ident.Z.sub_with_get_borrow_concrete wordmax)
| ok_rshi :
forall (x : scalar (type.prod type.Z type.Z)) n,
in_word_range (fst (get_range x)) ->
@@ -9397,6 +9423,12 @@ Module PreFancy.
match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
rewrite Z.div_sub_small by omega.
split; break_match; lia. }
+ {
+ autorewrite with to_div_mod.
+ match goal with |- context [?a - ?b - ?c] => replace (a - b - c) with (a - (b + c)) by ring end.
+ match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
+ rewrite Z.div_sub_small by omega.
+ split; break_match; lia. }
{ apply has_word_range_rshi; omega. }
{ rewrite Z.zselect_correct. break_match; omega. }
{ cbn [interp_scalar fst snd get_range] in *.
@@ -9793,6 +9825,9 @@ Module PreFancy.
Notation "sub@( y , x1 , x2 , n ); g"
:= (LetInAppIdentZZ (uint256, bool) (sub n) (Pair x1 x2) (fun y => g))
(at level 10, g at level 200, format "sub@( y , x1 , x2 , n ); '//' g").
+ Notation "subb@( y , x1 , x2 , x3 , n ); g"
+ := (LetInAppIdentZZ (uint256, bool) (subb n) (Pair (Pair x1 x2) x3) (fun y => g))
+ (at level 10, g at level 200, format "subb@( y , x1 , x2 , x3 , n ); '//' g").
Notation "rshi@( y , x1 , x2 , n ); g"
:= (LetInAppIdentZ _ (rshi n) (Pair x1 x2) (fun y => g))
(at level 10, g at level 200, format "rshi@( y , x1 , x2 , n ); '//' g ").
@@ -9927,6 +9962,15 @@ Module Fancy.
r0 - (r1 << imm))
|}.
+ Definition SUBC (imm : int) : instruction :=
+ {|
+ num_source_regs := 2;
+ writes_conditions := [C; M; L; Z];
+ spec := (fun '(r0, r1) cc =>
+ r0 - (r1 << imm) - cc[C])
+ |}.
+
+
Definition MUL128LL : instruction :=
{|
num_source_regs := 2;
@@ -10093,6 +10137,8 @@ Module Fancy.
existT _ (ADDC imm) (of_prefancy_scalar (Pair (Snd (Fst args)) (Snd args)))
| PreFancy.sub imm => fun args : @scalar var (type.Z * type.Z) =>
existT _ (SUB imm) (of_prefancy_scalar args)
+ | PreFancy.subb imm => fun args : @scalar var (type.Z * type.Z * type.Z) =>
+ existT _ (SUBC imm) (of_prefancy_scalar (Pair (Snd (Fst args)) (Snd args)))
| PreFancy.mulll => fun args : @scalar var (type.Z * type.Z) =>
existT _ MUL128LL (of_prefancy_scalar args)
| PreFancy.mullh => fun args : @scalar var (type.Z * type.Z) =>
@@ -10514,7 +10560,7 @@ Module BarrettReduction.
dlet_nd qt := qt in
dlet_nd r2 := mul (low qt) M in
dlet_nd r := sub xt r2 in
- dlet_nd q3 := cond_sub1 r M in
+ let q3 := cond_sub1 r M in
cond_sub2 q3 M.
Lemma looser_bound : M * 2 ^ k < 2 ^ (2*k).
@@ -10698,7 +10744,16 @@ Module BarrettReduction.
Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r.
Definition wideadd t1 t2 := fst (Rows.add w 2 t1 t2).
- Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2).
+ (* TODO: use this definition once issue #352 is resolved *)
+ (* Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2). *)
+ Definition widesub (t1 t2 : list Z) :=
+ let t1_0 := hd 0 t1 in
+ let t1_1 := hd 0 (tl t1) in
+ let t2_0 := hd 0 t2 in
+ let t2_1 := hd 0 (tl t2) in
+ dlet_nd x0 := Z.sub_get_borrow_full (2^k) t1_0 t2_0 in
+ dlet_nd x1 := Z.sub_with_get_borrow_full (2^k) (snd x0) t1_1 t2_1 in
+ [fst x0; fst x1].
Definition widemul := BaseConversion.widemul_inlined k n nout.
Lemma partition_represents x :
@@ -10739,7 +10794,27 @@ Module BarrettReduction.
represents t2 y ->
0 <= x - y < 2^k*2^k ->
represents (widesub t1 t2) (x - y).
+ Proof.
+ intros; cbv [widesub Let_In].
+ rewrite (represents_eq t1 x) by assumption.
+ rewrite (represents_eq t2 y) by assumption.
+ cbn [hd tl].
+ autorewrite with to_div_mod.
+ pull_Zmod.
+ match goal with |- represents [?m; ?d] ?x =>
+ replace d with (x / 2 ^ k); [solve [auto using represents_id] |] end.
+ rewrite <-(Z.mod_small ((x - y) / 2^k) (2^k)) by (split; try apply Z.div_lt_upper_bound; Z.zero_bounds).
+ f_equal.
+ transitivity ((x mod 2^k - y mod 2^k + 2^k * (x / 2 ^ k) - 2^k * (y / 2^k)) / 2^k). {
+ rewrite (Z.div_mod x (2^k)) at 1 by auto using Z.pow_nonzero with omega.
+ rewrite (Z.div_mod y (2^k)) at 1 by auto using Z.pow_nonzero with omega.
+ f_equal. ring. }
+ autorewrite with zsimplify.
+ ring.
+ Qed.
+ (* Works with Rows.sub-based widesub definition
Proof. intros; cbv [widesub]. wide_op Rows.sub_partitions. Qed.
+ *)
Lemma widemul_represents x y :
0 <= x < 2^k ->
@@ -10836,7 +10911,8 @@ Module BarrettReduction.
Definition cond_sub1 (a : list Z) y : Z :=
dlet_nd maybe_y := Z.zselect (Z.cc_l (high a)) 0 y in
- fst (Z.sub_get_borrow_full (2^k) (low a) maybe_y).
+ dlet_nd diff := Z.sub_get_borrow_full (2^k) (low a) maybe_y in
+ fst diff.
Lemma cc_l_only_bit : forall x s, 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s.
Proof.
@@ -10928,7 +11004,7 @@ Module BarrettReduction.
= barrett_reduce k M muLow n nout xLow xHigh)
As barrett_red_gen_correct.
Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed.
- (* TODO : reification here is still quite slow (~40s on a beefy machine). Possibly just due to size of term, but warrants further investigation. *)
+ (* TODO : reification here is still quite slow (~90s on a beefy machine). Possibly just due to size of term, but warrants further investigation. *)
Module Export ReifyHints.
Global Hint Extern 1 (_ = barrett_reduce _ _ _ _ _ _ _) => simple apply barrett_red_gen_correct : reify_gen_cache.
End ReifyHints.
@@ -11010,9 +11086,6 @@ Module Barrett256.
Open Scope expr_scope.
Print barrett_red256.
- (* TODO: the ADD/ADDC instructions containing Z.opp should be translated to SUB/SUBB in partial evaluation *)
- (* Note: the aforementioned ADD/ADDC instructions currently *do* fail to bounds-check, although the pipeline doesn't give an error.
- This is why their results are not cast (because the carry has range [-1~>0]). *)
(*
barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype,
expr_let x0 := SELM (x₂, 0, 26959946667150639793205513449348445388433292963828203772348655992835) in
@@ -11039,11 +11112,11 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type
expr_let x21 := ADDC_256 (x20₂, (uint128)(x18 >> 128), x16) in
expr_let x22 := ADD_256 ((uint256)(((uint128)(x17) & 340282366920938463463374607431768211455) << 128), x20₁) in
expr_let x23 := ADDC_256 (x22₂, (uint128)(x17 >> 128), x21₁) in
- expr_let x24 := Z.add_get_carry_concrete 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ (Z.opp @@ (fst @@ x22), x₁) in
- expr_let x25 := Z.add_with_get_carry_concrete 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ (x24₂, Z.opp @@ (fst @@ x23), x₂) in
+ expr_let x24 := SUB_256 (x₁, x22₁) in
+ expr_let x25 := SUBB_256 (x24₂, x₂, x23₁) in
expr_let x26 := SELL (x25₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
- expr_let x27 := Z.cast uint256 @@ (fst @@ SUB_256 (x24₁, x26)) in
- ADDM (x27, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
+ expr_let x27 := SUB_256 (x24₁, x26) in
+ ADDM (x27₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
: Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z))
*)
@@ -11056,7 +11129,6 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type
Local Notation "'RegMod'" := (Straightline.expr.Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951).
Local Notation "'RegMuLow'" := (Straightline.expr.Primitive (t:=type.Z) 26959946667150639793205513449348445388433292963828203772348655992835).
Print barrett_red256_prefancy.
- (* Note : currently (correctly) fails to convert the adds that should be subs *)
(*
selm@(y, $x₂, RegZero, RegMuLow);
rshi@(y0, RegZero, $x₂,255);
@@ -11081,8 +11153,13 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type
add@(y19, $y18, $y17, 128);
addc@(y20, carry{$y19}, $y15, $y17, -128);
add@(y21, $y19, $y16, 128);
- addc@(_, carry{$y21}, $y20, $y16, -128);
- Straightline.expr.Scalar (Straightline.expr.Primitive (-1))
+ addc@(y22, carry{$y21}, $y20, $y16, -128);
+ sub@(y23, $x₁, $y21, 0);
+ subb@(y24, carry{$y23}, $x₂, $y22, 0);
+ sell@(y25, $y24, RegZero, RegMod);
+ sub@(y26, $y23, $y25, 0);
+ addm@(y27, $y26, RegZero, RegMod);
+ ret $y27
*)
End Barrett256.