aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-04-15 02:01:56 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2017-05-14 00:52:04 -0400
commit096a24265d4df0bbb5321c6fa794577bee5cae25 (patch)
tree4b7cbeefaf12fc5ce836e80864a6221c7b44dcf9
parent63e036b685457b7ecfb44e6caf966c4a7e8462d1 (diff)
CSE without inlining arithmetic expressions
This takes care of most of #158. The remaining bits are reworking the Wf and interpretation lemmas to actually work. (The former needs a only bit of rethinking and rephrasing to handle the fact that sometimes we change the stored symbolic expression from a complicated one to a fresh variable, while the latter needs major surgery, which Adam tells me is easy, and this is a note that when I come back to it, I should look at the email thread with Adam about CSE from last summer.)
-rw-r--r--src/Compilers/CommonSubexpressionElimination.v16
-rw-r--r--src/Compilers/CommonSubexpressionEliminationInterp.v9
-rw-r--r--src/Compilers/CommonSubexpressionEliminationWf.v56
-rw-r--r--src/Compilers/TestCase.v2
-rw-r--r--src/Compilers/Z/Bounds/Pipeline/Definition.v4
-rw-r--r--src/Compilers/Z/CommonSubexpressionElimination.v16
-rw-r--r--src/Compilers/Z/CommonSubexpressionEliminationInterp.v8
-rw-r--r--src/Compilers/Z/CommonSubexpressionEliminationWf.v8
-rw-r--r--src/Specific/FancyMachine256/Core.v2
9 files changed, 78 insertions, 43 deletions
diff --git a/src/Compilers/CommonSubexpressionElimination.v b/src/Compilers/CommonSubexpressionElimination.v
index de91f03b6..c5ffeb088 100644
--- a/src/Compilers/CommonSubexpressionElimination.v
+++ b/src/Compilers/CommonSubexpressionElimination.v
@@ -57,7 +57,8 @@ Section symbolic.
(op_code_leb : op_code -> op_code -> bool)
(base_type_leb : base_type_code -> base_type_code -> bool).
Local Notation symbolic_expr := (symbolic_expr base_type_code op_code).
- Context (normalize_symbolic_op_arguments : op_code -> symbolic_expr -> symbolic_expr).
+ Context (normalize_symbolic_op_arguments : op_code -> symbolic_expr -> symbolic_expr)
+ (inline_symbolic_expr_in_lookup : bool).
Local Notation symbolic_expr_beq := (@symbolic_expr_beq base_type_code op_code base_type_code_beq op_code_beq).
Local Notation symbolic_expr_lb := (@internal_symbolic_expr_dec_lb base_type_code op_code base_type_code_beq op_code_beq base_type_code_lb op_code_lb).
@@ -223,11 +224,12 @@ Section symbolic.
| Some sx => (sx, lookupb xs sx)
| None => (symbolize_var xs tx, None)
end in
+ let reduced_sx := if inline_symbolic_expr_in_lookup then sx else symbolize_var xs tx in
match sv with
- | Some v => @csef _ (eC (symbolicify_smart_var v sx)) xs
+ | Some v => @csef _ (eC (symbolicify_smart_var v reduced_sx)) (extendb xs reduced_sx v)
| None
- => LetIn ex' (fun x => let sx' := symbolicify_smart_var x sx in
- @csef _ (eC sx') (extendb xs sx x))
+ => LetIn ex' (fun x => let sx' := symbolicify_smart_var x reduced_sx in
+ @csef _ (eC sx') (extendb (extendb xs sx x) reduced_sx x))
end
| TT => TT
| Var _ x => Var (fst x)
@@ -252,6 +254,6 @@ Section symbolic.
:= fun var => cse (prefix _) (e _) empty.
End symbolic.
-Global Arguments csef {_} op_code base_type_code_beq op_code_beq base_type_code_bl {_} symbolize_op normalize_symbolic_op_arguments {var t} _ _.
-Global Arguments cse {_} op_code base_type_code_beq op_code_beq base_type_code_bl {_} symbolize_op normalize_symbolic_op_arguments {var} prefix {t} _ _.
-Global Arguments CSE {_} op_code base_type_code_beq op_code_beq base_type_code_bl {_} symbolize_op normalize_symbolic_op_arguments {t} e prefix var.
+Global Arguments csef {_} op_code base_type_code_beq op_code_beq base_type_code_bl {_} symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup {var t} _ _.
+Global Arguments cse {_} op_code base_type_code_beq op_code_beq base_type_code_bl {_} symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup {var} prefix {t} _ _.
+Global Arguments CSE {_} op_code base_type_code_beq op_code_beq base_type_code_bl {_} symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup {t} e prefix var.
diff --git a/src/Compilers/CommonSubexpressionEliminationInterp.v b/src/Compilers/CommonSubexpressionEliminationInterp.v
index e0ebf71a4..0d7fa1e25 100644
--- a/src/Compilers/CommonSubexpressionEliminationInterp.v
+++ b/src/Compilers/CommonSubexpressionEliminationInterp.v
@@ -40,7 +40,8 @@ Section symbolic.
(symbolize_op : forall s d, op s d -> op_code)
(denote_op : forall s d, op_code -> option (op s d)).
Local Notation symbolic_expr := (symbolic_expr base_type_code op_code).
- Context (normalize_symbolic_op_arguments : op_code -> symbolic_expr -> symbolic_expr).
+ Context (normalize_symbolic_op_arguments : op_code -> symbolic_expr -> symbolic_expr)
+ (inline_symbolic_expr_in_lookup : bool).
Local Notation symbolic_expr_beq := (@symbolic_expr_beq base_type_code op_code base_type_code_beq op_code_beq).
Local Notation symbolic_expr_lb := (@internal_symbolic_expr_dec_lb base_type_code op_code base_type_code_beq op_code_beq base_type_code_lb op_code_lb).
@@ -57,9 +58,9 @@ Section symbolic.
Local Notation symbolicify_smart_var := (@symbolicify_smart_var base_type_code op_code).
Local Notation symbolize_exprf := (@symbolize_exprf base_type_code op_code op symbolize_op).
- Local Notation csef := (@csef base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments).
- Local Notation cse := (@cse base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments).
- Local Notation CSE := (@CSE base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments).
+ Local Notation csef := (@csef base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup).
+ Local Notation cse := (@cse base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup).
+ Local Notation CSE := (@CSE base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup).
Local Notation SymbolicExprContext := (@SymbolicExprContext base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl).
Local Notation SymbolicExprContextOk := (@SymbolicExprContextOk base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl base_type_code_lb op_code_bl op_code_lb).
Local Notation prepend_prefix := (@prepend_prefix base_type_code op).
diff --git a/src/Compilers/CommonSubexpressionEliminationWf.v b/src/Compilers/CommonSubexpressionEliminationWf.v
index b66db5c33..8d8a5d86f 100644
--- a/src/Compilers/CommonSubexpressionEliminationWf.v
+++ b/src/Compilers/CommonSubexpressionEliminationWf.v
@@ -14,6 +14,7 @@ Require Import Crypto.Util.Bool.
Require Import Crypto.Util.Tactics.RewriteHyp.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Tactics.DestructHead.
+Require Import Crypto.Util.Tactics.UniquePose.
Require Import Crypto.Util.Tactics.SplitInContext.
Require Import Crypto.Util.Decidable.
@@ -26,11 +27,11 @@ Section symbolic.
(base_type_code_lb : forall x y, x = y -> base_type_code_beq x y = true)
(op_code_bl : forall x y, op_code_beq x y = true -> x = y)
(op_code_lb : forall x y, x = y -> op_code_beq x y = true)
- (interp_base_type : base_type_code -> Type)
(op : flat_type base_type_code -> flat_type base_type_code -> Type)
(symbolize_op : forall s d, op s d -> op_code).
Local Notation symbolic_expr := (symbolic_expr base_type_code op_code).
- Context (normalize_symbolic_op_arguments : op_code -> symbolic_expr -> symbolic_expr).
+ Context (normalize_symbolic_op_arguments : op_code -> symbolic_expr -> symbolic_expr)
+ (inline_symbolic_expr_in_lookup : bool).
Local Notation symbolic_expr_beq := (@symbolic_expr_beq base_type_code op_code base_type_code_beq op_code_beq).
Local Notation symbolic_expr_lb := (@internal_symbolic_expr_dec_lb base_type_code op_code base_type_code_beq op_code_beq base_type_code_lb op_code_lb).
@@ -38,9 +39,6 @@ Section symbolic.
Local Notation flat_type := (flat_type base_type_code).
Local Notation type := (type base_type_code).
- Local Notation interp_type := (interp_type interp_base_type).
- Local Notation interp_flat_type_gen := interp_flat_type.
- Local Notation interp_flat_type := (interp_flat_type interp_base_type).
Local Notation exprf := (@exprf base_type_code op).
Local Notation expr := (@expr base_type_code op).
Local Notation Expr := (@Expr base_type_code op).
@@ -48,9 +46,9 @@ Section symbolic.
Local Notation symbolicify_smart_var := (@symbolicify_smart_var base_type_code op_code).
Local Notation symbolize_exprf := (@symbolize_exprf base_type_code op_code op symbolize_op).
Local Notation norm_symbolize_exprf := (@norm_symbolize_exprf base_type_code op_code op symbolize_op normalize_symbolic_op_arguments).
- Local Notation csef := (@csef base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments).
- Local Notation cse := (@cse base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments).
- Local Notation CSE := (@CSE base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments).
+ Local Notation csef := (@csef base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup).
+ Local Notation cse := (@cse base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup).
+ Local Notation CSE := (@CSE base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl op symbolize_op normalize_symbolic_op_arguments inline_symbolic_expr_in_lookup).
Local Notation SymbolicExprContext := (@SymbolicExprContext base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl).
Local Notation SymbolicExprContextOk := (@SymbolicExprContextOk base_type_code op_code base_type_code_beq op_code_beq base_type_code_bl base_type_code_lb op_code_bl op_code_lb).
Local Notation prepend_prefix := (@prepend_prefix base_type_code op).
@@ -82,8 +80,8 @@ Section symbolic.
Local Arguments lookupb : simpl never.
Local Arguments extendb : simpl never.
Lemma wff_csef G G' t e1 e2
- (m1 : @SymbolicExprContext (interp_flat_type_gen var1))
- (m2 : @SymbolicExprContext (interp_flat_type_gen var2))
+ (m1 : @SymbolicExprContext (interp_flat_type var1))
+ (m2 : @SymbolicExprContext (interp_flat_type var2))
(Hlen : length m1 = length m2)
(Hm1m2None : forall t v, lookupb m1 v t = None <-> lookupb m2 v t = None)
(Hm1m2Some : forall t v sv1 sv2,
@@ -108,9 +106,13 @@ Section symbolic.
| [ H : lookupb ?m1 ?x = Some ?k, H' : lookupb ?m2 ?x = None |- _ ]
=> apply Hm1m2None in H'; congruence
end;
- [ | constructor; intros; auto; [].. ];
+ lazymatch goal with
+ | [ |- wff _ (LetIn _ _) (LetIn _ _) ]
+ => constructor; intros; auto; []
+ | _ => idtac
+ end;
match goal with H : _ |- _ => apply H end;
- repeat first [ progress unfold symbolize_var
+ try solve [ repeat first [ progress unfold symbolize_var
| rewrite Hlen
| progress subst
| setoid_rewrite length_extendb
@@ -130,8 +132,34 @@ Section symbolic.
| break_innermost_match_step
| break_innermost_match_hyps_step
| progress simpl in *
- | solve [ intuition (eauto || congruence) ] ]. }
- Qed.
+ | solve [ intuition (eauto || congruence) ]
+ | match goal with
+ | [ H : forall t x y, _ |- _ ] => specialize (fun t x0 x1 y0 y1 => H t (x0, x1) (y0, y1)); cbn [fst snd] in H
+ | [ H : In (existT _ ?t (?x, ?x')) (flatten_binding_list (symbolicify_smart_var _ _) (symbolicify_smart_var _ _)),
+ Hm1m2Some : forall t v sv1 sv2, _ -> _ -> forall k', In k' (flatten_binding_list _ _) -> In k' ?G |- _ ]
+ => is_var x; is_var x';
+ lazymatch goal with
+ | [ H : In (existT _ t ((fst x, _), (fst x', _))) G |- _ ] => fail
+ | _ => let H' := fresh in
+ refine (let H' := flatten_binding_list_SmartVarfMap2_pair_in_generalize2 H _ _ in _);
+ destruct H' as [? [? H']];
+ eapply Hm1m2Some in H'; [ | eassumption.. ]
+ end
+ end ] ].
+ repeat first [ progress unfold symbolize_var
+ | rewrite Hlen
+ | progress subst
+ | setoid_rewrite length_extendb
+ | setoid_rewrite List.in_app_iff
+ | progress destruct_head' or
+ | solve [ eauto ]
+ | progress intros ].
+ (** FIXME: This actually isn't true, because the symbolic
+ expr stored in G might not be the same as the one in the
+ expression tree, when the one in the expression tree is a
+ fresh var *)
+ admit. }
+ Admitted.
Lemma wff_prepend_prefix {var1' var2'} prefix1 prefix2 G t e1 e2
(Hlen : length prefix1 = length prefix2)
diff --git a/src/Compilers/TestCase.v b/src/Compilers/TestCase.v
index a7fd81328..36774e4e3 100644
--- a/src/Compilers/TestCase.v
+++ b/src/Compilers/TestCase.v
@@ -187,7 +187,7 @@ Section cse.
| Mul => SMul
| Sub => SSub
end.
- Definition CSE {t} e := @CSE base_type op_code base_type_beq op_code_beq internal_base_type_dec_bl op symbolicify_op (fun _ x => x) t e (fun _ => nil).
+ Definition CSE {t} e := @CSE base_type op_code base_type_beq op_code_beq internal_base_type_dec_bl op symbolicify_op (fun _ x => x) true t e (fun _ => nil).
End cse.
Definition example_expr_simplified := Eval vm_compute in InlineConst is_const (ANormal example_expr).
diff --git a/src/Compilers/Z/Bounds/Pipeline/Definition.v b/src/Compilers/Z/Bounds/Pipeline/Definition.v
index fd131dffa..8ca3c9b46 100644
--- a/src/Compilers/Z/Bounds/Pipeline/Definition.v
+++ b/src/Compilers/Z/Bounds/Pipeline/Definition.v
@@ -61,6 +61,9 @@ Require Import Crypto.Compilers.Z.InlineWf.
Require Import Crypto.Compilers.Linearize.
Require Import Crypto.Compilers.LinearizeInterp.
Require Import Crypto.Compilers.LinearizeWf.
+Require Import Crypto.Compilers.Z.CommonSubexpressionElimination.
+Require Import Crypto.Compilers.Z.CommonSubexpressionEliminationInterp.
+Require Import Crypto.Compilers.Z.CommonSubexpressionEliminationWf.
Require Import Crypto.Compilers.Z.ArithmeticSimplifierWf.
Require Import Crypto.Compilers.Z.Bounds.MapCastByDeBruijn.
Require Import Crypto.Compilers.Z.Bounds.MapCastByDeBruijnInterp.
@@ -80,6 +83,7 @@ Definition PostWfPipeline
let e := SimplifyArith e in
let e := ANormal e in
let e := InlineConst e in
+ let e := CSE false e in
let e := MapCast _ e input_bounds in
option_map
(projT2_map
diff --git a/src/Compilers/Z/CommonSubexpressionElimination.v b/src/Compilers/Z/CommonSubexpressionElimination.v
index 210c2c7c8..6695d137e 100644
--- a/src/Compilers/Z/CommonSubexpressionElimination.v
+++ b/src/Compilers/Z/CommonSubexpressionElimination.v
@@ -151,25 +151,25 @@ Definition normalize_symbolic_expr_mod_c (opc : symbolic_op) (args : symbolic_ex
=> args
end.
-Definition csef {var t} (v : exprf _ _ t) xs
+Definition csef inline_symbolic_expr_in_lookup {var t} (v : exprf _ _ t) xs
:= @csef base_type symbolic_op base_type_beq symbolic_op_beq
internal_base_type_dec_bl op symbolize_op
normalize_symbolic_expr_mod_c
- var t v xs.
+ var inline_symbolic_expr_in_lookup t v xs.
-Definition cse {var} (prefix : list _) {t} (v : expr _ _ t) xs
+Definition cse inline_symbolic_expr_in_lookup {var} (prefix : list _) {t} (v : expr _ _ t) xs
:= @cse base_type symbolic_op base_type_beq symbolic_op_beq
internal_base_type_dec_bl op symbolize_op
normalize_symbolic_expr_mod_c
- var prefix t v xs.
+ inline_symbolic_expr_in_lookup var prefix t v xs.
-Definition CSE_gen {t} (e : Expr _ _ t) (prefix : forall var, list { t : flat_type base_type & exprf _ _ t })
+Definition CSE_gen inline_symbolic_expr_in_lookup {t} (e : Expr _ _ t) (prefix : forall var, list { t : flat_type base_type & exprf _ _ t })
: Expr _ _ t
:= @CSE base_type symbolic_op base_type_beq symbolic_op_beq
internal_base_type_dec_bl op symbolize_op
normalize_symbolic_expr_mod_c
- t e prefix.
+ inline_symbolic_expr_in_lookup t e prefix.
-Definition CSE {t} (e : Expr _ _ t)
+Definition CSE inline_symbolic_expr_in_lookup {t} (e : Expr _ _ t)
: Expr _ _ t
- := @CSE_gen t e (fun _ => nil).
+ := @CSE_gen inline_symbolic_expr_in_lookup t e (fun _ => nil).
diff --git a/src/Compilers/Z/CommonSubexpressionEliminationInterp.v b/src/Compilers/Z/CommonSubexpressionEliminationInterp.v
index 6552084d9..280039058 100644
--- a/src/Compilers/Z/CommonSubexpressionEliminationInterp.v
+++ b/src/Compilers/Z/CommonSubexpressionEliminationInterp.v
@@ -6,16 +6,16 @@ Require Import Crypto.Compilers.Z.Syntax.
Require Import Crypto.Compilers.CommonSubexpressionEliminationInterp.
Require Import Crypto.Compilers.Z.CommonSubexpressionElimination.
-Lemma InterpCSE_gen t (e : Expr _ _ t) prefix
+Lemma InterpCSE_gen inline_symbolic_expr_in_lookup t (e : Expr _ _ t) prefix
(Hwf : Wf e)
- : forall x, Interp interp_op (@CSE_gen t e prefix) x = Interp interp_op e x.
+ : forall x, Interp interp_op (@CSE_gen inline_symbolic_expr_in_lookup t e prefix) x = Interp interp_op e x.
Proof.
apply InterpCSE;
auto using internal_base_type_dec_bl, internal_base_type_dec_lb, internal_symbolic_op_dec_bl, internal_symbolic_op_dec_lb, denote_symbolic_op.
Qed.
-Lemma InterpCSE t (e : Expr _ _ t) (Hwf : Wf e)
- : forall x, Interp interp_op (@CSE t e) x = Interp interp_op e x.
+Lemma InterpCSE inline_symbolic_expr_in_lookup t (e : Expr _ _ t) (Hwf : Wf e)
+ : forall x, Interp interp_op (@CSE inline_symbolic_expr_in_lookup t e) x = Interp interp_op e x.
Proof.
apply InterpCSE_gen; auto.
Qed.
diff --git a/src/Compilers/Z/CommonSubexpressionEliminationWf.v b/src/Compilers/Z/CommonSubexpressionEliminationWf.v
index 4f46f9454..a7365397c 100644
--- a/src/Compilers/Z/CommonSubexpressionEliminationWf.v
+++ b/src/Compilers/Z/CommonSubexpressionEliminationWf.v
@@ -5,7 +5,7 @@ Require Import Crypto.Compilers.Z.Syntax.
Require Import Crypto.Compilers.CommonSubexpressionEliminationWf.
Require Import Crypto.Compilers.Z.CommonSubexpressionElimination.
-Lemma Wf_CSE_gen t (e : Expr _ _ t)
+Lemma Wf_CSE_gen inline_symbolic_expr_in_lookup t (e : Expr _ _ t)
prefix
(Hlen : forall var1 var2, length (prefix var1) = length (prefix var2))
(Hprefix : forall var1 var2 n t1 t2 e1 e2,
@@ -13,14 +13,14 @@ Lemma Wf_CSE_gen t (e : Expr _ _ t)
-> List.nth_error (prefix var2) n = Some (existT _ t2 e2)
-> exists pf : t1 = t2, wff nil (eq_rect _ (@exprf _ _ _) e1 _ pf) e2)
(Hwf : Wf e)
- : Wf (@CSE_gen t e prefix).
+ : Wf (@CSE_gen inline_symbolic_expr_in_lookup t e prefix).
Proof.
apply Wf_CSE; auto using internal_base_type_dec_bl, internal_base_type_dec_lb, internal_symbolic_op_dec_bl, internal_symbolic_op_dec_lb.
Qed.
-Lemma Wf_CSE t (e : Expr _ _ t)
+Lemma Wf_CSE inline_symbolic_expr_in_lookup t (e : Expr _ _ t)
(Hwf : Wf e)
- : Wf (@CSE t e).
+ : Wf (@CSE inline_symbolic_expr_in_lookup t e).
Proof.
apply Wf_CSE_gen; simpl; auto.
{ destruct n; simpl; try congruence. }
diff --git a/src/Specific/FancyMachine256/Core.v b/src/Specific/FancyMachine256/Core.v
index 0d521ae17..cd54402ea 100644
--- a/src/Specific/FancyMachine256/Core.v
+++ b/src/Specific/FancyMachine256/Core.v
@@ -108,7 +108,7 @@ Section reflection.
| OPaddm => SOPaddm
end.
- Definition CSE {t} e := @CSE base_type op_code base_type_beq op_code_beq internal_base_type_dec_bl op symbolicify_op (fun _ x => x) t e (fun _ => nil).
+ Definition CSE {t} e := @CSE base_type op_code base_type_beq op_code_beq internal_base_type_dec_bl op symbolicify_op (fun _ x => x) true t e (fun _ => nil).
Inductive inline_option := opt_inline | opt_default | opt_noinline.