From 6bbb3d948da709737011cc0fc502a271aae7fb36 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Fri, 23 Feb 2018 14:25:35 -0500 Subject: Update montred to newish pipeline, revive DCE - Update the style of montred synthesis to match the changes in the pipeline - Bring back non-quadratic dead code elimination and make use of it for montgomery reduction - Update partial reduction to inline "var-like" things (fst, snd, pair applied to var) --- src/Experiments/SimplyTypedArithmetic.v | 342 ++++++++++++++++++-------------- 1 file changed, 189 insertions(+), 153 deletions(-) (limited to 'src') diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index fef25fe59..254ef5adf 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -3564,6 +3564,18 @@ Module Compilers. Module ident. Section interp. Context {var : type -> Type}. + Fixpoint is_var_like {t} (e : @expr var t) : bool + := match e with + | Var t v => true + | TT => true + | AppIdent _ _ (ident.fst _ _) args => @is_var_like _ args + | AppIdent _ _ (ident.snd _ _) args => @is_var_like _ args + | Pair A B a b => @is_var_like A a && @is_var_like B b + | AppIdent _ _ _ _ => false + | App _ _ _ _ + | Abs _ _ _ + => false + end. Fixpoint interp_let_in {tC tx : type} {struct tx} : value var tx -> (value var tx -> value var tC) -> value var tC := match tx return value var tx -> (value var tx -> value var tC) -> value var tC with | type.arrow _ _ @@ -3600,10 +3612,9 @@ Module Compilers. 
(f : sign * expr t + type.interp t -> value var tC) => match x with | inl (sgn, e) - => match invert_Var e with - | Some v => f (inl (sgn, Var v)) - | None => partial.expr.reflect (expr_let y := e in partial.expr.reify (f (inl (sgn, Var y)%core)))%expr - end + => if is_var_like e + then f (inl (sgn, e)) + else partial.expr.reflect (expr_let y := e in partial.expr.reify (f (inl (sgn, Var y)%core)))%expr | inr v => f (inr v) end | type.type_primitive _ as t @@ -3967,6 +3978,104 @@ Module Compilers. Definition PartialReduce {t} (e : Expr t) : Expr t := fun var => @partial_reduce var t (e _). + Module DeadCodeElimination. + Fixpoint compute_live' {t} (e : @expr (fun _ => PositiveSet.t) t) (cur_idx : positive) + : positive * PositiveSet.t + := match e with + | Var t v => (cur_idx, v) + | TT => (cur_idx, PositiveSet.empty) + | AppIdent s d idc args + => let default := @compute_live' _ args cur_idx in + match args in expr.expr t return ident.ident t d -> _ with + | Pair A B x (Abs s d f) + => fun idc + => match idc with + | ident.Let_In _ _ + => let '(idx, live) := @compute_live' A x cur_idx in + let '(_, live) := @compute_live' _ (f (PositiveSet.add idx live)) (Pos.succ idx) in + (Pos.succ idx, live) + | _ => default + end + | _ => fun _ => default + end idc + | App s d f x + => let '(idx, live1) := @compute_live' _ f cur_idx in + let '(idx, live2) := @compute_live' _ x idx in + (idx, PositiveSet.union live1 live2) + | Pair A B a b + => let '(idx, live1) := @compute_live' A a cur_idx in + let '(idx, live2) := @compute_live' B b idx in + (idx, PositiveSet.union live1 live2) + | Abs s d f + => let '(_, live) := @compute_live' _ (f PositiveSet.empty) cur_idx in + (cur_idx, live) + end. + Definition compute_live {t} e : PositiveSet.t := snd (@compute_live' t e 1). + Definition ComputeLive {t} (e : Expr t) := compute_live (e _). + + Section with_var. + Context {var : type -> Type} + (live : PositiveSet.t). + Definition OUGHT_TO_BE_UNUSED {T1 T2} (v : T1) (v' : T2) := v. 
+ Global Opaque OUGHT_TO_BE_UNUSED. + Fixpoint eliminate_dead' {t} (e : @expr (@expr var) t) (cur_idx : positive) + : positive * @expr var t + := match e with + | Var t v => (cur_idx, v) + | TT => (cur_idx, TT) + | AppIdent s d idc args + => let default := @eliminate_dead' _ args cur_idx in + let default := (fst default, AppIdent idc (snd default)) in + match args in expr.expr t return ident.ident t d -> positive * expr d -> positive * expr d with + | Pair A B x y + => match y in expr.expr Y return ident.ident (A * Y) d -> positive * expr d -> positive * expr d with + | Abs s' d' f + => fun idc + => let '(idx, x') := @eliminate_dead' A x cur_idx in + let f' := fun v => snd (@eliminate_dead' _ (f v) (Pos.succ idx)) in + match idc in ident.ident s d + return (match s return Type with + | A * _ => expr A + | _ => unit + end%ctype + -> match s return Type with + | _ * (s -> d) => (expr s -> expr d)%type + | _ => unit + end%ctype + -> positive * expr d + -> positive * expr d) + with + | ident.Let_In _ _ + => fun x' f' _ + => if PositiveSet.mem idx live + then (Pos.succ idx, AppIdent ident.Let_In (Pair x' (Abs (fun v => f' (Var v))))) + else (Pos.succ idx, f' (OUGHT_TO_BE_UNUSED x' (Pos.succ idx, PositiveSet.elements live))) + | _ => fun _ _ default => default + end x' f' + | _ => fun _ default => default + end + | _ => fun _ default => default + end idc default + | App s d f x + => let '(idx, f') := @eliminate_dead' _ f cur_idx in + let '(idx, x') := @eliminate_dead' _ x idx in + (idx, App f' x') + | Pair A B a b + => let '(idx, a') := @eliminate_dead' A a cur_idx in + let '(idx, b') := @eliminate_dead' B b idx in + (idx, Pair a' b') + | Abs s d f + => (cur_idx, Abs (fun v => snd (@eliminate_dead' _ (f (Var v)) cur_idx))) + end. + + Definition eliminate_dead {t} e : expr t + := snd (@eliminate_dead' t e 1). + End with_var. + + Definition EliminateDead {t} (e : Expr t) : Expr t + := fun var => eliminate_dead (ComputeLive e) (e _). + End DeadCodeElimination. 
+ Module ReassociateSmallConstants. Import Compilers.Uncurried.expr.default. @@ -5347,8 +5456,11 @@ Ltac cache_reify _ := end; [ repeat match goal with |- context[expr.Interp _ _ _] => apply (f_equal (fun f => f _)) end; apply f_equal; - time lazy; - reflexivity + lazymatch goal with |- ?LHS = ?RHS => subst LHS end; + let RHS := lazymatch goal with |- ?LHS = ?RHS => RHS end in + time (let RHS' := (eval vm_compute in RHS) in (* [vm_compute] is much faster than [lazy] here on large things *) + time instantiate (1:=RHS'); + vm_cast_no_check (eq_refl RHS')) | clearbody E ]. Derive carry_mul_gen @@ -5505,6 +5617,7 @@ Module Pipeline. end. Definition BoundsPipeline + (with_dead_code_elimination : bool) relax_zrange {s d} arg_bounds @@ -5512,6 +5625,7 @@ Module Pipeline. (E : Expr (s -> d)) : ErrorT (BoundsAnalysis.Indexed.expr.Notations.expr (BoundsAnalysis.Indexed.Range.type_for_range relax_zrange (t:=BoundsAnalysis.Indexed.OfPHOAS.type.compile d) out_bounds)) := let E := PartialReduce E in + let E := if with_dead_code_elimination then DeadCodeElimination.EliminateDead E else E in let E := ReassociateSmallConstants.Reassociate (2^8) E in let E := BoundsAnalysis.OfPHOAS.AnalyzeBounds relax_zrange E arg_bounds in let E := match E with @@ -5524,13 +5638,14 @@ Module Pipeline. E. Lemma BoundsPipeline_correct + (with_dead_code_elimination : bool) relax_zrange {s d} arg_bounds out_bounds (E : Expr (s -> d)) rv - (Hrv : BoundsPipeline relax_zrange arg_bounds out_bounds E = Success rv) + (Hrv : BoundsPipeline with_dead_code_elimination relax_zrange arg_bounds out_bounds E = Success rv) : forall arg (Harg : BoundsAnalysis.Indexed.Range.type_for_range_bounded_by relax_zrange @@ -5554,12 +5669,14 @@ Module Pipeline. Admitted. 
Definition BoundsPipelineConst + (with_dead_code_elimination : bool) relax_zrange {t} bounds (E : Expr t) : ErrorT (BoundsAnalysis.Indexed.expr.Notations.expr (BoundsAnalysis.Indexed.Range.type_for_range relax_zrange (t:=BoundsAnalysis.Indexed.OfPHOAS.type.compile t) bounds)) := let E := PartialReduce E in + let E := if with_dead_code_elimination then DeadCodeElimination.EliminateDead E else E in let E := ReassociateSmallConstants.Reassociate (2^8) E in let E := BoundsAnalysis.OfPHOAS.AnalyzeBoundsConst relax_zrange E in let E := match E with @@ -5572,12 +5689,13 @@ Module Pipeline. E. Lemma BoundsPipelineConst_correct + (with_dead_code_elimination : bool) relax_zrange {d} bounds (E : Expr d) rv - (Hrv : BoundsPipelineConst relax_zrange bounds E = Success rv) + (Hrv : BoundsPipelineConst with_dead_code_elimination relax_zrange bounds E = Success rv) : exists res, let ctx := PositiveMap.empty _ in BoundsAnalysis.Indexed.expr.interp (@BoundsAnalysis.ident.interp) rv ctx @@ -5749,6 +5867,7 @@ Section rcarry_mul. Let BoundsPipeline21 in_bounds out_bounds res := let res := Pipeline.BoundsPipeline + false relax_zrange (s:=(type.list type.Z * type.list type.Z)%ctype) (d:=(type.list type.Z)%ctype) @@ -5759,6 +5878,7 @@ Section rcarry_mul. Let BoundsPipeline11 in_bounds out_bounds res := let res := Pipeline.BoundsPipeline + false relax_zrange (s:=(type.list type.Z)%ctype) (d:=(type.list type.Z)%ctype) @@ -5768,9 +5888,11 @@ Section rcarry_mul. res. Definition rexpr_1_correctT_ctx + relax_zrange + t_out ctx out_bounds - (f : type.interp (type.list type.Z)) + (f : type.interp t_out) rv := (exists res, BoundsAnalysis.Indexed.expr.interp (@BoundsAnalysis.ident.interp) rv ctx @@ -5781,9 +5903,10 @@ Section rcarry_mul. = f). 
Definition rexpr_n1_correctT - t_in + relax_zrange + t_in t_out in_bounds out_bounds - (f : _ -> type.interp (type.list type.Z)) + (f : _ -> type.interp t_out) rv := forall arg (arg' := @BoundsAnalysis.OfPHOAS.cast_back @@ -5798,34 +5921,38 @@ Section rcarry_mul. let ctx := BoundsAnalysis.Indexed.Context.extendb (PositiveMap.empty _) 1 arg in - @rexpr_1_correctT_ctx ctx out_bounds (f arg') rv. + @rexpr_1_correctT_ctx relax_zrange t_out ctx out_bounds (f arg') rv. Definition rexpr_21_correctT in_bounds out_bounds (f : _ -> type.interp (type.list type.Z)) rv - := @rexpr_n1_correctT (type.list type.Z * type.list type.Z) + := @rexpr_n1_correctT relax_zrange + (type.list type.Z * type.list type.Z) + (type.list type.Z) (in_bounds, in_bounds) out_bounds f rv. Definition rexpr_11_correctT in_bounds out_bounds (f : _ -> type.interp (type.list type.Z)) rv - := @rexpr_n1_correctT (type.list type.Z) + := @rexpr_n1_correctT relax_zrange + (type.list type.Z) + (type.list type.Z) in_bounds out_bounds f rv. Definition rexpr_Z1_correctT in_bounds out_bounds (f : _ -> type.interp (type.list type.Z)) rv - := @rexpr_n1_correctT type.Z + := @rexpr_n1_correctT relax_zrange type.Z (type.list type.Z) in_bounds out_bounds f rv. Definition rexpr_01_correctT out_bounds (f : type.interp (type.list type.Z)) rv - := @rexpr_1_correctT_ctx (PositiveMap.empty _) out_bounds f rv. + := @rexpr_1_correctT_ctx relax_zrange (type.list type.Z) (PositiveMap.empty _) out_bounds f rv. Definition rcarry_mul := let res := BoundsPipeline21 @@ -5858,6 +5985,7 @@ Section rcarry_mul. Definition rcarry := let res := Pipeline.BoundsPipeline + false relax_zrange (s:=(type.list type.Z)%ctype) (d:=(type.list type.Z)%ctype) @@ -5890,6 +6018,7 @@ Section rcarry_mul. Definition rrelax := let res := Pipeline.BoundsPipeline + false relax_zrange (s:=(type.list type.Z)%ctype) (d:=(type.list type.Z)%ctype) @@ -5996,6 +6125,7 @@ Section rcarry_mul. 
Definition rencode := let res := Pipeline.BoundsPipeline + false relax_zrange (s:=type.Z) (d:=(type.list type.Z)%ctype) @@ -6026,6 +6156,7 @@ Section rcarry_mul. Definition rzero := let res := Pipeline.BoundsPipelineConst + false relax_zrange (t:=(type.list type.Z)%ctype) tight_bounds @@ -6054,6 +6185,7 @@ Section rcarry_mul. Definition rone := let res := Pipeline.BoundsPipelineConst + false relax_zrange (t:=(type.list type.Z)%ctype) tight_bounds @@ -6281,7 +6413,7 @@ Section rcarry_mul. => cbv [tight_bounds loose_bounds] in H |- *; rewrite H | _ => rewrite BoundsAnalysis.OfPHOAS.cast_back_primitive_cast_primitive in * | [ fx := (BoundsAnalysis.Indexed.expr.interp _ ?rop (BoundsAnalysis.Indexed.Context.extendb _ _ ?x)), - H : forall arg, BoundsAnalysis.Indexed.Range.type_for_range_bounded_by _ ?bs arg = true -> rexpr_1_correctT_ctx _ _ _ ?rop + H : forall arg, BoundsAnalysis.Indexed.Range.type_for_range_bounded_by _ ?bs arg = true -> rexpr_1_correctT_ctx _ _ _ _ _ ?rop |- _ ] => destruct (H x); [ clear H | subst fx ] | [ fx := Option.bind ?x ?f, H : ?x = Some ?v |- _ ] @@ -6291,9 +6423,6 @@ Section rcarry_mul. assert (H' : fx = fx') by (subst fx fx'; apply f_equal2; [ exact H | reflexivity ]); cbn [Option.bind] in fx'; clearbody fx; subst fx - (*| [ H : forall arg, BoundsAnalysis.Indexed.Range.type_for_range_bounded_by _ ?bs arg = true -> rexpr_1_correctT_ctx _ _ _ ?rop - |- context[BoundsAnalysis.Indexed.expr.interp _ ?rop (BoundsAnalysis.Indexed.Context.extendb _ _ ?x)] ] - => specialize (H x); destruct H*) | _ => progress cbn [BoundsAnalysis.OfPHOAS.cast_back] in * (* for getting eauto to work *) end. @@ -6574,6 +6703,7 @@ Module PrintingNotations. Export BoundsAnalysis.ident. Export BoundsAnalysis.Notations. Open Scope btype_scope. + Global Set Printing Width 100000. Notation "'uint256'" := (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%btype) : btype_scope. Notation "'uint128'" @@ -6950,71 +7080,6 @@ Module X25519_32. 
End X25519_32. *) -Module RemoveDeadLets. - Import BoundsAnalysis.Indexed.expr. - Section RemoveDeadLets. - Local Notation ident := BoundsAnalysis.ident.ident. - - Fixpoint let_used (t : BoundsAnalysis.type.type) (n : positive) - (e : @expr ident t) : bool := - match e with - | Var T m => Pos.eqb n m - | TT => false - | AppIdent s _ _ x => let_used s n x - | Pair A B a b => (let_used A n a) || (let_used B n b) - | Let_In s d m x f => - (negb (Pos.eqb n m && negb (let_used s n x))) && ((let_used s n x) || (let_used d n f)) - end. - - Fixpoint remove_dead_lets (t : BoundsAnalysis.type.type) (e : @expr ident t) : @expr ident t := - match e in (expr t') return expr t' with - | Var T n => Var T n - | TT => TT - | AppIdent s T idc x => - AppIdent idc (remove_dead_lets _ x) - | Pair A B a b => Pair (remove_dead_lets _ a) (remove_dead_lets _ b) - | Let_In s T n x f => - if (let_used T n f) - then Let_In n (remove_dead_lets _ x) (remove_dead_lets _ f) - else remove_dead_lets _ f - end. - - Fixpoint inline_let (idx : positive) Tnew (new : @expr ident Tnew) t (e : @expr ident t) : @expr ident t := - match e in expr t' return expr t' with - | Var T n => if (Pos.eqb n idx) - then match BoundsAnalysis.type.transport (@expr ident) Tnew T new with - | Some new' => new' - | None => Var T n - end - else Var T n - | TT => TT - | AppIdent s T idc x => AppIdent idc (inline_let idx _ new _ x) - | Pair A B a b => Pair (inline_let idx _ new _ a) (inline_let idx _ new _ b) - | Let_In s T n x f => Let_In n (inline_let idx _ new _ x) (inline_let idx _ new _ f) - end. 
- - (* inlines lets that just re-bind a variable or half a variable with type prod *) - Fixpoint inline_silly_lets t (e : @expr ident t) : @expr ident t := - match e in (expr t') return expr t' with - | Var T n => Var T n - | TT => TT - | AppIdent s T idc x => - AppIdent idc (inline_silly_lets _ x) - | Pair A B a b => Pair (inline_silly_lets _ a) (inline_silly_lets _ b) - | Let_In s T n x f => - match x with - | Var T' m => inline_let n _ (Var T' m) _ f - | AppIdent _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m) => - inline_let n _ (@AppIdent _ _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m)) _ (inline_silly_lets _ f) - | _ => Let_In n (inline_silly_lets _ x) (inline_silly_lets _ f) - end - end. - - (* TODO: proofs--note these may block on getting canonical maps for contexts *) - (* TODO(jgross, from jadep): Should I put this into the pipeline? *) - End RemoveDeadLets. -End RemoveDeadLets. - Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. Require Import Crypto.Util.ZUtil.EquivModulo. @@ -7126,30 +7191,7 @@ Module MontgomeryReduction. montred_gen N R N' w w_half n lo_hi = montred' N R N' w w_half n lo_hi) As montred_gen_correct. - Proof. - intros. - etransitivity. - Focus 2. - { repeat apply (f_equal (fun f => f _)). - Reify_rhs(). - reflexivity. } Unfocus. - cbv beta. - let RHS := match goal with |- _ = ?RHS => RHS end in - let e := match RHS with context[expr.Interp _ ?e] => e end in - set (E := e). - Time let E' := constr:(PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) in - let E' := (eval vm_compute in E') in (* 0.131 for vm, about 0.6 for lazy, slower for native and cbv *) - pose E' as E''. - let LHS := match goal with |- ?LHS = _ => LHS end in - lazymatch LHS with - | context LHS[expr.Interp _ _] - => let LHS := context LHS[Interp E''] in - transitivity LHS - end; - [ clear E | exact admit ]. - subst montred_gen. - reflexivity. - Qed. 
+ Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed. Section rmontred. Context (N R N' : Z) @@ -7191,6 +7233,7 @@ Module MontgomeryReduction. Definition rmontred := let res := Pipeline.BoundsPipeline + true (relax_zrange) (s:=(type.Z * type.Z)%ctype) (d:=(type.Z)%ctype) @@ -7205,44 +7248,48 @@ Module MontgomeryReduction. @ (rw_half _) @ (rn _) )%expr in - check_args res. + res. + + (* copied from above *) + Local Ltac solve_correct_gen pipeline_lem gen_correct := + let Hrv := lazymatch goal with H : ?rop = Pipeline.Success _ |- _ => H end in + let rop := lazymatch type of Hrv with ?rop = Pipeline.Success _ => rop end in + hnf; intros; cbv [rop] in Hrv; + eapply pipeline_lem in Hrv; [ | eassumption.. ]; + let res := fresh "res" in + destruct Hrv as [res Hrv]; + exists res; do 2 try apply conj; + [ | | etransitivity ]; + [ solve [ apply Hrv ].. | ]; + repeat match goal with H := _ |- _ => subst H end; + erewrite <- gen_correct; + cbv [expr.Interp]; + cbn [expr.interp]; + f_equal; + cbn -[reify_list]; + try (rewrite interp_reify_list, map_map; cbn; + erewrite map_ext with (g:=id), map_id; try reflexivity); + try (intros []; reflexivity). + Local Ltac solve_correct gen_correct := + solve_correct_gen Pipeline.BoundsPipeline_correct gen_correct. + Local Ltac solve_correct_const gen_correct := + solve_correct_gen Pipeline.BoundsPipelineConst_correct gen_correct. Definition rmontred_correctT rv - := forall arg - (arg' := @BoundsAnalysis.OfPHOAS.cast_back - _ - (relax_zrange) - (arg_bounds) - arg), - BoundsAnalysis.OfPHOAS.Interp - (relax_zrange) - (arg_bounds) - (bs:=out_bounds) - arg - rv - = Some (montred' (Interp rN) (Interp rR) (Interp rN') (Interp rw) (Interp rw_half) (Interp rn) arg'). + := Eval hnf in + @rexpr_n1_correctT + relax_zrange + ((type.Z * type.Z)%ctype) type.Z + arg_bounds out_bounds + (montred' (Interp rN) (Interp rR) (Interp rN') (Interp rw) (Interp rw_half) (Interp rn)) + rv. 
Lemma rmontred_correct rv (Hrv : rmontred = Pipeline.Success rv) : rmontred_correctT rv. - Proof. - hnf; intros. - cbv [rmontred] in Hrv. - edestruct (Pipeline.BoundsPipeline _ _ _ _) as [rv'|] eqn:Hrv'; - [ | clear -Hrv; cbv [check_args] in Hrv; break_innermost_match_hyps; discriminate ]. - erewrite <- montred_gen_correct. - eapply Pipeline.BoundsPipeline_correct in Hrv'. - apply check_args_success_id in Hrv; inversion Hrv; subst rv. - rewrite Hrv'. - cbv [expr.Interp]. - cbn [expr.interp]. - apply f_equal; f_equal; - cbn -[reify_list]; - rewrite interp_reify_list, map_map; cbn; - erewrite map_ext with (g:=id), map_id; try reflexivity. - Qed. + Proof. solve_correct montred_gen_correct. Qed. End rmontred. End MontgomeryReduction. @@ -7257,25 +7304,14 @@ Module Montgomery256. Definition R := (2^256). Definition machine_wordsize := 256. - Derive montred256_with_dead_code - SuchThat (MontgomeryReduction.rmontred_correctT N R N' machine_wordsize montred256_with_dead_code) + Derive montred256 + SuchThat (MontgomeryReduction.rmontred_correctT N R N' machine_wordsize montred256) As montred256_correct. Proof. Time solve_rmontred(). Time Qed. - (* TODO: if dead code calls dead code, then remove_dead_lets can - progress if called multiple times. Should probably fix this, but - termination of fixpoints is hard *) - Definition montred256 := Eval lazy in - (RemoveDeadLets.remove_dead_lets _ - (RemoveDeadLets.remove_dead_lets _ - (RemoveDeadLets.remove_dead_lets _ - (RemoveDeadLets.remove_dead_lets _ - (RemoveDeadLets.remove_dead_lets _ - (RemoveDeadLets.remove_dead_lets _ - (RemoveDeadLets.inline_silly_lets _ montred256_with_dead_code))))))). - Import PrintingNotations. Open Scope nexpr_scope. + Set Printing Width 100000. Print montred256. (* expr_let 3 := (uint128)(fst @@ x_1 >> 128) in -- cgit v1.2.3