aboutsummaryrefslogtreecommitdiff
path: root/src/Compilers/Named/MapCast.v
blob: fddee84faf7d20d792a7fd3fd78b13193965451f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
Require Import Coq.Bool.Sumbool.
Require Import Crypto.Compilers.SmartMap.
Require Import Crypto.Compilers.Syntax.
Require Import Crypto.Compilers.Named.Context.
Require Import Crypto.Compilers.Named.Syntax.

Local Open Scope nexpr_scope.
(** * Bounds-driven casting of named expressions *)
(** This section defines [mapf_cast] (for flat expressions, [exprf]) and
    [map_cast] (for abstractions, [expr]).  Both traverse a named
    expression, compute bounds for each subexpression -- looking variables
    up in a bounds context and applying [interp_op_bounds] at each [Op] --
    and rebuild the expression at the types that [pick_typeb] selects from
    those bounds, replacing every operation by the one produced by
    [cast_op].  A [None] result originates only from a failed [lookupb]
    and is propagated outward. *)
Section language.
  Context {base_type_code : Type}
          {op : flat_type base_type_code -> flat_type base_type_code -> Type}
          {Name : Type}
          (* Interpretation of each base type in the bounds domain; the
             [bounds] values manipulated below inhabit these types. *)
          {interp_base_type_bounds : base_type_code -> Type}
          (* Bounds-level interpretation of operations: maps bounds on an
             operation's arguments to bounds on its result. *)
          (interp_op_bounds : forall src dst, op src dst -> interp_flat_type interp_base_type_bounds src -> interp_flat_type interp_base_type_bounds dst)
          (* Selects the base type at which a value with the given bounds
             is rebuilt. *)
          (pick_typeb : forall t, interp_base_type_bounds t -> base_type_code).
  (* Lifts [pick_typeb] from base types to flat types via [SmartFlatTypeMap]. *)
  Local Notation pick_type v := (SmartFlatTypeMap pick_typeb v).
  (* [cast_op t tR opc args_bs] is the replacement for [opc]: its argument
     type is picked from the argument bounds [args_bs], and its result type
     is picked from the bounds that [interp_op_bounds] assigns to the result. *)
  Context (cast_op : forall t tR (opc : op t tR) args_bs,
              op (pick_type args_bs) (pick_type (interp_op_bounds t tR opc args_bs)))
          (* Environment associating [Name]s with bounds; read with
             [lookupb] and grown with [extend] below. *)
          {BoundsContext : Context Name interp_base_type_bounds}.

  (** [mapf_cast ctx e] computes bounds for the flat expression [e], using
      [ctx] for the bounds of its variables, and returns those bounds
      together with a copy of [e] whose type is [pick_type] of them.
      The result is [None] exactly when some variable lookup (in [ctx] or
      in a context extended by an enclosing [LetIn]) fails.
      NOTE(review): proofs elsewhere in the development may rely on the
      exact shape of this definition (match order, use of [option_map]);
      keep its structure when editing -- confirm before refactoring. *)
  Fixpoint mapf_cast
           (ctx : BoundsContext)
           {t} (e : exprf base_type_code op Name t)
           {struct e}
    : option { bounds : interp_flat_type interp_base_type_bounds t
                        & exprf base_type_code op Name (pick_type bounds) }
    := match e in exprf _ _ _ t return option { bounds : interp_flat_type interp_base_type_bounds t
                                                         & exprf base_type_code op Name (pick_type bounds) } with
       (* Unit: trivial bounds [tt], expression unchanged. *)
       | TT => Some (existT _ tt TT)
       (* Pair: cast both components in the same context; the result is
          [None] if either component cast is [None]. *)
       | Pair tx ex ty ey
         => match @mapf_cast ctx _ ex, @mapf_cast ctx _ ey with
            | Some (existT x_bs xv), Some (existT y_bs yv)
              => Some (existT _ (x_bs, y_bs)%core (Pair xv yv))
            | None, _ | _, None => None
            end
       (* Variable: its bounds are looked up in [ctx] and the same name is
          reused at the picked type; [None] if [x] has no entry. *)
       | Var t x
         => option_map
              (fun bounds => existT _ bounds (Var x))
              (lookupb (t:=t) ctx x)
       (* Let-binding: cast the bound expression [ex] first, then cast the
          body [eC] in [ctx] extended with the binder names [n] mapped to
          the bounds of [ex].  The rebuilt [LetIn] keeps the same names,
          converted by [SmartFlatTypeMapInterp2] with an identity function
          on names so that they are indexed by [pick_type x_bounds].  The
          bounds of the whole expression are the bounds of the body. *)
       | LetIn tx n ex tC eC
         => match @mapf_cast ctx _ ex with
            | Some (existT x_bounds ex')
              => option_map
                   (fun eC' => let 'existT Cx_bounds C_expr := eC' in
                               existT _ Cx_bounds (LetIn (pick_type x_bounds)
                                                         (SmartFlatTypeMapInterp2 (t:=tx) (fun _ _ (n : Name) => n) x_bounds n) ex' C_expr))
                   (@mapf_cast (extend (t:=tx) ctx n x_bounds) _ eC)
            | None => None
            end
       (* Operation: cast the arguments, compute the result bounds with
          [interp_op_bounds], and replace [opc] by
          [cast_op t tR opc args_bounds] applied to the cast arguments. *)
       | Op t tR opc args
         => option_map
              (fun args'
               => let 'existT args_bounds argsv := args' in
                  existT _
                         (interp_op_bounds _ _ _ args_bounds)
                         (Op (cast_op t tR opc args_bounds) argsv))
              (@mapf_cast ctx _ args)
       end.

  (** [map_cast ctx e input_bounds] casts the abstraction [e]: the binder
      names [Abs_name e] are associated with [input_bounds] in [ctx], the
      body [invert_Abs e] is cast with [mapf_cast], and the result is
      re-wrapped in [Abs] with the same binder names (converted by
      [SmartFlatTypeMapInterp2] with an identity function on names) at
      type [pick_type input_bounds].  The returned output bounds are those
      computed for the body; the result is [None] when casting the body
      yields [None]. *)
  Definition map_cast
             (ctx : BoundsContext)
             {t} (e : expr base_type_code op Name t)
             (input_bounds : interp_flat_type interp_base_type_bounds (domain t))
    : option { output_bounds : interp_flat_type interp_base_type_bounds (codomain t)
                               & expr base_type_code op Name (Arrow (pick_type input_bounds) (pick_type output_bounds)) }
    := option_map
         (* [v] pairs the body's bounds ([projT1]) with the cast body ([projT2]). *)
         (fun v => existT
                     _
                     (projT1 v)
                     (Abs (SmartFlatTypeMapInterp2 (fun _ _ (n' : Name) => n') input_bounds (Abs_name e))
                          (projT2 v)))
         (mapf_cast (extend ctx (Abs_name e) input_bounds) (invert_Abs e)).
End language.