aboutsummaryrefslogtreecommitdiff
path: root/src/Compilers/SmartBound.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Compilers/SmartBound.v')
-rw-r--r--src/Compilers/SmartBound.v135
1 files changed, 135 insertions, 0 deletions
diff --git a/src/Compilers/SmartBound.v b/src/Compilers/SmartBound.v
new file mode 100644
index 000000000..39fd34dd7
--- /dev/null
+++ b/src/Compilers/SmartBound.v
@@ -0,0 +1,135 @@
+Require Import Crypto.Compilers.Syntax.
+Require Import Crypto.Compilers.ExprInversion.
+Require Import Crypto.Compilers.TypeUtil.
+Require Import Crypto.Compilers.SmartCast.
+Require Import Crypto.Compilers.SmartMap.
+Require Import Crypto.Util.Notations.
+
+Local Open Scope expr_scope.
+Local Open Scope ctype_scope.
+Section language.
+ Context {base_type_code : Type}
+ {op : flat_type base_type_code -> flat_type base_type_code -> Type}
+ (interp_base_type_bounds : base_type_code -> Type)
+ (interp_op_bounds : forall src dst, op src dst -> interp_flat_type interp_base_type_bounds src -> interp_flat_type interp_base_type_bounds dst)
+ (bound_base_type : forall t, interp_base_type_bounds t -> base_type_code)
+ (base_type_beq : base_type_code -> base_type_code -> bool)
+ (base_type_bl_transparent : forall x y, base_type_beq x y = true -> x = y)
+ (base_type_leb : base_type_code -> base_type_code -> bool)
+ (Cast : forall var A A', exprf base_type_code op (var:=var) (Tbase A) -> exprf base_type_code op (var:=var) (Tbase A'))
+ (genericize_op : forall src dst (opc : op src dst) (new_bounded_type_in new_bounded_type_out : base_type_code),
+ option { src'dst' : _ & op (fst src'dst') (snd src'dst') })
+ (failf : forall var t, @exprf base_type_code op var (Tbase t)).
+ Local Infix "<=?" := base_type_leb : expr_scope.
+ Local Infix "=?" := base_type_beq : expr_scope.
+
+ Local Notation flat_type_max := (flat_type_max base_type_leb).
+ Local Notation SmartCast := (@SmartCast _ op _ base_type_bl_transparent Cast).
+
+ Local Notation flat_type := (flat_type base_type_code).
+ Local Notation type := (type base_type_code).
+ Local Notation exprf := (@exprf base_type_code op).
+ Local Notation expr := (@expr base_type_code op).
+ Local Notation Expr := (@Expr base_type_code op).
+
+ Definition bound_flat_type {t} : interp_flat_type interp_base_type_bounds t
+ -> flat_type
+ := @SmartFlatTypeMap _ interp_base_type_bounds (fun t v => bound_base_type t v) t.
+ Definition bound_type {t}
+ (e_bounds : interp_type interp_base_type_bounds t)
+ (input_bounds : interp_flat_type interp_base_type_bounds (domain t))
+ : type
+ := Arrow (@bound_flat_type (domain t) input_bounds)
+ (@bound_flat_type (codomain t) (e_bounds input_bounds)).
+ Definition bound_op
+ ovar src1 dst1 src2 dst2 (opc1 : op src1 dst1) (opc2 : op src2 dst2)
+ : exprf (var:=ovar) src1
+ -> interp_flat_type interp_base_type_bounds src2
+ -> exprf (var:=ovar) dst1
+ := fun args input_bounds
+ => let output_bounds := interp_op_bounds _ _ opc2 input_bounds in
+ let input_ts := SmartVarfMap bound_base_type input_bounds in
+ let output_ts := SmartVarfMap bound_base_type output_bounds in
+ let new_type_in := flat_type_max input_ts in
+ let new_type_out := flat_type_max output_ts in
+ let new_opc := match new_type_in, new_type_out with
+ | Some new_type_in, Some new_type_out
+ => genericize_op _ _ opc1 new_type_in new_type_out
+ | None, _ | _, None => None
+ end in
+ match new_opc with
+ | Some (existT _ new_opc)
+ => match SmartCast _ _, SmartCast _ _ with
+ | Some SmartCast_args, Some SmartCast_result
+ => LetIn args
+ (fun args
+ => LetIn (Op new_opc (SmartCast_args args))
+ (fun opv => SmartCast_result opv))
+ | None, _
+ | _, None
+ => Op opc1 args
+ end
+ | None
+ => Op opc1 args
+ end.
+
+ Section smart_bound.
+ Definition interpf_smart_bound_exprf {var t}
+ (e : interp_flat_type var t) (bounds : interp_flat_type interp_base_type_bounds t)
+ : interp_flat_type (fun t => exprf (var:=var) (Tbase t)) (bound_flat_type bounds)
+ := SmartFlatTypeMap2Interp2
+ (f:=fun t v => Tbase _)
+ (fun t bs v => Cast _ t (bound_base_type t bs) (Var v))
+ bounds e.
+ Definition interpf_smart_unbound_exprf {var t}
+ (bounds : interp_flat_type interp_base_type_bounds t)
+ (e : interp_flat_type (fun t => exprf (var:=var) (Tbase t)) (bound_flat_type bounds))
+ : interp_flat_type (fun t => @exprf var (Tbase t)) t
+ := SmartFlatTypeMapUnInterp2
+ (f:=fun t v => Tbase (bound_base_type t _))
+ (fun t bs v => Cast _ (bound_base_type t bs) t v)
+ e.
+
+ Definition interpf_smart_bound
+ {interp_base_type}
+ (cast_val : forall A A', interp_base_type A -> interp_base_type A')
+ {t}
+ (e : interp_flat_type interp_base_type t)
+ (bounds : interp_flat_type interp_base_type_bounds t)
+ : interp_flat_type interp_base_type (bound_flat_type bounds)
+ := SmartFlatTypeMap2Interp2
+ (f:=fun t v => Tbase _)
+ (fun t bs v => cast_val t (bound_base_type t bs) v)
+ bounds e.
+ Definition interpf_smart_unbound
+ {interp_base_type}
+ (cast_val : forall A A', interp_base_type A -> interp_base_type A')
+ {t}
+ (bounds : interp_flat_type interp_base_type_bounds t)
+ (e : interp_flat_type interp_base_type (bound_flat_type bounds))
+ : interp_flat_type interp_base_type t
+ := SmartFlatTypeMapUnInterp2 (f:=fun _ _ => Tbase _) (fun t b v => cast_val _ _ v) e.
+
+ Definition smart_boundf {var t1} (e1 : exprf (var:=var) t1) (bounds : interp_flat_type interp_base_type_bounds t1)
+ : exprf (var:=var) (bound_flat_type bounds)
+ := LetIn e1 (fun e1' => SmartPairf (var:=var) (interpf_smart_bound_exprf e1' bounds)).
+ Definition smart_bound {var t1} (e1 : expr (var:=var) t1)
+ (e_bounds : interp_type interp_base_type_bounds t1)
+ (input_bounds : interp_flat_type interp_base_type_bounds (domain t1))
+ : expr (var:=var) (bound_type e_bounds input_bounds)
+ := Abs
+ (fun args
+ => LetIn
+ (SmartPairf (interpf_smart_unbound_exprf input_bounds (SmartVarfMap (fun _ => Var) args)))
+ (fun v => smart_boundf
+ (invert_Abs e1 v)
+ (e_bounds input_bounds))).
+ Definition SmartBound {t1} (e : Expr t1)
+ (input_bounds : interp_flat_type interp_base_type_bounds (domain t1))
+ : Expr (bound_type _ input_bounds)
+ := fun var => smart_bound (e var) (interp (@interp_op_bounds) (e _)) input_bounds.
+ End smart_bound.
+End language.
+
+Global Arguments bound_flat_type _ _ _ !_ _ / .
+Global Arguments bound_type _ _ _ !_ _ _ / .