diff options
author | 2018-03-21 13:54:51 -0400 | |
---|---|---|
committer | 2018-03-21 19:51:48 -0400 | |
commit | df9223b12bba6dd064bc7fc05ba64139a252d69d (patch) | |
tree | 2896aa27f07220f45a64d403ae25ada5922df896 | |
parent | fbf11137d51beef861cc0825603e33674b1353b4 (diff) |
Don't inline var nodes on the first pass through partial evaluation
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 132 |
1 files changed, 92 insertions, 40 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index e739d6bcc..16f2e3c7c 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -4263,7 +4263,8 @@ Module Compilers. Module ident. Section interp. - Context {var : type -> Type}. + Context (inline_var_nodes : bool) + {var : type -> Type}. Fixpoint is_var_like {t} (e : @expr var t) : bool := match e with | Var t v => true @@ -4309,15 +4310,18 @@ Module Compilers. => @interp_let_in _ B b (fun b => f (inr (a, b)))) - | inl (data, e) => partial.expr.reflect - (expr_let y := partial.expr.reify (t:=t) x in - partial.expr.reify (f (inl (data, Var y)%core)))%expr + | inl (data, e) + => if inline_var_nodes && is_var_like e + then f x + else partial.expr.reflect + (expr_let y := partial.expr.reify (t:=t) x in + partial.expr.reify (f (inl (data, Var y)%core)))%expr end | type.type_primitive _ as t => fun (x : data t * expr t + type.interp t) (f : data t * expr t + type.interp t -> value var tC) => match x with | inl (data, e) - => if is_var_like e + => if inline_var_nodes && is_var_like e then f x else partial.expr.reflect (expr_let y := (partial.expr.reify (t:=t) x) in @@ -4884,7 +4888,8 @@ Module Compilers. End partial. Section partial_reduce. - Context {var : type -> Type}. + Context (inline_var_nodes : bool) + {var : type -> Type}. Definition partial_reduce'_step (partial_reduce' : forall {t} (e : @expr (partial.value var) t), @@ -4894,7 +4899,7 @@ Module Compilers. := match e in expr.expr t return partial.value var t with | Var t v => v | TT => inr tt - | AppIdent s d idc args => partial.ident.interp idc (@partial_reduce' _ args) + | AppIdent s d idc args => partial.ident.interp inline_var_nodes idc (@partial_reduce' _ args) | Pair A B a b => inr (@partial_reduce' A a, @partial_reduce' B b) | App s d f x => @partial_reduce' _ f (@partial_reduce' _ x) | Abs s d f => fun x => @partial_reduce' d (f x) @@ -4918,8 +4923,8 @@ Module Compilers. End partial_reduce. - Definition PartialReduce {t} (e : Expr t) : Expr t - := fun var => @partial_reduce var t (e _). + Definition PartialReduce (inline_var_nodes : bool) {t} (e : Expr t) : Expr t + := fun var => @partial_reduce inline_var_nodes var t (e _). Module RelaxZRange. Module ident. @@ -4972,7 +4977,7 @@ Module Compilers. Definition PartialReduceWithBounds1 {s d} (e : Expr (s -> d)) (b : ZRange.type.interp s) : Expr (s -> d) - := fun var => @partial_reduce_with_bounds1 var s d (e _) b. + := fun var => @partial_reduce_with_bounds1 true var s d (e _) b. Definition CheckPartialReduceWithBounds1 (relax_zrange : zrange -> option zrange) @@ -5009,7 +5014,7 @@ Module Compilers. {t} (e : Expr t) (b_out : ZRange.type.interp t) : Expr t + ZRange.type.option.interp t - := dlet_nd E := PartialReduce e in + := dlet_nd E := PartialReduce true e in CheckPartialReduceWithBounds0 relax_zrange E b_out. Axiom admit_pf : False. @@ -5048,7 +5053,7 @@ Module Compilers. -> is_tighter_than_bool z r' = true) {t} (e : Expr t) (b_out : ZRange.type.interp t) - E (HE : PartialReduce e = E) + E (HE : PartialReduce true e = E) rv (Hrv : CheckPartialReduceWithBounds0 relax_zrange E b_out = inl rv) : Interp rv = Interp e /\ ZRange.type.is_bounded_by b_out (Interp rv) = true. @@ -5278,7 +5283,7 @@ Proof. let v := Reify ((fun x => 2^x) 255)%Z in pose v as E. vm_compute in E. - pose (PartialReduce (canonicalize_list_recursion E)) as E'. + pose (PartialReduce false (canonicalize_list_recursion E)) as E'. vm_compute in E'. lazymatch (eval cbv delta [E'] in E') with | (fun var => AppIdent (ident.primitive ?v) TT) => idtac @@ -5296,7 +5301,7 @@ Module test2. (fun v => v)) in pose v as E. vm_compute in E. - pose (PartialReduce (canonicalize_list_recursion E)) as E'. + pose (PartialReduce false (canonicalize_list_recursion E)) as E'. vm_compute in E'. lazymatch (eval cbv delta [E'] in E') with | (fun var : type -> Type => @@ -5329,7 +5334,7 @@ Module test3. (z * z)) in pose v as E. vm_compute in E. - pose (PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. + pose (PartialReduce false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. vm_compute in E'. lazymatch (eval cbv delta [E'] in E') with | (fun var : type -> Type => @@ -5366,7 +5371,7 @@ Module test4. (xz :: xz :: nil)) in pose v as E. vm_compute in E. - pose (PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. + pose (PartialReduce false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. lazy in E'. clear E. pose (PartialReduceWithBounds1 E' ([r[0~>10]%zrange],[r[0~>10]%zrange])) as E''. @@ -5393,7 +5398,7 @@ Module test5. x) in pose v as E. vm_compute in E. - pose (ReassociateSmallConstants.Reassociate (2^8) (PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))))) as E'. + pose (ReassociateSmallConstants.Reassociate (2^8) (PartialReduce false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))))) as E'. lazy in E'. clear E. lazymatch (eval cbv delta [E'] in E') with @@ -5420,7 +5425,7 @@ Module test6. pose (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))) as E'. lazy in E'. clear E. - pose (PartialReduce E') as E''. + pose (PartialReduce false E') as E''. lazy in E''. lazymatch eval cbv delta [E''] in E'' with | fun var : type -> Type => (λ x : var (type.type_primitive type.Z), Var x)%expr @@ -5443,7 +5448,7 @@ Ltac cache_reify _ := let e := match RHS with context[expr.Interp _ ?e] => e end in let E := fresh "E" in set (E := e); - let E' := constr:(PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) in + let E' := constr:(PartialReduce false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) in let LHS := match goal with |- ?LHS = _ => LHS end in lazymatch LHS with | context LHS[@expr.Interp ?ident ?interp_ident ?t ?e] @@ -5613,7 +5618,9 @@ Module Pipeline. {t} (E : for_reification.Expr t) : ErrorT (Expr t) - := let E := option_map PartialReduce (CPS.CallFunWithIdContinuation_opt (CPS.Translate (canonicalize_list_recursion E))) in + := let E := option_map + (PartialReduce false) + (CPS.CallFunWithIdContinuation_opt (CPS.Translate (canonicalize_list_recursion E))) in match E with | Some E => Success E | None => Error (Type_too_complicated_for_cps t) @@ -5631,7 +5638,7 @@ Module Pipeline. (E : Expr (s -> d)) arg_bounds : Expr (s -> d) - := let E := PartialReduce E in + := let E := PartialReduce true E in (* Note that DCE evaluates the expr with two different [var] arguments, and so will likely result in a pipeline that is 2x slower *) @@ -5770,13 +5777,13 @@ Module Pipeline. {t} (e : Expr t) : Expr t - := let E := PartialReduce e in + := let E := PartialReduce true e in (* Note that DCE evaluates the expr with two different [var] arguments, and so will likely result in a pipeline that is 2x slower *) let E := if with_dead_code_elimination then DeadCodeElimination.EliminateDead E else E in let E := ReassociateSmallConstants.Reassociate (2^8) E in - let E := PartialReduce E in + let E := PartialReduce true E in E. Definition CheckBoundsPipelineConst @@ -6637,23 +6644,68 @@ Module X25519_64. Import PrintingNotations. Print base_51_carry_mul. -(* base_51_carry_mul = fun var : type -> Type => (λ v : var (type.list (type.type_primitive type.Z) * type.list (type.type_primitive type.Z))%ctype, -expr_let v0 := (uint64)(v₁ [[0]] *₁₂₈ v₂ [[0]] +₁₂₈ (v₁ [[1]] *₁₂₈ (19 * (uint64)(v₂[[4]])) +₁₂₈ (v₁ [[2]] *₁₂₈ (19 * (uint64)(v₂[[3]])) +₁₂₈ (v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[2]])) +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[1]]))))) >> 51) in -expr_let v1 := ((uint64)(v₁ [[0]] *₁₂₈ v₂ [[0]] +₁₂₈ (v₁ [[1]] *₁₂₈ (19 * (uint64)(v₂[[4]])) +₁₂₈ (v₁ [[2]] *₁₂₈ (19 * (uint64)(v₂[[3]])) +₁₂₈ (v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[2]])) +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[1]])))))) & 2251799813685247) in -expr_let v2 := (uint64)((uint64)(v0) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[1]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[0]] +₁₂₈ (v₁ [[2]] *₁₂₈ (19 * (uint64)(v₂[[4]])) +₁₂₈ (v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[3]])) +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[2]])))))) >> 51) in -expr_let v3 := ((uint64)((uint64)(v0) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[1]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[0]] +₁₂₈ (v₁ [[2]] *₁₂₈ (19 * (uint64)(v₂[[4]])) +₁₂₈ (v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[3]])) +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[2]]))))))) & 2251799813685247) in -expr_let v4 := (uint64)((uint64)(v2) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[2]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[1]] +₁₂₈ (v₁ [[2]] *₁₂₈ v₂ [[0]] +₁₂₈ (v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[4]])) +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[3]])))))) >> 51) in -expr_let v5 := ((uint64)((uint64)(v2) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[2]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[1]] +₁₂₈ (v₁ [[2]] *₁₂₈ v₂ [[0]] +₁₂₈ (v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[4]])) +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[3]]))))))) & 2251799813685247) in -expr_let v6 := (uint64)((uint64)(v4) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[3]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[2]] +₁₂₈ (v₁ [[2]] *₁₂₈ v₂ [[1]] +₁₂₈ (v₁ [[3]] *₁₂₈ v₂ [[0]] +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[4]])))))) >> 51) in -expr_let v7 := ((uint64)((uint64)(v4) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[3]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[2]] +₁₂₈ (v₁ [[2]] *₁₂₈ v₂ [[1]] +₁₂₈ (v₁ [[3]] *₁₂₈ v₂ [[0]] +₁₂₈ v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[4]]))))))) & 2251799813685247) in -expr_let v8 := (uint64)((uint64)(v6) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[4]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[3]] +₁₂₈ (v₁ [[2]] *₁₂₈ v₂ [[2]] +₁₂₈ (v₁ [[3]] *₁₂₈ v₂ [[1]] +₁₂₈ v₁ [[4]] *₁₂₈ v₂ [[0]])))) >> 51) in -expr_let v9 := ((uint64)((uint64)(v6) +₁₂₈ (v₁ [[0]] *₁₂₈ v₂ [[4]] +₁₂₈ (v₁ [[1]] *₁₂₈ v₂ [[3]] +₁₂₈ (v₁ [[2]] *₁₂₈ v₂ [[2]] +₁₂₈ (v₁ [[3]] *₁₂₈ v₂ [[1]] +₁₂₈ v₁ [[4]] *₁₂₈ v₂ [[0]]))))) & 2251799813685247) in -expr_let v10 := (uint64)((uint64)(v1) +₆₄ 19 *₆₄ (uint64)(v8) >> 51) in -expr_let v11 := ((uint64)((uint64)(v1) +₆₄ 19 *₆₄ (uint64)(v8)) & 2251799813685247) in -expr_let v12 := (uint64)((uint64)(v10) +₆₄ (uint64)(v3) >> 51) in -expr_let v13 := ((uint64)((uint64)(v10) +₆₄ (uint64)(v3)) & 2251799813685247) in -(uint64)(v11) :: (uint64)(v13) :: (uint64)(v12) +₆₄ (uint64)(v5) :: (uint64)(v7) :: (uint64)(v9) :: [])%expr - : Expr (type.list (type.type_primitive type.Z) * type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z)) +(*base_51_carry_mul = +fun var : type -> Type => +(λ v : var + (type.list (type.type_primitive type.Z) * + type.list (type.type_primitive type.Z))%ctype, + expr_let v0 := v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[4]])) in + expr_let v1 := v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[3]])) in + expr_let v2 := v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[2]])) in + expr_let v3 := v₁ [[4]] *₁₂₈ (19 * (uint64)(v₂[[1]])) in + expr_let v4 := v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[4]])) in + expr_let v5 := v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[3]])) in + expr_let v6 := v₁ [[3]] *₁₂₈ (19 * (uint64)(v₂[[2]])) in + expr_let v7 := v₁ [[2]] *₁₂₈ (19 * (uint64)(v₂[[4]])) in + expr_let v8 := v₁ [[2]] *₁₂₈ (19 * (uint64)(v₂[[3]])) in + expr_let v9 := v₁ [[1]] *₁₂₈ (19 * (uint64)(v₂[[4]])) in + expr_let v10 := v₁ [[4]] *₁₂₈ v₂ [[0]] in + expr_let v11 := v₁ [[3]] *₁₂₈ v₂ [[1]] in + expr_let v12 := v₁ [[3]] *₁₂₈ v₂ [[0]] in + expr_let v13 := v₁ [[2]] *₁₂₈ v₂ [[2]] in + expr_let v14 := v₁ [[2]] *₁₂₈ v₂ [[1]] in + expr_let v15 := v₁ [[2]] *₁₂₈ v₂ [[0]] in + expr_let v16 := v₁ [[1]] *₁₂₈ v₂ [[3]] in + expr_let v17 := v₁ [[1]] *₁₂₈ v₂ [[2]] in + expr_let v18 := v₁ [[1]] *₁₂₈ v₂ [[1]] in + expr_let v19 := v₁ [[1]] *₁₂₈ v₂ [[0]] in + expr_let v20 := v₁ [[0]] *₁₂₈ v₂ [[4]] in + expr_let v21 := v₁ [[0]] *₁₂₈ v₂ [[3]] in + expr_let v22 := v₁ [[0]] *₁₂₈ v₂ [[2]] in + expr_let v23 := v₁ [[0]] *₁₂₈ v₂ [[1]] in + expr_let v24 := v₁ [[0]] *₁₂₈ v₂ [[0]] in + expr_let v25 := v24 +₁₂₈ (v9 +₁₂₈ (v8 +₁₂₈ (v6 +₁₂₈ v3))) in + expr_let v26 := (uint64)(v25 >> 51) in + expr_let v27 := ((uint64)(v25) & 2251799813685247) in + expr_let v28 := v20 +₁₂₈ (v16 +₁₂₈ (v13 +₁₂₈ (v11 +₁₂₈ v10))) in + expr_let v29 := v21 +₁₂₈ (v17 +₁₂₈ (v14 +₁₂₈ (v12 +₁₂₈ v0))) in + expr_let v30 := v22 +₁₂₈ (v18 +₁₂₈ (v15 +₁₂₈ (v4 +₁₂₈ v1))) in + expr_let v31 := v23 +₁₂₈ (v19 +₁₂₈ (v7 +₁₂₈ (v5 +₁₂₈ v2))) in + expr_let v32 := v26 +₁₂₈ v31 in + expr_let v33 := (uint64)(v32 >> 51) in + expr_let v34 := ((uint64)(v32) & 2251799813685247) in + expr_let v35 := v33 +₁₂₈ v30 in + expr_let v36 := (uint64)(v35 >> 51) in + expr_let v37 := ((uint64)(v35) & 2251799813685247) in + expr_let v38 := v36 +₁₂₈ v29 in + expr_let v39 := (uint64)(v38 >> 51) in + expr_let v40 := ((uint64)(v38) & 2251799813685247) in + expr_let v41 := v39 +₁₂₈ v28 in + expr_let v42 := (uint64)(v41 >> 51) in + expr_let v43 := ((uint64)(v41) & 2251799813685247) in + expr_let v44 := 19 *₆₄ v42 in + expr_let v45 := v27 +₆₄ v44 in + expr_let v46 := (uint64)(v45 >> 51) in + expr_let v47 := ((uint64)(v45) & 2251799813685247) in + expr_let v48 := v46 +₆₄ v34 in + expr_let v49 := (uint64)(v48 >> 51) in + expr_let v50 := ((uint64)(v48) & 2251799813685247) in + expr_let v51 := v49 +₆₄ v37 in + v47 :: v50 :: v51 :: v40 :: v43 :: [])%expr + : Expr + (type.list (type.type_primitive type.Z) * + type.list (type.type_primitive type.Z) -> + type.list (type.type_primitive type.Z)) *) End X25519_64. |