diff options
author | Jade Philipoom <jadep@google.com> | 2018-05-30 17:02:46 +0200 |
---|---|---|
committer | Jade Philipoom <jadep@google.com> | 2018-05-31 15:13:41 +0200 |
commit | e6119c9595326a910d177488bf44aab3cc275e49 (patch) | |
tree | bcc9c2f55725f90dfde2c3acca33adbf31762bdf /src | |
parent | 6f2493f77f61b3922f3bc01ce3ea613f2a70230c (diff) |
temporary workaround for #352
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 107 |
1 files changed, 92 insertions, 15 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 76a885d70..bb2255547 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -5009,7 +5009,7 @@ Module Compilers. | Some x, Some y, Some z => type.option.Some (t:=(type.Z*type.Z)%ctype) - (let b := ZRange.split_bounds (ZRange.eight_corners (fun x y z => (x - y - z)%Z) x y z) split_at in + (let b := ZRange.split_bounds (ZRange.eight_corners (fun x y z => (y - z - x)%Z) x y z) split_at in (* N.B. sub_get_borrow returns - ((x - y) / split_at) as the borrow, so we need to negate *) (fst b, ZRange.opp (snd b))) | _, _, _ => type.option.None @@ -7631,6 +7631,7 @@ Module PrintingNotations. Notation "'ADDC_256' ( x , y , z )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.add_with_get_carry_concrete TwoPow256 @@ (x, y, z)))%expr : expr_scope. Notation "'ADDC_128' ( x , y , z )" := (ident.Z.cast2 (uint128, bool)%core @@ (ident.Z.add_with_get_carry_concrete TwoPow256 @@ (x, y, z)))%expr : expr_scope. Notation "'SUB_256' ( x , y )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.sub_get_borrow_concrete TwoPow256 @@ (x, y)))%expr : expr_scope. + Notation "'SUBB_256' ( x , y , z )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.sub_with_get_borrow_concrete TwoPow256 @@ (x, y, z)))%expr : expr_scope. Notation "'ADDM' ( x , y , z )" := (ident.Z.cast uint256 @@ (ident.Z.add_modulo @@ (x, y, z)))%expr : expr_scope. Notation "'RSHI' ( x , y , z )" := (ident.Z.cast _ @@ (ident.Z.rshi_concrete _ z @@ (x, y)))%expr : expr_scope. Notation "'SELC' ( x , y , z )" := (ident.Z.cast uint256 @@ (ident.Z.zselect @@ (x, y, z)))%expr : expr_scope. @@ -8614,6 +8615,7 @@ Module PreFancy. | add (imm : BinInt.Z) : ident (Z * Z) (Z * Z) | addc (imm : BinInt.Z) : ident (Z * Z * Z) (Z * Z) | sub (imm : BinInt.Z) : ident (Z * Z) (Z * Z) + | subb (imm : BinInt.Z) : ident (Z * Z * Z) (Z * Z) | mulll : ident (Z * Z) Z | mullh : ident (Z * Z) Z | mulhl : ident (Z * Z) Z @@ -8774,6 +8776,19 @@ Module PreFancy. | _ => dummy _ end else dummy _ + | ident.Z.sub_with_get_borrow_concrete w => + fun t r x f => + if w =? wordmax + then + match x with + | Pair (type.prod Z Z) Z (Pair Z Z xb xl) xr => + match invert_shift xr with + | Some (xr', imm) => LetInAppIdentZZ r (subb imm) (Pair (Pair xb xl) xr') f + | None => LetInAppIdentZZ r (subb 0) (Pair (Pair xb xl) xr) f + end + | _ => dummy _ + end + else dummy _ | ident.Z.rshi_concrete w n => fun _ r x f => if w =? wordmax @@ -8840,6 +8855,7 @@ Module PreFancy. | add imm => fun x => Z.add_get_carry_full wordmax (fst x) (shift (snd x) imm) | addc imm => fun x => Z.add_with_get_carry_full wordmax (fst (fst x)) (snd (fst x)) (shift (snd x) imm) | sub imm => fun x => Z.sub_get_borrow_full wordmax (fst x) (shift (snd x) imm) + | subb imm => fun x => Z.sub_with_get_borrow_full wordmax (fst (fst x)) (snd (fst x)) (shift (snd x) imm) | mulll => fun x => low (fst x) * low (snd x) | mullh => fun x => low (fst x) * high (snd x) | mulhl => fun x => high (fst x) * low (snd x) @@ -8990,6 +9006,16 @@ Module PreFancy. (Pair x y) (word_range, flag_range) (ident.Z.sub_get_borrow_concrete wordmax) + | ok_subb : + forall b x y : scalar type.Z, + in_flag_range (get_range b) -> + in_word_range (get_range x) -> + in_word_range (get_range y) -> + ok_ident _ + (type.prod type.Z type.Z) + (Pair (Pair b x) y) + (word_range, flag_range) + (ident.Z.sub_with_get_borrow_concrete wordmax) | ok_rshi : forall (x : scalar (type.prod type.Z type.Z)) n, in_word_range (fst (get_range x)) -> @@ -9397,6 +9423,12 @@ Module PreFancy. match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. rewrite Z.div_sub_small by omega. split; break_match; lia. } + { + autorewrite with to_div_mod. + match goal with |- context [?a - ?b - ?c] => replace (a - b - c) with (a - (b + c)) by ring end. + match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + rewrite Z.div_sub_small by omega. + split; break_match; lia. } { apply has_word_range_rshi; omega. } { rewrite Z.zselect_correct. break_match; omega. } { cbn [interp_scalar fst snd get_range] in *. @@ -9793,6 +9825,9 @@ Module PreFancy. Notation "sub@( y , x1 , x2 , n ); g" := (LetInAppIdentZZ (uint256, bool) (sub n) (Pair x1 x2) (fun y => g)) (at level 10, g at level 200, format "sub@( y , x1 , x2 , n ); '//' g"). + Notation "subb@( y , x1 , x2 , x3 , n ); g" + := (LetInAppIdentZZ (uint256, bool) (subb n) (Pair (Pair x1 x2) x3) (fun y => g)) + (at level 10, g at level 200, format "subb@( y , x1 , x2 , x3 , n ); '//' g"). Notation "rshi@( y , x1 , x2 , n ); g" := (LetInAppIdentZ _ (rshi n) (Pair x1 x2) (fun y => g)) (at level 10, g at level 200, format "rshi@( y , x1 , x2 , n ); '//' g "). @@ -9927,6 +9962,15 @@ Module Fancy. r0 - (r1 << imm)) |}. + Definition SUBC (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [C; M; L; Z]; + spec := (fun '(r0, r1) cc => + r0 - (r1 << imm) - cc[C]) + |}. + + Definition MUL128LL : instruction := {| num_source_regs := 2; @@ -10093,6 +10137,8 @@ Module Fancy. existT _ (ADDC imm) (of_prefancy_scalar (Pair (Snd (Fst args)) (Snd args))) | PreFancy.sub imm => fun args : @scalar var (type.Z * type.Z) => existT _ (SUB imm) (of_prefancy_scalar args) + | PreFancy.subb imm => fun args : @scalar var (type.Z * type.Z * type.Z) => + existT _ (SUBC imm) (of_prefancy_scalar (Pair (Snd (Fst args)) (Snd args))) | PreFancy.mulll => fun args : @scalar var (type.Z * type.Z) => existT _ MUL128LL (of_prefancy_scalar args) | PreFancy.mullh => fun args : @scalar var (type.Z * type.Z) => @@ -10514,7 +10560,7 @@ Module BarrettReduction. dlet_nd qt := qt in dlet_nd r2 := mul (low qt) M in dlet_nd r := sub xt r2 in - dlet_nd q3 := cond_sub1 r M in + let q3 := cond_sub1 r M in cond_sub2 q3 M. Lemma looser_bound : M * 2 ^ k < 2 ^ (2*k). @@ -10698,7 +10744,16 @@ Module BarrettReduction. Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r. Definition wideadd t1 t2 := fst (Rows.add w 2 t1 t2). - Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2). + (* TODO: use this definition once issue #352 is resolved *) + (* Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2). *) + Definition widesub (t1 t2 : list Z) := + let t1_0 := hd 0 t1 in + let t1_1 := hd 0 (tl t1) in + let t2_0 := hd 0 t2 in + let t2_1 := hd 0 (tl t2) in + dlet_nd x0 := Z.sub_get_borrow_full (2^k) t1_0 t2_0 in + dlet_nd x1 := Z.sub_with_get_borrow_full (2^k) (snd x0) t1_1 t2_1 in + [fst x0; fst x1]. Definition widemul := BaseConversion.widemul_inlined k n nout. Lemma partition_represents x : @@ -10739,7 +10794,27 @@ Module BarrettReduction. represents t2 y -> 0 <= x - y < 2^k*2^k -> represents (widesub t1 t2) (x - y). + Proof. + intros; cbv [widesub Let_In]. + rewrite (represents_eq t1 x) by assumption. + rewrite (represents_eq t2 y) by assumption. + cbn [hd tl]. + autorewrite with to_div_mod. + pull_Zmod. + match goal with |- represents [?m; ?d] ?x => + replace d with (x / 2 ^ k); [solve [auto using represents_id] |] end. + rewrite <-(Z.mod_small ((x - y) / 2^k) (2^k)) by (split; try apply Z.div_lt_upper_bound; Z.zero_bounds). + f_equal. + transitivity ((x mod 2^k - y mod 2^k + 2^k * (x / 2 ^ k) - 2^k * (y / 2^k)) / 2^k). { + rewrite (Z.div_mod x (2^k)) at 1 by auto using Z.pow_nonzero with omega. + rewrite (Z.div_mod y (2^k)) at 1 by auto using Z.pow_nonzero with omega. + f_equal. ring. } + autorewrite with zsimplify. + ring. + Qed. + (* Works with Rows.sub-based widesub definition Proof. intros; cbv [widesub]. wide_op Rows.sub_partitions. Qed. + *) Lemma widemul_represents x y : 0 <= x < 2^k -> @@ -10836,7 +10911,8 @@ Module BarrettReduction. Definition cond_sub1 (a : list Z) y : Z := dlet_nd maybe_y := Z.zselect (Z.cc_l (high a)) 0 y in - fst (Z.sub_get_borrow_full (2^k) (low a) maybe_y). + dlet_nd diff := Z.sub_get_borrow_full (2^k) (low a) maybe_y in + fst diff. Lemma cc_l_only_bit : forall x s, 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s. Proof. @@ -10928,7 +11004,7 @@ Module BarrettReduction. = barrett_reduce k M muLow n nout xLow xHigh) As barrett_red_gen_correct. Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed. - (* TODO : reification here is still quite slow (~40s on a beefy machine). Possibly just due to size of term, but warrants further investigation. *) + (* TODO : reification here is still quite slow (~90s on a beefy machine). Possibly just due to size of term, but warrants further investigation. *) Module Export ReifyHints. Global Hint Extern 1 (_ = barrett_reduce _ _ _ _ _ _ _) => simple apply barrett_red_gen_correct : reify_gen_cache. End ReifyHints. @@ -11010,9 +11086,6 @@ Module Barrett256. Open Scope expr_scope. Print barrett_red256. - (* TODO: the ADD/ADDC instructions containing Z.opp should be translated to SUB/SUBB in partial evaluation *) - (* Note: the aforementioned ADD/ADDC instructions currently *do* fail to bounds-check, although the pipeline doesn't give an error. - This is why their results are not cast (because the carry has range [-1~>0]). *) (* barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, expr_let x0 := SELM (x₂, 0, 26959946667150639793205513449348445388433292963828203772348655992835) in @@ -11039,11 +11112,11 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type expr_let x21 := ADDC_256 (x20₂, (uint128)(x18 >> 128), x16) in expr_let x22 := ADD_256 ((uint256)(((uint128)(x17) & 340282366920938463463374607431768211455) << 128), x20₁) in expr_let x23 := ADDC_256 (x22₂, (uint128)(x17 >> 128), x21₁) in - expr_let x24 := Z.add_get_carry_concrete 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ (Z.opp @@ (fst @@ x22), x₁) in - expr_let x25 := Z.add_with_get_carry_concrete 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ (x24₂, Z.opp @@ (fst @@ x23), x₂) in + expr_let x24 := SUB_256 (x₁, x22₁) in + expr_let x25 := SUBB_256 (x24₂, x₂, x23₁) in expr_let x26 := SELL (x25₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in - expr_let x27 := Z.cast uint256 @@ (fst @@ SUB_256 (x24₁, x26)) in - ADDM (x27, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) + expr_let x27 := SUB_256 (x24₁, x26) in + ADDM (x27₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) : Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z)) *) @@ -11056,7 +11129,6 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type Local Notation "'RegMod'" := (Straightline.expr.Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951). Local Notation "'RegMuLow'" := (Straightline.expr.Primitive (t:=type.Z) 26959946667150639793205513449348445388433292963828203772348655992835). Print barrett_red256_prefancy. - (* Note : currently (correctly) fails to convert the adds that should be subs *) (* selm@(y, $x₂, RegZero, RegMuLow); rshi@(y0, RegZero, $x₂,255); @@ -11081,8 +11153,13 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type add@(y19, $y18, $y17, 128); addc@(y20, carry{$y19}, $y15, $y17, -128); add@(y21, $y19, $y16, 128); - addc@(_, carry{$y21}, $y20, $y16, -128); - Straightline.expr.Scalar (Straightline.expr.Primitive (-1)) + addc@(y22, carry{$y21}, $y20, $y16, -128); + sub@(y23, $x₁, $y21, 0); + subb@(y24, carry{$y23}, $x₂, $y22, 0); + sell@(y25, $y24, RegZero, RegMod); + sub@(y26, $y23, $y25, 0); + addm@(y27, $y26, RegZero, RegMod); + ret $y27 *) End Barrett256. |