aboutsummaryrefslogtreecommitdiff
path: root/src/Reflection/MapCastWithCastOp.v
blob: ae41e11814e84c87f5c684842473c2366ae364eb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
Require Import Crypto.Reflection.Syntax.
Require Import Crypto.Reflection.Equality.
Require Import Crypto.Reflection.Application.
Require Import Crypto.Reflection.MapCast.
Require Import Crypto.Util.Sigma.
Require Import Crypto.Util.Prod.
Require Import Crypto.Util.Option.

Local Open Scope ctype_scope.
Local Open Scope expr_scope.
Section language.
  (** Parameters of this section.

      - [base_type_code1], [op1]: base type codes and operations of the
        language in which the output expressions are built.
      - [base_type_code2], [op2], [interp_base_type2], [interp_op2]: base
        type codes, operations, and their interpretation for a second
        language; values obtained by interpreting expressions of this
        language are fed to [new_base_type] to choose output types.
      - [base_type_code1_beq], [base_type_code1_bl], [base_type_code1_lb]:
        a boolean equality test on [base_type_code1] together with both
        directions of its agreement with Leibniz equality.
      - [failv]: for every [var] and base type [t], an expression of type
        [Tbase t]; used below (through [SmartFail] / [failf]) as the
        filler when the shapes of two types do not line up.
      - [new_base_type]: maps a base type code [t] of the second language
        and a value of [interp_base_type2 t] to a [base_type_code1].
      - [Cast]: turns an expression of base type [t1] into one of base
        type [t2].
      - [is_cast]: boolean test on operations of [op1]; [SmartBound] and
        [bound_op] below treat operations for which it returns [true]
        specially. *)
  Context {base_type_code1 base_type_code2 : Type}
          {interp_base_type2 : base_type_code2 -> Type}
          {op1 : flat_type base_type_code1 -> flat_type base_type_code1 -> Type}
          {op2 : flat_type base_type_code2 -> flat_type base_type_code2 -> Type}
          (interp_op2 : forall src dst, op2 src dst -> interp_flat_type interp_base_type2 src -> interp_flat_type interp_base_type2 dst)
          (base_type_code1_beq : base_type_code1 -> base_type_code1 -> bool)
          (base_type_code1_bl : forall x y, base_type_code1_beq x y = true -> x = y)
          (base_type_code1_lb : forall x y, x = y -> base_type_code1_beq x y = true)
          (failv : forall {var t}, @exprf base_type_code1 op1 var (Tbase t))
          (new_base_type : forall t, interp_base_type2 t -> base_type_code1)
          (Cast : forall var t1 t2, @exprf base_type_code1 op1 var (Tbase t1)
                                    -> @exprf base_type_code1 op1 var (Tbase t2))
          (is_cast : forall t1 t2, op1 t1 t2 -> bool).
  (** [new_flat_type] computes an output flat type from a flat type of the
      second language and an interpreted value of it (see the type in the
      inline comment), using [new_base_type] for base types.
      NOTE(review): the traversal itself is done by [SmartFlatTypeMap2],
      which is defined outside this file -- confirm its behaviour there. *)
  Local Notation new_flat_type (*: forall t, interp_flat_type interp_base_type2 t -> flat_type base_type_code1*)
    := (@SmartFlatTypeMap2 _ _ interp_base_type2 (fun t v => Tbase (new_base_type t v))).
  (** [new_type] is the external [new_type] with the section parameters
      filled in, so that later uses only need the remaining arguments. *)
  Local Notation new_type := (@new_type base_type_code1 base_type_code2 interp_base_type2 new_base_type).
  (** [new_op] is the user-supplied translation of a non-cast operation.
      Given an operation [opc1] of the output language, an operation
      [opc2] of the second language and the arguments [args2] of [opc2],
      it optionally returns a source type [new_src] together with a
      function from expressions of type [new_src] to expressions whose
      type is [new_flat_type] of the interpretation of [Op opc2 args2].
      A result of [None] is turned into [failf] by [bound_op] below. *)
  Context (new_op : forall ovar src1 dst1 src2 dst2 (opc1 : op1 src1 dst1) (opc2 : op2 src2 dst2)
                           args2,
              option { new_src : _ & (@exprf base_type_code1 op1 ovar new_src
                                      -> @exprf base_type_code1 op1 ovar (new_flat_type (interpf interp_op2 (Op opc2 args2))))%type }).
  (** [SmartFail] builds a value from [failv] via [SmartValf]; [failf t]
      packs that value into a single expression via [SmartPairf].
      Both are used as the fallback result when two types do not match.
      NOTE(review): [SmartValf] and [SmartPairf] are defined outside this
      file; the description of how they assemble the result should be
      confirmed against their definitions. *)
  Local Notation SmartFail
    := (SmartValf _ (@failv _)).
  Local Notation failf t (* {t} : @exprf base_type_code1 op1 ovar t*)
    := (SmartPairf (SmartFail t)).

  (** [VarBound T1 T2 v] turns a tuple of variables [v] of shape [T1] into
      an expression of type [T2], by simultaneous case analysis on both
      types.  The order of the branches matters: earlier branches take
      precedence over the catch-all ones at the end. *)
  Fixpoint VarBound {var} T1 T2 : interp_flat_type var T1 -> exprf _ op1 (var:=var) T2
    := match T1, T2 return interp_flat_type var T1 -> exprf _ _ T2 with
       (* base to base: wrap the variable in a [Cast] *)
       | Tbase T1', Tbase T2' => fun v : var T1' => Cast _ _ _ (Var v)
       (* any remaining source with target [Unit]: the input is ignored *)
       | _, Unit => fun _ => TT
       (* product to product: convert the two components independently *)
       | Prod A1 B1, Prod A2 B2
         => fun xy
            => Pair (@VarBound _ _ _ (fst xy)) (@VarBound _ _ _ (snd xy))
       (* every other combination is a shape mismatch: emit [failf] *)
       | Tbase _, _
       | Prod _ _, _
       | Unit, _
         => fun _ => failf _
       end.
  (** [SmartBound v] converts an expression [v] of type [t1] into an
      expression of type [t2].  The term is built in three [refine]
      steps and closed with [Defined], so its body remains transparent.

      1. If the boolean equality test says [t1] and [t2] are equal, [v]
         is returned as is (transported along the resulting equality).
      2. Otherwise [v] is inspected: an operation recognised by
         [is_cast] is dropped and its argument is converted instead; a
         [Pair] whose target is a product is converted componentwise;
         anything else is handed to the default conversion of step 3.
      3. The default conversion goes by the shapes of [t1] and [t2]. *)
  Fixpoint SmartBound {var t1 t2} (v : @exprf _ op1 var t1) : @exprf _ op1 var t2.
  Proof.
    (* Step 1: equal types need no conversion. *)
    refine match Sumbool.sumbool_of_bool (flat_type_beq _ base_type_code1_beq t1 t2) with
           | left pf => match flat_type_dec_bl _ _ base_type_code1_bl _ _ pf in (_ = y) return exprf _ _ y with
                        | eq_refl => v
                        end
           | right _ => _
           end.
    (* Step 2: look at the head of [v]; the remaining hole (the trailing
       [_]) is the default conversion, passed in as [default]. *)
    refine (match v in exprf _ _ t1, t2 return (exprf _ _ _ -> exprf _ _ t2) -> exprf _ _ t2 with
            | Op t1 tR opc args, _
              => if is_cast _ _ opc
                 then fun _ => @SmartBound _ _ _ args
                 else fun default => default v
            | Pair _ ex _ ey, Prod _ _ => fun _ => Pair (@SmartBound _ _ _ ex) (@SmartBound _ _ _ ey)
            | v', _ => fun default => default v'
            end _).
    (* Step 3: default conversion by type shape.  Base to base uses
       [Cast]; a [Unit] target yields [TT]; product to product binds the
       expression with [LetIn] and converts the bound variables with
       [VarBound]; any other combination yields [failf]. *)
    refine (match t1, t2 return exprf _ _ t1 -> exprf _ _ t2 with
            | Tbase t1', Tbase t2' => Cast _ _ _
            | _, Unit => fun _ => TT
            | Prod A1 B1, Prod A2 B2 => fun x => LetIn x (VarBound _ _)
            | Tbase _, _
            | Prod _ _, _
            | Unit, _
              => fun _ => failf _
            end).
  Defined.
  (** [bound_op ovar opc1 opc2 args2 args'] produces the output expression
      for an operation.  [args'] is the already translated argument, and
      the result has the type computed by [new_flat_type] from the
      interpretation of [Op opc2 args2].

      - If [is_cast] holds for [opc1], the operation itself is dropped and
        [args'] is converted directly to the result type with [SmartBound].
      - Otherwise [new_op] is consulted: on [Some f], [args'] is converted
        with [SmartBound] to the source type chosen by [new_op] and passed
        to the function in [f]; on [None] the result is [failf]. *)
  Definition bound_op ovar
             {src1 dst1 src2 dst2}
             (opc1 : op1 src1 dst1)
             (opc2 : op2 src2 dst2)
    : forall args2
             (args' : @exprf base_type_code1 op1 ovar (@new_flat_type _ (interpf interp_op2 args2))),
      @exprf base_type_code1 op1 ovar (@new_flat_type _ (interpf interp_op2 (Op opc2 args2)))
    := if is_cast _ _ opc1
       then fun _ x => SmartBound x
       else fun args2 args'
            => match new_op _ _ _ _ _ opc1 opc2 args2 with
               | Some f => projT2 f (SmartBound args')
               | None => failf _
               end.

  Section with_var.
    (** [ovar] is the variable type of the output expressions.  [ivar t]
        abbreviates output expressions of base type [t]; [ivarf] is the
        same thing as a function of [t], used as the variable type of the
        input expressions below. *)
    Context {ovar : base_type_code1 -> Type}.
    Local Notation ivar t := (@exprf base_type_code1 op1 ovar (Tbase t)) (only parsing).
    Local Notation ivarf := (fun t => ivar t).

    (** [bound_var tx1 tx2 tC1 f v] applies the continuation [f], which
        expects a value of shape [tx1], to a value built from [v], which
        has shape [tx2].  Recursion is structural on [tx2]. *)
    Fixpoint bound_var tx1 tx2 tC1
             {struct tx2}
      : forall (f : interp_flat_type ivarf tx1 -> exprf _ op1 (var:=ovar) tC1)
               (v : interp_flat_type ivarf tx2),
        exprf _ op1 (var:=ovar) tC1
      := match tx1, tx2 return (interp_flat_type _ tx1 -> _) -> interp_flat_type _ tx2 -> _ with
         (* base to base: convert the single expression with [SmartBound] *)
         | Tbase T1, Tbase T2 => fun f v => f (SmartBound v)
         (* [f] expects unit: [v] is ignored *)
         | Unit, _ => fun f _ => f tt
         (* product to product: convert the first component, then the
            second, and hand the resulting pair to [f] *)
         | Prod A1 B1, Prod A2 B2
           => fun f v
              => @bound_var
                   _ _ _
                   (fun v1 => @bound_var _ _ _ (fun v2 => f (v1, v2)%core) (snd v))
                   (fst v)
         (* shape mismatch: [v] is ignored and [f] receives a value made
            of [failf] expressions *)
         | Tbase _, _
         | Prod _ _, _
           => fun f _ => f (SmartValf _ (fun t => failf _) _)
         end.

    (** [mapf_interp_cast_with_cast_op e1 e2] is [mapf_interp_cast]
        instantiated with [bound_op] and [bound_var] from this file.
        [e1] is an expression of the output language over [ivarf]; [e2] is
        an expression of the second language over its interpretation.  The
        result is an output expression whose type is [new_flat_type]
        applied to the interpretation of [e2].
        NOTE(review): [mapf_interp_cast] is presumably the one from
        [Crypto.Reflection.MapCast]; see there for the traversal itself. *)
    Definition mapf_interp_cast_with_cast_op
               {t1} (e1 : @exprf base_type_code1 op1 ivarf t1)
               {t2} (e2 : @exprf base_type_code2 op2 interp_base_type2 t2)
      : @exprf base_type_code1 op1 ovar (@new_flat_type _ (interpf interp_op2 e2))
      := @mapf_interp_cast
           base_type_code1 base_type_code2 interp_base_type2 op1 op2
           interp_op2 (@failv) new_base_type bound_op
           ovar bound_var
           t1 e1 t2 e2.
    (** [map_interp_cast_with_cast_op e1 e2 args2] is the counterpart of
        [mapf_interp_cast_with_cast_op] for [expr] rather than [exprf]:
        it is [map_interp_cast] instantiated with [bound_op] and
        [bound_var] from this file.  [args2] supplies interpreted values
        for the binders of [t2], and the result type is [new_type] applied
        to [args2] and the interpretation of [e2].
        NOTE(review): [map_interp_cast] is presumably the one from
        [Crypto.Reflection.MapCast]; see there for the traversal itself. *)
    Definition map_interp_cast_with_cast_op
             {t1} (e1 : @expr base_type_code1 op1 ivarf t1)
             {t2} (e2 : @expr base_type_code2 op2 interp_base_type2 t2)
      : forall (args2 : interp_all_binders_for' t2 interp_base_type2),
        @expr base_type_code1 op1 ovar (@new_type _ args2 (interp interp_op2 e2))
      := @map_interp_cast
           base_type_code1 base_type_code2 interp_base_type2 op1 op2
           interp_op2 (@failv) new_base_type bound_op
           ovar bound_var
           t1 e1 t2 e2.
  End with_var.
End language.

(* Argument status of the two entry points once the section is closed:
   the six leading arguments, the argument between [failv] and [Cast],
   and [ovar], [t1], [t2] are implicit; the equality test and its
   soundness proof, [failv], [Cast], [is_cast], [new_op] and the
   expressions (plus [args2] for the second function) are explicit. *)
Global Arguments mapf_interp_cast_with_cast_op {_ _ _ _ _ _} base_type_code1_beq base_type_code1_bl failv {_} Cast is_cast new_op {ovar} {t1} e1 {t2} e2.
Global Arguments map_interp_cast_with_cast_op {_ _ _ _ _ _} base_type_code1_beq base_type_code1_bl failv {_} Cast is_cast new_op {ovar} {t1} e1 {t2} e2 args2.

(** Specialisation to the case where both languages share the same base
    type codes and operations.  The parameters mirror those of the
    [language] section above, with a single [base_type_code] and [op]. *)
Section homogenous.
  Context {base_type_code : Type}
          {interp_base_type2 : base_type_code -> Type}
          {op : flat_type base_type_code -> flat_type base_type_code -> Type}
          (interp_op2 : forall src dst, op src dst -> interp_flat_type interp_base_type2 src -> interp_flat_type interp_base_type2 dst)
          (base_type_code_beq : base_type_code -> base_type_code -> bool)
          (base_type_code_bl : forall x y, base_type_code_beq x y = true -> x = y)
          (base_type_code_lb : forall x y, x = y -> base_type_code_beq x y = true)
          (failv : forall {var t}, @exprf base_type_code op var (Tbase t))
          {new_base_type : forall t, interp_base_type2 t -> base_type_code}
          (Cast : forall var t1 t2, @exprf base_type_code op var (Tbase t1)
                                    -> @exprf base_type_code op var (Tbase t2))
          (is_cast : forall t1 t2, op t1 t2 -> bool).
  (** [MapInterpCastWithCastOp new_op e args] lifts
      [map_interp_cast_with_cast_op] to [Expr].  The same [e] is
      instantiated twice: once as the expression being rewritten and once
      as the expression being interpreted.  The result type is [new_type]
      applied to [args] and [Interp interp_op2 e]. *)
  Definition MapInterpCastWithCastOp
             new_op
             {t} (e : Expr base_type_code op t) args
    : Expr base_type_code op (new_type (@new_base_type) args (Interp interp_op2 e))
    := fun var => map_interp_cast_with_cast_op base_type_code_beq base_type_code_bl (@failv) Cast is_cast new_op (e _) (e _) args.
End homogenous.