diff options
author | Jason Gross <jagro@google.com> | 2016-09-02 19:43:01 -0700 |
---|---|---|
committer | Jason Gross <jagro@google.com> | 2016-09-06 16:48:55 -0700 |
commit | e233e15c0eafd34d9cb6412361d7aaa373d774e0 (patch) | |
tree | 6d2afac75bdce67cc331326301597f5cf89d8114 /src | |
parent | f34ef621c8d5db15490038c082383145d9231761 (diff) |
Add Common Subexpression Elimination
Diffstat (limited to 'src')
-rw-r--r-- | src/Reflection/CommonSubexpressionElimination.v | 198 | ||||
-rw-r--r-- | src/Reflection/TestCase.v | 35 |
2 files changed, 228 insertions, 5 deletions
diff --git a/src/Reflection/CommonSubexpressionElimination.v b/src/Reflection/CommonSubexpressionElimination.v new file mode 100644 index 000000000..3c5e1cbd7 --- /dev/null +++ b/src/Reflection/CommonSubexpressionElimination.v @@ -0,0 +1,198 @@ +(** * Common Subexpression Elimination for PHOAS Syntax *) +Require Import Coq.Lists.List. +Require Import Crypto.Reflection.Syntax. +Require Import Crypto.Util.Tactics Crypto.Util.Bool. + +Local Open Scope list_scope. + +Inductive symbolic_expr {base_type_code SConstT op_code} : Type := +| SConst (v : SConstT) +| SVar (v : base_type_code) (n : nat) +| SOp (op : op_code) (args : symbolic_expr) +| SPair (x y : symbolic_expr) +| SInvalid. +Scheme Equality for symbolic_expr. + +Arguments symbolic_expr : clear implicits. + +Ltac inversion_symbolic_expr_step := + match goal with + | [ H : SConst _ = SConst _ |- _ ] => inversion H; clear H + | [ H : SVar _ _ = SVar _ _ |- _ ] => inversion H; clear H + | [ H : SOp _ _ = SOp _ _ |- _ ] => inversion H; clear H + | [ H : SPair _ _ = SPair _ _ |- _ ] => inversion H; clear H + end. +Ltac inversion_symbolic_expr := repeat inversion_symbolic_expr_step. + +Local Open Scope ctype_scope. +Section symbolic. + (** Holds decidably-equal versions of raw expressions, for lookup. *) + Context (base_type_code : Type) + (SConstT : Type) + (op_code : Type) + (base_type_code_beq : base_type_code -> base_type_code -> bool) + (SConstT_beq : SConstT -> SConstT -> bool) + (op_code_beq : op_code -> op_code -> bool) + (base_type_code_bl : forall x y, base_type_code_beq x y = true -> x = y) + (base_type_code_lb : forall x y, x = y -> base_type_code_beq x y = true) + (SConstT_bl : forall x y, SConstT_beq x y = true -> x = y) + (SConstT_lb : forall x y, x = y -> SConstT_beq x y = true) + (op_code_bl : forall x y, op_code_beq x y = true -> x = y) + (op_code_lb : forall x y, x = y -> op_code_beq x y = true) + (interp_base_type : base_type_code -> Type) + (op : flat_type base_type_code -> flat_type base_type_code -> Type) + (symbolize_const : forall t, interp_base_type t -> SConstT) + (symbolize_op : forall s d, op s d -> op_code). + + Local Notation symbolic_expr := (symbolic_expr base_type_code SConstT op_code). + Local Notation symbolic_expr_beq := (@symbolic_expr_beq base_type_code SConstT op_code base_type_code_beq SConstT_beq op_code_beq). + Local Notation symbolic_expr_lb := (@internal_symbolic_expr_dec_lb base_type_code SConstT op_code base_type_code_beq SConstT_beq op_code_beq base_type_code_lb SConstT_lb op_code_lb). + Local Notation symbolic_expr_bl := (@internal_symbolic_expr_dec_bl base_type_code SConstT op_code base_type_code_beq SConstT_beq op_code_beq base_type_code_bl SConstT_bl op_code_bl). + + Local Notation flat_type := (flat_type base_type_code). + Local Notation type := (type base_type_code). + Let Tbase := @Tbase base_type_code. + Local Coercion Tbase : base_type_code >-> flat_type. + Local Notation interp_type := (interp_type interp_base_type). + Local Notation interp_flat_type := (interp_flat_type_gen interp_base_type). + Local Notation exprf := (@exprf base_type_code interp_base_type op). + Local Notation expr := (@expr base_type_code interp_base_type op). + Local Notation Expr := (@Expr base_type_code interp_base_type op). + + + Section with_var. + Context {var : base_type_code -> Type}. + + Local Notation svar t := (var t * symbolic_expr)%type. + Local Notation fsvar := (fun t => svar t). + Local Notation mapping := (forall t : base_type_code, list (svar t))%type. + + Context (prefix : list (sigT (fun t : flat_type => @exprf fsvar t))). + + Definition empty_mapping : mapping := fun _ => nil. + Definition type_lookup t (xs : mapping) : list (svar t) := xs t. + Definition mapping_update_type t (xs : mapping) (upd : list (svar t) -> list (svar t)) + : mapping + := fun t' => (if base_type_code_beq t t' as b return base_type_code_beq t t' = b -> _ + then fun H => match base_type_code_bl _ _ H in (_ = t') return list (svar t') with + | eq_refl => upd (type_lookup t xs) + end + else fun _ => type_lookup t' xs) + eq_refl. + + Fixpoint lookup' {t} (sv : symbolic_expr) (xs : list (svar t)) {struct xs} : option (var t) := + match xs with + | nil => None + | (x, sv') :: xs' => + if symbolic_expr_beq sv' sv + then Some x + else lookup' sv xs' + end. + Definition lookup t (sv : symbolic_expr) (xs : mapping) : option (var t) := + lookup' sv (type_lookup t xs). + Definition symbolicify_var {t : base_type_code} (v : var t) (xs : mapping) : symbolic_expr := + SVar t (length (type_lookup t xs)). + Definition add_mapping {t} (v : var t) (sv : symbolic_expr) (xs : mapping) : mapping := + mapping_update_type t xs (fun ls => (v, sv) :: ls). + + Definition symbolize_smart_const {t} : interp_flat_type t -> symbolic_expr + := smart_interp_flat_map base_type_code (g:=fun _ => symbolic_expr) (fun t v => SConst (symbolize_const t v)) (fun A B => SPair). + + Fixpoint symbolize_exprf + {t} (v : @exprf fsvar t) {struct v} + : option symbolic_expr + := match v with + | Const t x => Some (symbolize_smart_const x) + | Var _ x => Some (snd x) + | Op _ _ op args => option_map + (fun sargs => SOp (symbolize_op _ _ op) sargs) + (@symbolize_exprf _ args) + | Let _ ex _ eC => None + | Pair _ ex _ ey => match @symbolize_exprf _ ex, @symbolize_exprf _ ey with + | Some sx, Some sy => Some (SPair sx sy) + | _, _ => None + end + end. + + Fixpoint smart_lookup_gen f (proj : forall t, svar t -> f t) + (t : flat_type) (sv : symbolic_expr) (xs : mapping) {struct t} + : option (interp_flat_type_gen f t) + := match t return option (interp_flat_type_gen f t) with + | Syntax.Tbase t => option_map (fun v => proj t (v, sv)) (lookup t sv xs) + | Prod A B => match @smart_lookup_gen f proj A sv xs, @smart_lookup_gen f proj B sv xs with + | Some a, Some b => Some (a, b) + | _, _ => None + end + end. + Definition smart_lookup (t : flat_type) (sv : symbolic_expr) (xs : mapping) : option (interp_flat_type_gen fsvar t) + := @smart_lookup_gen fsvar (fun _ x => x) t sv xs. + Definition smart_lookupo (t : flat_type) (sv : option symbolic_expr) (xs : mapping) : option (interp_flat_type_gen fsvar t) + := match sv with + | Some sv => smart_lookup t sv xs + | None => None + end. + Definition symbolicify_smart_var {t : flat_type} (xs : mapping) (replacement : option symbolic_expr) + : interp_flat_type_gen var t -> interp_flat_type_gen fsvar t + := smart_interp_flat_map + (g:=interp_flat_type_gen fsvar) + base_type_code (fun t v => (v, + match replacement with + | Some sv => sv + | None => symbolicify_var v xs + end)) + (fun A B => @pair _ _). + Fixpoint smart_add_mapping {t : flat_type} (xs : mapping) : interp_flat_type_gen fsvar t -> mapping + := match t return interp_flat_type_gen fsvar t -> mapping with + | Syntax.Tbase t => fun v => add_mapping (fst v) (snd v) xs + | Prod A B + => fun v => let xs := @smart_add_mapping B xs (snd v) in + let xs := @smart_add_mapping A xs (fst v) in + xs + end. + + Definition csef_step + (csef : forall {t} (v : @exprf fsvar t) (xs : mapping), @exprf var t) + {t} (v : @exprf fsvar t) (xs : mapping) + : @exprf var t + := match v in @Syntax.exprf _ _ _ _ t return exprf t with + | Let tx ex _ eC => let sx := symbolize_exprf ex in + let ex' := @csef _ ex xs in + let sv := smart_lookupo tx sx xs in + match sv with + | Some v => @csef _ (eC v) xs + | None + => Let ex' (fun x => let x' := symbolicify_smart_var xs sx x in + @csef _ (eC x') (smart_add_mapping xs x')) + end + | Const _ x => Const x + | Var _ x => Var (fst x) + | Op _ _ op args => Op op (@csef _ args xs) + | Pair _ ex _ ey => Pair (@csef _ ex xs) (@csef _ ey xs) + end. + + Fixpoint csef {t} (v : @exprf fsvar t) (xs : mapping) + := @csef_step (@csef) t v xs. + + Fixpoint prepend_prefix {t} (e : @exprf fsvar t) (ls : list (sigT (fun t : flat_type => @exprf fsvar t))) + : @exprf fsvar t + := match ls with + | nil => e + | x :: xs => Let (projT2 x) (fun _ => @prepend_prefix _ e xs) + end. + + Fixpoint cse {t} (v : @expr fsvar t) (xs : mapping) {struct v} : @expr var t + := match v in @Syntax.expr _ _ _ _ t return expr t with + | Return _ x => Return (csef (prepend_prefix x prefix) xs) + | Abs _ _ f => Abs (fun x => let x' := symbolicify_var x xs in + @cse _ (f (x, x')) (add_mapping x x' xs)) + end. + End with_var. + + Definition CSE {t} (e : Expr t) (prefix : forall var, list (sigT (fun t : flat_type => @exprf var t))) + : Expr t + := fun var => cse (prefix _) (e _) empty_mapping. +End symbolic. + +Global Arguments csef {_} SConstT op_code base_type_code_beq SConstT_beq op_code_beq base_type_code_bl {_ _} symbolize_const symbolize_op {var t} _ _. +Global Arguments cse {_} SConstT op_code base_type_code_beq SConstT_beq op_code_beq base_type_code_bl {_ _} symbolize_const symbolize_op {var} prefix {t} _ _. +Global Arguments CSE {_} SConstT op_code base_type_code_beq SConstT_beq op_code_beq base_type_code_bl {_ _} symbolize_const symbolize_op {t} e prefix var. diff --git a/src/Reflection/TestCase.v b/src/Reflection/TestCase.v index 0f6384bb2..17580aa5d 100644 --- a/src/Reflection/TestCase.v +++ b/src/Reflection/TestCase.v @@ -1,11 +1,14 @@ Require Import Crypto.Reflection.Syntax. Require Export Crypto.Reflection.Reify. Require Import Crypto.Reflection.InputSyntax. +Require Import Crypto.Reflection.CommonSubexpressionElimination. Require Crypto.Reflection.Linearize. Require Import Crypto.Reflection.WfReflective. Import ReifyDebugNotations. +Local Set Boolean Equality Schemes. +Local Set Decidable Equality Schemes. Inductive base_type := Tnat. Definition interp_base_type (v : base_type) : Type := match v with @@ -14,15 +17,18 @@ Definition interp_base_type (v : base_type) : Type := Local Notation tnat := (Tbase Tnat). Inductive op : flat_type base_type -> flat_type base_type -> Type := | Add : op (Prod tnat tnat) tnat -| Mul : op (Prod tnat tnat) tnat. +| Mul : op (Prod tnat tnat) tnat +| Sub : op (Prod tnat tnat) tnat. Definition interp_op src dst (f : op src dst) : interp_flat_type_gen interp_base_type src -> interp_flat_type_gen interp_base_type dst := match f with | Add => fun xy => fst xy + snd xy | Mul => fun xy => fst xy * snd xy + | Sub => fun xy => fst xy - snd xy end%nat. Global Instance: forall x y, reify_op op (x + y)%nat 2 Add := fun _ _ => I. Global Instance: forall x y, reify_op op (x * y)%nat 2 Mul := fun _ _ => I. +Global Instance: forall x y, reify_op op (x - y)%nat 2 Sub := fun _ _ => I. Global Instance: reify type nat := Tnat. Ltac Reify' e := Reify.Reify' base_type interp_base_type op e. @@ -70,7 +76,7 @@ Abort. Definition example_expr : Syntax.Expr base_type interp_base_type op (Tbase Tnat). Proof. - let x := Reify (let x := 1 in let y := 1 in (let a := 1 in let '(c, d) := (2, 3) in a + x + c + d) + y)%nat in + let x := Reify (let x := 1 in let y := 1 in (let a := 1 in let '(c, d) := (2, 3) in a + x + (x + x) + (x + x) - (x + x) + c + d) + y)%nat in exact x. Defined. @@ -89,6 +95,8 @@ Definition op_beq t1 tR : op t1 tR -> op t1 tR -> option pointed_Prop | Add, _ => None | Mul, Mul => Some trivial | Mul, _ => None + | Sub, Sub => Some trivial + | Sub, _ => None end. Lemma op_beq_bl t1 tR (x y : op t1 tR) : match op_beq t1 tR x y with @@ -96,9 +104,8 @@ Lemma op_beq_bl t1 tR (x y : op t1 tR) | None => False end -> x = y. Proof. - destruct x; simpl. - { refine match y with Add => _ | _ => _ end; tauto. } - { refine match y with Add => _ | _ => _ end; tauto. } + destruct x; simpl; + refine match y with Add => _ | _ => _ end; tauto. Qed. Ltac reflect_Wf := WfReflective.reflect_Wf base_type_eq_semidec_is_dec op_beq_bl. @@ -117,3 +124,21 @@ Qed. Lemma example_expr_wf : Wf example_expr. Proof. Time reflect_Wf. (* 0.008 s *) Qed. + +Section cse. + Let SConstT := nat. + Inductive op_code : Set := SAdd | SMul | SSub. + Definition symbolicify_const (t : base_type) : interp_base_type t -> SConstT + := match t with + | Tnat => fun x => x + end. + Definition symbolicify_op s d (v : op s d) : op_code + := match v with + | Add => SAdd + | Mul => SMul + | Sub => SSub + end. + Definition CSE {t} e := @CSE base_type SConstT op_code base_type_beq nat_beq op_code_beq internal_base_type_dec_bl interp_base_type op symbolicify_const symbolicify_op t e (fun _ => nil). +End cse. + +Compute CSE (InlineConst (Linearize example_expr)). |