diff options
Diffstat (limited to 'src/Compilers/InlineCast.v')
-rw-r--r-- | src/Compilers/InlineCast.v | 90 |
1 files changed, 90 insertions, 0 deletions
diff --git a/src/Compilers/InlineCast.v b/src/Compilers/InlineCast.v new file mode 100644 index 000000000..675835760 --- /dev/null +++ b/src/Compilers/InlineCast.v @@ -0,0 +1,90 @@ +Require Import Crypto.Compilers.Syntax. +Require Import Crypto.Compilers.SmartCast. +Require Import Crypto.Compilers.TypeUtil. +Require Import Crypto.Compilers.Inline. +Require Import Crypto.Util.Notations. + +Local Open Scope expr_scope. +Local Open Scope ctype_scope. +Section language. + Context {base_type_code : Type} + {op : flat_type base_type_code -> flat_type base_type_code -> Type} + (base_type_beq : base_type_code -> base_type_code -> bool) + (base_type_bl_transparent : forall x y, base_type_beq x y = true -> x = y) + (base_type_leb : base_type_code -> base_type_code -> bool) + (Cast : forall var A A', exprf base_type_code op (var:=var) (Tbase A) -> exprf base_type_code op (var:=var) (Tbase A')) + (is_cast : forall src dst, op src dst -> bool) + (is_const : forall src dst, op src dst -> bool). + Local Infix "<=?" := base_type_leb : expr_scope. + Local Infix "=?" := base_type_beq : expr_scope. + + Local Notation base_type_min := (base_type_min base_type_leb). + Local Notation SmartCast_base := (@SmartCast_base _ op _ base_type_bl_transparent Cast). + + Local Notation flat_type := (flat_type base_type_code). + Local Notation exprf := (@exprf base_type_code op). + Local Notation Expr := (@Expr base_type_code op). + + (** We can squash [a -> b -> c] into [a -> c] if [min a b c = min a + c], i.e., if the narrowest type we pass through in the original + case is the same as the narrowest type we pass through in the + new case. *) + Definition squash_cast {var} (a b c : base_type_code) + : @exprf var (Tbase a) -> @exprf var (Tbase c) + := if base_type_beq (base_type_min b (base_type_min a c)) (base_type_min a c) + then SmartCast_base + else fun x => Cast _ b c (Cast _ a b x). + Fixpoint maybe_push_cast {var t} (v : @exprf var t) : option (@exprf var t) + := match v in Syntax.exprf _ _ t return option (exprf t) with + | Var _ _ as v' + => Some v' + | Op t1 tR opc args + => match t1, tR return op t1 tR -> exprf t1 -> option (exprf tR) with + | Tbase b, Tbase c + => fun opc args + => if is_cast _ _ opc + then match @maybe_push_cast _ _ args with + | Some (Op t1 tR opc' args') + => match t1, tR return op t1 tR -> exprf t1 -> option (exprf (Tbase c)) with + | Tbase a, Tbase b + => fun opc' args' + => if is_cast _ _ opc' + then Some (squash_cast a b c args') + else None + | Unit, Tbase _ + => fun opc' args' + => if is_const _ _ opc' + then Some (SmartCast_base (Op opc' TT)) + else None + | _, _ => fun _ _ => None + end opc' args' + | Some (Var _ v as v') => Some (SmartCast_base v') + | Some _ => None + | None => None + end + else None + | Unit, _ + => fun opc args + => if is_const _ _ opc + then Some (Op opc TT) + else None + | _, _ + => fun _ _ => None + end opc args + | Pair _ _ _ _ + | LetIn _ _ _ _ + | TT + => None + end. + Definition push_cast {var t} : @exprf var t -> @inline_directive _ op var t + := match t with + | Tbase _ => fun v => match maybe_push_cast v with + | Some e => inline e + | None => default_inline v + end + | _ => default_inline + end. + + Definition InlineCast {t} (e : Expr t) : Expr t + := InlineConstGen (@push_cast) e. +End language. |