aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-05-18 01:12:56 +0200
committerGravatar Jade Philipoom <jadep@google.com>2018-05-31 13:46:48 +0200
commit9b894e59c4c5275477e7faf2138a104fa8e7c68f (patch)
treea6b3b031dde5dc75ffeea774f721345dc3ca7366 /src
parent0dd2e5f84b91f3af22a37f87eeea476f53445535 (diff)
[WIP] shifting adds
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v291
1 files changed, 200 insertions, 91 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 40cc90157..814161335 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -8547,16 +8547,13 @@ Module PreFancy.
Local Notation Z := (type.type_primitive type.Z).
Inductive ident : type -> type -> Type :=
- | add : ident (Z * Z) (Z * Z)
- | addc : ident (Z * Z * Z) (Z * Z)
+ | add (imm : BinInt.Z) : ident (Z * Z) (Z * Z)
+ | addc (imm : BinInt.Z) : ident (Z * Z * Z) (Z * Z)
+ | sub (imm : BinInt.Z) : ident (Z * Z) (Z * Z)
| mulll : ident (Z * Z) Z
| mullh : ident (Z * Z) Z
| mulhl : ident (Z * Z) Z
| mulhh : ident (Z * Z) Z
- | sub : ident (Z * Z) (Z * Z)
- (* | land : BinInt.Z -> ident Z Z *)
- | shiftr : BinInt.Z -> ident Z Z
- | shiftl : BinInt.Z -> ident Z Z
| rshi : BinInt.Z -> ident (Z * Z) Z
| selc : ident (Z * Z * Z) Z
| selm : ident (Z * Z * Z) Z
@@ -8649,27 +8646,70 @@ Module PreFancy.
| _ => None
end.
+ Definition invert_shift {t} (s : @scalar var t)
+ : option (@scalar var Z * BinInt.Z) :=
+ match s return option (@scalar var Z * BinInt.Z) with
+ | Cast r (Shiftl n x) =>
+ match invert_lower x return option (@scalar var Z * BinInt.Z) with
+ | Some x' =>
+ if (lower r =? 0) && (upper r =? wordmax-1) && (n =? half_bits)
+ then Some (x', half_bits)
+ else None
+ | None => None
+ end
+ | _ =>
+ match invert_upper s return _ with
+ | Some x => Some (x, -half_bits)
+ | None => None
+ end
+ end.
+
Definition of_straightline_ident {s d} (idc : ident.ident s d)
: forall t, range_type d -> @scalar var s -> (var d -> @expr var ident t) -> @expr var ident t :=
match idc in ident.ident s d return forall t, range_type d -> scalar s -> (var d -> @expr var ident t) -> @expr var ident t with
| ident.Z.add_get_carry_concrete w =>
fun t r x f =>
if w =? wordmax
- then LetInAppIdentZZ r add x f
+ then
+ match x with
+ | Pair Z Z xl xr =>
+ match invert_shift xl, invert_shift xr with
+ | _, Some (xr', imm) => LetInAppIdentZZ r (add imm) (Pair xl xr') f
+ | Some (xl', imm), None => LetInAppIdentZZ r (add imm) (Pair xr xl') f
+
+ | None, None => LetInAppIdentZZ r (add 0) (Pair xl xr) f
+ end
+ | _ => dummy _
+ end
else dummy _
| ident.Z.add_with_get_carry_concrete w =>
fun t r x f =>
if w =? wordmax
- then LetInAppIdentZZ r addc x f
+ then
+ match x with
+ | Pair (type.prod Z Z) Z (Pair Z Z xc xl) xr =>
+ match invert_shift xl, invert_shift xr with
+ | _, Some (xr', imm) => LetInAppIdentZZ r (addc imm) (Pair (Pair xc xl) xr') f
+ | Some (xl', imm), None => LetInAppIdentZZ r (addc imm) (Pair (Pair xc xr) xl') f
+
+ | None, None => LetInAppIdentZZ r (addc 0) (Pair (Pair xc xl) xr) f
+ end
+ | _ => dummy _
+ end
else dummy _
| ident.Z.sub_get_borrow_concrete w =>
fun t r x f =>
if w =? wordmax
- then LetInAppIdentZZ r sub x f
+ then
+ match x with
+ | Pair Z Z xl xr =>
+ match invert_shift xr with
+ | Some (xr', imm) => LetInAppIdentZZ r (sub imm) (Pair xl xr') f
+ | None => LetInAppIdentZZ r (sub 0) (Pair xl xr) f
+ end
+ | _ => dummy _
+ end
else dummy _
- (* | ident.Z.land n => fun _ r => LetInAppIdentZ r (land n) *)
- | ident.Z.shiftr n => fun _ r => LetInAppIdentZ r (shiftr n)
- | ident.Z.shiftl n => fun _ r => LetInAppIdentZ r (shiftl n)
| ident.Z.rshi_concrete w n =>
fun _ r x f =>
if w =? wordmax
@@ -8726,19 +8766,17 @@ Module PreFancy.
Section interp.
Local Notation low x := (Z.land x (wordmax_half_bits - 1)).
Local Notation high x := (x >> half_bits).
+ Local Notation shift x imm := ((x << imm) mod wordmax).
Definition interp_ident {s d} (idc : ident s d) : type.interp s -> type.interp d :=
match idc with
- | add => fun x => Z.add_get_carry_full wordmax (fst x) (snd x)
- | addc => fun x => Z.add_with_get_carry_full wordmax (fst (fst x)) (snd (fst x)) (snd x)
+ | add imm => fun x => Z.add_get_carry_full wordmax (fst x) (shift (snd x) imm)
+ | addc imm => fun x => Z.add_with_get_carry_full wordmax (fst (fst x)) (snd (fst x)) (shift (snd x) imm)
+ | sub imm => fun x => Z.sub_get_borrow_full wordmax (fst x) (shift (snd x) imm)
| mulll => fun x => low (fst x) * low (snd x)
| mullh => fun x => low (fst x) * high (snd x)
| mulhl => fun x => high (fst x) * low (snd x)
| mulhh => fun x => high (fst x) * high (snd x)
- | sub => fun x => Z.sub_get_borrow_full wordmax (fst x) (snd x)
- (* | land n => fun x => Z.land x n *) (* only allowed inside select/mul? *)
- | shiftr n => fun x => Z.shiftr x n
- | shiftl n => fun x => Z.shiftl x n
| rshi n => fun x => Z.rshi wordmax (fst x) (snd x) n
| selc => fun x => Z.zselect (fst (fst x)) (snd (fst x)) (snd x)
| selm => fun x => Z.zselect (Z.cc_m wordmax (fst (fst x))) (snd (fst x)) (snd x)
@@ -8855,50 +8893,33 @@ Module PreFancy.
Inductive ok_ident : forall s d, scalar s -> range_type d -> ident.ident s d -> Prop :=
| ok_add :
- forall x : scalar (type.prod type.Z type.Z),
- in_word_range (fst (get_range x)) ->
- in_word_range (snd (get_range x)) ->
+ forall x y : scalar type.Z,
+ in_word_range (get_range x) ->
+ in_word_range (get_range y) ->
ok_ident _
(type.prod type.Z type.Z)
- x
+ (Pair x y)
(word_range, flag_range)
(ident.Z.add_get_carry_concrete wordmax)
| ok_addc :
- forall x : scalar (type.prod (type.prod type.Z type.Z) type.Z),
- in_flag_range (fst (fst (get_range x))) ->
- in_word_range (snd (fst (get_range x))) ->
- in_word_range (snd (get_range x)) ->
+ forall c x y : scalar type.Z,
+ in_flag_range (get_range c) ->
+ in_word_range (get_range x) ->
+ in_word_range (get_range y) ->
ok_ident _
(type.prod type.Z type.Z)
- x
+ (Pair (Pair c x) y)
(word_range, flag_range)
(ident.Z.add_with_get_carry_concrete wordmax)
| ok_sub :
- forall x : scalar (type.prod type.Z type.Z),
- in_word_range (fst (get_range x)) ->
- in_word_range (snd (get_range x)) ->
+ forall x y : scalar type.Z,
+ in_word_range (get_range x) ->
+ in_word_range (get_range y) ->
ok_ident _
(type.prod type.Z type.Z)
- x
+ (Pair x y)
(word_range, flag_range)
(ident.Z.sub_get_borrow_concrete wordmax)
- (* | ok_land :
- forall (x : scalar type.Z) n,
- in_word_range (get_range x) ->
- 0 <= n < wordmax ->
- ok_ident type.Z type.Z x (ZRange.map (fun y => Z.land y n) (get_range x)) (ident.Z.land n)*)
- | ok_shiftr :
- forall (x : scalar type.Z) n r,
- in_word_range (get_range x) ->
- 0 <= n <= log2wordmax ->
- r = ZRange.map (fun y => Z.shiftr y n) (get_range x) ->
- ok_ident type.Z type.Z x r (ident.Z.shiftr n)
- | ok_shiftl :
- forall (x : scalar type.Z) n,
- in_word_range (get_range x) ->
- 0 <= n ->
- upper (get_range x) * 2 ^ n <= wordmax - 1 ->
- ok_ident type.Z type.Z x word_range (ident.Z.shiftl n)
| ok_rshi :
forall (x : scalar (type.prod type.Z type.Z)) n,
in_word_range (fst (get_range x)) ->
@@ -9150,6 +9171,16 @@ Module PreFancy.
| true => invert H; extract_ok_scalar' 1 x
| _ => fail
end
+ | H : ok_scalar (Pair (Pair x _) _) |- _ =>
+ match (eval compute in (2 <=? level)) with
+ | true => invert H; extract_ok_scalar' 1 x
+ | _ => fail
+ end
+ | H : ok_scalar (Pair (Pair _ x) _) |- _ =>
+ match (eval compute in (2 <=? level)) with
+ | true => invert H; extract_ok_scalar' 1 x
+ | _ => fail
+ end
| H : ok_scalar (?g x) |- _ => invert H
| H : ok_scalar (Pair x _) |- _ => invert H
| H : ok_scalar (Pair _ x) |- _ => invert H
@@ -9304,9 +9335,6 @@ Module PreFancy.
match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
rewrite Z.div_sub_small by omega.
split; break_match; lia. }
- { apply has_range_shiftr; cbn [has_range]; omega. }
- { eapply has_word_range_shiftl; eauto with has_range omega.
- apply in_word_range_spec ; omega. }
{ apply has_word_range_rshi; omega. }
{ rewrite Z.zselect_correct. break_match; omega. }
{ cbn [interp_scalar fst snd get_range] in *.
@@ -9423,6 +9451,73 @@ Module PreFancy.
cbn; rewrite Pxi, Pyi; rewrite interp_cast_noop by auto; reflexivity.
Qed.
+ Lemma has_word_range_mod_small x:
+ @has_range type.Z word_range x ->
+ x mod wordmax = x.
+ Proof.
+ cbv [has_range upper lower].
+ intros. apply Z.mod_small; omega.
+ Qed.
+
+ Lemma invert_shift_correct (s : scalar type.Z) x imm :
+ ok_scalar s ->
+ invert_shift consts s = Some (x, imm) ->
+ interp_scalar s = (interp_scalar x << imm) mod wordmax.
+ Proof.
+ (*
+ intros Hok ?; invert Hok;
+ try match goal with H : ok_scalar ?x, H' : context[Cast _ ?x] |- _ =>
+ invert H end;
+ try match goal with H : ok_scalar ?x, H' : context[Shiftl _ ?x] |- _ =>
+ invert H end;
+ try match goal with H : ok_scalar ?x, H' : context[Shiftl _ (Cast _ ?x)] |- _ =>
+ invert H end;
+ try (cbn [invert_shift invert_upper invert_upper'] in *; discriminate).
+ {
+ repeat match goal with
+ | _ => progress (cbn [invert_shift invert_upper invert_upper'
+ interp_scalar fst snd] in * )
+ | _ => rewrite interp_cast_noop by eauto using has_range_loosen
+ | H : ok_scalar (Shiftr _ _) |- _ => apply has_range_interp_scalar in H
+ | H : context [if ?x then _ else _] |- _ =>
+ let Heq := fresh in case_eq x; intro Heq; rewrite Heq in H
+ | H : _ |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt
+ | H : Some _ = Some _ |- _ => progress (invert H)
+ | _ => progress subst
+ | _ => reflexivity
+ | _ => discriminate
+ end.
+ rewrite has_word_range_mod_small.
+ 2:eauto using has_range_loosen.
+ }
+ {
+ repeat match goal with
+ | _ => progress (cbn [invert_shift invert_upper invert_upper'
+ invert_lower invert_lower'
+ interp_scalar fst snd] in * )
+ | _ => rewrite interp_cast_noop by eauto using has_range_loosen
+ | H : ok_scalar (Shiftr _ _) |- _ => apply has_range_interp_scalar in H
+ | H : ok_scalar (Shiftl _ _) |- _ => apply has_range_interp_scalar in H
+ | H : ok_scalar (Land _ _) |- _ => apply has_range_interp_scalar in H
+ | H : context [if ?x then _ else _] |- _ =>
+ let Heq := fresh in case_eq x; intro Heq; rewrite Heq in H
+ | H : _ |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt
+ | H : Some _ = Some _ |- _ => progress (invert H)
+ | _ => progress subst
+ | _ => reflexivity
+ | _ => discriminate
+ end.
+ Qed.
+ *)
+ Admitted.
+
+ Local Ltac solve_commutative_replace :=
+ match goal with
+ | |- @eq (_ * _) ?x ?y =>
+ replace x with (fst x, snd x) by (destruct x; reflexivity);
+ replace y with (fst y, snd y) by (destruct y; reflexivity)
+ end; autorewrite with to_div_mod; solve [repeat (f_equal; try ring)].
+
Lemma of_straightline_ident_correct s d t x r (idc : ident.ident s d) g :
ok_ident s d x r idc ->
ok_scalar x ->
@@ -9433,16 +9528,29 @@ Module PreFancy.
pose proof wordmax_half_bits_pos.
pose proof (ident_interp_has_range _ _ x r idc ltac:(assumption) ltac:(assumption)).
induction H; try solve [auto using of_straightline_ident_mul_correct];
- cbn [of_straightline_ident ident.interp ident.gen_interp invert_selm invert_sell] in *; intros;
- rewrite ?Z.eqb_refl; cbn [interp interp_ident andb]; try destruct_scalar;
+ cbn [of_straightline_ident ident.interp ident.gen_interp
+ invert_selm invert_sell] in *;
+ intros; rewrite ?Z.eqb_refl; cbn [andb];
+ try match goal with |- context [invert_shift] => break_match end;
+ cbn [interp interp_ident]; try destruct_scalar;
repeat match goal with
| _ => progress (cbn [fst snd interp_scalar] in * )
| _ => progress break_match; [ ]
+ | _ => progress autorewrite with zsimplify_fast
| _ => rewrite interp_cast_noop with (r:=flag_range) in *
by (apply has_flag_range_cc_m'; auto; extract_ok_scalar)
| _ => rewrite interp_cast_noop with (r:=flag_range) in *
by (apply has_flag_range_land'; auto; extract_ok_scalar)
| H : _ = (_,_) |- _ => progress (inversion H; subst)
+ | H : invert_shift _ _ = Some _ |- _ =>
+ apply invert_shift_correct in H; [|extract_ok_scalar];
+ rewrite <-H
+ | H : has_range ?r (?f ?x ?y) |- context [?f ?y ?x] =>
+ replace (f y x) with (f x y) by solve_commutative_replace
+ | _ => rewrite has_word_range_mod_small
+ by (eapply has_range_loosen;
+ [apply has_range_interp_scalar; extract_ok_scalar|];
+ assumption)
| _ => rewrite interp_cast_noop by assumption
| _ => rewrite interp_cast2_noop by assumption
| _ => reflexivity
@@ -9513,25 +9621,32 @@ Module PreFancy.
Notation "f @( y , x1 , x2 , x3 ); g "
:= (LetInAppIdentZZ (uint256, bool)%core f (Pair (Pair x1 x2) x3) (fun y => g))
(at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g ").
- Notation "f @( y , x1 , x2 , x3 ); g "
+ Notation "f @( y , x1 , x2 , x3 ); '#128' g "
:= (LetInAppIdentZZ (uint128, bool)%core f (Pair (Pair x1 x2) x3) (fun y => g))
- (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g ").
+ (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '#128' '//' g ").
Notation "f @( y , x1 , x2 ); g "
:= (LetInAppIdentZ uint256 f (Pair x1 x2) (fun y => g))
(at level 10, g at level 200, format "f @( y , x1 , x2 ); '//' g ").
Notation "f @( y , x1 , x2 , x3 ); g "
:= (LetInAppIdentZ uint256 f (Pair (Pair x1 x2) x3) (fun y => g))
(at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g ").
- Notation "shiftL@( y , x , n ); g"
- := (LetInAppIdentZ uint256 (shiftl n) (Lower{x}) (fun y => g))
- (at level 10, g at level 200, format "shiftL@( y , x , n ); '//' g").
- Notation "shiftR@( y , x , n ); g "
- := (LetInAppIdentZ uint128 (shiftr n) x (fun y => g))
- (at level 10, g at level 200, format "shiftR@( y , x , n ); '//' g ").
+ (* special cases for when the ident constructor takes a constant argument *)
+ Notation "add@( y , x1 , x2 , n ); g"
+ := (LetInAppIdentZZ (uint256, bool) (add n) (Pair x1 x2) (fun y => g))
+ (at level 10, g at level 200, format "add@( y , x1 , x2 , n ); '//' g").
+ Notation "addc@( y , x1 , x2 , x3 , n ); g"
+ := (LetInAppIdentZZ (uint256, bool) (addc n) (Pair (Pair x1 x2) x3) (fun y => g))
+ (at level 10, g at level 200, format "addc@( y , x1 , x2 , x3 , n ); '//' g").
+ Notation "addc@( y , x1 , x2 , x3 , n ); '#128' g"
+ := (LetInAppIdentZZ (uint128, bool) (addc n) (Pair (Pair x1 x2) x3) (fun y => g))
+ (at level 10, g at level 200, format "addc@( y , x1 , x2 , x3 , n ); '#128' '//' g").
+ Notation "sub@( y , x1 , x2 , n ); g"
+ := (LetInAppIdentZZ (uint256, bool) (sub n) (Pair x1 x2) (fun y => g))
+ (at level 10, g at level 200, format "sub@( y , x1 , x2 , n ); '//' g").
Notation "rshi@( y , x1 , x2 , n ); g"
:= (LetInAppIdentZ _ (rshi n) (Pair x1 x2) (fun y => g))
(at level 10, g at level 200, format "rshi@( y , x1 , x2 , n ); '//' g ").
- Notation "'Ret' $ x" := (Scalar (Var tZ x)) (at level 10, format "'Ret' $ x").
+ Notation "'ret' $ x" := (Scalar (Var tZ x)) (at level 10, format "'ret' $ x").
Notation "( x , y )" := (Pair x y) (at level 10, left associativity).
End Notations.
End PreFancy.
@@ -10159,30 +10274,24 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type
mulhl@(y3, RegMuLow, $y1);
mullh@(y4, RegMuLow, $y1);
mulll@(y5, RegMuLow, $y1);
- add@(y6, Straightline.expr.Cast uint256 (Straightline.expr.Shiftl 128 (Lower{$y3})), $y5);
- addc@(y7, carry{$y6}, $y2, Straightline.expr.Cast uint128
- (Straightline.expr.Shiftr 128 ($y4)));
- add@(y8, Straightline.expr.Cast uint256 (Straightline.expr.Shiftl 128 (Lower{$y4})), $y6);
- addc@(y9, carry{$y8}, Straightline.expr.Cast uint128
- (Straightline.expr.Shiftr 128 ($y3)), $y7);
- add@(y10, $y1, $y9);
- addc@(y11, carry{$y10}, RegZero, $y0);
- add@(y12, $y, $y10);
- addc@(y13, carry{$y12}, RegZero, $y11);
+ add@(y6, $y5, $y3, 128);
+ addc@(y7, carry{$y6}, $y2, $y4, -128);
+ add@(y8, $y6, $y4, 128);
+ addc@(y9, carry{$y8}, $y7, $y3, -128);
+ add@(y10, $y1, $y9, 0);
+ addc@(y11, carry{$y10}, RegZero, $y0, 0); #128
+ add@(y12, $y, $y10, 0);
+ addc@(y13, carry{$y12}, RegZero, $y11, 0); #128
rshi@(y14, $y13, $y12,1);
mulhh@(y15, RegMod, $y14);
mullh@(y16, RegMod, $y14);
mulhl@(y17, RegMod, $y14);
mulll@(y18, RegMod, $y14);
- add@(y19, Straightline.expr.Cast uint256
- (Straightline.expr.Shiftl 128 (Lower{$y16})), $y18);
- addc@(y20, carry{$y19}, $y15, Straightline.expr.Cast uint128
- (Straightline.expr.Shiftr 128 ($y17)));
- add@(y21, Straightline.expr.Cast uint256
- (Straightline.expr.Shiftl 128 (Lower{$y17})), $y19);
- addc@(_, carry{$y21}, Straightline.expr.Cast uint128
- (Straightline.expr.Shiftr 128 ($y16)), $y20);
- Straightline.expr.Scalar (Straightline.expr.Primitive (-1))
+ add@(y19, $y18, $y16, 128);
+ addc@(y20, carry{$y19}, $y15, $y17, -128);
+ add@(y21, $y19, $y17, 128);
+ addc@(_, carry{$y21}, $y20, $y16, -128);
+ Straightline.expr.Scalar (Straightline.expr.Primitive (-1))
*)
End Barrett256.
@@ -10880,22 +10989,22 @@ montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z *
(*
mulhl@(y0, RegPInv, $x₁);
mulll@(y1, RegPInv, $x₁);
- add@(y2, Straightline.expr.Cast uint256 (Straightline.expr.Shiftl 128 (Lower{$y})), $y1);
- add@(y3, Straightline.expr.Cast uint256 (Straightline.expr.Shiftl 128 (Lower{$y0})), $y2);
+ add@(y2, $y1, $y, 128);
+ add@(y3, $y2, $y0, 128);
mulhh@(y4, RegMod, $y3);
mullh@(y5, RegMod, $y3);
mulhl@(y6, RegMod, $y3);
mulll@(y7, RegMod, $y3);
- add@(y8, Straightline.expr.Cast uint256 (Straightline.expr.Shiftl 128 (Lower{$y5})), $y7);
- addc@(y9, carry{$y8}, $y4, Straightline.expr.Cast uint128 (Straightline.expr.Shiftr 128 ($y6)));
- add@(y10, Straightline.expr.Cast uint256 (Straightline.expr.Shiftl 128 (Lower{$y6})), $y8);
- addc@(y11, carry{$y10}, Straightline.expr.Cast uint128 (Straightline.expr.Shiftr 128 ($y5)), $y9);
- add@(y12, $y10, $x₁);
- addc@(y13, carry{$y12}, $y11, $x₂);
+ add@(y8, $y7, $y5, 128);
+ addc@(y9, carry{$y8}, $y4, $y6, -128);
+ add@(y10, $y8, $y6, 128);
+ addc@(y11, carry{$y10}, $y9, $y5, -128);
+ add@(y12, $y10, $x₁, 0);
+ addc@(y13, carry{$y12}, $y11, $x₂, 0);
selc@(y14, carry{$y13}, RegZero, RegMod);
- sub@(y15, $y13, $y14);
+ sub@(y15, $y13, $y14, 0);
addm@(y16, $y15, RegZero, RegMod);
- Ret $y16
+ ret $y16
*)
End Montgomery256.