aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/SimplyTypedArithmetic.v
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-04-30 15:46:28 +0200
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-05-07 04:29:09 -0400
commit838bdf01407af6025c8ac403458dd55ff27ef68f (patch)
treea73d5851f432335398ce89a311d21182c32e034c /src/Experiments/SimplyTypedArithmetic.v
parentde3ec0210ea1d40e2e796591c9a192711e79a03f (diff)
prefancy now works on barrett (modulo add-opp=>sub)
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v246
1 files changed, 196 insertions, 50 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index a1fec63ff..b00da7293 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -7942,6 +7942,7 @@ Module Straightline.
| Shiftr : Z -> scalar type.Z -> scalar type.Z
| Shiftl : Z -> scalar type.Z -> scalar type.Z
| Land : Z -> scalar type.Z -> scalar type.Z
+ | CC_m : Z -> scalar type.Z -> scalar type.Z
| Primitive {t} : type.interp (type.type_primitive t) -> scalar t
.
End with_ident.
@@ -7958,6 +7959,7 @@ Module Straightline.
| ident.Z.shiftr n => fun args => Some (Shiftr n args)
| ident.Z.shiftl n => fun args => Some (Shiftl n args)
| ident.Z.land n => fun args => Some (Land n args)
+ | ident.Z.cc_m_concrete s => fun args => Some (CC_m s args)
| @ident.primitive p x => fun _ => Some (Primitive x)
| _ => fun _ => None
end.
@@ -8028,7 +8030,7 @@ Module Straightline.
end
| None => None
end.
-
+
Definition mk_LetInAppIdent {s d t} (default : expr t)
: range_type d -> ident.ident s d -> scalar s -> (var d -> expr t) -> expr t :=
match d as d0 return range_type d0 -> ident.ident s d0 -> scalar s -> (var d0 -> expr t) -> expr t with
@@ -8136,8 +8138,22 @@ Module StraightlineTest.
AppIdent ident.Let_In
(Pair (AppIdent (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent ident.Z.mul (Pair (AppIdent (@ident.primitive type.Z 12) TT) (Var y))))
(Abs (fun z : var type.Z => (AppIdent (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent (ident.Z.shiftr 3) (Var z)))))
- ))))).
- Eval vm_compute in (Straightline.of_Expr test_mul).
+ ))))).
+
+ Definition test_selm : Expr (type.Z -> type.Z) :=
+ fun var =>
+ Abs (fun x : var type.Z =>
+ AppIdent (var:=var) ident.Let_In
+ (Pair (AppIdent (var:=var) (ident.Z.cast r[0~>4294967295]%zrange)
+ (AppIdent (var:=var) ident.Z.zselect
+ (Pair
+ (Pair
+ (AppIdent (var:=var) (ident.Z.cast r[0~>1]%zrange)
+ (AppIdent (var:=var) (ident.Z.cc_m_concrete 4294967296)
+ (AppIdent (ident.Z.cast r[0~>4294967295]%zrange) (Var x))))
+ (AppIdent (@ident.primitive type.Z 0) TT))
+ (AppIdent (@ident.primitive type.Z 100) TT))))
+ (Abs (fun z : var type.Z => Var z)))).
End StraightlineTest.
(* Convert straightline code to code that uses only a certain set of identifiers *)
@@ -8159,9 +8175,13 @@ Module PreFancy.
| mulhl : ident (Z * Z) Z
| mulhh : ident (Z * Z) Z
| sub : ident (Z * Z) (Z * Z)
+ | land : BinInt.Z -> ident Z Z
| shiftr : BinInt.Z -> ident Z Z
| shiftl : BinInt.Z -> ident Z Z
- | sel : ident (Z * Z * Z) Z
+ | rshi : BinInt.Z -> ident (Z * Z) Z
+ | selc : ident (Z * Z * Z) Z
+ | selm : ident (Z * Z * Z) Z
+ | sell : ident (Z * Z * Z) Z
| addm : ident (Z * Z * Z) Z
.
Let dummy t : @expr var ident t := Scalar (Var _ (dummy_var t)).
@@ -8208,6 +8228,33 @@ Module PreFancy.
| _ => invert_upper' e
end.
+ Definition invert_sell {t} (e : @scalar var ident t) :
+ option (@scalar var ident Z * @scalar var ident Z * @scalar var ident Z) :=
+ match e return _ with
+ | Pair _ Z (Pair Z Z x y) z =>
+ match x return option (@scalar var ident Z * @scalar var ident Z * @scalar var ident Z) with
+ | Cast r (Land n x') =>
+ if (lower r =? 0) && (upper r =? 1) && (n =? 1)
+ then Some (x', y, z)
+ else None
+ | _ => (@None _)
+ end
+ | _ => None
+ end.
+
+ Definition invert_selm {t} (e : @scalar var ident t) :
+ option (@scalar var ident Z * @scalar var ident Z * @scalar var ident Z) :=
+ match e return _ with
+ | Pair _ Z (Pair Z Z x y) z =>
+ match x return option (@scalar var ident Z * @scalar var ident Z * @scalar var ident Z) with
+ | Cast r (CC_m n x') =>
+ if (lower r =? 0) && (upper r =? 1) && (n =? wordsize)
+ then Some (x', y, z)
+ else None
+ | _ => (@None _)
+ end
+ | _ => None
+ end.
Definition of_straightline_ident {s d} (idc : ident.ident s d)
: forall t, range_type d -> @scalar var ident s -> (var d -> @expr var ident t) -> @expr var ident t :=
@@ -8227,9 +8274,23 @@ Module PreFancy.
if w =? wordsize
then LetInAppIdentZZ r sub x f
else dummy _
+ | ident.Z.land n => fun _ r => LetInAppIdentZ r (land n)
| ident.Z.shiftr n => fun _ r => LetInAppIdentZ r (shiftr n)
| ident.Z.shiftl n => fun _ r => LetInAppIdentZ r (shiftl n)
- | ident.Z.zselect => fun _ r => LetInAppIdentZ r sel
+ | ident.Z.rshi_concrete w n =>
+ fun _ r x f =>
+ if w =? wordsize
+ then LetInAppIdentZ r (rshi n) x f
+ else dummy _
+ | ident.Z.zselect =>
+ fun t r x f =>
+ match invert_selm x with
+ | Some (x, y, z) => LetInAppIdentZ r selm (Pair (Pair x y) z) f
+ | None => match invert_sell x with
+ | Some (x, y, z) => LetInAppIdentZ r sell (Pair (Pair x y) z) f
+ | None => LetInAppIdentZ r selc x f
+ end
+ end
| ident.Z.add_modulo => fun _ r => LetInAppIdentZ r addm
| ident.Z.mul =>
fun t r x f =>
@@ -8271,6 +8332,7 @@ Module PreFancy.
| Shiftr n x => Shiftr n (of_straightline_scalar x)
| Shiftl n x => Shiftl n (of_straightline_scalar x)
| Land n x => Land n (of_straightline_scalar x)
+ | CC_m n x => CC_m n (of_straightline_scalar x)
| Primitive _ x => Primitive x
end.
@@ -8341,6 +8403,7 @@ Module BarrettReduction.
dlet_nd twoq := mul_high mut q1 muSelect in
shiftr twoq 1.
Definition reduce :=
+ dlet_nd qt := qt in
dlet_nd r2 := mul (low qt) M in
dlet_nd r := sub xt r2 in
dlet_nd q3 := cond_sub1 r M in
@@ -8528,7 +8591,7 @@ Module BarrettReduction.
Definition wideadd t1 t2 := fst (Rows.add w 2 t1 t2).
Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2).
- Definition widemul := BaseConversion.widemul k n nout.
+ Definition widemul := BaseConversion.widemul_inlined k n nout.
Lemma partition_represents x :
0 <= x < 2^k*2^k ->
@@ -8836,7 +8899,6 @@ Module Barrett256.
Import PrintingNotations.
Open Scope expr_scope.
- Set Printing Width 100000.
Print barrett_red256.
(* TODO: the ADD/ADDC instructions containing Z.opp should be translated to SUB/SUBB in partial evaluation *)
@@ -8847,50 +8909,134 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type.
expr_let x0 := SELM (x₂, 0, 26959946667150639793205513449348445388433292963828203772348655992835) in
expr_let x1 := RSHI (0, x₂, 255) in
expr_let x2 := RSHI (x₂, x₁, 255) in
- expr_let x3 := (uint128)(x2 >> 128) in
- expr_let x4 := ((uint128)(x2) & 340282366920938463463374607431768211455) in
- expr_let x5 := 79228162514264337589248983038 *₂₅₆ x4 in
- expr_let x6 := (uint128)(x5 >> 128) in
- expr_let x7 := ((uint128)(x5) & 340282366920938463463374607431768211455) in
- expr_let x8 := 340282366841710300930663525764514709507 *₂₅₆ x3 in
- expr_let x9 := (uint128)(x8 >> 128) in
- expr_let x10 := ((uint128)(x8) & 340282366920938463463374607431768211455) in
- expr_let x11 := 79228162514264337589248983038 *₂₅₆ x3 in
- expr_let x12 := (uint256)(x7 << 128) in
- expr_let x13 := (uint256)(x10 << 128) in
- expr_let x14 := 340282366841710300930663525764514709507 *₂₅₆ x4 in
- expr_let x15 := ADD_256 (x13, x14) in
- expr_let x16 := ADDC_128 (x15₂, x6, x9) in
- expr_let x17 := ADD_256 (x12, x15₁) in
- expr_let x18 := ADDC_256 (x17₂, x11, x16₁) in
- expr_let x19 := ADD_256 (x2, x18₁) in
- expr_let x20 := ADDC_128 (x19₂, 0, x1) in
- expr_let x21 := ADD_256 (x0, x19₁) in
- expr_let x22 := ADDC_128 (x21₂, 0, x20₁) in
- expr_let x23 := RSHI (x22₁, x21₁, 1) in
- expr_let x24 := (uint128)(x23 >> 128) in
- expr_let x25 := ((uint128)(x23) & 340282366920938463463374607431768211455) in
- expr_let x26 := 79228162514264337593543950335 *₂₅₆ x24 in
- expr_let x27 := (uint128)(x26 >> 128) in
- expr_let x28 := ((uint128)(x26) & 340282366920938463463374607431768211455) in
- expr_let x29 := 340282366841710300967557013911933812736 *₂₅₆ x25 in
- expr_let x30 := (uint128)(x29 >> 128) in
- expr_let x31 := ((uint128)(x29) & 340282366920938463463374607431768211455) in
- expr_let x32 := 340282366841710300967557013911933812736 *₂₅₆ x24 in
- expr_let x33 := (uint256)(x28 << 128) in
- expr_let x34 := (uint256)(x31 << 128) in
- expr_let x35 := 79228162514264337593543950335 *₂₅₆ x25 in
- expr_let x36 := ADD_256 (x34, x35) in
- expr_let x37 := ADDC_256 (x36₂, x27, x30) in
- expr_let x38 := ADD_256 (x33, x36₁) in
- expr_let x39 := ADDC_256 (x38₂, x32, x37₁) in
- expr_let x40 := Z.add_get_carry_concrete 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ (Z.opp @@ (fst @@ x38), x₁) in
- expr_let x41 := Z.add_with_get_carry_concrete 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ (x40₂, Z.opp @@ (fst @@ x39), x₂) in
- expr_let x42 := SELL (x41₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
- expr_let x43 := Z.cast uint256 @@ (fst @@ SUB_256 (x40₁, x42)) in
- ADDM (x43, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
+ expr_let x3 := 79228162514264337589248983038 *₂₅₆ ((uint128)(x2) & 340282366920938463463374607431768211455) in
+ expr_let x4 := 340282366841710300930663525764514709507 *₂₅₆ (uint128)(x2 >> 128) in
+ expr_let x5 := 79228162514264337589248983038 *₂₅₆ (uint128)(x2 >> 128) in
+ expr_let x6 := (uint256)(((uint128)(x3) & 340282366920938463463374607431768211455) << 128) in
+ expr_let x7 := (uint128)(x3 >> 128) in
+ expr_let x8 := (uint256)(((uint128)(x4) & 340282366920938463463374607431768211455) << 128) in
+ expr_let x9 := (uint128)(x4 >> 128) in
+ expr_let x10 := 340282366841710300930663525764514709507 *₂₅₆ ((uint128)(x2) & 340282366920938463463374607431768211455) in
+ expr_let x11 := ADD_256 (x8, x10) in
+ expr_let x12 := ADDC_128 (x11₂, x7, x9) in
+ expr_let x13 := ADD_256 (x6, x11₁) in
+ expr_let x14 := ADDC_256 (x13₂, x5, x12₁) in
+ expr_let x15 := ADD_256 (x2, x14₁) in
+ expr_let x16 := ADDC_128 (x15₂, 0, x1) in
+ expr_let x17 := ADD_256 (x0, x15₁) in
+ expr_let x18 := ADDC_128 (x17₂, 0, x16₁) in
+ expr_let x19 := RSHI (x18₁, x17₁, 1) in
+ expr_let x20 := 79228162514264337593543950335 *₂₅₆ (uint128)(x19 >> 128) in
+ expr_let x21 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x19) & 340282366920938463463374607431768211455) in
+ expr_let x22 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x19 >> 128) in
+ expr_let x23 := (uint256)(((uint128)(x20) & 340282366920938463463374607431768211455) << 128) in
+ expr_let x24 := (uint128)(x20 >> 128) in
+ expr_let x25 := (uint256)(((uint128)(x21) & 340282366920938463463374607431768211455) << 128) in
+ expr_let x26 := (uint128)(x21 >> 128) in
+ expr_let x27 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x19) & 340282366920938463463374607431768211455) in
+ expr_let x28 := ADD_256 (x25, x27) in
+ expr_let x29 := ADDC_256 (x28₂, x24, x26) in
+ expr_let x30 := ADD_256 (x23, x28₁) in
+ expr_let x31 := ADDC_256 (x30₂, x22, x29₁) in
+ expr_let x32 := Z.add_get_carry_concrete 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ (Z.opp @@ (fst @@ x30), x₁) in
+ expr_let x33 := Z.add_with_get_carry_concrete 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ (x32₂, Z.opp @@ (fst @@ x31), x₂) in
+ expr_let x34 := SELL (x33₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
+ expr_let x35 := Z.cast uint256 @@ (fst @@ SUB_256 (x32₁, x34)) in
+ ADDM (x35, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
: Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z))
*)
+ Import Straightline.expr.
+ Import PreFancy.
+
+ Definition barrett_red256_straightline := Eval lazy in (fun var dummy_var x => Straightline.of_Expr barrett_red256 var x dummy_var).
+
+ Definition constant_to_scalar_gen var ident (const x : Z) : option (@scalar var ident type.Z) :=
+ if x =? (BinInt.Z.shiftr const 128)
+ then Some (Cast uint128 (Shiftr 128 (Primitive (t:=type.Z) const)))
+ else if x =? (BinInt.Z.land const (2^128 - 1))
+ then Some (Cast uint128 (Land (2^128-1) (Primitive (t:=type.Z) const)))
+ else None.
+
+ Definition muLow := (2 ^ (2 * machine_wordsize) / M) mod (2^machine_wordsize).
+ Definition constant_to_scalar var ident (x : Z) : option (@scalar var ident type.Z) :=
+ match (constant_to_scalar_gen var ident M x) with
+ | Some s => Some s
+ | None => constant_to_scalar_gen var ident muLow x
+ end.
+
+ Definition barrett_red256_prefancy :=
+ Eval vm_compute in (fun var dummy_var x =>
+ @of_straightline var dummy_var machine_wordsize (constant_to_scalar var) _
+ (barrett_red256_straightline var dummy_var x)).
+
+ Local Notation "'tZ'" := (type.type_primitive type.Z).
+ Local Notation "'RegMod'" := (Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951).
+ Local Notation "'RegMuLow'" := (Primitive (t:=type.Z) 26959946667150639793205513449348445388433292963828203772348655992835).
+ Local Notation "'RegZero'" := (Primitive (t:=type.Z) 0).
+ Local Notation "$ x" := (Cast uint256 (Fst (Cast2 (uint256,bool)%core (Var (tZ * tZ) x)))) (at level 10, format "$ x").
+ Local Notation "$ x" := (Cast uint128 (Fst (Cast2 (uint128,bool)%core (Var (tZ * tZ) x)))) (at level 10, format "$ x").
+ Local Notation "$ x ₁" := (Cast uint256 (Fst (Var (tZ * tZ) x))) (at level 10, format "$ x ₁").
+ Local Notation "$ x ₂" := (Cast uint256 (Snd (Var (tZ * tZ) x))) (at level 10, format "$ x ₂").
+ Local Notation "carry{ $ x }" := (Cast bool (Snd (Cast2 (uint256, bool)%core (Var (tZ * tZ) x)))) (at level 10, format "carry{ $ x }").
+ Local Notation "Lower{ x }" := (Cast uint128 (Land 340282366920938463463374607431768211455 x)) (at level 10, format "Lower{ x }").
+ Local Notation "$ x" := (Cast uint256 (Var tZ x)) (at level 10, format "$ x").
+ Local Notation "$ x" := (Cast uint128 (Var tZ x)) (at level 10, format "$ x").
+ Local Notation "f @( y , x1 , x2 ); g "
+ := (LetInAppIdentZZ (uint256, bool)%core f (Pair x1 x2) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 ); '//' g ").
+ Local Notation "f @( y , x1 , x2 , x3 ); g "
+ := (LetInAppIdentZZ (uint256, bool)%core f (Pair (Pair x1 x2) x3) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g ").
+ Local Notation "f @( y , x1 , x2 , x3 ); g "
+ := (LetInAppIdentZZ (uint128, bool)%core f (Pair (Pair x1 x2) x3) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g ").
+ Local Notation "f @( y , x1 , x2 ); g "
+ := (LetInAppIdentZ uint256 f (Pair x1 x2) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 ); '//' g ").
+ Local Notation "f @( y , x1 , x2 , x3 ); g "
+ := (LetInAppIdentZ uint256 f (Pair (Pair x1 x2) x3) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g ").
+ Local Notation "shiftL@( y , x , n ); g"
+ := (LetInAppIdentZ uint256 (shiftl n) (Lower{x}) (fun y => g)) (at level 10, g at level 200, format "shiftL@( y , x , n ); '//' g").
+ Local Notation "shiftR@( y , x , n ); g "
+ := (LetInAppIdentZ uint128 (shiftr n) x (fun y => g)) (at level 10, g at level 200, format "shiftR@( y , x , n ); '//' g ").
+ Local Notation "rshi@( y , x1 , x2 , n ); g"
+ := (LetInAppIdentZ _ (rshi n) (Pair x1 x2) (fun y => g)) (at level 10, g at level 200, format "rshi@( y , x1 , x2 , n ); '//' g ").
+ Local Notation "'Ret' $ x" := (Scalar (Var tZ x)) (at level 10, format "'Ret' $ x").
+ Local Notation "( x , y )" := (Pair x y) (at level 10, left associativity).
+ Print barrett_red256_prefancy.
+ (* TODO : make prefancy cast uint128s to uint256s *)
+ (* Note : currently (correctly) fails to convert the adds that should be subs *)
+ (*
+ selm@(x0, $x₂, RegZero, RegMuLow);
+ rshi@(x1, RegZero, $x₂,255);
+ rshi@(x2, $x₂, $x₁,255);
+ mulhl@(x3, RegMuLow, $x2);
+ mullh@(x4, RegMuLow, $x2);
+ mulhh@(x5, RegMuLow, $x2);
+ shiftL@(x6, $x3, 128);
+ shiftR@(x7, $x3, 128);
+ shiftL@(x8, $x4, 128);
+ shiftR@(x9, $x4, 128);
+ mulll@(x10, RegMuLow, $x2);
+ add@(x11, $x8, $x10);
+ addc@(x12, carry{$x11}, $x7, $x9);
+ add@(x13, $x6, $x11);
+ addc@(x14, carry{$x13}, $x5, $x12);
+ add@(x15, $x2, $x14);
+ addc@(x16, carry{$x15}, RegZero, Cast bool (Var tZ x1));
+ add@(x17, $x0, $x15);
+ addc@(x18, carry{$x17}, RegZero, $x16);
+ rshi@(x19, $x18, $x17,1);
+ mullh@(x20, RegMod, $x19);
+ mulhl@(x21, RegMod, $x19);
+ mulhh@(x22, RegMod, $x19);
+ shiftL@(x23, $x20, 128);
+ shiftR@(x24, $x20, 128);
+ shiftL@(x25, $x21, 128);
+ shiftR@(x26, $x21, 128);
+ mulll@(x27, RegMod, $x19);
+ add@(x28, $x25, $x27);
+ addc@(x29, carry{$x28}, $x24, $x26);
+ add@(x30, $x23, $x28);
+ addc@(_, carry{$x30}, $x22, $x29);
+ Ret $(dummy_var tZ)
+ *)
End Barrett256.
Module SaturatedSolinas.
@@ -9495,7 +9641,7 @@ montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z *
else if x =? (BinInt.Z.land N' (2^128 - 1))
then Some (Cast uint128 (Land (2^128-1) (Primitive (t:=type.Z) N')))
else None.
-
+
Definition montred256_prefancy :=
Eval vm_compute in (fun var dummy_var x =>
@of_straightline var dummy_var machine_wordsize (constant_to_scalar var) _