From 9b894e59c4c5275477e7faf2138a104fa8e7c68f Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Fri, 18 May 2018 01:12:56 +0200 Subject: [WIP] shifting adds --- src/Experiments/SimplyTypedArithmetic.v | 291 ++++++++++++++++++++++---------- 1 file changed, 200 insertions(+), 91 deletions(-) (limited to 'src/Experiments/SimplyTypedArithmetic.v') diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 40cc90157..814161335 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -8547,16 +8547,13 @@ Module PreFancy. Local Notation Z := (type.type_primitive type.Z). Inductive ident : type -> type -> Type := - | add : ident (Z * Z) (Z * Z) - | addc : ident (Z * Z * Z) (Z * Z) + | add (imm : BinInt.Z) : ident (Z * Z) (Z * Z) + | addc (imm : BinInt.Z) : ident (Z * Z * Z) (Z * Z) + | sub (imm : BinInt.Z) : ident (Z * Z) (Z * Z) | mulll : ident (Z * Z) Z | mullh : ident (Z * Z) Z | mulhl : ident (Z * Z) Z | mulhh : ident (Z * Z) Z - | sub : ident (Z * Z) (Z * Z) - (* | land : BinInt.Z -> ident Z Z *) - | shiftr : BinInt.Z -> ident Z Z - | shiftl : BinInt.Z -> ident Z Z | rshi : BinInt.Z -> ident (Z * Z) Z | selc : ident (Z * Z * Z) Z | selm : ident (Z * Z * Z) Z @@ -8649,27 +8646,70 @@ Module PreFancy. | _ => None end. + Definition invert_shift {t} (s : @scalar var t) + : option (@scalar var Z * BinInt.Z) := + match s return option (@scalar var Z * BinInt.Z) with + | Cast r (Shiftl n x) => + match invert_lower x return option (@scalar var Z * BinInt.Z) with + | Some x' => + if (lower r =? 0) && (upper r =? wordmax-1) && (n =? half_bits) + then Some (x', half_bits) + else None + | None => None + end + | _ => + match invert_upper s return _ with + | Some x => Some (x, -half_bits) + | None => None + end + end. + Definition of_straightline_ident {s d} (idc : ident.ident s d) : forall t, range_type d -> @scalar var s -> (var d -> @expr var ident t) -> @expr var ident t := match idc in ident.ident s d return forall t, range_type d -> scalar s -> (var d -> @expr var ident t) -> @expr var ident t with | ident.Z.add_get_carry_concrete w => fun t r x f => if w =? wordmax - then LetInAppIdentZZ r add x f + then + match x with + | Pair Z Z xl xr => + match invert_shift xl, invert_shift xr with + | _, Some (xr', imm) => LetInAppIdentZZ r (add imm) (Pair xl xr') f + | Some (xl', imm), None => LetInAppIdentZZ r (add imm) (Pair xr xl') f + + | None, None => LetInAppIdentZZ r (add 0) (Pair xl xr) f + end + | _ => dummy _ + end else dummy _ | ident.Z.add_with_get_carry_concrete w => fun t r x f => if w =? wordmax - then LetInAppIdentZZ r addc x f + then + match x with + | Pair (type.prod Z Z) Z (Pair Z Z xc xl) xr => + match invert_shift xl, invert_shift xr with + | _, Some (xr', imm) => LetInAppIdentZZ r (addc imm) (Pair (Pair xc xl) xr') f + | Some (xl', imm), None => LetInAppIdentZZ r (addc imm) (Pair (Pair xc xr) xl') f + + | None, None => LetInAppIdentZZ r (addc 0) (Pair (Pair xc xl) xr) f + end + | _ => dummy _ + end else dummy _ | ident.Z.sub_get_borrow_concrete w => fun t r x f => if w =? wordmax - then LetInAppIdentZZ r sub x f + then + match x with + | Pair Z Z xl xr => + match invert_shift xr with + | Some (xr', imm) => LetInAppIdentZZ r (sub imm) (Pair xl xr') f + | None => LetInAppIdentZZ r (sub 0) (Pair xl xr) f + end + | _ => dummy _ + end else dummy _ - (* | ident.Z.land n => fun _ r => LetInAppIdentZ r (land n) *) - | ident.Z.shiftr n => fun _ r => LetInAppIdentZ r (shiftr n) - | ident.Z.shiftl n => fun _ r => LetInAppIdentZ r (shiftl n) | ident.Z.rshi_concrete w n => fun _ r x f => if w =? wordmax @@ -8726,19 +8766,17 @@ Module PreFancy. Section interp. Local Notation low x := (Z.land x (wordmax_half_bits - 1)). Local Notation high x := (x >> half_bits). + Local Notation shift x imm := ((x << imm) mod wordmax). Definition interp_ident {s d} (idc : ident s d) : type.interp s -> type.interp d := match idc with - | add => fun x => Z.add_get_carry_full wordmax (fst x) (snd x) - | addc => fun x => Z.add_with_get_carry_full wordmax (fst (fst x)) (snd (fst x)) (snd x) + | add imm => fun x => Z.add_get_carry_full wordmax (fst x) (shift (snd x) imm) + | addc imm => fun x => Z.add_with_get_carry_full wordmax (fst (fst x)) (snd (fst x)) (shift (snd x) imm) + | sub imm => fun x => Z.sub_get_borrow_full wordmax (fst x) (shift (snd x) imm) | mulll => fun x => low (fst x) * low (snd x) | mullh => fun x => low (fst x) * high (snd x) | mulhl => fun x => high (fst x) * low (snd x) | mulhh => fun x => high (fst x) * high (snd x) - | sub => fun x => Z.sub_get_borrow_full wordmax (fst x) (snd x) - (* | land n => fun x => Z.land x n *) (* only allowed inside select/mul? *) - | shiftr n => fun x => Z.shiftr x n - | shiftl n => fun x => Z.shiftl x n | rshi n => fun x => Z.rshi wordmax (fst x) (snd x) n | selc => fun x => Z.zselect (fst (fst x)) (snd (fst x)) (snd x) | selm => fun x => Z.zselect (Z.cc_m wordmax (fst (fst x))) (snd (fst x)) (snd x) @@ -8855,50 +8893,33 @@ Module PreFancy. Inductive ok_ident : forall s d, scalar s -> range_type d -> ident.ident s d -> Prop := | ok_add : - forall x : scalar (type.prod type.Z type.Z), - in_word_range (fst (get_range x)) -> - in_word_range (snd (get_range x)) -> + forall x y : scalar type.Z, + in_word_range (get_range x) -> + in_word_range (get_range y) -> ok_ident _ (type.prod type.Z type.Z) - x + (Pair x y) (word_range, flag_range) (ident.Z.add_get_carry_concrete wordmax) | ok_addc : - forall x : scalar (type.prod (type.prod type.Z type.Z) type.Z), - in_flag_range (fst (fst (get_range x))) -> - in_word_range (snd (fst (get_range x))) -> - in_word_range (snd (get_range x)) -> + forall c x y : scalar type.Z, + in_flag_range (get_range c) -> + in_word_range (get_range x) -> + in_word_range (get_range y) -> ok_ident _ (type.prod type.Z type.Z) - x + (Pair (Pair c x) y) (word_range, flag_range) (ident.Z.add_with_get_carry_concrete wordmax) | ok_sub : - forall x : scalar (type.prod type.Z type.Z), - in_word_range (fst (get_range x)) -> - in_word_range (snd (get_range x)) -> + forall x y : scalar type.Z, + in_word_range (get_range x) -> + in_word_range (get_range y) -> ok_ident _ (type.prod type.Z type.Z) - x + (Pair x y) (word_range, flag_range) (ident.Z.sub_get_borrow_concrete wordmax) - (* | ok_land : - forall (x : scalar type.Z) n, - in_word_range (get_range x) -> - 0 <= n < wordmax -> - ok_ident type.Z type.Z x (ZRange.map (fun y => Z.land y n) (get_range x)) (ident.Z.land n)*) - | ok_shiftr : - forall (x : scalar type.Z) n r, - in_word_range (get_range x) -> - 0 <= n <= log2wordmax -> - r = ZRange.map (fun y => Z.shiftr y n) (get_range x) -> - ok_ident type.Z type.Z x r (ident.Z.shiftr n) - | ok_shiftl : - forall (x : scalar type.Z) n, - in_word_range (get_range x) -> - 0 <= n -> - upper (get_range x) * 2 ^ n <= wordmax - 1 -> - ok_ident type.Z type.Z x word_range (ident.Z.shiftl n) | ok_rshi : forall (x : scalar (type.prod type.Z type.Z)) n, in_word_range (fst (get_range x)) -> @@ -9150,6 +9171,16 @@ Module PreFancy. | true => invert H; extract_ok_scalar' 1 x | _ => fail end + | H : ok_scalar (Pair (Pair x _) _) |- _ => + match (eval compute in (2 <=? level)) with + | true => invert H; extract_ok_scalar' 1 x + | _ => fail + end + | H : ok_scalar (Pair (Pair _ x) _) |- _ => + match (eval compute in (2 <=? level)) with + | true => invert H; extract_ok_scalar' 1 x + | _ => fail + end | H : ok_scalar (?g x) |- _ => invert H | H : ok_scalar (Pair x _) |- _ => invert H | H : ok_scalar (Pair _ x) |- _ => invert H @@ -9304,9 +9335,6 @@ Module PreFancy. match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. rewrite Z.div_sub_small by omega. split; break_match; lia. } - { apply has_range_shiftr; cbn [has_range]; omega. } - { eapply has_word_range_shiftl; eauto with has_range omega. - apply in_word_range_spec ; omega. } { apply has_word_range_rshi; omega. } { rewrite Z.zselect_correct. break_match; omega. } { cbn [interp_scalar fst snd get_range] in *. @@ -9423,6 +9451,73 @@ Module PreFancy. cbn; rewrite Pxi, Pyi; rewrite interp_cast_noop by auto; reflexivity. Qed. + Lemma has_word_range_mod_small x: + @has_range type.Z word_range x -> + x mod wordmax = x. + Proof. + cbv [has_range upper lower]. + intros. apply Z.mod_small; omega. + Qed. + + Lemma invert_shift_correct (s : scalar type.Z) x imm : + ok_scalar s -> + invert_shift consts s = Some (x, imm) -> + interp_scalar s = (interp_scalar x << imm) mod wordmax. + Proof. + (* + intros Hok ?; invert Hok; + try match goal with H : ok_scalar ?x, H' : context[Cast _ ?x] |- _ => + invert H end; + try match goal with H : ok_scalar ?x, H' : context[Shiftl _ ?x] |- _ => + invert H end; + try match goal with H : ok_scalar ?x, H' : context[Shiftl _ (Cast _ ?x)] |- _ => + invert H end; + try (cbn [invert_shift invert_upper invert_upper'] in *; discriminate). + { + repeat match goal with + | _ => progress (cbn [invert_shift invert_upper invert_upper' + interp_scalar fst snd] in * ) + | _ => rewrite interp_cast_noop by eauto using has_range_loosen + | H : ok_scalar (Shiftr _ _) |- _ => apply has_range_interp_scalar in H + | H : context [if ?x then _ else _] |- _ => + let Heq := fresh in case_eq x; intro Heq; rewrite Heq in H + | H : _ |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt + | H : Some _ = Some _ |- _ => progress (invert H) + | _ => progress subst + | _ => reflexivity + | _ => discriminate + end. + rewrite has_word_range_mod_small. + 2:eauto using has_range_loosen. + } + { + repeat match goal with + | _ => progress (cbn [invert_shift invert_upper invert_upper' + invert_lower invert_lower' + interp_scalar fst snd] in * ) + | _ => rewrite interp_cast_noop by eauto using has_range_loosen + | H : ok_scalar (Shiftr _ _) |- _ => apply has_range_interp_scalar in H + | H : ok_scalar (Shiftl _ _) |- _ => apply has_range_interp_scalar in H + | H : ok_scalar (Land _ _) |- _ => apply has_range_interp_scalar in H + | H : context [if ?x then _ else _] |- _ => + let Heq := fresh in case_eq x; intro Heq; rewrite Heq in H + | H : _ |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt + | H : Some _ = Some _ |- _ => progress (invert H) + | _ => progress subst + | _ => reflexivity + | _ => discriminate + end. + Qed. + *) + Admitted. + + Local Ltac solve_commutative_replace := + match goal with + | |- @eq (_ * _) ?x ?y => + replace x with (fst x, snd x) by (destruct x; reflexivity); + replace y with (fst y, snd y) by (destruct y; reflexivity) + end; autorewrite with to_div_mod; solve [repeat (f_equal; try ring)]. + Lemma of_straightline_ident_correct s d t x r (idc : ident.ident s d) g : ok_ident s d x r idc -> ok_scalar x -> @@ -9433,16 +9528,29 @@ Module PreFancy. pose proof wordmax_half_bits_pos. pose proof (ident_interp_has_range _ _ x r idc ltac:(assumption) ltac:(assumption)). induction H; try solve [auto using of_straightline_ident_mul_correct]; - cbn [of_straightline_ident ident.interp ident.gen_interp invert_selm invert_sell] in *; intros; - rewrite ?Z.eqb_refl; cbn [interp interp_ident andb]; try destruct_scalar; + cbn [of_straightline_ident ident.interp ident.gen_interp + invert_selm invert_sell] in *; + intros; rewrite ?Z.eqb_refl; cbn [andb]; + try match goal with |- context [invert_shift] => break_match end; + cbn [interp interp_ident]; try destruct_scalar; repeat match goal with | _ => progress (cbn [fst snd interp_scalar] in * ) | _ => progress break_match; [ ] + | _ => progress autorewrite with zsimplify_fast | _ => rewrite interp_cast_noop with (r:=flag_range) in * by (apply has_flag_range_cc_m'; auto; extract_ok_scalar) | _ => rewrite interp_cast_noop with (r:=flag_range) in * by (apply has_flag_range_land'; auto; extract_ok_scalar) | H : _ = (_,_) |- _ => progress (inversion H; subst) + | H : invert_shift _ _ = Some _ |- _ => + apply invert_shift_correct in H; [|extract_ok_scalar]; + rewrite <-H + | H : has_range ?r (?f ?x ?y) |- context [?f ?y ?x] => + replace (f y x) with (f x y) by solve_commutative_replace + | _ => rewrite has_word_range_mod_small + by (eapply has_range_loosen; + [apply has_range_interp_scalar; extract_ok_scalar|]; + assumption) | _ => rewrite interp_cast_noop by assumption | _ => rewrite interp_cast2_noop by assumption | _ => reflexivity @@ -9513,25 +9621,32 @@ Module PreFancy. Notation "f @( y , x1 , x2 , x3 ); g " := (LetInAppIdentZZ (uint256, bool)%core f (Pair (Pair x1 x2) x3) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g "). - Notation "f @( y , x1 , x2 , x3 ); g " + Notation "f @( y , x1 , x2 , x3 ); '#128' g " := (LetInAppIdentZZ (uint128, bool)%core f (Pair (Pair x1 x2) x3) (fun y => g)) - (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g "). + (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '#128' '//' g "). Notation "f @( y , x1 , x2 ); g " := (LetInAppIdentZ uint256 f (Pair x1 x2) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 ); '//' g "). Notation "f @( y , x1 , x2 , x3 ); g " := (LetInAppIdentZ uint256 f (Pair (Pair x1 x2) x3) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g "). - Notation "shiftL@( y , x , n ); g" - := (LetInAppIdentZ uint256 (shiftl n) (Lower{x}) (fun y => g)) - (at level 10, g at level 200, format "shiftL@( y , x , n ); '//' g"). - Notation "shiftR@( y , x , n ); g " - := (LetInAppIdentZ uint128 (shiftr n) x (fun y => g)) - (at level 10, g at level 200, format "shiftR@( y , x , n ); '//' g "). + (* special cases for when the ident constructor takes a constant argument *) + Notation "add@( y , x1 , x2 , n ); g" + := (LetInAppIdentZZ (uint256, bool) (add n) (Pair x1 x2) (fun y => g)) + (at level 10, g at level 200, format "add@( y , x1 , x2 , n ); '//' g"). + Notation "addc@( y , x1 , x2 , x3 , n ); g" + := (LetInAppIdentZZ (uint256, bool) (addc n) (Pair (Pair x1 x2) x3) (fun y => g)) + (at level 10, g at level 200, format "addc@( y , x1 , x2 , x3 , n ); '//' g"). + Notation "addc@( y , x1 , x2 , x3 , n ); '#128' g" + := (LetInAppIdentZZ (uint128, bool) (addc n) (Pair (Pair x1 x2) x3) (fun y => g)) + (at level 10, g at level 200, format "addc@( y , x1 , x2 , x3 , n ); '#128' '//' g"). + Notation "sub@( y , x1 , x2 , n ); g" + := (LetInAppIdentZZ (uint256, bool) (sub n) (Pair x1 x2) (fun y => g)) + (at level 10, g at level 200, format "sub@( y , x1 , x2 , n ); '//' g"). Notation "rshi@( y , x1 , x2 , n ); g" := (LetInAppIdentZ _ (rshi n) (Pair x1 x2) (fun y => g)) (at level 10, g at level 200, format "rshi@( y , x1 , x2 , n ); '//' g "). - Notation "'Ret' $ x" := (Scalar (Var tZ x)) (at level 10, format "'Ret' $ x"). + Notation "'ret' $ x" := (Scalar (Var tZ x)) (at level 10, format "'ret' $ x"). Notation "( x , y )" := (Pair x y) (at level 10, left associativity). End Notations. End PreFancy. @@ -10159,30 +10274,24 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type mulhl@(y3, RegMuLow, $y1); mullh@(y4, RegMuLow, $y1); mulll@(y5, RegMuLow, $y1); - add@(y6, Straightline.expr.Cast uint256 (Straightline.expr.Shiftl 128 (Lower{$y3})), $y5); - addc@(y7, carry{$y6}, $y2, Straightline.expr.Cast uint128 - (Straightline.expr.Shiftr 128 ($y4))); - add@(y8, Straightline.expr.Cast uint256 (Straightline.expr.Shiftl 128 (Lower{$y4})), $y6); - addc@(y9, carry{$y8}, Straightline.expr.Cast uint128 - (Straightline.expr.Shiftr 128 ($y3)), $y7); - add@(y10, $y1, $y9); - addc@(y11, carry{$y10}, RegZero, $y0); - add@(y12, $y, $y10); - addc@(y13, carry{$y12}, RegZero, $y11); + add@(y6, $y5, $y3, 128); + addc@(y7, carry{$y6}, $y2, $y4, -128); + add@(y8, $y6, $y4, 128); + addc@(y9, carry{$y8}, $y7, $y3, -128); + add@(y10, $y1, $y9, 0); + addc@(y11, carry{$y10}, RegZero, $y0, 0); #128 + add@(y12, $y, $y10, 0); + addc@(y13, carry{$y12}, RegZero, $y11, 0); #128 rshi@(y14, $y13, $y12,1); mulhh@(y15, RegMod, $y14); mullh@(y16, RegMod, $y14); mulhl@(y17, RegMod, $y14); mulll@(y18, RegMod, $y14); - add@(y19, Straightline.expr.Cast uint256 - (Straightline.expr.Shiftl 128 (Lower{$y16})), $y18); - addc@(y20, carry{$y19}, $y15, Straightline.expr.Cast uint128 - (Straightline.expr.Shiftr 128 ($y17))); - add@(y21, Straightline.expr.Cast uint256 - (Straightline.expr.Shiftl 128 (Lower{$y17})), $y19); - addc@(_, carry{$y21}, Straightline.expr.Cast uint128 - (Straightline.expr.Shiftr 128 ($y16)), $y20); - Straightline.expr.Scalar (Straightline.expr.Primitive (-1)) + add@(y19, $y18, $y16, 128); + addc@(y20, carry{$y19}, $y15, $y17, -128); + add@(y21, $y19, $y17, 128); + addc@(_, carry{$y21}, $y20, $y16, -128); + Straightline.expr.Scalar (Straightline.expr.Primitive (-1)) *) End Barrett256. @@ -10880,22 +10989,22 @@ montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z * (* mulhl@(y0, RegPInv, $x₁); mulll@(y1, RegPInv, $x₁); - add@(y2, Straightline.expr.Cast uint256 (Straightline.expr.Shiftl 128 (Lower{$y})), $y1); - add@(y3, Straightline.expr.Cast uint256 (Straightline.expr.Shiftl 128 (Lower{$y0})), $y2); + add@(y2, $y1, $y, 128); + add@(y3, $y2, $y0, 128); mulhh@(y4, RegMod, $y3); mullh@(y5, RegMod, $y3); mulhl@(y6, RegMod, $y3); mulll@(y7, RegMod, $y3); - add@(y8, Straightline.expr.Cast uint256 (Straightline.expr.Shiftl 128 (Lower{$y5})), $y7); - addc@(y9, carry{$y8}, $y4, Straightline.expr.Cast uint128 (Straightline.expr.Shiftr 128 ($y6))); - add@(y10, Straightline.expr.Cast uint256 (Straightline.expr.Shiftl 128 (Lower{$y6})), $y8); - addc@(y11, carry{$y10}, Straightline.expr.Cast uint128 (Straightline.expr.Shiftr 128 ($y5)), $y9); - add@(y12, $y10, $x₁); - addc@(y13, carry{$y12}, $y11, $x₂); + add@(y8, $y7, $y5, 128); + addc@(y9, carry{$y8}, $y4, $y6, -128); + add@(y10, $y8, $y6, 128); + addc@(y11, carry{$y10}, $y9, $y5, -128); + add@(y12, $y10, $x₁, 0); + addc@(y13, carry{$y12}, $y11, $x₂, 0); selc@(y14, carry{$y13}, RegZero, RegMod); - sub@(y15, $y13, $y14); + sub@(y15, $y13, $y14, 0); addm@(y16, $y15, RegZero, RegMod); - Ret $y16 + ret $y16 *) End Montgomery256. -- cgit v1.2.3