aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jagro@google.com>2016-09-02 19:43:01 -0700
committerGravatar Jason Gross <jagro@google.com>2016-09-06 16:48:55 -0700
commite233e15c0eafd34d9cb6412361d7aaa373d774e0 (patch)
tree6d2afac75bdce67cc331326301597f5cf89d8114 /src
parentf34ef621c8d5db15490038c082383145d9231761 (diff)
Add Common Subexpression Elimination
Diffstat (limited to 'src')
-rw-r--r--src/Reflection/CommonSubexpressionElimination.v198
-rw-r--r--src/Reflection/TestCase.v35
2 files changed, 228 insertions, 5 deletions
diff --git a/src/Reflection/CommonSubexpressionElimination.v b/src/Reflection/CommonSubexpressionElimination.v
new file mode 100644
index 000000000..3c5e1cbd7
--- /dev/null
+++ b/src/Reflection/CommonSubexpressionElimination.v
@@ -0,0 +1,198 @@
+(** * Common Subexpression Elimination for PHOAS Syntax *)
+Require Import Coq.Lists.List.
+Require Import Crypto.Reflection.Syntax.
+Require Import Crypto.Util.Tactics Crypto.Util.Bool.
+
+Local Open Scope list_scope.
+
+Inductive symbolic_expr {base_type_code SConstT op_code} : Type :=
+| SConst (v : SConstT)
+| SVar (v : base_type_code) (n : nat)
+| SOp (op : op_code) (args : symbolic_expr)
+| SPair (x y : symbolic_expr)
+| SInvalid.
+Scheme Equality for symbolic_expr.
+
+Arguments symbolic_expr : clear implicits.
+
+Ltac inversion_symbolic_expr_step :=
+ match goal with
+ | [ H : SConst _ = SConst _ |- _ ] => inversion H; clear H
+ | [ H : SVar _ _ = SVar _ _ |- _ ] => inversion H; clear H
+ | [ H : SOp _ _ = SOp _ _ |- _ ] => inversion H; clear H
+ | [ H : SPair _ _ = SPair _ _ |- _ ] => inversion H; clear H
+ end.
+Ltac inversion_symbolic_expr := repeat inversion_symbolic_expr_step.
+
+Local Open Scope ctype_scope.
+Section symbolic.
+ (** Holds decidably-equal versions of raw expressions, for lookup. *)
+ Context (base_type_code : Type)
+ (SConstT : Type)
+ (op_code : Type)
+ (base_type_code_beq : base_type_code -> base_type_code -> bool)
+ (SConstT_beq : SConstT -> SConstT -> bool)
+ (op_code_beq : op_code -> op_code -> bool)
+ (base_type_code_bl : forall x y, base_type_code_beq x y = true -> x = y)
+ (base_type_code_lb : forall x y, x = y -> base_type_code_beq x y = true)
+ (SConstT_bl : forall x y, SConstT_beq x y = true -> x = y)
+ (SConstT_lb : forall x y, x = y -> SConstT_beq x y = true)
+ (op_code_bl : forall x y, op_code_beq x y = true -> x = y)
+ (op_code_lb : forall x y, x = y -> op_code_beq x y = true)
+ (interp_base_type : base_type_code -> Type)
+ (op : flat_type base_type_code -> flat_type base_type_code -> Type)
+ (symbolize_const : forall t, interp_base_type t -> SConstT)
+ (symbolize_op : forall s d, op s d -> op_code).
+
+ Local Notation symbolic_expr := (symbolic_expr base_type_code SConstT op_code).
+ Local Notation symbolic_expr_beq := (@symbolic_expr_beq base_type_code SConstT op_code base_type_code_beq SConstT_beq op_code_beq).
+ Local Notation symbolic_expr_lb := (@internal_symbolic_expr_dec_lb base_type_code SConstT op_code base_type_code_beq SConstT_beq op_code_beq base_type_code_lb SConstT_lb op_code_lb).
+ Local Notation symbolic_expr_bl := (@internal_symbolic_expr_dec_bl base_type_code SConstT op_code base_type_code_beq SConstT_beq op_code_beq base_type_code_bl SConstT_bl op_code_bl).
+
+ Local Notation flat_type := (flat_type base_type_code).
+ Local Notation type := (type base_type_code).
+ Let Tbase := @Tbase base_type_code.
+ Local Coercion Tbase : base_type_code >-> flat_type.
+ Local Notation interp_type := (interp_type interp_base_type).
+ Local Notation interp_flat_type := (interp_flat_type_gen interp_base_type).
+ Local Notation exprf := (@exprf base_type_code interp_base_type op).
+ Local Notation expr := (@expr base_type_code interp_base_type op).
+ Local Notation Expr := (@Expr base_type_code interp_base_type op).
+
+
+ Section with_var.
+ Context {var : base_type_code -> Type}.
+
+ Local Notation svar t := (var t * symbolic_expr)%type.
+ Local Notation fsvar := (fun t => svar t).
+ Local Notation mapping := (forall t : base_type_code, list (svar t))%type.
+
+ Context (prefix : list (sigT (fun t : flat_type => @exprf fsvar t))).
+
+ Definition empty_mapping : mapping := fun _ => nil.
+ Definition type_lookup t (xs : mapping) : list (svar t) := xs t.
+ Definition mapping_update_type t (xs : mapping) (upd : list (svar t) -> list (svar t))
+ : mapping
+ := fun t' => (if base_type_code_beq t t' as b return base_type_code_beq t t' = b -> _
+ then fun H => match base_type_code_bl _ _ H in (_ = t') return list (svar t') with
+ | eq_refl => upd (type_lookup t xs)
+ end
+ else fun _ => type_lookup t' xs)
+ eq_refl.
+
+ Fixpoint lookup' {t} (sv : symbolic_expr) (xs : list (svar t)) {struct xs} : option (var t) :=
+ match xs with
+ | nil => None
+ | (x, sv') :: xs' =>
+ if symbolic_expr_beq sv' sv
+ then Some x
+ else lookup' sv xs'
+ end.
+ Definition lookup t (sv : symbolic_expr) (xs : mapping) : option (var t) :=
+ lookup' sv (type_lookup t xs).
+ Definition symbolicify_var {t : base_type_code} (v : var t) (xs : mapping) : symbolic_expr :=
+ SVar t (length (type_lookup t xs)).
+ Definition add_mapping {t} (v : var t) (sv : symbolic_expr) (xs : mapping) : mapping :=
+ mapping_update_type t xs (fun ls => (v, sv) :: ls).
+
+ Definition symbolize_smart_const {t} : interp_flat_type t -> symbolic_expr
+ := smart_interp_flat_map base_type_code (g:=fun _ => symbolic_expr) (fun t v => SConst (symbolize_const t v)) (fun A B => SPair).
+
+ Fixpoint symbolize_exprf
+ {t} (v : @exprf fsvar t) {struct v}
+ : option symbolic_expr
+ := match v with
+ | Const t x => Some (symbolize_smart_const x)
+ | Var _ x => Some (snd x)
+ | Op _ _ op args => option_map
+ (fun sargs => SOp (symbolize_op _ _ op) sargs)
+ (@symbolize_exprf _ args)
+ | Let _ ex _ eC => None
+ | Pair _ ex _ ey => match @symbolize_exprf _ ex, @symbolize_exprf _ ey with
+ | Some sx, Some sy => Some (SPair sx sy)
+ | _, _ => None
+ end
+ end.
+
+ Fixpoint smart_lookup_gen f (proj : forall t, svar t -> f t)
+ (t : flat_type) (sv : symbolic_expr) (xs : mapping) {struct t}
+ : option (interp_flat_type_gen f t)
+ := match t return option (interp_flat_type_gen f t) with
+ | Syntax.Tbase t => option_map (fun v => proj t (v, sv)) (lookup t sv xs)
+ | Prod A B => match @smart_lookup_gen f proj A sv xs, @smart_lookup_gen f proj B sv xs with
+ | Some a, Some b => Some (a, b)
+ | _, _ => None
+ end
+ end.
+ Definition smart_lookup (t : flat_type) (sv : symbolic_expr) (xs : mapping) : option (interp_flat_type_gen fsvar t)
+ := @smart_lookup_gen fsvar (fun _ x => x) t sv xs.
+ Definition smart_lookupo (t : flat_type) (sv : option symbolic_expr) (xs : mapping) : option (interp_flat_type_gen fsvar t)
+ := match sv with
+ | Some sv => smart_lookup t sv xs
+ | None => None
+ end.
+ Definition symbolicify_smart_var {t : flat_type} (xs : mapping) (replacement : option symbolic_expr)
+ : interp_flat_type_gen var t -> interp_flat_type_gen fsvar t
+ := smart_interp_flat_map
+ (g:=interp_flat_type_gen fsvar)
+ base_type_code (fun t v => (v,
+ match replacement with
+ | Some sv => sv
+ | None => symbolicify_var v xs
+ end))
+ (fun A B => @pair _ _).
+ Fixpoint smart_add_mapping {t : flat_type} (xs : mapping) : interp_flat_type_gen fsvar t -> mapping
+ := match t return interp_flat_type_gen fsvar t -> mapping with
+ | Syntax.Tbase t => fun v => add_mapping (fst v) (snd v) xs
+ | Prod A B
+ => fun v => let xs := @smart_add_mapping B xs (snd v) in
+ let xs := @smart_add_mapping A xs (fst v) in
+ xs
+ end.
+
+ Definition csef_step
+ (csef : forall {t} (v : @exprf fsvar t) (xs : mapping), @exprf var t)
+ {t} (v : @exprf fsvar t) (xs : mapping)
+ : @exprf var t
+ := match v in @Syntax.exprf _ _ _ _ t return exprf t with
+ | Let tx ex _ eC => let sx := symbolize_exprf ex in
+ let ex' := @csef _ ex xs in
+ let sv := smart_lookupo tx sx xs in
+ match sv with
+ | Some v => @csef _ (eC v) xs
+ | None
+ => Let ex' (fun x => let x' := symbolicify_smart_var xs sx x in
+ @csef _ (eC x') (smart_add_mapping xs x'))
+ end
+ | Const _ x => Const x
+ | Var _ x => Var (fst x)
+ | Op _ _ op args => Op op (@csef _ args xs)
+ | Pair _ ex _ ey => Pair (@csef _ ex xs) (@csef _ ey xs)
+ end.
+
+ Fixpoint csef {t} (v : @exprf fsvar t) (xs : mapping)
+ := @csef_step (@csef) t v xs.
+
+ Fixpoint prepend_prefix {t} (e : @exprf fsvar t) (ls : list (sigT (fun t : flat_type => @exprf fsvar t)))
+ : @exprf fsvar t
+ := match ls with
+ | nil => e
+ | x :: xs => Let (projT2 x) (fun _ => @prepend_prefix _ e xs)
+ end.
+
+ Fixpoint cse {t} (v : @expr fsvar t) (xs : mapping) {struct v} : @expr var t
+ := match v in @Syntax.expr _ _ _ _ t return expr t with
+ | Return _ x => Return (csef (prepend_prefix x prefix) xs)
+ | Abs _ _ f => Abs (fun x => let x' := symbolicify_var x xs in
+ @cse _ (f (x, x')) (add_mapping x x' xs))
+ end.
+ End with_var.
+
+ Definition CSE {t} (e : Expr t) (prefix : forall var, list (sigT (fun t : flat_type => @exprf var t)))
+ : Expr t
+ := fun var => cse (prefix _) (e _) empty_mapping.
+End symbolic.
+
+Global Arguments csef {_} SConstT op_code base_type_code_beq SConstT_beq op_code_beq base_type_code_bl {_ _} symbolize_const symbolize_op {var t} _ _.
+Global Arguments cse {_} SConstT op_code base_type_code_beq SConstT_beq op_code_beq base_type_code_bl {_ _} symbolize_const symbolize_op {var} prefix {t} _ _.
+Global Arguments CSE {_} SConstT op_code base_type_code_beq SConstT_beq op_code_beq base_type_code_bl {_ _} symbolize_const symbolize_op {t} e prefix var.
diff --git a/src/Reflection/TestCase.v b/src/Reflection/TestCase.v
index 0f6384bb2..17580aa5d 100644
--- a/src/Reflection/TestCase.v
+++ b/src/Reflection/TestCase.v
@@ -1,11 +1,14 @@
Require Import Crypto.Reflection.Syntax.
Require Export Crypto.Reflection.Reify.
Require Import Crypto.Reflection.InputSyntax.
+Require Import Crypto.Reflection.CommonSubexpressionElimination.
Require Crypto.Reflection.Linearize.
Require Import Crypto.Reflection.WfReflective.
Import ReifyDebugNotations.
+Local Set Boolean Equality Schemes.
+Local Set Decidable Equality Schemes.
Inductive base_type := Tnat.
Definition interp_base_type (v : base_type) : Type :=
match v with
@@ -14,15 +17,18 @@ Definition interp_base_type (v : base_type) : Type :=
Local Notation tnat := (Tbase Tnat).
Inductive op : flat_type base_type -> flat_type base_type -> Type :=
| Add : op (Prod tnat tnat) tnat
-| Mul : op (Prod tnat tnat) tnat.
+| Mul : op (Prod tnat tnat) tnat
+| Sub : op (Prod tnat tnat) tnat.
Definition interp_op src dst (f : op src dst) : interp_flat_type_gen interp_base_type src -> interp_flat_type_gen interp_base_type dst
:= match f with
| Add => fun xy => fst xy + snd xy
| Mul => fun xy => fst xy * snd xy
+ | Sub => fun xy => fst xy - snd xy
end%nat.
Global Instance: forall x y, reify_op op (x + y)%nat 2 Add := fun _ _ => I.
Global Instance: forall x y, reify_op op (x * y)%nat 2 Mul := fun _ _ => I.
+Global Instance: forall x y, reify_op op (x - y)%nat 2 Sub := fun _ _ => I.
Global Instance: reify type nat := Tnat.
Ltac Reify' e := Reify.Reify' base_type interp_base_type op e.
@@ -70,7 +76,7 @@ Abort.
Definition example_expr : Syntax.Expr base_type interp_base_type op (Tbase Tnat).
Proof.
- let x := Reify (let x := 1 in let y := 1 in (let a := 1 in let '(c, d) := (2, 3) in a + x + c + d) + y)%nat in
+ let x := Reify (let x := 1 in let y := 1 in (let a := 1 in let '(c, d) := (2, 3) in a + x + (x + x) + (x + x) - (x + x) + c + d) + y)%nat in
exact x.
Defined.
@@ -89,6 +95,8 @@ Definition op_beq t1 tR : op t1 tR -> op t1 tR -> option pointed_Prop
| Add, _ => None
| Mul, Mul => Some trivial
| Mul, _ => None
+ | Sub, Sub => Some trivial
+ | Sub, _ => None
end.
Lemma op_beq_bl t1 tR (x y : op t1 tR)
: match op_beq t1 tR x y with
@@ -96,9 +104,8 @@ Lemma op_beq_bl t1 tR (x y : op t1 tR)
| None => False
end -> x = y.
Proof.
- destruct x; simpl.
- { refine match y with Add => _ | _ => _ end; tauto. }
- { refine match y with Add => _ | _ => _ end; tauto. }
+ destruct x; simpl;
+ refine match y with Add => _ | _ => _ end; tauto.
Qed.
Ltac reflect_Wf := WfReflective.reflect_Wf base_type_eq_semidec_is_dec op_beq_bl.
@@ -117,3 +124,21 @@ Qed.
Lemma example_expr_wf : Wf example_expr.
Proof. Time reflect_Wf. (* 0.008 s *) Qed.
+
+Section cse.
+ Let SConstT := nat.
+ Inductive op_code : Set := SAdd | SMul | SSub.
+ Definition symbolicify_const (t : base_type) : interp_base_type t -> SConstT
+ := match t with
+ | Tnat => fun x => x
+ end.
+ Definition symbolicify_op s d (v : op s d) : op_code
+ := match v with
+ | Add => SAdd
+ | Mul => SMul
+ | Sub => SSub
+ end.
+ Definition CSE {t} e := @CSE base_type SConstT op_code base_type_beq nat_beq op_code_beq internal_base_type_dec_bl interp_base_type op symbolicify_const symbolicify_op t e (fun _ => nil).
+End cse.
+
+Compute CSE (InlineConst (Linearize example_expr)).