aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-04-30 11:12:45 +0200
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-05-07 04:29:09 -0400
commit6f8733f05344bda560fabb384ac25971089e7783 (patch)
tree4260b420af5aa2370188fc2ed4860da29733040a /src/Experiments
parent306b6f1900c49747b6e9c911ab395e43223ed77e (diff)
Translating to 'pre-fancy' form now works on Montgomery
Diffstat (limited to 'src/Experiments')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v553
1 files changed, 353 insertions, 200 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 880128643..9618588d4 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -1797,6 +1797,18 @@ Module BaseConversion.
push_eval.
Qed.
+ (* bind the input, but *not* the carrying operations *)
+ Derive from_associational_inlined
+ SuchThat (forall n idxs p,
+ from_associational_inlined n idxs p = from_associational n idxs p)
+ As from_associational_inlined_correct.
+ Proof.
+ intros.
+ cbv beta iota delta [from_associational Associational.carry Associational.carryterm].
+ unfold Let_In at 2 3.
+ subst from_associational_inlined; reflexivity.
+ Qed.
+
(* carry chain that aligns terms in the intermediate weight with the final weight *)
Definition aligned_carries (log_dw_sw nout : nat)
:= (map (fun i => ((log_dw_sw * (i + 1)) - 1))%nat (seq 0 nout)).
@@ -1858,6 +1870,23 @@ Module BaseConversion.
rewrite Z.pow_mul_r, Z.pow_2_r by omega.
Z.rewrite_mod_small. reflexivity.
Qed.
+
+ Derive widemul_inlined
+ SuchThat (forall a b,
+ 0 <= a * b < 2^log2base * 2^log2base ->
+ widemul_inlined a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base])
+ As widemul_inlined_correct.
+ Proof.
+ intros.
+ rewrite <-widemul_correct by auto.
+ cbv beta iota delta [widemul mul_converted to_associational convert_bases
+ chained_carries_no_reduce carry
+ Associational.carry Associational.carryterm
+ Positional.from_associational].
+ cbv beta iota delta [Let_In].
+ rewrite <-from_associational_inlined_correct.
+ subst widemul_inlined; reflexivity.
+ Qed.
End widemul.
End BaseConversion.
@@ -8493,7 +8522,6 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type.
ADDM (x43, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
: Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z))
*)
-
End Barrett256.
Module SaturatedSolinas.
@@ -8901,30 +8929,39 @@ Module Straightline.
Let uexpr t := @Uncurried.expr.expr ident.ident var t.
- Inductive expr : type.type -> Type :=
- | Scalar {t} : scalar t -> expr t
- | LetInAppIdentZ {s d} : zrange -> ident.ident s (type.Z) -> scalar s -> (var (type.Z) -> expr d) -> expr d
- | LetInAppIdentZZ {s d} : zrange * zrange -> ident.ident s (type.Z*type.Z) -> scalar s -> (var (type.Z*type.Z) -> expr d) -> expr d
- with scalar : type.type -> Type :=
- | Var t : var t -> scalar t
- | TT : scalar (type.type_primitive type.unit)
- | Pair {a b} : scalar a -> scalar b -> scalar (a * b)
- | Cast : zrange -> scalar type.Z -> scalar type.Z
- | Cast2 : zrange * zrange -> scalar (type.Z*type.Z) -> scalar (type.Z*type.Z)
- | Fst {a b} : scalar (a * b) -> scalar a
- | Snd {a b} : scalar (a * b) -> scalar b
- | Primitive {t} : type.interp (type.type_primitive t) -> scalar t
- .
-
- Fixpoint dummy t : expr t := Scalar (Var t (dummy_var t)).
+ Section with_ident.
+ Context {ident : type.type -> type.type -> Type}.
+ Inductive expr : type.type -> Type :=
+ | Scalar {t} : scalar t -> expr t
+ | LetInAppIdentZ {s d} : zrange -> ident s (type.Z) -> scalar s -> (var (type.Z) -> expr d) -> expr d
+ | LetInAppIdentZZ {s d} : zrange * zrange -> ident s (type.Z*type.Z) -> scalar s -> (var (type.Z*type.Z) -> expr d) -> expr d
+ with scalar : type.type -> Type :=
+ | Var t : var t -> scalar t
+ | TT : scalar (type.type_primitive type.unit)
+ | Pair {a b} : scalar a -> scalar b -> scalar (a * b)
+ | Cast : zrange -> scalar type.Z -> scalar type.Z
+ | Cast2 : zrange * zrange -> scalar (type.Z*type.Z) -> scalar (type.Z*type.Z)
+ | Fst {a b} : scalar (a * b) -> scalar a
+ | Snd {a b} : scalar (a * b) -> scalar b
+ | Shiftr : Z -> scalar type.Z -> scalar type.Z
+ | Shiftl : Z -> scalar type.Z -> scalar type.Z
+ | Land : Z -> scalar type.Z -> scalar type.Z
+ | Primitive {t} : type.interp (type.type_primitive t) -> scalar t
+ .
+ End with_ident.
+
+ Fixpoint dummy t : @expr ident.ident t := Scalar (Var t (dummy_var t)).
Definition scalar_of_uncurried_ident {s d} (idc : ident.ident s d)
: scalar s -> option (scalar d) :=
- match idc in ident.ident s d return scalar s -> option (scalar d) with
+ match idc in ident.ident s d return scalar (ident:=ident.ident) s -> option (scalar d) with
| ident.Z.cast r => fun args => Some (Cast r args)
| ident.Z.cast2 r => fun args => Some (Cast2 r args)
- | @ident.fst A B => fun args => Some (@Fst A B args)
- | @ident.snd A B => fun args => Some (@Snd A B args)
+ | @ident.fst A B => fun args => Some (Fst args)
+ | @ident.snd A B => fun args => Some (Snd args)
+ | ident.Z.shiftr n => fun args => Some (Shiftr n args)
+ | ident.Z.shiftl n => fun args => Some (Shiftl n args)
+ | ident.Z.land n => fun args => Some (Land n args)
| @ident.primitive p x => fun _ => Some (Primitive x)
| _ => fun _ => None
end.
@@ -9000,9 +9037,9 @@ Module Straightline.
: range_type d -> ident.ident s d -> scalar s -> (var d -> expr t) -> expr t :=
match d as d0 return range_type d0 -> ident.ident s d0 -> scalar s -> (var d0 -> expr t) -> expr t with
| type.type_primitive type.Z =>
- fun r idc x k => @LetInAppIdentZ s t r idc x k
+ fun r idc x k => @LetInAppIdentZ ident.ident s t r idc x k
| type.prod type.Z type.Z =>
- fun r idc x k => @LetInAppIdentZZ s t r idc x k
+ fun r idc x k => @LetInAppIdentZZ ident.ident s t r idc x k
| _ => fun _ _ _ _ => default
end.
@@ -9107,89 +9144,151 @@ Module StraightlineTest.
Eval vm_compute in (Straightline.of_Expr test_mul).
End StraightlineTest.
-(*
-Module InlineOperations.
+(* Convert straightline code to code that uses only a certain set of identifiers *)
+Module PreFancy.
Section with_var.
- Context {var : type.type -> Type} (op_match : forall t, @expr.expr ident.ident var t -> bool).
-
- Fixpoint inline_op {t} (e : @expr.expr ident.ident var t) : @expr.expr ident.ident var t
- :=
- match e in expr.expr t return @expr.expr ident.ident var t with
- | Var t _ as e => e
- | TT as e
- => e
- | Pair A B a b
- => Pair (@inline_op A a) (@inline_op B b)
- | App s d f x => App (@inline_op _ f) (@inline_op _ x)
- | Abs s d f => Abs (fun v => @inline_op _ (f v))
- | AppIdent s d idc args
- => inline_op_ident idc (@inline_op s args)
- end.
+ Import Straightline.expr.
+ Context {var : type -> Type} (dummy_var : forall t, var t) (log2wordsize : Z)
+ (constant_to_scalar : forall ident, Z -> option (@scalar var ident type.Z)).
+ Local Notation Z := (type.type_primitive type.Z).
+ Let wordsize := 2 ^ log2wordsize.
+ Let half_bits := log2wordsize / 2.
+ Let half_wordsize := 2 ^ half_bits.
+
+ Inductive ident : type -> type -> Type :=
+ | add : ident (Z * Z) (Z * Z)
+ | addc : ident (Z * Z * Z) (Z * Z)
+ | mulll : ident (Z * Z) Z
+ | mullh : ident (Z * Z) Z
+ | mulhl : ident (Z * Z) Z
+ | mulhh : ident (Z * Z) Z
+ | sub : ident (Z * Z) (Z * Z)
+ | shiftr : BinInt.Z -> ident Z Z
+ | shiftl : BinInt.Z -> ident Z Z
+ | sel : ident (Z * Z * Z) Z
+ | addm : ident (Z * Z * Z) Z
+ .
+ Let dummy t : @expr var ident t := Scalar (Var _ (dummy_var t)).
+
+ Definition invert_lower' {t} (e : @scalar var ident t) :
+ option (@scalar var ident Z) :=
+ match e in scalar t return option (@scalar var ident Z) with
+ | Cast r (Land n x) =>
+ if (lower r =? 0) && (upper r =? (half_wordsize - 1)) && (n =? 2^half_bits-1)
+ then Some x
+ else None
+ | _ => None
+ end.
- (*
- AppIdent ident.Let_In (Pair (AppIdent (ident.shiftr n) x) (Abs (fun y : var _ => F (Var _ y))))
+ Definition invert_upper' {t} (e : @scalar var ident t) :
+ option (@scalar var ident Z) :=
+ match e in scalar t return option (@scalar var ident Z) with
+ | Cast r (Shiftr n x) =>
+ if (lower r =? 0) && (upper r =? (half_wordsize - 1)) && (n =? half_bits)
+ then Some x
+ else None
+ | _ => None
+ end.
- =>
+ Definition invert_lower {t} (e : @scalar var ident t) :
+ option (@scalar var ident Z) :=
+ match e in scalar t return option (@scalar var ident Z) with
+ | Primitive type.Z x =>
+ match constant_to_scalar ident x with
+ | Some y => invert_lower' y
+ | None => None
+ end
+ | _ => invert_lower' e
+ end.
- F (AppIdent (ident.shiftr n) x)
- *)
+ Definition invert_upper {t} (e : @scalar var ident t) :
+ option (@scalar var ident Z) :=
+ match e in scalar t return option (@scalar var ident Z) with
+ | Primitive type.Z x =>
+ match constant_to_scalar ident x with
+ | Some y => invert_upper' y
+ | None => None
+ end
+ | _ => invert_upper' e
+ end.
- Definition replace_num {t} (old : var type.Z) (new : @expr.expr ident.ident var type.Z) (e : @expr.expr ident.ident var t)
- : @expr.expr ident.ident var t
- := match e with
- | Var _ v =>
- | TT => TT
- |
-
-
- Definition inline_if_match {tx tC} (args : @expr.expr ident.ident var (tx * (tx -> tC)))
- : @expr.expr ident.ident var tC :=
- let default _ := AppIdent (@ident.Let_In tx tC) args in
- match invert_Pair args with
- | Some (e, Abs s d k) =>
- if op_match tx e then App k e else default tt
- | None => default tt
+ Definition of_straightline_ident {s d} (idc : ident.ident s d)
+ : forall t, range_type d -> @scalar var ident s -> (var d -> @expr var ident t) -> @expr var ident t :=
+ match idc in ident.ident s d return forall t, range_type d -> scalar s -> (var d -> @expr var ident t) -> @expr var ident t with
+ | ident.Z.add_get_carry_concrete w =>
+ fun t r x f =>
+ if w =? wordsize
+ then LetInAppIdentZZ r add x f
+ else dummy _
+ | ident.Z.add_with_get_carry_concrete w =>
+ fun t r x f =>
+ if w =? wordsize
+ then LetInAppIdentZZ r addc x f
+ else dummy _
+ | ident.Z.sub_get_borrow_concrete w =>
+ fun t r x f =>
+ if w =? wordsize
+ then LetInAppIdentZZ r sub x f
+ else dummy _
+ | ident.Z.shiftr n => fun _ r => LetInAppIdentZ r (shiftr n)
+ | ident.Z.shiftl n => fun _ r => LetInAppIdentZ r (shiftl n)
+ | ident.Z.zselect => fun _ r => LetInAppIdentZ r sel
+ | ident.Z.add_modulo => fun _ r => LetInAppIdentZ r addm
+ | ident.Z.mul =>
+ fun t r x f =>
+ match x return expr t with
+ | Pair _ _ x0 x1 =>
+ match invert_lower x0, invert_lower x1 with
+ | Some y0, Some y1 => LetInAppIdentZ r mulll (Pair y0 y1) f
+ | Some y0, None =>
+ match invert_upper x1 with
+ | Some y1 => LetInAppIdentZ r mullh (Pair y0 y1) f
+ | None => dummy _
+ end
+ | None, Some y1 =>
+ match invert_upper x0 with
+ | Some y0 => LetInAppIdentZ r mulhl (Pair y0 y1) f
+ | None => dummy _
+ end
+ | None, None =>
+ match invert_upper x0, invert_upper x1 with
+ | Some y0, Some y1 => LetInAppIdentZ r mulhh (Pair y0 y1) f
+ | _,_ => dummy _
+ end
+ end
+ | _ => dummy _
+ end
+ | _ => fun t _ _ _ => dummy t
end.
- Definition inline_op_ident {s d} (idc : ident.ident s d)
- : expr.expr s -> expr.expr d
- :=
- match idc in ident.ident s d return @expr.expr ident.ident var s -> expr.expr d with
- | ident.Let_In tx tC => fun args : @expr.expr ident.ident var (tx * (tx -> tC)) =>
- @inline_if_match tx tC args
- | _ as i => fun args => AppIdent i args
- end.
+ Fixpoint of_straightline_scalar {t} (s : @scalar var ident.ident t)
+ : @scalar var ident t :=
+ match s with
+ | Var _ v => Var _ v
+ | TT => TT
+ | Pair _ _ x y => Pair (of_straightline_scalar x) (of_straightline_scalar y)
+ | Cast r x => Cast r (of_straightline_scalar x)
+ | Cast2 r x => Cast2 r (of_straightline_scalar x)
+ | Fst _ _ x => Fst (of_straightline_scalar x)
+ | Snd _ _ x => Snd (of_straightline_scalar x)
+ | Shiftr n x => Shiftr n (of_straightline_scalar x)
+ | Shiftl n x => Shiftl n (of_straightline_scalar x)
+ | Land n x => Land n (of_straightline_scalar x)
+ | Primitive _ x => Primitive x
+ end.
- Fixpoint inline_op {t} (e : @expr.expr ident.ident var t) : @expr.expr ident.ident var t
- :=
- match e in expr.expr t return @expr.expr ident.ident var t with
- | Var t _ as e => e
- | TT as e
- => e
- | Pair A B a b
- => Pair (@inline_op A a) (@inline_op B b)
- | App s d f x => App (@inline_op _ f) (@inline_op _ x)
- | Abs s d f => Abs (fun v => @inline_op _ (f v))
- | AppIdent s d idc args
- => inline_op_ident idc (@inline_op s args)
- end.
+ Fixpoint of_straightline {t} (e : @expr var ident.ident t)
+ : @expr var ident t :=
+ match e with
+ | Scalar _ s => Scalar (of_straightline_scalar s)
+ | LetInAppIdentZ _ t r idc x f =>
+ of_straightline_ident idc t r (of_straightline_scalar x) (fun y => of_straightline (f y))
+ | LetInAppIdentZZ _ t r idc x f =>
+ of_straightline_ident idc t r (of_straightline_scalar x) (fun y => of_straightline (f y))
+ end.
End with_var.
-
- Fixpoint shift_land_matcher {var} t (e : @expr.expr ident.ident var t) : bool :=
- match e with
- | AppIdent _ _ (ident.Z_cast _) args => shift_land_matcher _ args
- | AppIdent _ _ (ident.Z_cast2 _) args => shift_land_matcher _ args
- | AppIdent _ _ (ident.Z_land _) args => true
- | AppIdent _ _ (ident.Z_shiftr _) args => true
- | _ => false
- end.
-
- Definition inline_shiftr_and_land {t} (e : Expr t) : Expr t :=
- fun var => inline_op shift_land_matcher (e _).
-
-End InlineOperations.
-*)
+End PreFancy.
Module MontgomeryReduction.
Section MontRed'.
@@ -9206,8 +9305,8 @@ Module MontgomeryReduction.
Context (nout : nat) (Hnout : nout = 2%nat).
Definition montred' (lo_hi : (Z * Z)) :=
- dlet_nd y := nth_default 0 (BaseConversion.widemul Zlog2R n nout (fst lo_hi) N') 0 in
- dlet_nd t1_t2 := (BaseConversion.widemul Zlog2R n nout y N) in
+ dlet_nd y := nth_default 0 (BaseConversion.widemul_inlined Zlog2R n nout (fst lo_hi) N') 0 in
+ dlet_nd t1_t2 := (BaseConversion.widemul_inlined Zlog2R n nout y N) in
dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [fst lo_hi; snd lo_hi] t1_t2 in
dlet_nd y' := Z.zselect (snd sum_carry) 0 N in
dlet_nd lo''_carry := Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y' in
@@ -9244,7 +9343,7 @@ Module MontgomeryReduction.
rewrite Hlo, Hhi.
assert (0 <= T mod R * N' < w 2) by (solve_range).
- rewrite !BaseConversion.widemul_correct
+ rewrite !BaseConversion.widemul_inlined_correct
by (rewrite ?BaseConversion.widemul_correct; autorewrite with push_nth_default; solve_range).
rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega).
rewrite R_two_pow.
@@ -9297,7 +9396,6 @@ Module MontgomeryReduction.
Context (N R N' : Z)
(machine_wordsize : Z).
- Let n : nat := Z.to_nat (Qceiling ((Z.log2_up N) / machine_wordsize)).
Let bound := Some r[0 ~> (2^machine_wordsize - 1)%Z]%zrange.
Definition relax_zrange_of_machine_wordsize
@@ -9308,11 +9406,7 @@ Module MontgomeryReduction.
Definition check_args {T} (res : Pipeline.ErrorT T)
: Pipeline.ErrorT T
- := if (N =? 0)%Z
- then Pipeline.Error (Pipeline.Values_not_provably_distinct "N ≠ 0" N 0)
- else if (n =? 0)%Z
- then Pipeline.Error (Pipeline.Values_not_provably_distinct "n ≠ 0" N 0)
- else res.
+ := res. (* TODO: this should actually check stuff that corresponds with preconditions of montred'_correct *)
Notation BoundsPipeline_correct in_bounds out_bounds op
:= (fun rv (rop : Expr (type.reify_type_of op)) Hrop
@@ -9355,51 +9449,113 @@ Module Montgomery256.
As montred256_correct.
Proof. Time solve_rmontred machine_wordsize. Time Qed.
- Import Straightline.expr.
Import PrintingNotations.
- Set Printing Depth 1000000.
- Local Notation "'tZ'" := (type.type_primitive type.Z).
- Eval lazy in (Straightline.of_Expr montred256).
+ Set Printing Width 10000.
+
Print montred256.
+(*
+montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype,
+ expr_let x0 := 79228162514264337593543950337 *₂₅₆ (uint128)(x₁ >> 128) in
+ expr_let x1 := 340282366841710300986003757985643364352 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in
+ expr_let x2 := (uint256)(((uint128)(x0) & 340282366920938463463374607431768211455) << 128) in
+ expr_let x3 := (uint256)(((uint128)(x1) & 340282366920938463463374607431768211455) << 128) in
+ expr_let x4 := 79228162514264337593543950337 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in
+ expr_let x5 := ADD_256 (x3, x4) in
+ expr_let x6 := ADD_256 (x2, x5₁) in
+ expr_let x7 := 79228162514264337593543950335 *₂₅₆ (uint128)(x6₁ >> 128) in
+ expr_let x8 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x6₁) & 340282366920938463463374607431768211455) in
+ expr_let x9 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x6₁ >> 128) in
+ expr_let x10 := (uint256)(((uint128)(x7) & 340282366920938463463374607431768211455) << 128) in
+ expr_let x11 := (uint128)(x7 >> 128) in
+ expr_let x12 := (uint256)(((uint128)(x8) & 340282366920938463463374607431768211455) << 128) in
+ expr_let x13 := (uint128)(x8 >> 128) in
+ expr_let x14 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x6₁) & 340282366920938463463374607431768211455) in
+ expr_let x15 := ADD_256 (x12, x14) in
+ expr_let x16 := ADDC_256 (x15₂, x11, x13) in
+ expr_let x17 := ADD_256 (x10, x15₁) in
+ expr_let x18 := ADDC_256 (x17₂, x9, x16₁) in
+ expr_let x19 := ADD_256 (x17₁, x₁) in
+ expr_let x20 := ADDC_256 (x19₂, x18₁, x₂) in
+ expr_let x21 := SELC (x20₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
+ expr_let x22 := SUB_256 (x20₁, x21) in
+ ADDM (x22₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951))%expr
+ : Expr (type.uncurry (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z))
+*)
+
+ Import Straightline.expr.
+ Import PreFancy.
+
+ Definition montred256_straightline := Eval lazy in (fun var dummy_var x => Straightline.of_Expr montred256 var x dummy_var).
+
+ Definition constant_to_scalar var ident (x : Z) : option (@scalar var ident type.Z) :=
+ if x =? (BinInt.Z.shiftr N 128)
+ then Some (Cast uint128 (Shiftr 128 (Primitive (t:=type.Z) N)))
+ else if x =? (BinInt.Z.shiftr N' 128)
+ then Some (Cast uint128 (Shiftr 128 (Primitive (t:=type.Z) N')))
+ else if x =? (BinInt.Z.land N (2^128 - 1))
+ then Some (Cast uint128 (Land (2^128-1) (Primitive (t:=type.Z) N)))
+ else if x =? (BinInt.Z.land N' (2^128 - 1))
+ then Some (Cast uint128 (Land (2^128-1) (Primitive (t:=type.Z) N')))
+ else None.
+
+ Definition montred256_prefancy :=
+ Eval vm_compute in (fun var dummy_var x =>
+ @of_straightline var dummy_var machine_wordsize (constant_to_scalar var) _
+ (montred256_straightline var dummy_var x)).
-(* TODO: to get the right kind of output operation, I probably want
- to inline all these shifts/ands. *)
+ Local Notation "'tZ'" := (type.type_primitive type.Z).
+ Local Notation "'RegMod'" := (Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951).
+ Local Notation "'RegPInv'" := (Primitive (t:=type.Z) 115792089210356248768974548684794254293921932838497980611635986753331132366849).
+ Local Notation "'RegZero'" := (Primitive (t:=type.Z) 0).
+ Local Notation "$ x" := (Cast uint256 (Fst (Cast2 (uint256,bool) (Var (tZ * tZ) x)))) (at level 10, format "$ x").
+ Local Notation "$ x ₁" := (Cast uint256 (Fst (Var (tZ * tZ) x))) (at level 10, format "$ x ₁").
+ Local Notation "$ x ₂" := (Cast uint256 (Snd (Var (tZ * tZ) x))) (at level 10, format "$ x ₂").
+ Local Notation "carry{ $ x }" := (Cast bool (Snd (Cast2 (uint256, bool) (Var (tZ * tZ) x)))) (at level 10, format "carry{ $ x }").
+ Local Notation "Lower{ x }" := (Cast uint128 (Land 340282366920938463463374607431768211455 x)) (at level 10, format "Lower{ x }").
+ Local Notation "$ x" := (Cast uint256 (Var tZ x)) (at level 10, format "$ x").
+ Local Notation "$ x" := (Cast uint128 (Var tZ x)) (at level 10, format "$ x").
+ Local Notation "f @( y , x1 , x2 ); g "
+ := (LetInAppIdentZZ (uint256, bool) f (Pair x1 x2) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 ); '//' g ").
+ Local Notation "f @( y , x1 , x2 , x3 ); g "
+ := (LetInAppIdentZZ (uint256, bool) f (Pair (Pair x1 x2) x3) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g ").
+ Local Notation "f @( y , x1 , x2 ); g "
+ := (LetInAppIdentZ uint256 f (Pair x1 x2) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 ); '//' g ").
+ Local Notation "f @( y , x1 , x2 , x3 ); g "
+ := (LetInAppIdentZ uint256 f (Pair (Pair x1 x2) x3) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g ").
+ Local Notation "shiftL@( y , x , n ); g"
+ := (LetInAppIdentZ uint256 (shiftl n) (Lower{x}) (fun y => g)) (at level 10, g at level 200, format "shiftL@( y , x , n ); '//' g ").
+ Local Notation "shiftR@( y , x , n ); g"
+ := (LetInAppIdentZ uint128 (shiftr n) x (fun y => g)) (at level 10, g at level 200, format "shiftR@( y , x , n ); '//' g ").
+ Local Notation "'Ret' $ x" := (Scalar (Var tZ x)) (at level 10, format "'Ret' $ x").
+ Local Notation "( x , y )" := (Pair x y) (at level 10, left associativity).
+ Print montred256_prefancy.
(*
-montred256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype,
- expr_let x0 := (uint128)(x₁ >> 128) in
- expr_let x1 := ((uint128)(x₁) & 340282366920938463463374607431768211455) in
- expr_let x2 := 79228162514264337593543950337 *₂₅₆ x0 in
- expr_let x3 := ((uint128)(x2) & 340282366920938463463374607431768211455) in
- expr_let x4 := 340282366841710300986003757985643364352 *₂₅₆ x1 in
- expr_let x5 := ((uint128)(x4) & 340282366920938463463374607431768211455) in
- expr_let x6 := (uint256)(x3 << 128) in
- expr_let x7 := (uint256)(x5 << 128) in
- expr_let x8 := 79228162514264337593543950337 *₂₅₆ x1 in
- expr_let x9 := ADD_256 (x7, x8) in
- expr_let x10 := ADD_256 (x6, x9₁) in
- expr_let x11 := (uint128)(x10₁ >> 128) in
- expr_let x12 := ((uint128)(x10₁) & 340282366920938463463374607431768211455) in
- expr_let x13 := 79228162514264337593543950335 *₂₅₆ x11 in
- expr_let x14 := (uint128)(x13 >> 128) in
- expr_let x15 := ((uint128)(x13) & 340282366920938463463374607431768211455) in
- expr_let x16 := 340282366841710300967557013911933812736 *₂₅₆ x12 in
- expr_let x17 := (uint128)(x16 >> 128) in
- expr_let x18 := ((uint128)(x16) & 340282366920938463463374607431768211455) in
- expr_let x19 := 340282366841710300967557013911933812736 *₂₅₆ x11 in
- expr_let x20 := (uint256)(x15 << 128) in
- expr_let x21 := (uint256)(x18 << 128) in
- expr_let x22 := 79228162514264337593543950335 *₂₅₆ x12 in
- expr_let x23 := ADD_256 (x21, x22) in
- expr_let x24 := ADDC_256 (x23₂, x14, x17) in
- expr_let x25 := ADD_256 (x20, x23₁) in
- expr_let x26 := ADDC_256 (x25₂, x19, x24₁) in
- expr_let x27 := ADD_256 (x25₁, x₁) in
- expr_let x28 := ADDC_256 (x27₂, x26₁, x₂) in
- expr_let x29 := SELC (x28₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
- expr_let x30 := Z.cast uint256 @@ (fst @@ SUB_256 (x28₁, x29)) in
- ADDM (x30, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
- : Expr (type.uncurry (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z))
- *)
+ mullh@(x0, RegPInv, $x₁);
+ mulhl@(x1, RegPInv, $x₁);
+ shiftL@(x2, $x0, 128);
+ shiftL@(x3, $x1, 128);
+ mulll@(x4, RegPInv, $x₁);
+ add@(x5, $x3, $x4);
+ add@(x6, $x2, $x5);
+ mullh@(x7, RegMod, $x6);
+ mulhl@(x8, RegMod, $x6);
+ mulhh@(x9, RegMod, $x6);
+ shiftL@(x10, $x7, 128);
+ shiftR@(x11, $x7, 128);
+ shiftL@(x12, $x8, 128);
+ shiftR@(x13, $x8, 128);
+ mulll@(x14, RegMod, $x6);
+ add@(x15, $x12, $x14);
+ addc@(x16, carry{$x15}, $x11, $x13);
+ add@(x17, $x10, $x15);
+ addc@(x18, carry{$x17}, $x9, $x16);
+ add@(x19, $x17, $x₁);
+ addc@(x20, carry{$x19}, $x18, $x₂);
+ sel@(x21, carry{$x20}, RegZero, RegMod);
+ sub@(x22, $x20, $x21);
+ addm@(x23, $x22, RegZero, RegMod);
+ Ret $x23
+ *)
End Montgomery256.
(* Extra-specialized ad-hoc pretty-printing *)
@@ -9424,30 +9580,30 @@ Module FancyPrintingNotations.
(primitive 0)
TT) (only printing, at level 9) : expr_scope.
Notation "'$R'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 : expr_scope.
- Notation "'Lower128{RegMod}'" :=
+ Notation "'c.Lower(RegMod)'" :=
(AppIdent
(primitive 79228162514264337593543950335)
TT) (only printing, at level 9) : expr_scope.
- Notation "'RegMod' '>>' '128'" :=
+ Notation "'c.Upper(RegMod)'" :=
(AppIdent
(primitive 340282366841710300967557013911933812736)
- TT) (only printing, at level 9, format "'RegMod' '>>' '128'") : expr_scope.
- Notation "'Lower128{RegMuLow}'" :=
+ TT) (only printing, at level 9) : expr_scope.
+ Notation "'c.Lower(RegMuLow)'" :=
(AppIdent
(primitive 340282366841710300930663525764514709507)
TT) (only printing, at level 9) : expr_scope.
- Notation "'RegMuLow' '>>' '128'" :=
+ Notation "'c.Upper(RegMuLow)'" :=
(AppIdent
(primitive 79228162514264337589248983038)
- TT) (only printing, at level 9, format "'RegMuLow' '>>' '128'") : expr_scope.
- Notation "'Lower128{RegPinv}'" :=
+ TT) (only printing, at level 9) : expr_scope.
+ Notation "'c.Lower(RegPinv)'" :=
(AppIdent
(primitive 79228162514264337593543950337)
TT) (only printing, at level 9) : expr_scope.
- Notation "'RegPinv' '>>' '128'" :=
+ Notation "'c.Upper(RegPinv)'" :=
(AppIdent
(primitive 340282366841710300986003757985643364352)
- TT) (only printing, at level 9, format "'RegPinv' '>>' '128'") : expr_scope.
+ TT) (only printing, at level 9) : expr_scope.
Notation "'uint256'"
:= (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange) : ctype_scope.
Notation "'uint128'"
@@ -9495,6 +9651,9 @@ Module FancyPrintingNotations.
Notation "'c.Sub(' '$' n ',' x ',' y ');' f" :=
(expr_let n := Z.cast uint256 @@ (fst @@ (Z.cast2 (uint256, _)%core @@ (Z.sub_get_borrow_concrete $R @@ (x, y)))) in
f)%expr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$' n ',' x ',' y ');' '//' f") : expr_scope.
+ Notation "'c.Sub(' '$' n ',' x ',' y ');' f" :=
+ (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.sub_get_borrow_concrete $R @@ (x, y)) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$' n ',' x ',' y ');' '//' f") : expr_scope.
Notation "'c.AddM(' '$ret' ',' x ',' y ',' z ');'" :=
(Z.cast uint256 @@ (Z.add_modulo @@ (x, y, z)))%expr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : expr_scope.
Notation "'c.ShiftR(' '$' n ',' x ',' y ');' f" :=
@@ -9503,15 +9662,17 @@ Module FancyPrintingNotations.
(expr_let n := Z.cast _ @@ (Z.rshi_concrete $R m @@ (x, y)) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Rshi(' '$' n ',' x ',' y ',' m ');' ']' '//' f") : expr_scope.
Notation "'c.ShiftL(' '$' n ',' x ',' y ');' f" :=
(expr_let n := Z.cast _ @@ (Z.shiftl y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
+ Notation "'c.ShiftL(' '$' n ',' x ',' y ');' f" :=
+ (expr_let n := Z.cast _ @@ (Z.shiftl y @@ (Z.cast uint128 @@ (Z.land 340282366920938463463374607431768211455 @@ x))) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
Notation "'c.Lower128(' '$' n ',' x ');' f" :=
(expr_let n := Z.cast _ @@ (Z.land 340282366920938463463374607431768211455 @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$' n ',' x ');' ']' '//' f") : expr_scope.
- Notation "'c.LowerHalf(' x ')'"
- := (Z.cast uint128 @@ (Z.land 340282366920938463463374607431768211455))
- (at level 10, only printing, format "c.LowerHalf( x )")
+ Notation "'c.Lower(' x ')'"
+ := (Z.cast uint128 @@ (Z.land 340282366920938463463374607431768211455 @@ x))
+ (at level 10, only printing, format "c.Lower( x )")
: expr_scope.
- Notation "'c.UpperHalf(' x ')'"
- := (Z.cast uint128 @@ (Z.shiftr 340282366920938463463374607431768211455))
- (at level 10, only printing, format "c.UpperHalf( x )")
+ Notation "'c.Upper(' x ')'"
+ := (Z.cast uint128 @@ (Z.shiftr 128 @@ x))
+ (at level 10, only printing, format "c.Upper( x )")
: expr_scope.
Notation "( v << count )"
:= (Z.cast _ @@ (Z.shiftl count @@ v)%expr)
@@ -9538,16 +9699,16 @@ c.Rshi($x1, RegZero, $x_hi, 255);
c.Rshi($x2, $x_hi, $x_lo, 255);
c.ShiftR($x3, $x2, 128);
c.Lower128($x4, $x2);
-c.Mul128x128($x5, RegMuLow >> 128, $x4);
+c.Mul128x128($x5, c.Upper(RegMuLow), $x4);
c.ShiftR($x6, $x5, 128);
c.Lower128($x7, $x5);
-c.Mul128x128($x8, Lower128{RegMuLow}, $x3);
+c.Mul128x128($x8, c.Lower(RegMuLow), $x3);
c.ShiftR($x9, $x8, 128);
c.Lower128($x10, $x8);
-c.Mul128x128($x11, RegMuLow >> 128, $x3);
+c.Mul128x128($x11, c.Upper(RegMuLow), $x3);
c.ShiftL($x12, $x7, 128);
c.ShiftL($x13, $x10, 128);
-c.Mul128x128($x14, Lower128{RegMuLow}, $x4);
+c.Mul128x128($x14, c.Lower(RegMuLow), $x4);
c.Add256($x15, $x13, $x14);
c.Addc128($x16, $x15_hi, $x6, $x9);
c.Add256($x17, $x12, $x15_lo);
@@ -9559,16 +9720,16 @@ c.Addc128($x22, $x21_hi, RegZero, $x20_lo);
c.Rshi($x23, $x22_lo, $x21_lo, 1);
c.ShiftR($x24, $x23, 128);
c.Lower128($x25, $x23);
-c.Mul128x128($x26, Lower128{RegMod}, $x24);
+c.Mul128x128($x26, c.Lower(RegMod), $x24);
c.ShiftR($x27, $x26, 128);
c.Lower128($x28, $x26);
-c.Mul128x128($x29, RegMod >> 128, $x25);
+c.Mul128x128($x29, c.Upper(RegMod), $x25);
c.ShiftR($x30, $x29, 128);
c.Lower128($x31, $x29);
-c.Mul128x128($x32, RegMod >> 128, $x24);
+c.Mul128x128($x32, c.Upper(RegMod), $x24);
c.ShiftL($x33, $x28, 128);
c.ShiftL($x34, $x31, 128);
-c.Mul128x128($x35, Lower128{RegMod}, $x25);
+c.Mul128x128($x35, c.Lower(RegMod), $x25);
c.Add256($x36, $x34, $x35);
c.Addc256($x37, $x36_hi, $x27, $x30);
c.Add256($x38, $x33, $x36_lo);
@@ -9584,36 +9745,28 @@ c.AddM($ret, $x43, RegZero, RegMod);
Print Montgomery256.montred256.
(*
-c.ShiftR($x0, $x_lo, 128);
-c.Lower128($x1, $x_lo);
-c.Mul128x128($x2, Lower128{RegPinv}, $x0);
-c.Lower128($x3, $x2);
-c.Mul128x128($x4, RegPinv >> 128, $x1);
-c.Lower128($x5, $x4);
-c.ShiftL($x6, $x3, 128);
-c.ShiftL($x7, $x5, 128);
-c.Mul128x128($x8, Lower128{RegPinv}, $x1);
-c.Add256($x9, $x7, $x8);
-c.Add256($x10, $x6, $x9_lo);
-c.ShiftR($x11, $x10_lo, 128);
-c.Lower128($x12, $x10_lo);
-c.Mul128x128($x13, Lower128{RegMod}, $x11);
-c.ShiftR($x14, $x13, 128);
-c.Lower128($x15, $x13);
-c.Mul128x128($x16, RegMod >> 128, $x12);
-c.ShiftR($x17, $x16, 128);
-c.Lower128($x18, $x16);
-c.Mul128x128($x19, RegMod >> 128, $x11);
-c.ShiftL($x20, $x15, 128);
-c.ShiftL($x21, $x18, 128);
-c.Mul128x128($x22, Lower128{RegMod}, $x12);
-c.Add256($x23, $x21, $x22);
-c.Addc256($x24, $x23_hi, $x14, $x17);
-c.Add256($x25, $x20, $x23_lo);
-c.Addc256($x26, $x25_hi, $x19, $x24_lo);
-c.Add256($x27, $x25_lo, $x_lo);
-c.Addc256($x28, $x27_hi, $x26_lo, $x_hi);
-c.Selc($x29, $x28_hi, RegZero, RegMod);
-c.Sub($x30, $x28_lo, $x29);
-c.AddM($ret, $x30, RegZero, RegMod);
+c.Mul128x128($x0, c.Lower(RegPinv), c.Upper($x_lo));
+c.Mul128x128($x1, c.Upper(RegPinv), c.Lower($x_lo));
+c.ShiftL($x2, $x0, 128);
+c.ShiftL($x3, $x1, 128);
+c.Mul128x128($x4, c.Lower(RegPinv), c.Lower($x_lo));
+c.Add256($x5, $x3, $x4);
+c.Add256($x6, $x2, $x5_lo);
+c.Mul128x128($x7, c.Lower(RegMod), c.Upper($x6_lo));
+c.Mul128x128($x8, c.Upper(RegMod), c.Lower($x6_lo));
+c.Mul128x128($x9, c.Upper(RegMod), c.Upper($x6_lo));
+c.ShiftL($x10, $x7, 128);
+c.ShiftR($x11, $x7, 128);
+c.ShiftL($x12, $x8, 128);
+c.ShiftR($x13, $x8, 128);
+c.Mul128x128($x14, c.Lower(RegMod), c.Lower($x6_lo));
+c.Add256($x15, $x12, $x14);
+c.Addc256($x16, $x15_hi, $x11, $x13);
+c.Add256($x17, $x10, $x15_lo);
+c.Addc256($x18, $x17_hi, $x9, $x16_lo);
+c.Add256($x19, $x17_lo, $x_lo);
+c.Addc256($x20, $x19_hi, $x18_lo, $x_hi);
+c.Selc($x21, $x20_hi, RegZero, RegMod);
+c.Sub($x22, $x20_lo, $x21);
+c.AddM($ret, $x22_lo, RegZero, RegMod);
*)