From 82cb73f445ab650a5fecdedc942481d5abfdabc7 Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Wed, 25 Apr 2018 14:12:32 +0200 Subject: Translation to straightline code (first attempts, mostly working) --- src/Experiments/SimplyTypedArithmetic.v | 323 ++++++++++++++++++++++++++++++-- 1 file changed, 311 insertions(+), 12 deletions(-) (limited to 'src') diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index a8f1e0a13..59ac46582 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -5529,11 +5529,12 @@ Module Compilers. | ident.Z_mul_split_concrete _ as idc | ident.Z.sub_get_borrow_concrete _ as idc => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => match x_y return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with + => let default _ := default_interp idc x_y in + match x_y return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with | inr (inr x, inr y) => let result := ident.interp idc (x, y) in inr (inr (fst result), inr (snd result)) - | _ => default_interp idc x_y + | _ => default tt end | ident.Z.add_get_carry_concrete _ as idc => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) @@ -7746,7 +7747,7 @@ Module X25519_32. Import PrintingNotations. Print base_25p5_carry_mul. (* -base_25p5_carry_mul = +base_25p5_carry_mul = fun var : type -> Type => (λ x : var (type.list (type.type_primitive type.Z) * type.list (type.type_primitive type.Z))%ctype, expr_let x0 := x₁ [[0]] *₆₄ x₂ [[0]] +₆₄ @@ -7885,7 +7886,7 @@ fun var : type -> Type => type.list (type.type_primitive type.Z))) *) End X25519_32. -*) + *) Module BarrettReduction. (* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *) @@ -8618,7 +8619,7 @@ Module P192_64. Open Scope expr_scope. Set Printing Width 100000. Set Printing Depth 100000. - + Local Notation "'mul64' '(' x ',' y ')'" := (Z.cast2 (uint64, _)%core @@ (Z.mul_split_concrete 18446744073709551616 @@ (x , y)))%expr (at level 50) : expr_scope. Local Notation "'add64' '(' x ',' y ')'" := @@ -8627,7 +8628,7 @@ Module P192_64. (Z.cast2 (uint64, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (c, x , y)))%expr (at level 50) : expr_scope. Local Notation "'adx64' '(' c ',' x ',' y ')'" := (Z.cast bool @@ (Z.add_with_carry @@ (c, x , y)))%expr (at level 50) : expr_scope. - + Print mulmod. (* mulmod = fun var : type -> Type => λ x : var (type.list (type.type_primitive type.Z) * type.list (type.type_primitive type.Z))%ctype, @@ -8695,7 +8696,7 @@ Module P192_32. (Z.cast2 (uint32, bool)%core @@ (Z.add_get_carry_concrete 4294967296 @@ (x , y)))%expr (at level 50) : expr_scope. Local Notation "'adc32' '(' c ',' x ',' y ')'" := (Z.cast2 (uint32, bool)%core @@ (Z.add_with_get_carry_concrete 4294967296 @@ (c, x , y)))%expr (at level 50) : expr_scope. - + Print mulmod. (* mulmod = fun var : type -> Type => λ x : var (type.list (type.type_primitive type.Z) * type.list (type.type_primitive type.Z))%ctype, @@ -8890,6 +8891,291 @@ Module P256_32. End P256_32. *) +Require Import Coq.Program.Wf. + +Module Straightline. + Module expr. + Section with_var. + Context {var : type.type -> Type}. + Context {dummy_var : forall t, var t}. + + Let uexpr t := @Uncurried.expr.expr ident.ident var t. + + Inductive expr : type.type -> Type := + | Scalar {t} : scalar t -> expr t + | LetInAppIdentZ {s d} : zrange -> ident.ident s (type.Z) -> scalar s -> (var (type.Z) -> expr d) -> expr d + | LetInAppIdentZZ {s d} : zrange * zrange -> ident.ident s (type.Z*type.Z) -> scalar s -> (var (type.Z*type.Z) -> expr d) -> expr d + with scalar : type.type -> Type := + | Var t : var t -> scalar t + | TT : scalar (type.type_primitive type.unit) + | Pair {a b} : scalar a -> scalar b -> scalar (a * b) + | Cast : zrange -> scalar type.Z -> scalar type.Z + | Cast2 : zrange * zrange -> scalar (type.Z*type.Z) -> scalar (type.Z*type.Z) + | Fst {a b} : scalar (a * b) -> scalar a + | Snd {a b} : scalar (a * b) -> scalar b + | Primitive {t} : type.interp (type.type_primitive t) -> scalar t + . + + Fixpoint dummy t : expr t := Scalar (Var t (dummy_var t)). + + Definition scalar_of_uncurried_ident {s d} (idc : ident.ident s d) + : scalar s -> option (scalar d) := + match idc in ident.ident s d return scalar s -> option (scalar d) with + | ident.Z.cast r => fun args => Some (Cast r args) + | ident.Z.cast2 r => fun args => Some (Cast2 r args) + | @ident.fst A B => fun args => Some (@Fst A B args) + | @ident.snd A B => fun args => Some (@Snd A B args) + | @ident.primitive p x => fun _ => Some (Primitive x) + | _ => fun _ => None + end. + + Fixpoint scalar_of_uncurried {t} (e : uexpr t) : option (scalar t) := + match e in Uncurried.expr.expr t return option (scalar t) with + | expr.Var t v as e => Some (Var t v) + | expr.TT as e => Some TT + | expr.Pair A B a b + => match scalar_of_uncurried a, scalar_of_uncurried b with + | Some x, Some y => Some (Pair x y) + | _, _ => None + end + | expr.AppIdent _ _ idc args + => match scalar_of_uncurried args with + | Some x => scalar_of_uncurried_ident idc x + | None => None + end + | _ => None + end. + + Fixpoint range_type t : Type := + match t with + | type.type_primitive type.Z => zrange + | type.prod x y => range_type x * range_type y + | _ => unit + end. + + Definition invert_cast {t} (e : uexpr t) + : option (range_type t * uexpr t) := + match invert_AppIdent e with + | Some (existT s (idc, x)) => + (match idc in ident.ident s t return uexpr s -> option (range_type t * uexpr t) with + | ident.Z.cast r => fun x => Some (r, x) + | ident.Z.cast2 r => fun x => Some (r, x) + | _ => fun _ => None + end) x + | None => None + end. + + + (* if we have a cast, what we have is + cast r (AppIdent idc x') + where x' has type s and idc is s -> type.Z + and tx = type.Z + + we want this to be translated to + + (idc, x', + *) + (* ident.Let_In @@ (cast r x) => r, x *) + Definition invert_LetInCast {tx tC} (args : uexpr (tx * (tx -> tC))) + : option (range_type tx * uexpr tx * uexpr (tx -> tC)) := + match invert_Pair args with + | Some (x, e) => + match invert_cast x with + | Some (r, x') => Some (r, x', e) + | None => None + end + | None => None + end. + + + (* TODO : currently we look at the first application, which might be a cast. *) + Definition invert_LetInAppIdent {tx tC} (args : uexpr (tx * (tx -> tC))) + : option { s : type.type & (range_type tx * ident.ident s tx * scalar s * (var tx -> uexpr tC))%type } := + match invert_LetInCast args with + | Some (r, x, e) => + match invert_AppIdent x with + | Some (existT s idc_x') => + match scalar_of_uncurried (snd idc_x') with + | Some x'' => + match invert_Abs e with + | Some k => Some (existT _ s (r, fst idc_x', x'', k)) + | None => None + end + | None => None + end + | None => None + end + | None => None + end. + + Fixpoint depth {t} (e : uexpr t) : nat := + match e with + | expr.Var _ _ => O + | expr.TT => O + | expr.AppIdent _ _ idc args => S (depth args) + | expr.App _ _ f x => S (Nat.max (depth f) (depth x)) + | expr.Pair _ _ x y => S (Nat.max (depth x) (depth y)) + | expr.Abs _ _ f => S (depth (f (dummy_var _))) + end. + + Definition of_uncurried_step {t} (e : uexpr t) + (of_uncurried : forall {t}, uexpr t -> expr t) + : expr t := + (match e in Uncurried.expr.expr t return expr t -> expr t with + | AppIdent s d idc args => + (match idc in ident.ident s d return uexpr s -> expr d -> expr d with + | ident.Let_In tx tC => + (fun args default => + match invert_LetInAppIdent args return expr tC with + | Some (existT s (r, idc, x, k)) => + (match tx as tx0 return range_type tx0 -> ident.ident s tx0 -> (var tx0 -> expr tC) -> expr tC with + | type.type_primitive type.Z => + fun r idc k => @LetInAppIdentZ s tC r idc x k + | type.prod type.Z type.Z => + fun r idc k => @LetInAppIdentZZ s tC r idc x k + | _ => fun _ _ _ => default + end) r idc (fun y : var tx => of_uncurried (k y)) + | None => default + end) + | _ => fun _ default => default + end) args + | _ as e => + (fun default => + match scalar_of_uncurried e with + | Some s => Scalar s + | None => default + end) + end) (dummy t). + + (* TODO : uses fuel; ideally want a cleaner termination proof *) + Fixpoint of_uncurried (fuel : nat) {t} (e : uexpr t) + : expr t := + match fuel with + | S fuel' => of_uncurried_step e (@of_uncurried fuel') + | O => dummy t + end. + + End with_var. + End expr. + + (* TODO : Can I avoid having dummy_var appear here? *) + Definition of_Expr {s d} (e : Expr (s->d)) (var : type -> Type) (x:var s) dummy_var : expr.expr d + := + match invert_Abs (e var) with + | Some f => expr.of_uncurried (dummy_var:=dummy_var) (expr.depth (dummy_var:=dummy_var) (f x)) (f x) + | None => expr.dummy (dummy_var:=dummy_var) d + end. + +End Straightline. + +Module StraightlineTest. + Definition test : Expr (type.Z -> type.Z) := + fun var => + Abs + (fun (x : var type.Z) => + AppIdent (var:=var) ident.Let_In + (Pair (AppIdent (var:=var) (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent (var:=var) (ident.Z.shiftr 8) (Var x))) + (Abs (fun x : var type.Z => expr.Var x)))). + + Eval vm_compute in (Straightline.of_Expr test). + + Definition test_mul : Expr (type.Z -> type.Z) := + fun var => + Abs + (fun (x : var type.Z) => + AppIdent (var:=var) ident.Let_In + (Pair (AppIdent (var:=var) (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent (var:=var) (ident.Z.shiftr 8) (Var x))) + (Abs (fun x : var type.Z => + AppIdent ident.Let_In + (Pair (AppIdent (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent ident.Z.mul (Pair (AppIdent (@ident.primitive type.Z 12) TT) (Var x)))) + (Abs (fun x : var type.Z => Var x))) + )))). + Eval vm_compute in (Straightline.of_Expr test_mul). +End StraightlineTest. + +(* +Module InlineOperations. + Section with_var. + Context {var : type.type -> Type} (op_match : forall t, @expr.expr ident.ident var t -> bool). + + Fixpoint inline_op {t} (e : @expr.expr ident.ident var t) : @expr.expr ident.ident var t + := + match e in expr.expr t return @expr.expr ident.ident var t with + | Var t _ as e => e + | TT as e + => e + | Pair A B a b + => Pair (@inline_op A a) (@inline_op B b) + | App s d f x => App (@inline_op _ f) (@inline_op _ x) + | Abs s d f => Abs (fun v => @inline_op _ (f v)) + | AppIdent s d idc args + => inline_op_ident idc (@inline_op s args) + end. + + (* + AppIdent ident.Let_In (Pair (AppIdent (ident.shiftr n) x) (Abs (fun y : var _ => F (Var _ y)))) + + => + + F (AppIdent (ident.shiftr n) x) + *) + + Definition replace_num {t} (old : var type.Z) (new : @expr.expr ident.ident var type.Z) (e : @expr.expr ident.ident var t) + : @expr.expr ident.ident var t + := match e with + | Var _ v => + | TT => TT + | + + + Definition inline_if_match {tx tC} (args : @expr.expr ident.ident var (tx * (tx -> tC))) + : @expr.expr ident.ident var tC := + let default _ := AppIdent (@ident.Let_In tx tC) args in + match invert_Pair args with + | Some (e, Abs s d k) => + if op_match tx e then App k e else default tt + | None => default tt + end. + + Definition inline_op_ident {s d} (idc : ident.ident s d) + : expr.expr s -> expr.expr d + := + match idc in ident.ident s d return @expr.expr ident.ident var s -> expr.expr d with + | ident.Let_In tx tC => fun args : @expr.expr ident.ident var (tx * (tx -> tC)) => + @inline_if_match tx tC args + | _ as i => fun args => AppIdent i args + end. + + Fixpoint inline_op {t} (e : @expr.expr ident.ident var t) : @expr.expr ident.ident var t + := + match e in expr.expr t return @expr.expr ident.ident var t with + | Var t _ as e => e + | TT as e + => e + | Pair A B a b + => Pair (@inline_op A a) (@inline_op B b) + | App s d f x => App (@inline_op _ f) (@inline_op _ x) + | Abs s d f => Abs (fun v => @inline_op _ (f v)) + | AppIdent s d idc args + => inline_op_ident idc (@inline_op s args) + end. + End with_var. + + Fixpoint shift_land_matcher {var} t (e : @expr.expr ident.ident var t) : bool := + match e with + | AppIdent _ _ (ident.Z_cast _) args => shift_land_matcher _ args + | AppIdent _ _ (ident.Z_cast2 _) args => shift_land_matcher _ args + | AppIdent _ _ (ident.Z_land _) args => true + | AppIdent _ _ (ident.Z_shiftr _) args => true + | _ => false + end. + + Definition inline_shiftr_and_land {t} (e : Expr t) : Expr t := + fun var => inline_op shift_land_matcher (e _). + +End InlineOperations. +*) + Module MontgomeryReduction. Section MontRed'. Context (N R N' R' : Z). @@ -8909,8 +9195,8 @@ Module MontgomeryReduction. dlet_nd t1_t2 := (BaseConversion.widemul Zlog2R n nout y N) in dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [fst lo_hi; snd lo_hi] t1_t2 in dlet_nd y' := Z.zselect (snd sum_carry) 0 N in - dlet_nd lo'' := fst (Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y') in - Z.add_modulo lo'' 0 N. + dlet_nd lo''_carry := Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y' in + Z.add_modulo (fst lo''_carry) 0 N. Local Lemma Hw : forall i, w i = R ^ Z.of_nat i. Proof. @@ -9054,11 +9340,24 @@ Module Montgomery256. As montred256_correct. Proof. Time solve_rmontred machine_wordsize. Time Qed. + Import Straightline.expr. Import PrintingNotations. - Open Scope expr_scope. - Set Printing Width 100000. - + Set Printing Depth 1000000. + Local Notation "'tZ'" := (type.type_primitive type.Z). + Eval lazy in (Straightline.of_Expr montred256). + (* TODO : why is the sub not cast when I remove fst? *) Print montred256. + + (* Check bounds on the sub, make sure it's outputting Some *) + Locate interp. + Eval lazy in (ZRange.ident.option.interp (ident.Z.sub_get_borrow_concrete (2^256)) (Some uint256, Some uint256)). + + Check ZRange.split_bounds. + Print ZRange.split_bounds. + (* TODO: apply the split_bounds updates from the barrett branch! That's what currently is fucking up *) + +(* TODO: to get the right kind of output operation, I probably want + to inline all these shifts/ands. *) (* montred256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, expr_let x0 := (uint128)(x₁ >> 128) in -- cgit v1.2.3