aboutsummaryrefslogtreecommitdiff
path: root/src/Reflection/Named/MapCast.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Reflection/Named/MapCast.v')
-rw-r--r--src/Reflection/Named/MapCast.v71
1 files changed, 71 insertions, 0 deletions
diff --git a/src/Reflection/Named/MapCast.v b/src/Reflection/Named/MapCast.v
new file mode 100644
index 000000000..a0b161a0a
--- /dev/null
+++ b/src/Reflection/Named/MapCast.v
@@ -0,0 +1,71 @@
+Require Import Coq.Bool.Sumbool.
+Require Import Crypto.Reflection.SmartMap.
+Require Import Crypto.Reflection.Syntax.
+Require Import Crypto.Reflection.Named.Syntax.
+
+Local Open Scope nexpr_scope.
+Section language.
+ Context {base_type_code : Type}
+ {op : flat_type base_type_code -> flat_type base_type_code -> Type}
+ {Name : Type}
+ {interp_base_type_bounds : base_type_code -> Type}
+ (interp_op_bounds : forall src dst, op src dst -> interp_flat_type interp_base_type_bounds src -> interp_flat_type interp_base_type_bounds dst)
+ (pick_typeb : forall t, interp_base_type_bounds t -> base_type_code).
+ Local Notation pick_type v := (SmartFlatTypeMap pick_typeb v).
+ Context (cast_op : forall t tR (opc : op t tR) args_bs,
+ op (pick_type args_bs) (pick_type (interp_op_bounds t tR opc args_bs)))
+ {BoundsContext : Context Name interp_base_type_bounds}.
+
+ Fixpoint mapf_cast
+ (ctx : BoundsContext)
+ {t} (e : exprf base_type_code op Name t)
+ {struct e}
+ : option { bounds : interp_flat_type interp_base_type_bounds t
+ & exprf base_type_code op Name (pick_type bounds) }
+ := match e in exprf _ _ _ t return option { bounds : interp_flat_type interp_base_type_bounds t
+ & exprf base_type_code op Name (pick_type bounds) } with
+ | TT => Some (existT _ tt TT)
+ | Pair tx ex ty ey
+ => match @mapf_cast ctx _ ex, @mapf_cast ctx _ ey with
+ | Some (existT x_bs xv), Some (existT y_bs yv)
+ => Some (existT _ (x_bs, y_bs)%core (Pair xv yv))
+ | None, _ | _, None => None
+ end
+ | Var t x
+ => option_map
+ (fun bounds => existT _ bounds (Var x))
+ (lookupb (t:=t) ctx x)
+ | LetIn tx n ex tC eC
+ => match @mapf_cast ctx _ ex with
+ | Some (existT x_bounds ex')
+ => option_map
+ (fun eC' => let 'existT Cx_bounds C_expr := eC' in
+ existT _ Cx_bounds (LetIn (pick_type x_bounds)
+ (SmartFlatTypeMapInterp2 (t:=tx) (fun _ _ (n : Name) => n) x_bounds n) ex' C_expr))
+ (@mapf_cast (extend (t:=tx) ctx n x_bounds) _ eC)
+ | None => None
+ end
+ | Op t tR opc args
+ => option_map
+ (fun args'
+ => let 'existT args_bounds argsv := args' in
+ existT _
+ (interp_op_bounds _ _ _ args_bounds)
+ (Op (cast_op t tR opc args_bounds) argsv))
+ (@mapf_cast ctx _ args)
+ end.
+
+ Definition map_cast
+ (ctx : BoundsContext)
+ {t} (e : expr base_type_code op Name t)
+ (input_bounds : interp_flat_type interp_base_type_bounds (domain t))
+ : option { output_bounds : interp_flat_type interp_base_type_bounds (codomain t)
+ & expr base_type_code op Name (Arrow (pick_type input_bounds) (pick_type output_bounds)) }
+ := option_map
+ (fun v => existT
+ _
+ (projT1 v)
+ (Abs (SmartFlatTypeMapInterp2 (fun _ _ (n' : Name) => n') input_bounds (Abs_name e))
+ (projT2 v)))
+ (mapf_cast (extend ctx (Abs_name e) input_bounds) (invert_Abs e)).
+End language.