diff options
author | Jason Gross <jgross@mit.edu> | 2017-04-15 02:01:56 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2017-05-14 00:52:04 -0400 |
commit | 096a24265d4df0bbb5321c6fa794577bee5cae25 (patch) | |
tree | 4b7cbeefaf12fc5ce836e80864a6221c7b44dcf9 | |
parent | 63e036b685457b7ecfb44e6caf966c4a7e8462d1 (diff) |
CSE without inlining arithmetic expressions
This takes care of most of #158. The remaining bits are reworking the
Wf and interpretation lemmas to actually work. (The former needs a only
bit of rethinking and rephrasing to handle the fact that sometimes we
change the stored symbolic expression from a complicated one to a fresh
variable, while the latter needs major surgery, which Adam tells me is
easy, and this is a note that when I come back to it, I should look at
the email thread with Adam about CSE from last summer.)
-rw-r--r-- | src/Compilers/CommonSubexpressionElimination.v | 16 | ||||
-rw-r--r-- | src/Compilers/CommonSubexpressionEliminationInterp.v | 9 | ||||
-rw-r--r-- | src/Compilers/CommonSubexpressionEliminationWf.v | 56 | ||||
-rw-r--r-- | src/Compilers/TestCase.v | 2 | ||||
-rw-r--r-- | src/Compilers/Z/Bounds/Pipeline/Definition.v | 4 | ||||
-rw-r--r-- | src/Compilers/Z/CommonSubexpressionElimination.v | 16 | ||||
-rw-r--r-- | src/Compilers/Z/CommonSubexpressionEliminationInterp.v | 8 | ||||
-rw-r--r-- | src/Compilers/Z/CommonSubexpressionEliminationWf.v | 8 | ||||
-rw-r--r-- | src/Specific/FancyMachine256/Core.v | 2 |
9 files changed, 78 insertions, 43 deletions
diff --git a/src/Compilers/CommonSubexpressionElimination.v b/src/Compilers/CommonSubexpressionElimination.v index de91f03b6..c5ffeb088 100644 --- a/src/Compilers/CommonSubexpressionElimination.v +++ b/src/Compilers/CommonSubexpressionElimination.v @@ -57,7 +57,8 @@ Section symbolic. (op_code_leb : op_code -> op_code -> bool) (base_type_leb : base_type_code -> base_type_code -> bool). Local Notation symbolic_expr := (symbolic_expr base_type_code op_code). - Context (normalize_symbolic_op_arguments : op_code -> symbolic_expr -> symbolic_expr). + Context (normalize_symbolic_op_arguments : op_code -> symbolic_expr -> symbolic_expr) + (inline_symbolic_expr_in_lookup : bool). Local Notation symbolic_expr_beq := (@symbolic_expr_beq base_type_code op_code base_type_code_beq op_code_beq). Local Notation symbolic_expr_lb := (@internal_symbolic_expr_dec_lb base_type_code op_code base_type_code_beq op_code_beq base_type_code_lb op_code_lb). @@ -223,11 +224,12 @@ Section symbolic. | Some sx => (sx, lookupb xs sx) | None => (symbolize_var xs tx, None) end in + let reduced_sx := if inline_symbolic_expr_in_lookup then sx else symbolize_var xs tx in match sv with - | Some v => @csef _ (eC (symbolicify_smart_var v sx)) xs + | Some v => @csef _ (eC (symbolicify_smart_var v reduced_sx)) (extendb xs reduced_sx v) | None - => LetIn ex' (fun x => let sx' := symbolicify_smart_var x sx in - @csef _ (eC sx') (extendb xs sx x)) + => LetIn ex' (fun x => let sx' := symbolicify_smart_var x reduced_sx in + @csef _ (eC sx') (extendb (extendb xs sx x) reduced_sx x)) end | TT => TT | Var _ x => Var (fst x) @@ -252,6 +254,6 @@ Section symbolic. := fun var => cse (prefix _) (e _) empty. End symbolic. -Global Arguments csef {_} op_code base_type_code_beq op_code_beq base_type_code_bl {_} symbolize_op normalize_symbolic_op_arguments {var t} _ _. -Global Arguments cse {_} op_code base_type_code_beq op_code_beq base_type_code_bl {_} symbolize_op normalize_symbolic_op_arguments {var} prefix {t} _ _. -Global Arguments CSE {_} op_code base_type_code_beq op_code_beq base_type_code_bl {_} symbolize_op normalize_symbolic_op_arguments {t} e prefix var. +Global Arguments csef {_} op_code base_type_code_beq op_code_beq base_type_code_bl {_} symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup {var t} _ _. +Global Arguments cse {_} op_code base_type_code_beq op_code_beq base_type_code_bl {_} symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup {var} prefix {t} _ _. +Global Arguments CSE {_} op_code base_type_code_beq op_code_beq base_type_code_bl {_} symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup {t} e prefix var. diff --git a/src/Compilers/CommonSubexpressionEliminationInterp.v b/src/Compilers/CommonSubexpressionEliminationInterp.v index e0ebf71a4..0d7fa1e25 100644 --- a/src/Compilers/CommonSubexpressionEliminationInterp.v +++ b/src/Compilers/CommonSubexpressionEliminationInterp.v @@ -40,7 +40,8 @@ Section symbolic. (symbolize_op : forall s d, op s d -> op_code) (denote_op : forall s d, op_code -> option (op s d)). Local Notation symbolic_expr := (symbolic_expr base_type_code op_code). - Context (normalize_symbolic_op_arguments : op_code -> symbolic_expr -> symbolic_expr). + Context (normalize_symbolic_op_arguments : op_code -> symbolic_expr -> symbolic_expr) + (inline_symbolic_expr_in_lookup : bool). Local Notation symbolic_expr_beq := (@symbolic_expr_beq base_type_code op_code base_type_code_beq op_code_beq). Local Notation symbolic_expr_lb := (@internal_symbolic_expr_dec_lb base_type_code op_code base_type_code_beq op_code_beq base_type_code_lb op_code_lb). @@ -57,9 +58,9 @@ Section symbolic. Local Notation symbolicify_smart_var := (@symbolicify_smart_var base_type_code op_code). Local Notation symbolize_exprf := (@symbolize_exprf base_type_code op_code op symbolize_op). - Local Notation csef := (@csef base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments). - Local Notation cse := (@cse base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments). - Local Notation CSE := (@CSE base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments). + Local Notation csef := (@csef base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup). + Local Notation cse := (@cse base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup). + Local Notation CSE := (@CSE base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup). Local Notation SymbolicExprContext := (@SymbolicExprContext base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl). Local Notation SymbolicExprContextOk := (@SymbolicExprContextOk base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl base_type_code_lb op_code_bl op_code_lb). Local Notation prepend_prefix := (@prepend_prefix base_type_code op). diff --git a/src/Compilers/CommonSubexpressionEliminationWf.v b/src/Compilers/CommonSubexpressionEliminationWf.v index b66db5c33..8d8a5d86f 100644 --- a/src/Compilers/CommonSubexpressionEliminationWf.v +++ b/src/Compilers/CommonSubexpressionEliminationWf.v @@ -14,6 +14,7 @@ Require Import Crypto.Util.Bool. Require Import Crypto.Util.Tactics.RewriteHyp. Require Import Crypto.Util.Tactics.BreakMatch. Require Import Crypto.Util.Tactics.DestructHead. +Require Import Crypto.Util.Tactics.UniquePose. Require Import Crypto.Util.Tactics.SplitInContext. Require Import Crypto.Util.Decidable. @@ -26,11 +27,11 @@ Section symbolic. (base_type_code_lb : forall x y, x = y -> base_type_code_beq x y = true) (op_code_bl : forall x y, op_code_beq x y = true -> x = y) (op_code_lb : forall x y, x = y -> op_code_beq x y = true) - (interp_base_type : base_type_code -> Type) (op : flat_type base_type_code -> flat_type base_type_code -> Type) (symbolize_op : forall s d, op s d -> op_code). Local Notation symbolic_expr := (symbolic_expr base_type_code op_code). - Context (normalize_symbolic_op_arguments : op_code -> symbolic_expr -> symbolic_expr). + Context (normalize_symbolic_op_arguments : op_code -> symbolic_expr -> symbolic_expr) + (inline_symbolic_expr_in_lookup : bool). Local Notation symbolic_expr_beq := (@symbolic_expr_beq base_type_code op_code base_type_code_beq op_code_beq). Local Notation symbolic_expr_lb := (@internal_symbolic_expr_dec_lb base_type_code op_code base_type_code_beq op_code_beq base_type_code_lb op_code_lb). @@ -38,9 +39,6 @@ Section symbolic. Local Notation flat_type := (flat_type base_type_code). Local Notation type := (type base_type_code). - Local Notation interp_type := (interp_type interp_base_type). - Local Notation interp_flat_type_gen := interp_flat_type. - Local Notation interp_flat_type := (interp_flat_type interp_base_type). Local Notation exprf := (@exprf base_type_code op). Local Notation expr := (@expr base_type_code op). Local Notation Expr := (@Expr base_type_code op). @@ -48,9 +46,9 @@ Section symbolic. Local Notation symbolicify_smart_var := (@symbolicify_smart_var base_type_code op_code). Local Notation symbolize_exprf := (@symbolize_exprf base_type_code op_code op symbolize_op). Local Notation norm_symbolize_exprf := (@norm_symbolize_exprf base_type_code op_code op symbolize_op normalize_symbolic_op_arguments). - Local Notation csef := (@csef base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments). - Local Notation cse := (@cse base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments). - Local Notation CSE := (@CSE base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments). + Local Notation csef := (@csef base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup). + Local Notation cse := (@cse base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup). + Local Notation CSE := (@CSE base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup). Local Notation SymbolicExprContext := (@SymbolicExprContext base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl). Local Notation SymbolicExprContextOk := (@SymbolicExprContextOk base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl base_type_code_lb op_code_bl op_code_lb). Local Notation prepend_prefix := (@prepend_prefix base_type_code op). @@ -82,8 +80,8 @@ Section symbolic. Local Arguments lookupb : simpl never. Local Arguments extendb : simpl never. Lemma wff_csef G G' t e1 e2 - (m1 : @SymbolicExprContext (interp_flat_type_gen var1)) - (m2 : @SymbolicExprContext (interp_flat_type_gen var2)) + (m1 : @SymbolicExprContext (interp_flat_type var1)) + (m2 : @SymbolicExprContext (interp_flat_type var2)) (Hlen : length m1 = length m2) (Hm1m2None : forall t v, lookupb m1 v t = None <-> lookupb m2 v t = None) (Hm1m2Some : forall t v sv1 sv2, @@ -108,9 +106,13 @@ Section symbolic. | [ H : lookupb ?m1 ?x = Some ?k, H' : lookupb ?m2 ?x = None |- _ ] => apply Hm1m2None in H'; congruence end; - [ | constructor; intros; auto; [].. ]; + lazymatch goal with + | [ |- wff _ (LetIn _ _) (LetIn _ _) ] + => constructor; intros; auto; [] + | _ => idtac + end; match goal with H : _ |- _ => apply H end; - repeat first [ progress unfold symbolize_var + try solve [ repeat first [ progress unfold symbolize_var | rewrite Hlen | progress subst | setoid_rewrite length_extendb @@ -130,8 +132,34 @@ Section symbolic. | break_innermost_match_step | break_innermost_match_hyps_step | progress simpl in * - | solve [ intuition (eauto || congruence) ] ]. } - Qed. + | solve [ intuition (eauto || congruence) ] + | match goal with + | [ H : forall t x y, _ |- _ ] => specialize (fun t x0 x1 y0 y1 => H t (x0, x1) (y0, y1)); cbn [fst snd] in H + | [ H : In (existT _ ?t (?x, ?x')) (flatten_binding_list (symbolicify_smart_var _ _) (symbolicify_smart_var _ _)), + Hm1m2Some : forall t v sv1 sv2, _ -> _ -> forall k', In k' (flatten_binding_list _ _) -> In k' ?G |- _ ] + => is_var x; is_var x'; + lazymatch goal with + | [ H : In (existT _ t ((fst x, _), (fst x', _))) G |- _ ] => fail + | _ => let H' := fresh in + refine (let H' := flatten_binding_list_SmartVarfMap2_pair_in_generalize2 H _ _ in _); + destruct H' as [? [? H']]; + eapply Hm1m2Some in H'; [ | eassumption.. ] + end + end ] ]. + repeat first [ progress unfold symbolize_var + | rewrite Hlen + | progress subst + | setoid_rewrite length_extendb + | setoid_rewrite List.in_app_iff + | progress destruct_head' or + | solve [ eauto ] + | progress intros ]. + (** FIXME: This actually isn't true, because the symbolic + expr stored in G might not be the same as the one in the + expression tree, when the one in the expression tree is a + fresh var *) + admit. } + Admitted. Lemma wff_prepend_prefix {var1' var2'} prefix1 prefix2 G t e1 e2 (Hlen : length prefix1 = length prefix2) diff --git a/src/Compilers/TestCase.v b/src/Compilers/TestCase.v index a7fd81328..36774e4e3 100644 --- a/src/Compilers/TestCase.v +++ b/src/Compilers/TestCase.v @@ -187,7 +187,7 @@ Section cse. | Mul => SMul | Sub => SSub end. - Definition CSE {t} e := @CSE base_type op_code base_type_beq op_code_beq internal_base_type_dec_bl op symbolicify_op (fun _ x => x) t e (fun _ => nil). + Definition CSE {t} e := @CSE base_type op_code base_type_beq op_code_beq internal_base_type_dec_bl op symbolicify_op (fun _ x => x) true t e (fun _ => nil). End cse. Definition example_expr_simplified := Eval vm_compute in InlineConst is_const (ANormal example_expr). diff --git a/src/Compilers/Z/Bounds/Pipeline/Definition.v b/src/Compilers/Z/Bounds/Pipeline/Definition.v index fd131dffa..8ca3c9b46 100644 --- a/src/Compilers/Z/Bounds/Pipeline/Definition.v +++ b/src/Compilers/Z/Bounds/Pipeline/Definition.v @@ -61,6 +61,9 @@ Require Import Crypto.Compilers.Z.InlineWf. Require Import Crypto.Compilers.Linearize. Require Import Crypto.Compilers.LinearizeInterp. Require Import Crypto.Compilers.LinearizeWf. +Require Import Crypto.Compilers.Z.CommonSubexpressionElimination. +Require Import Crypto.Compilers.Z.CommonSubexpressionEliminationInterp. +Require Import Crypto.Compilers.Z.CommonSubexpressionEliminationWf. Require Import Crypto.Compilers.Z.ArithmeticSimplifierWf. Require Import Crypto.Compilers.Z.Bounds.MapCastByDeBruijn. Require Import Crypto.Compilers.Z.Bounds.MapCastByDeBruijnInterp. @@ -80,6 +83,7 @@ Definition PostWfPipeline let e := SimplifyArith e in let e := ANormal e in let e := InlineConst e in + let e := CSE false e in let e := MapCast _ e input_bounds in option_map (projT2_map diff --git a/src/Compilers/Z/CommonSubexpressionElimination.v b/src/Compilers/Z/CommonSubexpressionElimination.v index 210c2c7c8..6695d137e 100644 --- a/src/Compilers/Z/CommonSubexpressionElimination.v +++ b/src/Compilers/Z/CommonSubexpressionElimination.v @@ -151,25 +151,25 @@ Definition normalize_symbolic_expr_mod_c (opc : symbolic_op) (args : symbolic_ex => args end. -Definition csef {var t} (v : exprf _ _ t) xs +Definition csef inline_symbolic_expr_in_lookup {var t} (v : exprf _ _ t) xs := @csef base_type symbolic_op base_type_beq symbolic_op_beq internal_base_type_dec_bl op symbolize_op normalize_symbolic_expr_mod_c - var t v xs. + var inline_symbolic_expr_in_lookup t v xs. -Definition cse {var} (prefix : list _) {t} (v : expr _ _ t) xs +Definition cse inline_symbolic_expr_in_lookup {var} (prefix : list _) {t} (v : expr _ _ t) xs := @cse base_type symbolic_op base_type_beq symbolic_op_beq internal_base_type_dec_bl op symbolize_op normalize_symbolic_expr_mod_c - var prefix t v xs. + inline_symbolic_expr_in_lookup var prefix t v xs. -Definition CSE_gen {t} (e : Expr _ _ t) (prefix : forall var, list { t : flat_type base_type & exprf _ _ t }) +Definition CSE_gen inline_symbolic_expr_in_lookup {t} (e : Expr _ _ t) (prefix : forall var, list { t : flat_type base_type & exprf _ _ t }) : Expr _ _ t := @CSE base_type symbolic_op base_type_beq symbolic_op_beq internal_base_type_dec_bl op symbolize_op normalize_symbolic_expr_mod_c - t e prefix. + inline_symbolic_expr_in_lookup t e prefix. -Definition CSE {t} (e : Expr _ _ t) +Definition CSE inline_symbolic_expr_in_lookup {t} (e : Expr _ _ t) : Expr _ _ t - := @CSE_gen t e (fun _ => nil). + := @CSE_gen inline_symbolic_expr_in_lookup t e (fun _ => nil). diff --git a/src/Compilers/Z/CommonSubexpressionEliminationInterp.v b/src/Compilers/Z/CommonSubexpressionEliminationInterp.v index 6552084d9..280039058 100644 --- a/src/Compilers/Z/CommonSubexpressionEliminationInterp.v +++ b/src/Compilers/Z/CommonSubexpressionEliminationInterp.v @@ -6,16 +6,16 @@ Require Import Crypto.Compilers.Z.Syntax. Require Import Crypto.Compilers.CommonSubexpressionEliminationInterp. Require Import Crypto.Compilers.Z.CommonSubexpressionElimination. -Lemma InterpCSE_gen t (e : Expr _ _ t) prefix +Lemma InterpCSE_gen inline_symbolic_expr_in_lookup t (e : Expr _ _ t) prefix (Hwf : Wf e) - : forall x, Interp interp_op (@CSE_gen t e prefix) x = Interp interp_op e x. + : forall x, Interp interp_op (@CSE_gen inline_symbolic_expr_in_lookup t e prefix) x = Interp interp_op e x. Proof. apply InterpCSE; auto using internal_base_type_dec_bl, internal_base_type_dec_lb, internal_symbolic_op_dec_bl, internal_symbolic_op_dec_lb, denote_symbolic_op. Qed. -Lemma InterpCSE t (e : Expr _ _ t) (Hwf : Wf e) - : forall x, Interp interp_op (@CSE t e) x = Interp interp_op e x. +Lemma InterpCSE inline_symbolic_expr_in_lookup t (e : Expr _ _ t) (Hwf : Wf e) + : forall x, Interp interp_op (@CSE inline_symbolic_expr_in_lookup t e) x = Interp interp_op e x. Proof. apply InterpCSE_gen; auto. Qed. diff --git a/src/Compilers/Z/CommonSubexpressionEliminationWf.v b/src/Compilers/Z/CommonSubexpressionEliminationWf.v index 4f46f9454..a7365397c 100644 --- a/src/Compilers/Z/CommonSubexpressionEliminationWf.v +++ b/src/Compilers/Z/CommonSubexpressionEliminationWf.v @@ -5,7 +5,7 @@ Require Import Crypto.Compilers.Z.Syntax. Require Import Crypto.Compilers.CommonSubexpressionEliminationWf. Require Import Crypto.Compilers.Z.CommonSubexpressionElimination. -Lemma Wf_CSE_gen t (e : Expr _ _ t) +Lemma Wf_CSE_gen inline_symbolic_expr_in_lookup t (e : Expr _ _ t) prefix (Hlen : forall var1 var2, length (prefix var1) = length (prefix var2)) (Hprefix : forall var1 var2 n t1 t2 e1 e2, @@ -13,14 +13,14 @@ Lemma Wf_CSE_gen t (e : Expr _ _ t) -> List.nth_error (prefix var2) n = Some (existT _ t2 e2) -> exists pf : t1 = t2, wff nil (eq_rect _ (@exprf _ _ _) e1 _ pf) e2) (Hwf : Wf e) - : Wf (@CSE_gen t e prefix). + : Wf (@CSE_gen inline_symbolic_expr_in_lookup t e prefix). Proof. apply Wf_CSE; auto using internal_base_type_dec_bl, internal_base_type_dec_lb, internal_symbolic_op_dec_bl, internal_symbolic_op_dec_lb. Qed. -Lemma Wf_CSE t (e : Expr _ _ t) +Lemma Wf_CSE inline_symbolic_expr_in_lookup t (e : Expr _ _ t) (Hwf : Wf e) - : Wf (@CSE t e). + : Wf (@CSE inline_symbolic_expr_in_lookup t e). Proof. apply Wf_CSE_gen; simpl; auto. { destruct n; simpl; try congruence. } diff --git a/src/Specific/FancyMachine256/Core.v b/src/Specific/FancyMachine256/Core.v index 0d521ae17..cd54402ea 100644 --- a/src/Specific/FancyMachine256/Core.v +++ b/src/Specific/FancyMachine256/Core.v @@ -108,7 +108,7 @@ Section reflection. | OPaddm => SOPaddm end. - Definition CSE {t} e := @CSE base_type op_code base_type_beq op_code_beq internal_base_type_dec_bl op symbolicify_op (fun _ x => x) t e (fun _ => nil). + Definition CSE {t} e := @CSE base_type op_code base_type_beq op_code_beq internal_base_type_dec_bl op symbolicify_op (fun _ x => x) true t e (fun _ => nil). Inductive inline_option := opt_inline | opt_default | opt_noinline. |