aboutsummaryrefslogtreecommitdiff
path: root/src/Reflection
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-02-13 14:01:56 -0500
committerGravatar Jason Gross <jgross@mit.edu>2017-02-13 14:01:56 -0500
commit093834a3e35d86ce768102d1c5e894565a36fd74 (patch)
tree65870dfd5d20c413b01be5e542ec72417c9d732f /src/Reflection
parentf5ed7d87fddab66267f2a8242d31c205a0e24a4a (diff)
Split up BoundByCast
Diffstat (limited to 'src/Reflection')
-rw-r--r--src/Reflection/BoundByCast.v241
-rw-r--r--src/Reflection/InlineCast.v90
-rw-r--r--src/Reflection/MultiSizeTest2.v1
-rw-r--r--src/Reflection/SmartBound.v143
-rw-r--r--src/Reflection/SmartCast.v41
-rw-r--r--src/Reflection/TypeUtil.v35
6 files changed, 318 insertions, 233 deletions
diff --git a/src/Reflection/BoundByCast.v b/src/Reflection/BoundByCast.v
index fb7b4b576..09bdc207e 100644
--- a/src/Reflection/BoundByCast.v
+++ b/src/Reflection/BoundByCast.v
@@ -1,12 +1,11 @@
-Require Import Coq.Bool.Sumbool.
Require Import Crypto.Reflection.Syntax.
+Require Import Crypto.Reflection.SmartBound.
+Require Import Crypto.Reflection.InlineCast.
Require Import Crypto.Reflection.Application.
-Require Import Crypto.Reflection.SmartMap.
Require Import Crypto.Reflection.Inline.
Require Import Crypto.Reflection.Linearize.
Require Import Crypto.Reflection.MapCast.
Require Import Crypto.Reflection.Eta.
-Require Import Crypto.Util.Notations.
Local Open Scope expr_scope.
Local Open Scope ctype_scope.
@@ -25,249 +24,25 @@ Section language.
(genericize_op : forall src dst (opc : op src dst) (new_bounded_type_in new_bounded_type_out : base_type_code),
option { src'dst' : _ & op (fst src'dst') (snd src'dst') })
(failf : forall var t, @exprf base_type_code op var (Tbase t)).
- Local Infix "<=?" := base_type_leb : expr_scope.
- Local Infix "=?" := base_type_beq : expr_scope.
- Local Notation flat_type := (flat_type base_type_code).
- Local Notation type := (type base_type_code).
- Local Notation exprf := (@exprf base_type_code op).
- Local Notation expr := (@expr base_type_code op).
Local Notation Expr := (@Expr base_type_code op).
- Definition base_type_min (a b : base_type_code) : base_type_code
- := if a <=? b then a else b.
- Definition base_type_max (a b : base_type_code) : base_type_code
- := if a <=? b then b else a.
- Section gen.
- Context (join : base_type_code -> base_type_code -> base_type_code).
- Fixpoint flat_type_join {t : flat_type}
- : interp_flat_type (fun _ => base_type_code) t -> option base_type_code
- := match t with
- | Tbase _ => fun v => Some v
- | Unit => fun _ => None
- | Prod A B
- => fun v => match @flat_type_join A (fst v), @flat_type_join B (snd v) with
- | Some a, Some b => Some (join a b)
- | Some a, None => Some a
- | None, Some b => Some b
- | None, None => None
- end
- end.
- End gen.
- Definition flat_type_min {t} := @flat_type_join base_type_min t.
- Definition flat_type_max {t} := @flat_type_join base_type_max t.
-
- Definition SmartCast_base {var A A'} (x : exprf (var:=var) (Tbase A))
- : exprf (var:=var) (Tbase A')
- := match sumbool_of_bool (base_type_beq A A') with
- | left pf => match base_type_bl_transparent _ _ pf with
- | eq_refl => x
- end
- | right _ => Cast _ _ A' x
- end.
-
- Fixpoint SmartCast {var} A B
- : option (interp_flat_type var A -> exprf (var:=var) B)
- := match A, B return option (interp_flat_type var A -> exprf (var:=var) B) with
- | Tbase A, Tbase B => Some (fun v => SmartCast_base (Var (var:=var) v))
- | Prod A0 A1, Prod B0 B1
- => match @SmartCast _ A0 B0, @SmartCast _ A1 B1 with
- | Some f, Some g => Some (fun xy => Pair (f (fst xy)) (g (snd xy)))
- | _, _ => None
- end
- | Unit, Unit => Some (fun _ => TT)
- | Tbase _, _
- | Prod _ _, _
- | Unit, _
- => None
- end.
-
- Section inline_cast.
- (** We can squash [a -> b -> c] into [a -> c] if [min a b c = min
- a c], i.e., if the narrowest type we pass through in the
- original case is the same as the narrowest type we pass
- through in the new case. *)
- Definition squash_cast {var} (a b c : base_type_code)
- : @exprf var (Tbase a) -> @exprf var (Tbase c)
- := if base_type_beq (base_type_min b (base_type_min a c)) (base_type_min a c)
- then SmartCast_base
- else fun x => Cast _ b c (Cast _ a b x).
- Fixpoint maybe_push_cast {var t} (v : @exprf var t) : option (@exprf var t)
- := match v in Syntax.exprf _ _ t return option (exprf t) with
- | Var _ _ as v'
- => Some v'
- | Op t1 tR opc args
- => match t1, tR return op t1 tR -> exprf t1 -> option (exprf tR) with
- | Tbase b, Tbase c
- => fun opc args
- => if is_cast _ _ opc
- then match @maybe_push_cast _ _ args with
- | Some (Op t1 tR opc' args')
- => match t1, tR return op t1 tR -> exprf t1 -> option (exprf (Tbase c)) with
- | Tbase a, Tbase b
- => fun opc' args'
- => if is_cast _ _ opc'
- then Some (squash_cast a b c args')
- else None
- | Unit, Tbase _
- => fun opc' args'
- => if is_const _ _ opc'
- then Some (SmartCast_base (Op opc' TT))
- else None
- | _, _ => fun _ _ => None
- end opc' args'
- | Some (Var _ v as v') => Some (SmartCast_base v')
- | Some _ => None
- | None => None
- end
- else None
- | Unit, _
- => fun opc args
- => if is_const _ _ opc
- then Some (Op opc TT)
- else None
- | _, _
- => fun _ _ => None
- end opc args
- | Pair _ _ _ _
- | LetIn _ _ _ _
- | TT
- => None
- end.
- Definition push_cast {var t} : @exprf var t -> inline_directive (op:=op) (var:=var) t
- := match t with
- | Tbase _ => fun v => match maybe_push_cast v with
- | Some e => inline e
- | None => default_inline v
- end
- | _ => default_inline
- end.
-
- Definition InlineCast {t} (e : Expr t) : Expr t
- := InlineConstGen (@push_cast) e.
- End inline_cast.
-
- Definition bound_flat_type {t} : interp_flat_type interp_base_type_bounds t
- -> flat_type
- := @SmartFlatTypeMap2 _ _ interp_base_type_bounds (fun t v => Tbase (bound_base_type t v)) t.
- Fixpoint bound_type {t} : forall (e_bounds : interp_type interp_base_type_bounds t)
- (input_bounds : interp_all_binders_for' t interp_base_type_bounds),
- type
- := match t return interp_type _ t -> interp_all_binders_for' t _ -> type with
- | Tflat T => fun e_bounds _ => @bound_flat_type T e_bounds
- | Arrow A B
- => fun e_bounds input_bounds
- => Arrow (@bound_base_type A (fst input_bounds))
- (@bound_type B (e_bounds (fst input_bounds)) (snd input_bounds))
- end.
- Definition bound_op
- ovar src1 dst1 src2 dst2 (opc1 : op src1 dst1) (opc2 : op src2 dst2)
- : exprf (var:=ovar) src1
- -> interp_flat_type interp_base_type_bounds src2
- -> exprf (var:=ovar) dst1
- := fun args input_bounds
- => let output_bounds := interp_op_bounds _ _ opc2 input_bounds in
- let input_ts := SmartVarfMap bound_base_type input_bounds in
- let output_ts := SmartVarfMap bound_base_type output_bounds in
- let new_type_in := flat_type_max input_ts in
- let new_type_out := flat_type_max output_ts in
- let new_opc := match new_type_in, new_type_out with
- | Some new_type_in, Some new_type_out
- => genericize_op _ _ opc1 new_type_in new_type_out
- | None, _ | _, None => None
- end in
- match new_opc with
- | Some (existT _ new_opc)
- => match SmartCast _ _, SmartCast _ _ with
- | Some SmartCast_args, Some SmartCast_result
- => LetIn args
- (fun args
- => LetIn (Op new_opc (SmartCast_args args))
- (fun opv => SmartCast_result opv))
- | None, _
- | _, None
- => Op opc1 args
- end
- | None
- => Op opc1 args
- end.
-
- Section smart_bound.
- Definition interpf_smart_bound {var t}
- (e : interp_flat_type var t) (bounds : interp_flat_type interp_base_type_bounds t)
- : interp_flat_type (fun t => exprf (var:=var) (Tbase t)) (bound_flat_type bounds)
- := SmartFlatTypeMap2Interp2
- (f:=fun t v => Tbase _)
- (fun t bs v => Cast _ t (bound_base_type t bs) (Var v))
- bounds e.
- Definition interpf_smart_unbound {var t}
- (bounds : interp_flat_type interp_base_type_bounds t)
- (e : interp_flat_type (fun t => exprf (var:=var) (Tbase t)) (bound_flat_type bounds))
- : interp_flat_type (fun t => @exprf var (Tbase t)) t
- := SmartFlatTypeMapUnInterp2
- (f:=fun t v => Tbase (bound_base_type t _))
- (fun t bs v => Cast _ (bound_base_type t bs) t v)
- e.
-
- Definition smart_boundf {var t1} (e1 : exprf (var:=var) t1) (bounds : interp_flat_type interp_base_type_bounds t1)
- : exprf (var:=var) (bound_flat_type bounds)
- := LetIn e1 (fun e1' => SmartPairf (var:=var) (interpf_smart_bound e1' bounds)).
- Fixpoint UnSmartArrow {P t}
- : forall (e_bounds : interp_type interp_base_type_bounds t)
- (input_bounds : interp_all_binders_for' t interp_base_type_bounds)
- (e : P (SmartArrow (bound_flat_type input_bounds)
- (bound_flat_type (ApplyInterpedAll' e_bounds input_bounds)))),
- P (bound_type e_bounds input_bounds)
- := match t
- return (forall (e_bounds : interp_type interp_base_type_bounds t)
- (input_bounds : interp_all_binders_for' t interp_base_type_bounds)
- (e : P (SmartArrow (bound_flat_type input_bounds)
- (bound_flat_type (ApplyInterpedAll' e_bounds input_bounds)))),
- P (bound_type e_bounds input_bounds))
- with
- | Tflat T => fun _ _ x => x
- | Arrow A B => fun e_bounds input_bounds
- => @UnSmartArrow
- (fun t => P (Arrow (bound_base_type A (fst input_bounds)) t))
- B
- (e_bounds (fst input_bounds))
- (snd input_bounds)
- end.
- Definition smart_bound {var t1} (e1 : expr (var:=var) t1)
- (e_bounds : interp_type interp_base_type_bounds t1)
- (input_bounds : interp_all_binders_for' t1 interp_base_type_bounds)
- : expr (var:=var) (bound_type e_bounds input_bounds)
- := UnSmartArrow
- e_bounds
- input_bounds
- (SmartAbs
- (fun args
- => LetIn
- args
- (fun args
- => LetIn
- (SmartPairf (interpf_smart_unbound input_bounds (SmartVarfMap (fun _ => Var) args)))
- (fun v => smart_boundf
- (ApplyAll e1 (interp_all_binders_for_of' v))
- (ApplyInterpedAll' e_bounds input_bounds))))).
- Definition SmartBound {t1} (e : Expr t1)
- (input_bounds : interp_all_binders_for' t1 interp_base_type_bounds)
- : Expr (bound_type _ input_bounds)
- := fun var => smart_bound (e var) (interp (@interp_op_bounds) (e _)) input_bounds.
- End smart_bound.
-
Definition Boundify {t1} (e1 : Expr t1) args2
: Expr _
:= ExprEta
(InlineConstGen
- (@push_cast)
+ (@push_cast _ _ _ base_type_bl_transparent base_type_leb Cast is_cast is_const)
(Linearize
(SmartBound
+ _
+ interp_op_bounds
+ bound_base_type
+ Cast
(@MapInterpCast
base_type_code interp_base_type_bounds
op (@interp_op_bounds)
(@failf)
- (@bound_op)
+ (@bound_op _ _ _ interp_op_bounds bound_base_type _ base_type_bl_transparent base_type_leb Cast genericize_op)
t1 e1 (interp_all_binders_for_to' args2))
(interp_all_binders_for_to' args2)))).
End language.
diff --git a/src/Reflection/InlineCast.v b/src/Reflection/InlineCast.v
new file mode 100644
index 000000000..554c42da7
--- /dev/null
+++ b/src/Reflection/InlineCast.v
@@ -0,0 +1,90 @@
+Require Import Crypto.Reflection.Syntax.
+Require Import Crypto.Reflection.SmartCast.
+Require Import Crypto.Reflection.TypeUtil.
+Require Import Crypto.Reflection.Inline.
+Require Import Crypto.Util.Notations.
+
+Local Open Scope expr_scope.
+Local Open Scope ctype_scope.
+Section language.
+ Context {base_type_code : Type}
+ {op : flat_type base_type_code -> flat_type base_type_code -> Type}
+ (base_type_beq : base_type_code -> base_type_code -> bool)
+ (base_type_bl_transparent : forall x y, base_type_beq x y = true -> x = y)
+ (base_type_leb : base_type_code -> base_type_code -> bool)
+ (Cast : forall var A A', exprf base_type_code op (var:=var) (Tbase A) -> exprf base_type_code op (var:=var) (Tbase A'))
+ (is_cast : forall src dst, op src dst -> bool)
+ (is_const : forall src dst, op src dst -> bool).
+ Local Infix "<=?" := base_type_leb : expr_scope.
+ Local Infix "=?" := base_type_beq : expr_scope.
+
+ Local Notation base_type_min := (base_type_min base_type_leb).
+ Local Notation SmartCast_base := (@SmartCast_base _ op _ base_type_bl_transparent Cast).
+
+ Local Notation flat_type := (flat_type base_type_code).
+ Local Notation exprf := (@exprf base_type_code op).
+ Local Notation Expr := (@Expr base_type_code op).
+
+ (** We can squash [a -> b -> c] into [a -> c] if [min a b c = min a
+ c], i.e., if the narrowest type we pass through in the original
+ case is the same as the narrowest type we pass through in the
+ new case. *)
+ Definition squash_cast {var} (a b c : base_type_code)
+ : @exprf var (Tbase a) -> @exprf var (Tbase c)
+ := if base_type_beq (base_type_min b (base_type_min a c)) (base_type_min a c)
+ then SmartCast_base
+ else fun x => Cast _ b c (Cast _ a b x).
+ Fixpoint maybe_push_cast {var t} (v : @exprf var t) : option (@exprf var t)
+ := match v in Syntax.exprf _ _ t return option (exprf t) with
+ | Var _ _ as v'
+ => Some v'
+ | Op t1 tR opc args
+ => match t1, tR return op t1 tR -> exprf t1 -> option (exprf tR) with
+ | Tbase b, Tbase c
+ => fun opc args
+ => if is_cast _ _ opc
+ then match @maybe_push_cast _ _ args with
+ | Some (Op t1 tR opc' args')
+ => match t1, tR return op t1 tR -> exprf t1 -> option (exprf (Tbase c)) with
+ | Tbase a, Tbase b
+ => fun opc' args'
+ => if is_cast _ _ opc'
+ then Some (squash_cast a b c args')
+ else None
+ | Unit, Tbase _
+ => fun opc' args'
+ => if is_const _ _ opc'
+ then Some (SmartCast_base (Op opc' TT))
+ else None
+ | _, _ => fun _ _ => None
+ end opc' args'
+ | Some (Var _ v as v') => Some (SmartCast_base v')
+ | Some _ => None
+ | None => None
+ end
+ else None
+ | Unit, _
+ => fun opc args
+ => if is_const _ _ opc
+ then Some (Op opc TT)
+ else None
+ | _, _
+ => fun _ _ => None
+ end opc args
+ | Pair _ _ _ _
+ | LetIn _ _ _ _
+ | TT
+ => None
+ end.
+ Definition push_cast {var t} : @exprf var t -> inline_directive (op:=op) (var:=var) t
+ := match t with
+ | Tbase _ => fun v => match maybe_push_cast v with
+ | Some e => inline e
+ | None => default_inline v
+ end
+ | _ => default_inline
+ end.
+
+ Definition InlineCast {t} (e : Expr t) : Expr t
+ := InlineConstGen (@push_cast) e.
+End language.
diff --git a/src/Reflection/MultiSizeTest2.v b/src/Reflection/MultiSizeTest2.v
index 49629cf58..afd53bd19 100644
--- a/src/Reflection/MultiSizeTest2.v
+++ b/src/Reflection/MultiSizeTest2.v
@@ -1,5 +1,6 @@
Require Import Coq.omega.Omega.
Require Import Crypto.Reflection.Syntax.
+Require Import Crypto.Reflection.TypeUtil.
Require Import Crypto.Reflection.BoundByCast.
(** * Preliminaries: bounded and unbounded number types *)
diff --git a/src/Reflection/SmartBound.v b/src/Reflection/SmartBound.v
new file mode 100644
index 000000000..f77fe4274
--- /dev/null
+++ b/src/Reflection/SmartBound.v
@@ -0,0 +1,143 @@
+Require Import Crypto.Reflection.Syntax.
+Require Import Crypto.Reflection.TypeUtil.
+Require Import Crypto.Reflection.SmartCast.
+Require Import Crypto.Reflection.Application.
+Require Import Crypto.Reflection.SmartMap.
+Require Import Crypto.Util.Notations.
+
+Local Open Scope expr_scope.
+Local Open Scope ctype_scope.
+Section language.
+ Context {base_type_code : Type}
+ {op : flat_type base_type_code -> flat_type base_type_code -> Type}
+ (interp_base_type_bounds : base_type_code -> Type)
+ (interp_op_bounds : forall src dst, op src dst -> interp_flat_type interp_base_type_bounds src -> interp_flat_type interp_base_type_bounds dst)
+ (bound_base_type : forall t, interp_base_type_bounds t -> base_type_code)
+ (base_type_beq : base_type_code -> base_type_code -> bool)
+ (base_type_bl_transparent : forall x y, base_type_beq x y = true -> x = y)
+ (base_type_leb : base_type_code -> base_type_code -> bool)
+ (Cast : forall var A A', exprf base_type_code op (var:=var) (Tbase A) -> exprf base_type_code op (var:=var) (Tbase A'))
+ (genericize_op : forall src dst (opc : op src dst) (new_bounded_type_in new_bounded_type_out : base_type_code),
+ option { src'dst' : _ & op (fst src'dst') (snd src'dst') })
+ (failf : forall var t, @exprf base_type_code op var (Tbase t)).
+ Local Infix "<=?" := base_type_leb : expr_scope.
+ Local Infix "=?" := base_type_beq : expr_scope.
+
+ Local Notation flat_type_max := (flat_type_max base_type_leb).
+ Local Notation SmartCast := (@SmartCast _ op _ base_type_bl_transparent Cast).
+
+ Local Notation flat_type := (flat_type base_type_code).
+ Local Notation type := (type base_type_code).
+ Local Notation exprf := (@exprf base_type_code op).
+ Local Notation expr := (@expr base_type_code op).
+ Local Notation Expr := (@Expr base_type_code op).
+
+ Definition bound_flat_type {t} : interp_flat_type interp_base_type_bounds t
+ -> flat_type
+ := @SmartFlatTypeMap2 _ _ interp_base_type_bounds (fun t v => Tbase (bound_base_type t v)) t.
+ Fixpoint bound_type {t} : forall (e_bounds : interp_type interp_base_type_bounds t)
+ (input_bounds : interp_all_binders_for' t interp_base_type_bounds),
+ type
+ := match t return interp_type _ t -> interp_all_binders_for' t _ -> type with
+ | Tflat T => fun e_bounds _ => @bound_flat_type T e_bounds
+ | Arrow A B
+ => fun e_bounds input_bounds
+ => Arrow (@bound_base_type A (fst input_bounds))
+ (@bound_type B (e_bounds (fst input_bounds)) (snd input_bounds))
+ end.
+ Definition bound_op
+ ovar src1 dst1 src2 dst2 (opc1 : op src1 dst1) (opc2 : op src2 dst2)
+ : exprf (var:=ovar) src1
+ -> interp_flat_type interp_base_type_bounds src2
+ -> exprf (var:=ovar) dst1
+ := fun args input_bounds
+ => let output_bounds := interp_op_bounds _ _ opc2 input_bounds in
+ let input_ts := SmartVarfMap bound_base_type input_bounds in
+ let output_ts := SmartVarfMap bound_base_type output_bounds in
+ let new_type_in := flat_type_max input_ts in
+ let new_type_out := flat_type_max output_ts in
+ let new_opc := match new_type_in, new_type_out with
+ | Some new_type_in, Some new_type_out
+ => genericize_op _ _ opc1 new_type_in new_type_out
+ | None, _ | _, None => None
+ end in
+ match new_opc with
+ | Some (existT _ new_opc)
+ => match SmartCast _ _, SmartCast _ _ with
+ | Some SmartCast_args, Some SmartCast_result
+ => LetIn args
+ (fun args
+ => LetIn (Op new_opc (SmartCast_args args))
+ (fun opv => SmartCast_result opv))
+ | None, _
+ | _, None
+ => Op opc1 args
+ end
+ | None
+ => Op opc1 args
+ end.
+
+ Section smart_bound.
+ Definition interpf_smart_bound {var t}
+ (e : interp_flat_type var t) (bounds : interp_flat_type interp_base_type_bounds t)
+ : interp_flat_type (fun t => exprf (var:=var) (Tbase t)) (bound_flat_type bounds)
+ := SmartFlatTypeMap2Interp2
+ (f:=fun t v => Tbase _)
+ (fun t bs v => Cast _ t (bound_base_type t bs) (Var v))
+ bounds e.
+ Definition interpf_smart_unbound {var t}
+ (bounds : interp_flat_type interp_base_type_bounds t)
+ (e : interp_flat_type (fun t => exprf (var:=var) (Tbase t)) (bound_flat_type bounds))
+ : interp_flat_type (fun t => @exprf var (Tbase t)) t
+ := SmartFlatTypeMapUnInterp2
+ (f:=fun t v => Tbase (bound_base_type t _))
+ (fun t bs v => Cast _ (bound_base_type t bs) t v)
+ e.
+
+ Definition smart_boundf {var t1} (e1 : exprf (var:=var) t1) (bounds : interp_flat_type interp_base_type_bounds t1)
+ : exprf (var:=var) (bound_flat_type bounds)
+ := LetIn e1 (fun e1' => SmartPairf (var:=var) (interpf_smart_bound e1' bounds)).
+ Fixpoint UnSmartArrow {P t}
+ : forall (e_bounds : interp_type interp_base_type_bounds t)
+ (input_bounds : interp_all_binders_for' t interp_base_type_bounds)
+ (e : P (SmartArrow (bound_flat_type input_bounds)
+ (bound_flat_type (ApplyInterpedAll' e_bounds input_bounds)))),
+ P (bound_type e_bounds input_bounds)
+ := match t
+ return (forall (e_bounds : interp_type interp_base_type_bounds t)
+ (input_bounds : interp_all_binders_for' t interp_base_type_bounds)
+ (e : P (SmartArrow (bound_flat_type input_bounds)
+ (bound_flat_type (ApplyInterpedAll' e_bounds input_bounds)))),
+ P (bound_type e_bounds input_bounds))
+ with
+ | Tflat T => fun _ _ x => x
+ | Arrow A B => fun e_bounds input_bounds
+ => @UnSmartArrow
+ (fun t => P (Arrow (bound_base_type A (fst input_bounds)) t))
+ B
+ (e_bounds (fst input_bounds))
+ (snd input_bounds)
+ end.
+ Definition smart_bound {var t1} (e1 : expr (var:=var) t1)
+ (e_bounds : interp_type interp_base_type_bounds t1)
+ (input_bounds : interp_all_binders_for' t1 interp_base_type_bounds)
+ : expr (var:=var) (bound_type e_bounds input_bounds)
+ := UnSmartArrow
+ e_bounds
+ input_bounds
+ (SmartAbs
+ (fun args
+ => LetIn
+ args
+ (fun args
+ => LetIn
+ (SmartPairf (interpf_smart_unbound input_bounds (SmartVarfMap (fun _ => Var) args)))
+ (fun v => smart_boundf
+ (ApplyAll e1 (interp_all_binders_for_of' v))
+ (ApplyInterpedAll' e_bounds input_bounds))))).
+ Definition SmartBound {t1} (e : Expr t1)
+ (input_bounds : interp_all_binders_for' t1 interp_base_type_bounds)
+ : Expr (bound_type _ input_bounds)
+ := fun var => smart_bound (e var) (interp (@interp_op_bounds) (e _)) input_bounds.
+ End smart_bound.
+End language.
diff --git a/src/Reflection/SmartCast.v b/src/Reflection/SmartCast.v
new file mode 100644
index 000000000..ee3712954
--- /dev/null
+++ b/src/Reflection/SmartCast.v
@@ -0,0 +1,41 @@
+Require Import Coq.Bool.Sumbool.
+Require Import Crypto.Reflection.Syntax.
+Require Import Crypto.Reflection.TypeUtil.
+Require Import Crypto.Util.Notations.
+
+Local Open Scope expr_scope.
+Local Open Scope ctype_scope.
+Section language.
+ Context {base_type_code : Type}
+ {op : flat_type base_type_code -> flat_type base_type_code -> Type}
+ (base_type_beq : base_type_code -> base_type_code -> bool)
+ (base_type_bl_transparent : forall x y, base_type_beq x y = true -> x = y)
+ (Cast : forall var A A', exprf base_type_code op (var:=var) (Tbase A) -> exprf base_type_code op (var:=var) (Tbase A')).
+
+ Local Notation exprf := (@exprf base_type_code op).
+
+ Definition SmartCast_base {var A A'} (x : exprf (var:=var) (Tbase A))
+ : exprf (var:=var) (Tbase A')
+ := match sumbool_of_bool (base_type_beq A A') with
+ | left pf => match base_type_bl_transparent _ _ pf with
+ | eq_refl => x
+ end
+ | right _ => Cast _ _ A' x
+ end.
+
+ Fixpoint SmartCast {var} A B
+ : option (interp_flat_type var A -> exprf (var:=var) B)
+ := match A, B return option (interp_flat_type var A -> exprf (var:=var) B) with
+ | Tbase A, Tbase B => Some (fun v => SmartCast_base (Var (var:=var) v))
+ | Prod A0 A1, Prod B0 B1
+ => match @SmartCast _ A0 B0, @SmartCast _ A1 B1 with
+ | Some f, Some g => Some (fun xy => Pair (f (fst xy)) (g (snd xy)))
+ | _, _ => None
+ end
+ | Unit, Unit => Some (fun _ => TT)
+ | Tbase _, _
+ | Prod _ _, _
+ | Unit, _
+ => None
+ end.
+End language.
diff --git a/src/Reflection/TypeUtil.v b/src/Reflection/TypeUtil.v
new file mode 100644
index 000000000..8f7661bde
--- /dev/null
+++ b/src/Reflection/TypeUtil.v
@@ -0,0 +1,35 @@
+Require Import Crypto.Reflection.Syntax.
+Require Import Crypto.Util.Notations.
+
+Local Open Scope expr_scope.
+
+Section language.
+ Context {base_type_code : Type}
+ (base_type_beq : base_type_code -> base_type_code -> bool)
+ (base_type_leb : base_type_code -> base_type_code -> bool).
+ Local Infix "<=?" := base_type_leb : expr_scope.
+ Local Infix "=?" := base_type_beq : expr_scope.
+
+ Definition base_type_min (a b : base_type_code) : base_type_code
+ := if a <=? b then a else b.
+ Definition base_type_max (a b : base_type_code) : base_type_code
+ := if a <=? b then b else a.
+ Section gen.
+ Context (join : base_type_code -> base_type_code -> base_type_code).
+ Fixpoint flat_type_join {t : flat_type base_type_code}
+ : interp_flat_type (fun _ => base_type_code) t -> option base_type_code
+ := match t with
+ | Tbase _ => fun v => Some v
+ | Unit => fun _ => None
+ | Prod A B
+ => fun v => match @flat_type_join A (fst v), @flat_type_join B (snd v) with
+ | Some a, Some b => Some (join a b)
+ | Some a, None => Some a
+ | None, Some b => Some b
+ | None, None => None
+ end
+ end.
+ End gen.
+ Definition flat_type_min {t} := @flat_type_join base_type_min t.
+ Definition flat_type_max {t} := @flat_type_join base_type_max t.
+End language.