From 6f8733f05344bda560fabb384ac25971089e7783 Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Mon, 30 Apr 2018 11:12:45 +0200 Subject: Translating to 'pre-fancy' form now works on Montgomery --- src/Experiments/SimplyTypedArithmetic.v | 553 ++++++++++++++++++++------------ 1 file changed, 353 insertions(+), 200 deletions(-) (limited to 'src') diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 880128643..9618588d4 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -1797,6 +1797,18 @@ Module BaseConversion. push_eval. Qed. + (* bind the input, but *not* the carrying operations *) + Derive from_associational_inlined + SuchThat (forall n idxs p, + from_associational_inlined n idxs p = from_associational n idxs p) + As from_associational_inlined_correct. + Proof. + intros. + cbv beta iota delta [from_associational Associational.carry Associational.carryterm]. + unfold Let_In at 2 3. + subst from_associational_inlined; reflexivity. + Qed. + (* carry chain that aligns terms in the intermediate weight with the final weight *) Definition aligned_carries (log_dw_sw nout : nat) := (map (fun i => ((log_dw_sw * (i + 1)) - 1))%nat (seq 0 nout)). @@ -1858,6 +1870,23 @@ Module BaseConversion. rewrite Z.pow_mul_r, Z.pow_2_r by omega. Z.rewrite_mod_small. reflexivity. Qed. + + Derive widemul_inlined + SuchThat (forall a b, + 0 <= a * b < 2^log2base * 2^log2base -> + widemul_inlined a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base]) + As widemul_inlined_correct. + Proof. + intros. + rewrite <-widemul_correct by auto. + cbv beta iota delta [widemul mul_converted to_associational convert_bases + chained_carries_no_reduce carry + Associational.carry Associational.carryterm + Positional.from_associational]. + cbv beta iota delta [Let_In]. + rewrite <-from_associational_inlined_correct. + subst widemul_inlined; reflexivity. + Qed. End widemul. End BaseConversion. @@ -8493,7 +8522,6 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type. ADDM (x43, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) : Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z)) *) - End Barrett256. Module SaturatedSolinas. @@ -8901,30 +8929,39 @@ Module Straightline. Let uexpr t := @Uncurried.expr.expr ident.ident var t. - Inductive expr : type.type -> Type := - | Scalar {t} : scalar t -> expr t - | LetInAppIdentZ {s d} : zrange -> ident.ident s (type.Z) -> scalar s -> (var (type.Z) -> expr d) -> expr d - | LetInAppIdentZZ {s d} : zrange * zrange -> ident.ident s (type.Z*type.Z) -> scalar s -> (var (type.Z*type.Z) -> expr d) -> expr d - with scalar : type.type -> Type := - | Var t : var t -> scalar t - | TT : scalar (type.type_primitive type.unit) - | Pair {a b} : scalar a -> scalar b -> scalar (a * b) - | Cast : zrange -> scalar type.Z -> scalar type.Z - | Cast2 : zrange * zrange -> scalar (type.Z*type.Z) -> scalar (type.Z*type.Z) - | Fst {a b} : scalar (a * b) -> scalar a - | Snd {a b} : scalar (a * b) -> scalar b - | Primitive {t} : type.interp (type.type_primitive t) -> scalar t - . - - Fixpoint dummy t : expr t := Scalar (Var t (dummy_var t)). + Section with_ident. + Context {ident : type.type -> type.type -> Type}. + Inductive expr : type.type -> Type := + | Scalar {t} : scalar t -> expr t + | LetInAppIdentZ {s d} : zrange -> ident s (type.Z) -> scalar s -> (var (type.Z) -> expr d) -> expr d + | LetInAppIdentZZ {s d} : zrange * zrange -> ident s (type.Z*type.Z) -> scalar s -> (var (type.Z*type.Z) -> expr d) -> expr d + with scalar : type.type -> Type := + | Var t : var t -> scalar t + | TT : scalar (type.type_primitive type.unit) + | Pair {a b} : scalar a -> scalar b -> scalar (a * b) + | Cast : zrange -> scalar type.Z -> scalar type.Z + | Cast2 : zrange * zrange -> scalar (type.Z*type.Z) -> scalar (type.Z*type.Z) + | Fst {a b} : scalar (a * b) -> scalar a + | Snd {a b} : scalar (a * b) -> scalar b + | Shiftr : Z -> scalar type.Z -> scalar type.Z + | Shiftl : Z -> scalar type.Z -> scalar type.Z + | Land : Z -> scalar type.Z -> scalar type.Z + | Primitive {t} : type.interp (type.type_primitive t) -> scalar t + . + End with_ident. + + Fixpoint dummy t : @expr ident.ident t := Scalar (Var t (dummy_var t)). Definition scalar_of_uncurried_ident {s d} (idc : ident.ident s d) : scalar s -> option (scalar d) := - match idc in ident.ident s d return scalar s -> option (scalar d) with + match idc in ident.ident s d return scalar (ident:=ident.ident) s -> option (scalar d) with | ident.Z.cast r => fun args => Some (Cast r args) | ident.Z.cast2 r => fun args => Some (Cast2 r args) - | @ident.fst A B => fun args => Some (@Fst A B args) - | @ident.snd A B => fun args => Some (@Snd A B args) + | @ident.fst A B => fun args => Some (Fst args) + | @ident.snd A B => fun args => Some (Snd args) + | ident.Z.shiftr n => fun args => Some (Shiftr n args) + | ident.Z.shiftl n => fun args => Some (Shiftl n args) + | ident.Z.land n => fun args => Some (Land n args) | @ident.primitive p x => fun _ => Some (Primitive x) | _ => fun _ => None end. @@ -9000,9 +9037,9 @@ Module Straightline. : range_type d -> ident.ident s d -> scalar s -> (var d -> expr t) -> expr t := match d as d0 return range_type d0 -> ident.ident s d0 -> scalar s -> (var d0 -> expr t) -> expr t with | type.type_primitive type.Z => - fun r idc x k => @LetInAppIdentZ s t r idc x k + fun r idc x k => @LetInAppIdentZ ident.ident s t r idc x k | type.prod type.Z type.Z => - fun r idc x k => @LetInAppIdentZZ s t r idc x k + fun r idc x k => @LetInAppIdentZZ ident.ident s t r idc x k | _ => fun _ _ _ _ => default end. @@ -9107,89 +9144,151 @@ Module StraightlineTest. Eval vm_compute in (Straightline.of_Expr test_mul). End StraightlineTest. -(* -Module InlineOperations. +(* Convert straightline code to code that uses only a certain set of identifiers *) +Module PreFancy. Section with_var. - Context {var : type.type -> Type} (op_match : forall t, @expr.expr ident.ident var t -> bool). - - Fixpoint inline_op {t} (e : @expr.expr ident.ident var t) : @expr.expr ident.ident var t - := - match e in expr.expr t return @expr.expr ident.ident var t with - | Var t _ as e => e - | TT as e - => e - | Pair A B a b - => Pair (@inline_op A a) (@inline_op B b) - | App s d f x => App (@inline_op _ f) (@inline_op _ x) - | Abs s d f => Abs (fun v => @inline_op _ (f v)) - | AppIdent s d idc args - => inline_op_ident idc (@inline_op s args) - end. + Import Straightline.expr. + Context {var : type -> Type} (dummy_var : forall t, var t) (log2wordsize : Z) + (constant_to_scalar : forall ident, Z -> option (@scalar var ident type.Z)). + Local Notation Z := (type.type_primitive type.Z). + Let wordsize := 2 ^ log2wordsize. + Let half_bits := log2wordsize / 2. + Let half_wordsize := 2 ^ half_bits. + + Inductive ident : type -> type -> Type := + | add : ident (Z * Z) (Z * Z) + | addc : ident (Z * Z * Z) (Z * Z) + | mulll : ident (Z * Z) Z + | mullh : ident (Z * Z) Z + | mulhl : ident (Z * Z) Z + | mulhh : ident (Z * Z) Z + | sub : ident (Z * Z) (Z * Z) + | shiftr : BinInt.Z -> ident Z Z + | shiftl : BinInt.Z -> ident Z Z + | sel : ident (Z * Z * Z) Z + | addm : ident (Z * Z * Z) Z + . + Let dummy t : @expr var ident t := Scalar (Var _ (dummy_var t)). + + Definition invert_lower' {t} (e : @scalar var ident t) : + option (@scalar var ident Z) := + match e in scalar t return option (@scalar var ident Z) with + | Cast r (Land n x) => + if (lower r =? 0) && (upper r =? (half_wordsize - 1)) && (n =? 2^half_bits-1) + then Some x + else None + | _ => None + end. - (* - AppIdent ident.Let_In (Pair (AppIdent (ident.shiftr n) x) (Abs (fun y : var _ => F (Var _ y)))) + Definition invert_upper' {t} (e : @scalar var ident t) : + option (@scalar var ident Z) := + match e in scalar t return option (@scalar var ident Z) with + | Cast r (Shiftr n x) => + if (lower r =? 0) && (upper r =? (half_wordsize - 1)) && (n =? half_bits) + then Some x + else None + | _ => None + end. - => + Definition invert_lower {t} (e : @scalar var ident t) : + option (@scalar var ident Z) := + match e in scalar t return option (@scalar var ident Z) with + | Primitive type.Z x => + match constant_to_scalar ident x with + | Some y => invert_lower' y + | None => None + end + | _ => invert_lower' e + end. - F (AppIdent (ident.shiftr n) x) - *) + Definition invert_upper {t} (e : @scalar var ident t) : + option (@scalar var ident Z) := + match e in scalar t return option (@scalar var ident Z) with + | Primitive type.Z x => + match constant_to_scalar ident x with + | Some y => invert_upper' y + | None => None + end + | _ => invert_upper' e + end. - Definition replace_num {t} (old : var type.Z) (new : @expr.expr ident.ident var type.Z) (e : @expr.expr ident.ident var t) - : @expr.expr ident.ident var t - := match e with - | Var _ v => - | TT => TT - | - - - Definition inline_if_match {tx tC} (args : @expr.expr ident.ident var (tx * (tx -> tC))) - : @expr.expr ident.ident var tC := - let default _ := AppIdent (@ident.Let_In tx tC) args in - match invert_Pair args with - | Some (e, Abs s d k) => - if op_match tx e then App k e else default tt - | None => default tt + Definition of_straightline_ident {s d} (idc : ident.ident s d) + : forall t, range_type d -> @scalar var ident s -> (var d -> @expr var ident t) -> @expr var ident t := + match idc in ident.ident s d return forall t, range_type d -> scalar s -> (var d -> @expr var ident t) -> @expr var ident t with + | ident.Z.add_get_carry_concrete w => + fun t r x f => + if w =? wordsize + then LetInAppIdentZZ r add x f + else dummy _ + | ident.Z.add_with_get_carry_concrete w => + fun t r x f => + if w =? wordsize + then LetInAppIdentZZ r addc x f + else dummy _ + | ident.Z.sub_get_borrow_concrete w => + fun t r x f => + if w =? wordsize + then LetInAppIdentZZ r sub x f + else dummy _ + | ident.Z.shiftr n => fun _ r => LetInAppIdentZ r (shiftr n) + | ident.Z.shiftl n => fun _ r => LetInAppIdentZ r (shiftl n) + | ident.Z.zselect => fun _ r => LetInAppIdentZ r sel + | ident.Z.add_modulo => fun _ r => LetInAppIdentZ r addm + | ident.Z.mul => + fun t r x f => + match x return expr t with + | Pair _ _ x0 x1 => + match invert_lower x0, invert_lower x1 with + | Some y0, Some y1 => LetInAppIdentZ r mulll (Pair y0 y1) f + | Some y0, None => + match invert_upper x1 with + | Some y1 => LetInAppIdentZ r mullh (Pair y0 y1) f + | None => dummy _ + end + | None, Some y1 => + match invert_upper x0 with + | Some y0 => LetInAppIdentZ r mulhl (Pair y0 y1) f + | None => dummy _ + end + | None, None => + match invert_upper x0, invert_upper x1 with + | Some y0, Some y1 => LetInAppIdentZ r mulhh (Pair y0 y1) f + | _,_ => dummy _ + end + end + | _ => dummy _ + end + | _ => fun t _ _ _ => dummy t end. - Definition inline_op_ident {s d} (idc : ident.ident s d) - : expr.expr s -> expr.expr d - := - match idc in ident.ident s d return @expr.expr ident.ident var s -> expr.expr d with - | ident.Let_In tx tC => fun args : @expr.expr ident.ident var (tx * (tx -> tC)) => - @inline_if_match tx tC args - | _ as i => fun args => AppIdent i args - end. + Fixpoint of_straightline_scalar {t} (s : @scalar var ident.ident t) + : @scalar var ident t := + match s with + | Var _ v => Var _ v + | TT => TT + | Pair _ _ x y => Pair (of_straightline_scalar x) (of_straightline_scalar y) + | Cast r x => Cast r (of_straightline_scalar x) + | Cast2 r x => Cast2 r (of_straightline_scalar x) + | Fst _ _ x => Fst (of_straightline_scalar x) + | Snd _ _ x => Snd (of_straightline_scalar x) + | Shiftr n x => Shiftr n (of_straightline_scalar x) + | Shiftl n x => Shiftl n (of_straightline_scalar x) + | Land n x => Land n (of_straightline_scalar x) + | Primitive _ x => Primitive x + end. - Fixpoint inline_op {t} (e : @expr.expr ident.ident var t) : @expr.expr ident.ident var t - := - match e in expr.expr t return @expr.expr ident.ident var t with - | Var t _ as e => e - | TT as e - => e - | Pair A B a b - => Pair (@inline_op A a) (@inline_op B b) - | App s d f x => App (@inline_op _ f) (@inline_op _ x) - | Abs s d f => Abs (fun v => @inline_op _ (f v)) - | AppIdent s d idc args - => inline_op_ident idc (@inline_op s args) - end. + Fixpoint of_straightline {t} (e : @expr var ident.ident t) + : @expr var ident t := + match e with + | Scalar _ s => Scalar (of_straightline_scalar s) + | LetInAppIdentZ _ t r idc x f => + of_straightline_ident idc t r (of_straightline_scalar x) (fun y => of_straightline (f y)) + | LetInAppIdentZZ _ t r idc x f => + of_straightline_ident idc t r (of_straightline_scalar x) (fun y => of_straightline (f y)) + end. End with_var. - - Fixpoint shift_land_matcher {var} t (e : @expr.expr ident.ident var t) : bool := - match e with - | AppIdent _ _ (ident.Z_cast _) args => shift_land_matcher _ args - | AppIdent _ _ (ident.Z_cast2 _) args => shift_land_matcher _ args - | AppIdent _ _ (ident.Z_land _) args => true - | AppIdent _ _ (ident.Z_shiftr _) args => true - | _ => false - end. - - Definition inline_shiftr_and_land {t} (e : Expr t) : Expr t := - fun var => inline_op shift_land_matcher (e _). - -End InlineOperations. -*) +End PreFancy. Module MontgomeryReduction. Section MontRed'. @@ -9206,8 +9305,8 @@ Module MontgomeryReduction. Context (nout : nat) (Hnout : nout = 2%nat). Definition montred' (lo_hi : (Z * Z)) := - dlet_nd y := nth_default 0 (BaseConversion.widemul Zlog2R n nout (fst lo_hi) N') 0 in - dlet_nd t1_t2 := (BaseConversion.widemul Zlog2R n nout y N) in + dlet_nd y := nth_default 0 (BaseConversion.widemul_inlined Zlog2R n nout (fst lo_hi) N') 0 in + dlet_nd t1_t2 := (BaseConversion.widemul_inlined Zlog2R n nout y N) in dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [fst lo_hi; snd lo_hi] t1_t2 in dlet_nd y' := Z.zselect (snd sum_carry) 0 N in dlet_nd lo''_carry := Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y' in @@ -9244,7 +9343,7 @@ Module MontgomeryReduction. rewrite Hlo, Hhi. assert (0 <= T mod R * N' < w 2) by (solve_range). - rewrite !BaseConversion.widemul_correct + rewrite !BaseConversion.widemul_inlined_correct by (rewrite ?BaseConversion.widemul_correct; autorewrite with push_nth_default; solve_range). rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega). rewrite R_two_pow. @@ -9297,7 +9396,6 @@ Module MontgomeryReduction. Context (N R N' : Z) (machine_wordsize : Z). - Let n : nat := Z.to_nat (Qceiling ((Z.log2_up N) / machine_wordsize)). Let bound := Some r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. Definition relax_zrange_of_machine_wordsize @@ -9308,11 +9406,7 @@ Module MontgomeryReduction. Definition check_args {T} (res : Pipeline.ErrorT T) : Pipeline.ErrorT T - := if (N =? 0)%Z - then Pipeline.Error (Pipeline.Values_not_provably_distinct "N ≠ 0" N 0) - else if (n =? 0)%Z - then Pipeline.Error (Pipeline.Values_not_provably_distinct "n ≠ 0" N 0) - else res. + := res. (* TODO: this should actually check stuff that corresponds with preconditions of montred'_correct *) Notation BoundsPipeline_correct in_bounds out_bounds op := (fun rv (rop : Expr (type.reify_type_of op)) Hrop @@ -9355,51 +9449,113 @@ Module Montgomery256. As montred256_correct. Proof. Time solve_rmontred machine_wordsize. Time Qed. - Import Straightline.expr. Import PrintingNotations. - Set Printing Depth 1000000. - Local Notation "'tZ'" := (type.type_primitive type.Z). - Eval lazy in (Straightline.of_Expr montred256). + Set Printing Width 10000. + Print montred256. +(* +montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, + expr_let x0 := 79228162514264337593543950337 *₂₅₆ (uint128)(x₁ >> 128) in + expr_let x1 := 340282366841710300986003757985643364352 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in + expr_let x2 := (uint256)(((uint128)(x0) & 340282366920938463463374607431768211455) << 128) in + expr_let x3 := (uint256)(((uint128)(x1) & 340282366920938463463374607431768211455) << 128) in + expr_let x4 := 79228162514264337593543950337 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in + expr_let x5 := ADD_256 (x3, x4) in + expr_let x6 := ADD_256 (x2, x5₁) in + expr_let x7 := 79228162514264337593543950335 *₂₅₆ (uint128)(x6₁ >> 128) in + expr_let x8 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x6₁) & 340282366920938463463374607431768211455) in + expr_let x9 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x6₁ >> 128) in + expr_let x10 := (uint256)(((uint128)(x7) & 340282366920938463463374607431768211455) << 128) in + expr_let x11 := (uint128)(x7 >> 128) in + expr_let x12 := (uint256)(((uint128)(x8) & 340282366920938463463374607431768211455) << 128) in + expr_let x13 := (uint128)(x8 >> 128) in + expr_let x14 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x6₁) & 340282366920938463463374607431768211455) in + expr_let x15 := ADD_256 (x12, x14) in + expr_let x16 := ADDC_256 (x15₂, x11, x13) in + expr_let x17 := ADD_256 (x10, x15₁) in + expr_let x18 := ADDC_256 (x17₂, x9, x16₁) in + expr_let x19 := ADD_256 (x17₁, x₁) in + expr_let x20 := ADDC_256 (x19₂, x18₁, x₂) in + expr_let x21 := SELC (x20₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in + expr_let x22 := SUB_256 (x20₁, x21) in + ADDM (x22₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951))%expr + : Expr (type.uncurry (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z)) +*) + + Import Straightline.expr. + Import PreFancy. + + Definition montred256_straightline := Eval lazy in (fun var dummy_var x => Straightline.of_Expr montred256 var x dummy_var). + + Definition constant_to_scalar var ident (x : Z) : option (@scalar var ident type.Z) := + if x =? (BinInt.Z.shiftr N 128) + then Some (Cast uint128 (Shiftr 128 (Primitive (t:=type.Z) N))) + else if x =? (BinInt.Z.shiftr N' 128) + then Some (Cast uint128 (Shiftr 128 (Primitive (t:=type.Z) N'))) + else if x =? (BinInt.Z.land N (2^128 - 1)) + then Some (Cast uint128 (Land (2^128-1) (Primitive (t:=type.Z) N))) + else if x =? (BinInt.Z.land N' (2^128 - 1)) + then Some (Cast uint128 (Land (2^128-1) (Primitive (t:=type.Z) N'))) + else None. + + Definition montred256_prefancy := + Eval vm_compute in (fun var dummy_var x => + @of_straightline var dummy_var machine_wordsize (constant_to_scalar var) _ + (montred256_straightline var dummy_var x)). -(* TODO: to get the right kind of output operation, I probably want - to inline all these shifts/ands. *) + Local Notation "'tZ'" := (type.type_primitive type.Z). + Local Notation "'RegMod'" := (Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951). + Local Notation "'RegPInv'" := (Primitive (t:=type.Z) 115792089210356248768974548684794254293921932838497980611635986753331132366849). + Local Notation "'RegZero'" := (Primitive (t:=type.Z) 0). + Local Notation "$ x" := (Cast uint256 (Fst (Cast2 (uint256,bool) (Var (tZ * tZ) x)))) (at level 10, format "$ x"). + Local Notation "$ x ₁" := (Cast uint256 (Fst (Var (tZ * tZ) x))) (at level 10, format "$ x ₁"). + Local Notation "$ x ₂" := (Cast uint256 (Snd (Var (tZ * tZ) x))) (at level 10, format "$ x ₂"). + Local Notation "carry{ $ x }" := (Cast bool (Snd (Cast2 (uint256, bool) (Var (tZ * tZ) x)))) (at level 10, format "carry{ $ x }"). + Local Notation "Lower{ x }" := (Cast uint128 (Land 340282366920938463463374607431768211455 x)) (at level 10, format "Lower{ x }"). + Local Notation "$ x" := (Cast uint256 (Var tZ x)) (at level 10, format "$ x"). + Local Notation "$ x" := (Cast uint128 (Var tZ x)) (at level 10, format "$ x"). + Local Notation "f @( y , x1 , x2 ); g " + := (LetInAppIdentZZ (uint256, bool) f (Pair x1 x2) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 ); '//' g "). + Local Notation "f @( y , x1 , x2 , x3 ); g " + := (LetInAppIdentZZ (uint256, bool) f (Pair (Pair x1 x2) x3) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g "). + Local Notation "f @( y , x1 , x2 ); g " + := (LetInAppIdentZ uint256 f (Pair x1 x2) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 ); '//' g "). + Local Notation "f @( y , x1 , x2 , x3 ); g " + := (LetInAppIdentZ uint256 f (Pair (Pair x1 x2) x3) (fun y => g)) (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g "). + Local Notation "shiftL@( y , x , n ); g" + := (LetInAppIdentZ uint256 (shiftl n) (Lower{x}) (fun y => g)) (at level 10, g at level 200, format "shiftL@( y , x , n ); '//' g "). + Local Notation "shiftR@( y , x , n ); g" + := (LetInAppIdentZ uint128 (shiftr n) x (fun y => g)) (at level 10, g at level 200, format "shiftR@( y , x , n ); '//' g "). + Local Notation "'Ret' $ x" := (Scalar (Var tZ x)) (at level 10, format "'Ret' $ x"). + Local Notation "( x , y )" := (Pair x y) (at level 10, left associativity). + Print montred256_prefancy. (* -montred256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, - expr_let x0 := (uint128)(x₁ >> 128) in - expr_let x1 := ((uint128)(x₁) & 340282366920938463463374607431768211455) in - expr_let x2 := 79228162514264337593543950337 *₂₅₆ x0 in - expr_let x3 := ((uint128)(x2) & 340282366920938463463374607431768211455) in - expr_let x4 := 340282366841710300986003757985643364352 *₂₅₆ x1 in - expr_let x5 := ((uint128)(x4) & 340282366920938463463374607431768211455) in - expr_let x6 := (uint256)(x3 << 128) in - expr_let x7 := (uint256)(x5 << 128) in - expr_let x8 := 79228162514264337593543950337 *₂₅₆ x1 in - expr_let x9 := ADD_256 (x7, x8) in - expr_let x10 := ADD_256 (x6, x9₁) in - expr_let x11 := (uint128)(x10₁ >> 128) in - expr_let x12 := ((uint128)(x10₁) & 340282366920938463463374607431768211455) in - expr_let x13 := 79228162514264337593543950335 *₂₅₆ x11 in - expr_let x14 := (uint128)(x13 >> 128) in - expr_let x15 := ((uint128)(x13) & 340282366920938463463374607431768211455) in - expr_let x16 := 340282366841710300967557013911933812736 *₂₅₆ x12 in - expr_let x17 := (uint128)(x16 >> 128) in - expr_let x18 := ((uint128)(x16) & 340282366920938463463374607431768211455) in - expr_let x19 := 340282366841710300967557013911933812736 *₂₅₆ x11 in - expr_let x20 := (uint256)(x15 << 128) in - expr_let x21 := (uint256)(x18 << 128) in - expr_let x22 := 79228162514264337593543950335 *₂₅₆ x12 in - expr_let x23 := ADD_256 (x21, x22) in - expr_let x24 := ADDC_256 (x23₂, x14, x17) in - expr_let x25 := ADD_256 (x20, x23₁) in - expr_let x26 := ADDC_256 (x25₂, x19, x24₁) in - expr_let x27 := ADD_256 (x25₁, x₁) in - expr_let x28 := ADDC_256 (x27₂, x26₁, x₂) in - expr_let x29 := SELC (x28₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in - expr_let x30 := Z.cast uint256 @@ (fst @@ SUB_256 (x28₁, x29)) in - ADDM (x30, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) - : Expr (type.uncurry (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z)) - *) + mullh@(x0, RegPInv, $x₁); + mulhl@(x1, RegPInv, $x₁); + shiftL@(x2, $x0, 128); + shiftL@(x3, $x1, 128); + mulll@(x4, RegPInv, $x₁); + add@(x5, $x3, $x4); + add@(x6, $x2, $x5); + mullh@(x7, RegMod, $x6); + mulhl@(x8, RegMod, $x6); + mulhh@(x9, RegMod, $x6); + shiftL@(x10, $x7, 128); + shiftR@(x11, $x7, 128); + shiftL@(x12, $x8, 128); + shiftR@(x13, $x8, 128); + mulll@(x14, RegMod, $x6); + add@(x15, $x12, $x14); + addc@(x16, carry{$x15}, $x11, $x13); + add@(x17, $x10, $x15); + addc@(x18, carry{$x17}, $x9, $x16); + add@(x19, $x17, $x₁); + addc@(x20, carry{$x19}, $x18, $x₂); + sel@(x21, carry{$x20}, RegZero, RegMod); + sub@(x22, $x20, $x21); + addm@(x23, $x22, RegZero, RegMod); + Ret $x23 + *) End Montgomery256. (* Extra-specialized ad-hoc pretty-printing *) @@ -9424,30 +9580,30 @@ Module FancyPrintingNotations. (primitive 0) TT) (only printing, at level 9) : expr_scope. Notation "'$R'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 : expr_scope. - Notation "'Lower128{RegMod}'" := + Notation "'c.Lower(RegMod)'" := (AppIdent (primitive 79228162514264337593543950335) TT) (only printing, at level 9) : expr_scope. - Notation "'RegMod' '>>' '128'" := + Notation "'c.Upper(RegMod)'" := (AppIdent (primitive 340282366841710300967557013911933812736) - TT) (only printing, at level 9, format "'RegMod' '>>' '128'") : expr_scope. - Notation "'Lower128{RegMuLow}'" := + TT) (only printing, at level 9) : expr_scope. + Notation "'c.Lower(RegMuLow)'" := (AppIdent (primitive 340282366841710300930663525764514709507) TT) (only printing, at level 9) : expr_scope. - Notation "'RegMuLow' '>>' '128'" := + Notation "'c.Upper(RegMuLow)'" := (AppIdent (primitive 79228162514264337589248983038) - TT) (only printing, at level 9, format "'RegMuLow' '>>' '128'") : expr_scope. - Notation "'Lower128{RegPinv}'" := + TT) (only printing, at level 9) : expr_scope. + Notation "'c.Lower(RegPinv)'" := (AppIdent (primitive 79228162514264337593543950337) TT) (only printing, at level 9) : expr_scope. - Notation "'RegPinv' '>>' '128'" := + Notation "'c.Upper(RegPinv)'" := (AppIdent (primitive 340282366841710300986003757985643364352) - TT) (only printing, at level 9, format "'RegPinv' '>>' '128'") : expr_scope. + TT) (only printing, at level 9) : expr_scope. Notation "'uint256'" := (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange) : ctype_scope. Notation "'uint128'" @@ -9495,6 +9651,9 @@ Module FancyPrintingNotations. Notation "'c.Sub(' '$' n ',' x ',' y ');' f" := (expr_let n := Z.cast uint256 @@ (fst @@ (Z.cast2 (uint256, _)%core @@ (Z.sub_get_borrow_concrete $R @@ (x, y)))) in f)%expr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$' n ',' x ',' y ');' '//' f") : expr_scope. + Notation "'c.Sub(' '$' n ',' x ',' y ');' f" := + (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.sub_get_borrow_concrete $R @@ (x, y)) in + f)%expr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$' n ',' x ',' y ');' '//' f") : expr_scope. Notation "'c.AddM(' '$ret' ',' x ',' y ',' z ');'" := (Z.cast uint256 @@ (Z.add_modulo @@ (x, y, z)))%expr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : expr_scope. Notation "'c.ShiftR(' '$' n ',' x ',' y ');' f" := @@ -9503,15 +9662,17 @@ Module FancyPrintingNotations. (expr_let n := Z.cast _ @@ (Z.rshi_concrete $R m @@ (x, y)) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Rshi(' '$' n ',' x ',' y ',' m ');' ']' '//' f") : expr_scope. Notation "'c.ShiftL(' '$' n ',' x ',' y ');' f" := (expr_let n := Z.cast _ @@ (Z.shiftl y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. + Notation "'c.ShiftL(' '$' n ',' x ',' y ');' f" := + (expr_let n := Z.cast _ @@ (Z.shiftl y @@ (Z.cast uint128 @@ (Z.land 340282366920938463463374607431768211455 @@ x))) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. Notation "'c.Lower128(' '$' n ',' x ');' f" := (expr_let n := Z.cast _ @@ (Z.land 340282366920938463463374607431768211455 @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$' n ',' x ');' ']' '//' f") : expr_scope. - Notation "'c.LowerHalf(' x ')'" - := (Z.cast uint128 @@ (Z.land 340282366920938463463374607431768211455)) - (at level 10, only printing, format "c.LowerHalf( x )") + Notation "'c.Lower(' x ')'" + := (Z.cast uint128 @@ (Z.land 340282366920938463463374607431768211455 @@ x)) + (at level 10, only printing, format "c.Lower( x )") : expr_scope. - Notation "'c.UpperHalf(' x ')'" - := (Z.cast uint128 @@ (Z.shiftr 340282366920938463463374607431768211455)) - (at level 10, only printing, format "c.UpperHalf( x )") + Notation "'c.Upper(' x ')'" + := (Z.cast uint128 @@ (Z.shiftr 128 @@ x)) + (at level 10, only printing, format "c.Upper( x )") : expr_scope. Notation "( v << count )" := (Z.cast _ @@ (Z.shiftl count @@ v)%expr) @@ -9538,16 +9699,16 @@ c.Rshi($x1, RegZero, $x_hi, 255); c.Rshi($x2, $x_hi, $x_lo, 255); c.ShiftR($x3, $x2, 128); c.Lower128($x4, $x2); -c.Mul128x128($x5, RegMuLow >> 128, $x4); +c.Mul128x128($x5, c.Upper(RegMuLow), $x4); c.ShiftR($x6, $x5, 128); c.Lower128($x7, $x5); -c.Mul128x128($x8, Lower128{RegMuLow}, $x3); +c.Mul128x128($x8, c.Lower(RegMuLow), $x3); c.ShiftR($x9, $x8, 128); c.Lower128($x10, $x8); -c.Mul128x128($x11, RegMuLow >> 128, $x3); +c.Mul128x128($x11, c.Upper(RegMuLow), $x3); c.ShiftL($x12, $x7, 128); c.ShiftL($x13, $x10, 128); -c.Mul128x128($x14, Lower128{RegMuLow}, $x4); +c.Mul128x128($x14, c.Lower(RegMuLow), $x4); c.Add256($x15, $x13, $x14); c.Addc128($x16, $x15_hi, $x6, $x9); c.Add256($x17, $x12, $x15_lo); @@ -9559,16 +9720,16 @@ c.Addc128($x22, $x21_hi, RegZero, $x20_lo); c.Rshi($x23, $x22_lo, $x21_lo, 1); c.ShiftR($x24, $x23, 128); c.Lower128($x25, $x23); -c.Mul128x128($x26, Lower128{RegMod}, $x24); +c.Mul128x128($x26, c.Lower(RegMod), $x24); c.ShiftR($x27, $x26, 128); c.Lower128($x28, $x26); -c.Mul128x128($x29, RegMod >> 128, $x25); +c.Mul128x128($x29, c.Upper(RegMod), $x25); c.ShiftR($x30, $x29, 128); c.Lower128($x31, $x29); -c.Mul128x128($x32, RegMod >> 128, $x24); +c.Mul128x128($x32, c.Upper(RegMod), $x24); c.ShiftL($x33, $x28, 128); c.ShiftL($x34, $x31, 128); -c.Mul128x128($x35, Lower128{RegMod}, $x25); +c.Mul128x128($x35, c.Lower(RegMod), $x25); c.Add256($x36, $x34, $x35); c.Addc256($x37, $x36_hi, $x27, $x30); c.Add256($x38, $x33, $x36_lo); @@ -9584,36 +9745,28 @@ c.AddM($ret, $x43, RegZero, RegMod); Print Montgomery256.montred256. (* -c.ShiftR($x0, $x_lo, 128); -c.Lower128($x1, $x_lo); -c.Mul128x128($x2, Lower128{RegPinv}, $x0); -c.Lower128($x3, $x2); -c.Mul128x128($x4, RegPinv >> 128, $x1); -c.Lower128($x5, $x4); -c.ShiftL($x6, $x3, 128); -c.ShiftL($x7, $x5, 128); -c.Mul128x128($x8, Lower128{RegPinv}, $x1); -c.Add256($x9, $x7, $x8); -c.Add256($x10, $x6, $x9_lo); -c.ShiftR($x11, $x10_lo, 128); -c.Lower128($x12, $x10_lo); -c.Mul128x128($x13, Lower128{RegMod}, $x11); -c.ShiftR($x14, $x13, 128); -c.Lower128($x15, $x13); -c.Mul128x128($x16, RegMod >> 128, $x12); -c.ShiftR($x17, $x16, 128); -c.Lower128($x18, $x16); -c.Mul128x128($x19, RegMod >> 128, $x11); -c.ShiftL($x20, $x15, 128); -c.ShiftL($x21, $x18, 128); -c.Mul128x128($x22, Lower128{RegMod}, $x12); -c.Add256($x23, $x21, $x22); -c.Addc256($x24, $x23_hi, $x14, $x17); -c.Add256($x25, $x20, $x23_lo); -c.Addc256($x26, $x25_hi, $x19, $x24_lo); -c.Add256($x27, $x25_lo, $x_lo); -c.Addc256($x28, $x27_hi, $x26_lo, $x_hi); -c.Selc($x29, $x28_hi, RegZero, RegMod); -c.Sub($x30, $x28_lo, $x29); -c.AddM($ret, $x30, RegZero, RegMod); +c.Mul128x128($x0, c.Lower(RegPinv), c.Upper($x_lo)); +c.Mul128x128($x1, c.Upper(RegPinv), c.Lower($x_lo)); +c.ShiftL($x2, $x0, 128); +c.ShiftL($x3, $x1, 128); +c.Mul128x128($x4, c.Lower(RegPinv), c.Lower($x_lo)); +c.Add256($x5, $x3, $x4); +c.Add256($x6, $x2, $x5_lo); +c.Mul128x128($x7, c.Lower(RegMod), c.Upper($x6_lo)); +c.Mul128x128($x8, c.Upper(RegMod), c.Lower($x6_lo)); +c.Mul128x128($x9, c.Upper(RegMod), c.Upper($x6_lo)); +c.ShiftL($x10, $x7, 128); +c.ShiftR($x11, $x7, 128); +c.ShiftL($x12, $x8, 128); +c.ShiftR($x13, $x8, 128); +c.Mul128x128($x14, c.Lower(RegMod), c.Lower($x6_lo)); +c.Add256($x15, $x12, $x14); +c.Addc256($x16, $x15_hi, $x11, $x13); +c.Add256($x17, $x10, $x15_lo); +c.Addc256($x18, $x17_hi, $x9, $x16_lo); +c.Add256($x19, $x17_lo, $x_lo); +c.Addc256($x20, $x19_hi, $x18_lo, $x_hi); +c.Selc($x21, $x20_hi, RegZero, RegMod); +c.Sub($x22, $x20_lo, $x21); +c.AddM($ret, $x22_lo, RegZero, RegMod); *) -- cgit v1.2.3