aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2018-03-21 13:54:51 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2018-03-21 19:51:48 -0400
commitdf9223b12bba6dd064bc7fc05ba64139a252d69d (patch)
tree2896aa27f07220f45a64d403ae25ada5922df896
parentfbf11137d51beef861cc0825603e33674b1353b4 (diff)
Don't inline var nodes on the first pass through partial evaluation
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v132
1 files changed, 92 insertions, 40 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index e739d6bcc..16f2e3c7c 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -4263,7 +4263,8 @@ Module Compilers.
Module ident.
Section interp.
- Context {var : type -> Type}.
+ Context (inline_var_nodes : bool)
+ {var : type -> Type}.
Fixpoint is_var_like {t} (e : @expr var t) : bool
:= match e with
| Var t v => true
@@ -4309,15 +4310,18 @@ Module Compilers.
=> @interp_let_in
_ B b
(fun b => f (inr (a, b))))
- | inl (data, e) => partial.expr.reflect
- (expr_let y := partial.expr.reify (t:=t) x in
- partial.expr.reify (f (inl (data, Var y)%core)))%expr
+ | inl (data, e)
+ => if inline_var_nodes && is_var_like e
+ then f x
+ else partial.expr.reflect
+ (expr_let y := partial.expr.reify (t:=t) x in
+ partial.expr.reify (f (inl (data, Var y)%core)))%expr
end
| type.type_primitive _ as t
=> fun (x : data t * expr t + type.interp t) (f : data t * expr t + type.interp t -> value var tC)
=> match x with
| inl (data, e)
- => if is_var_like e
+ => if inline_var_nodes && is_var_like e
then f x
else partial.expr.reflect
(expr_let y := (partial.expr.reify (t:=t) x) in
@@ -4884,7 +4888,8 @@ Module Compilers.
End partial.
Section partial_reduce.
- Context {var : type -> Type}.
+ Context (inline_var_nodes : bool)
+ {var : type -> Type}.
Definition partial_reduce'_step
(partial_reduce' : forall {t} (e : @expr (partial.value var) t),
@@ -4894,7 +4899,7 @@ Module Compilers.
:= match e in expr.expr t return partial.value var t with
| Var t v => v
| TT => inr tt
- | AppIdent s d idc args => partial.ident.interp idc (@partial_reduce' _ args)
+ | AppIdent s d idc args => partial.ident.interp inline_var_nodes idc (@partial_reduce' _ args)
| Pair A B a b => inr (@partial_reduce' A a, @partial_reduce' B b)
| App s d f x => @partial_reduce' _ f (@partial_reduce' _ x)
| Abs s d f => fun x => @partial_reduce' d (f x)
@@ -4918,8 +4923,8 @@ Module Compilers.
End partial_reduce.
- Definition PartialReduce {t} (e : Expr t) : Expr t
- := fun var => @partial_reduce var t (e _).
+ Definition PartialReduce (inline_var_nodes : bool) {t} (e : Expr t) : Expr t
+ := fun var => @partial_reduce inline_var_nodes var t (e _).
Module RelaxZRange.
Module ident.
@@ -4972,7 +4977,7 @@ Module Compilers.
Definition PartialReduceWithBounds1
{s d} (e : Expr (s -> d)) (b : ZRange.type.interp s)
: Expr (s -> d)
- := fun var => @partial_reduce_with_bounds1 var s d (e _) b.
+ := fun var => @partial_reduce_with_bounds1 true var s d (e _) b.
Definition CheckPartialReduceWithBounds1
(relax_zrange : zrange -> option zrange)
@@ -5009,7 +5014,7 @@ Module Compilers.
{t} (e : Expr t)
(b_out : ZRange.type.interp t)
: Expr t + ZRange.type.option.interp t
- := dlet_nd E := PartialReduce e in
+ := dlet_nd E := PartialReduce true e in
CheckPartialReduceWithBounds0 relax_zrange E b_out.
Axiom admit_pf : False.
@@ -5048,7 +5053,7 @@ Module Compilers.
-> is_tighter_than_bool z r' = true)
{t} (e : Expr t)
(b_out : ZRange.type.interp t)
- E (HE : PartialReduce e = E)
+ E (HE : PartialReduce true e = E)
rv (Hrv : CheckPartialReduceWithBounds0 relax_zrange E b_out = inl rv)
: Interp rv = Interp e
/\ ZRange.type.is_bounded_by b_out (Interp rv) = true.
@@ -5278,7 +5283,7 @@ Proof.
let v := Reify ((fun x => 2^x) 255)%Z in
pose v as E.
vm_compute in E.
- pose (PartialReduce (canonicalize_list_recursion E)) as E'.
+ pose (PartialReduce false (canonicalize_list_recursion E)) as E'.
vm_compute in E'.
lazymatch (eval cbv delta [E'] in E') with
| (fun var => AppIdent (ident.primitive ?v) TT) => idtac
@@ -5296,7 +5301,7 @@ Module test2.
(fun v => v)) in
pose v as E.
vm_compute in E.
- pose (PartialReduce (canonicalize_list_recursion E)) as E'.
+ pose (PartialReduce false (canonicalize_list_recursion E)) as E'.
vm_compute in E'.
lazymatch (eval cbv delta [E'] in E') with
| (fun var : type -> Type =>
@@ -5329,7 +5334,7 @@ Module test3.
(z * z)) in
pose v as E.
vm_compute in E.
- pose (PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'.
+ pose (PartialReduce false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'.
vm_compute in E'.
lazymatch (eval cbv delta [E'] in E') with
| (fun var : type -> Type =>
@@ -5366,7 +5371,7 @@ Module test4.
(xz :: xz :: nil)) in
pose v as E.
vm_compute in E.
- pose (PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'.
+ pose (PartialReduce false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'.
lazy in E'.
clear E.
pose (PartialReduceWithBounds1 E' ([r[0~>10]%zrange],[r[0~>10]%zrange])) as E''.
@@ -5393,7 +5398,7 @@ Module test5.
x) in
pose v as E.
vm_compute in E.
- pose (ReassociateSmallConstants.Reassociate (2^8) (PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))))) as E'.
+ pose (ReassociateSmallConstants.Reassociate (2^8) (PartialReduce false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))))) as E'.
lazy in E'.
clear E.
lazymatch (eval cbv delta [E'] in E') with
@@ -5420,7 +5425,7 @@ Module test6.
pose (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))) as E'.
lazy in E'.
clear E.
- pose (PartialReduce E') as E''.
+ pose (PartialReduce false E') as E''.
lazy in E''.
lazymatch eval cbv delta [E''] in E'' with
| fun var : type -> Type => (λ x : var (type.type_primitive type.Z), Var x)%expr
@@ -5443,7 +5448,7 @@ Ltac cache_reify _ :=
let e := match RHS with context[expr.Interp _ ?e] => e end in
let E := fresh "E" in
set (E := e);
- let E' := constr:(PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) in
+ let E' := constr:(PartialReduce false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) in
let LHS := match goal with |- ?LHS = _ => LHS end in
lazymatch LHS with
| context LHS[@expr.Interp ?ident ?interp_ident ?t ?e]
@@ -5613,7 +5618,9 @@ Module Pipeline.
{t}
(E : for_reification.Expr t)
: ErrorT (Expr t)
- := let E := option_map PartialReduce (CPS.CallFunWithIdContinuation_opt (CPS.Translate (canonicalize_list_recursion E))) in
+ := let E := option_map
+ (PartialReduce false)
+ (CPS.CallFunWithIdContinuation_opt (CPS.Translate (canonicalize_list_recursion E))) in
match E with
| Some E => Success E
| None => Error (Type_too_complicated_for_cps t)
@@ -5631,7 +5638,7 @@ Module Pipeline.
(E : Expr (s -> d))
arg_bounds
: Expr (s -> d)
- := let E := PartialReduce E in
+ := let E := PartialReduce true E in
(* Note that DCE evaluates the expr with two different [var]
arguments, and so will likely result in a pipeline that is
2x slower *)
@@ -5770,13 +5777,13 @@ Module Pipeline.
{t}
(e : Expr t)
: Expr t
- := let E := PartialReduce e in
+ := let E := PartialReduce true e in
(* Note that DCE evaluates the expr with two different [var]
arguments, and so will likely result in a pipeline that is
2x slower *)
let E := if with_dead_code_elimination then DeadCodeElimination.EliminateDead E else E in
let E := ReassociateSmallConstants.Reassociate (2^8) E in
- let E := PartialReduce E in
+ let E := PartialReduce true E in
E.
Definition CheckBoundsPipelineConst
@@ -6637,23 +6644,68 @@ Module X25519_64.
Import PrintingNotations.
Print base_51_carry_mul.
-(* base_51_carry_mul = fun var : type -> Type => (λ v : var (type.list (type.type_primitive type.Z) * type.list (type.type_primitive type.Z))%ctype,
-expr_let v0 := (uint64)(v₁ [[0]] *₁₂₈ v₂ [[0]] +₁₂₈ (v₁ [[1]] *₁₂₈ (19 * (uint64)(v₂[[4]])) +₁₂₈ (v₁ [[2]] *₁₂₈ (19 * (uint64)(v₂[[3]])) +₁₂₈ (v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[2]])) +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[1]]))))) >> 51) in
-expr_let v1 := ((uint64)(v₁ [[0]] *₁₂₈ v₂ [[0]] +₁₂₈ (v₁ [[1]] *₁₂₈ (19 * (uint64)(v₂[[4]])) +₁₂₈ (v₁ [[2]] *₁₂₈ (19 * (uint64)(v₂[[3]])) +₁₂₈ (v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[2]])) +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[1]])))))) & 2251799813685247) in
-expr_let v2 := (uint64)((uint64)(v0) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[1]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[0]] +₁₂₈ (v₁ [[2]] *₁₂₈ (19 * (uint64)(v₂[[4]])) +₁₂₈ (v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[3]])) +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[2]])))))) >> 51) in
-expr_let v3 := ((uint64)((uint64)(v0) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[1]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[0]] +₁₂₈ (v₁ [[2]] *₁₂₈ (19 * (uint64)(v₂[[4]])) +₁₂₈ (v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[3]])) +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[2]]))))))) & 2251799813685247) in
-expr_let v4 := (uint64)((uint64)(v2) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[2]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[1]] +₁₂₈ (v₁ [[2]] *₁₂₈ v₂ [[0]] +₁₂₈ (v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[4]])) +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[3]])))))) >> 51) in
-expr_let v5 := ((uint64)((uint64)(v2) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[2]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[1]] +₁₂₈ (v₁ [[2]] *₁₂₈ v₂ [[0]] +₁₂₈ (v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[4]])) +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[3]]))))))) & 2251799813685247) in
-expr_let v6 := (uint64)((uint64)(v4) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[3]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[2]] +₁₂₈ (v₁ [[2]] *₁₂₈ v₂ [[1]] +₁₂₈ (v₁ [[3]] *₁₂₈ v₂ [[0]] +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[4]])))))) >> 51) in
-expr_let v7 := ((uint64)((uint64)(v4) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[3]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[2]] +₁₂₈ (v₁ [[2]] *₁₂₈ v₂ [[1]] +₁₂₈ (v₁ [[3]] *₁₂₈ v₂ [[0]] +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[4]]))))))) & 2251799813685247) in
-expr_let v8 := (uint64)((uint64)(v6) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[4]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[3]] +₁₂₈ (v₁ [[2]] *₁₂₈ v₂ [[2]] +₁₂₈ (v₁ [[3]] *₁₂₈ v₂ [[1]] +₁₂₈ v₁ [[4]] *₁₂₈ v₂ [[0]])))) >> 51) in
-expr_let v9 := ((uint64)((uint64)(v6) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[4]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[3]] +₁₂₈ (v₁ [[2]] *₁₂₈ v₂ [[2]] +₁₂₈ (v₁ [[3]] *₁₂₈ v₂ [[1]] +₁₂₈ v₁ [[4]] *₁₂₈ v₂ [[0]]))))) & 2251799813685247) in
-expr_let v10 := (uint64)((uint64)(v1) +₆₄ 19 *₆₄ (uint64)(v8) >> 51) in
-expr_let v11 := ((uint64)((uint64)(v1) +₆₄ 19 *₆₄ (uint64)(v8)) & 2251799813685247) in
-expr_let v12 := (uint64)((uint64)(v10) +₆₄ (uint64)(v3) >> 51) in
-expr_let v13 := ((uint64)((uint64)(v10) +₆₄ (uint64)(v3)) & 2251799813685247) in
-(uint64)(v11) :: (uint64)(v13) :: (uint64)(v12) +₆₄ (uint64)(v5) :: (uint64)(v7) :: (uint64)(v9) :: [])%expr
- : Expr (type.list (type.type_primitive type.Z) * type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z))
+(*base_51_carry_mul =
+fun var : type -> Type =>
+(λ v : var
+ (type.list (type.type_primitive type.Z) *
+ type.list (type.type_primitive type.Z))%ctype,
+ expr_let v0 := v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[4]])) in
+ expr_let v1 := v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[3]])) in
+ expr_let v2 := v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[2]])) in
+ expr_let v3 := v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[1]])) in
+ expr_let v4 := v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[4]])) in
+ expr_let v5 := v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[3]])) in
+ expr_let v6 := v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[2]])) in
+ expr_let v7 := v₁ [[2]] *₁₂₈ (19 * (uint64)(v₂[[4]])) in
+ expr_let v8 := v₁ [[2]] *₁₂₈ (19 * (uint64)(v₂[[3]])) in
+ expr_let v9 := v₁ [[1]] *₁₂₈ (19 * (uint64)(v₂[[4]])) in
+ expr_let v10 := v₁ [[4]] *₁₂₈ v₂ [[0]] in
+ expr_let v11 := v₁ [[3]] *₁₂₈ v₂ [[1]] in
+ expr_let v12 := v₁ [[3]] *₁₂₈ v₂ [[0]] in
+ expr_let v13 := v₁ [[2]] *₁₂₈ v₂ [[2]] in
+ expr_let v14 := v₁ [[2]] *₁₂₈ v₂ [[1]] in
+ expr_let v15 := v₁ [[2]] *₁₂₈ v₂ [[0]] in
+ expr_let v16 := v₁ [[1]] *₁₂₈ v₂ [[3]] in
+ expr_let v17 := v₁ [[1]] *₁₂₈ v₂ [[2]] in
+ expr_let v18 := v₁ [[1]] *₁₂₈ v₂ [[1]] in
+ expr_let v19 := v₁ [[1]] *₁₂₈ v₂ [[0]] in
+ expr_let v20 := v₁ [[0]] *₁₂₈ v₂ [[4]] in
+ expr_let v21 := v₁ [[0]] *₁₂₈ v₂ [[3]] in
+ expr_let v22 := v₁ [[0]] *₁₂₈ v₂ [[2]] in
+ expr_let v23 := v₁ [[0]] *₁₂₈ v₂ [[1]] in
+ expr_let v24 := v₁ [[0]] *₁₂₈ v₂ [[0]] in
+ expr_let v25 := v24 +₁₂₈ (v9 +₁₂₈ (v8 +₁₂₈ (v6 +₁₂₈ v3))) in
+ expr_let v26 := (uint64)(v25 >> 51) in
+ expr_let v27 := ((uint64)(v25) & 2251799813685247) in
+ expr_let v28 := v20 +₁₂₈ (v16 +₁₂₈ (v13 +₁₂₈ (v11 +₁₂₈ v10))) in
+ expr_let v29 := v21 +₁₂₈ (v17 +₁₂₈ (v14 +₁₂₈ (v12 +₁₂₈ v0))) in
+ expr_let v30 := v22 +₁₂₈ (v18 +₁₂₈ (v15 +₁₂₈ (v4 +₁₂₈ v1))) in
+ expr_let v31 := v23 +₁₂₈ (v19 +₁₂₈ (v7 +₁₂₈ (v5 +₁₂₈ v2))) in
+ expr_let v32 := v26 +₁₂₈ v31 in
+ expr_let v33 := (uint64)(v32 >> 51) in
+ expr_let v34 := ((uint64)(v32) & 2251799813685247) in
+ expr_let v35 := v33 +₁₂₈ v30 in
+ expr_let v36 := (uint64)(v35 >> 51) in
+ expr_let v37 := ((uint64)(v35) & 2251799813685247) in
+ expr_let v38 := v36 +₁₂₈ v29 in
+ expr_let v39 := (uint64)(v38 >> 51) in
+ expr_let v40 := ((uint64)(v38) & 2251799813685247) in
+ expr_let v41 := v39 +₁₂₈ v28 in
+ expr_let v42 := (uint64)(v41 >> 51) in
+ expr_let v43 := ((uint64)(v41) & 2251799813685247) in
+ expr_let v44 := 19 *₆₄ v42 in
+ expr_let v45 := v27 +₆₄ v44 in
+ expr_let v46 := (uint64)(v45 >> 51) in
+ expr_let v47 := ((uint64)(v45) & 2251799813685247) in
+ expr_let v48 := v46 +₆₄ v34 in
+ expr_let v49 := (uint64)(v48 >> 51) in
+ expr_let v50 := ((uint64)(v48) & 2251799813685247) in
+ expr_let v51 := v49 +₆₄ v37 in
+ v47 :: v50 :: v51 :: v40 :: v43 :: [])%expr
+ : Expr
+ (type.list (type.type_primitive type.Z) *
+ type.list (type.type_primitive type.Z) ->
+ type.list (type.type_primitive type.Z))
*)
End X25519_64.