aboutsummaryrefslogtreecommitdiff
path: root/src/Compilers/MapBaseType.v
blob: 554bba4a2ddbf743ac91cb74b2124e16b6b06e55 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
Require Import Coq.Bool.Sumbool.
Require Import Crypto.Compilers.SmartMap.
Require Import Crypto.Compilers.Syntax.
Require Import Crypto.Compilers.ExprInversion.

Section language.
  (** Parameters of the translation from one base-type language to another:
      - [base_type_code1] / [op1]: base-type codes and operations of the
        source language;
      - [base_type_code2] / [op2]: base-type codes and operations of the
        target language;
      - [f_base]: maps each source base type to a target base type; it is
        lifted to flat types with [lift_flat_type];
      - [f_op]: for any variable representation [var], takes a source
        operation together with its (untranslated) argument expression and
        returns the corresponding target operation on the lifted types, or
        [None] when no such operation is available. *)
  Context {base_type_code1 base_type_code2 : Type}
          {op1 : flat_type base_type_code1 -> flat_type base_type_code1 -> Type}
          {op2 : flat_type base_type_code2 -> flat_type base_type_code2 -> Type}
          (f_base : base_type_code1 -> base_type_code2)
          (f_op : forall var s d,
              op1 s d
              -> exprf base_type_code1 op1 (var:=var) s
              -> option (op2 (lift_flat_type f_base s) (lift_flat_type f_base d))).

  Section with_var.
    (** Variable plumbing for the translation:
        - [var1] / [var2]: variable representations of source / target
          expressions;
        - [f_var12]: converts a source variable into a target variable
          (used on [Var] nodes);
        - [f_var21]: converts a target variable back into a source one
          (used through [untransfer_interp_flat_type] on the binders of
          [LetIn] and of the outer [Abs]);
        - [failb]: a target expression for every target base type, from
          which the placeholder [failf] is built. *)
    Context {var1 : base_type_code1 -> Type}
            {var2 : base_type_code2 -> Type}
            (f_var12 : forall t, var1 t -> var2 (f_base t))
            (f_var21 : forall t, var2 (f_base t) -> var1 t)
            (failb : forall t, exprf _ op2 (var:=var2) (Tbase t)).

    (** [failf t]: placeholder target expression of flat type [t], assembled
        from [failb] with [SmartValf] and [SmartPairf].  It is substituted
        for any operation that [f_op] cannot translate. *)
    Local Notation failf t
      := (SmartPairf (SmartValf _ failb t)).

    (** Translate a source expression into the target language, node by node:
        - [TT] stays [TT];
        - a variable is converted with [f_var12];
        - an [Op] is translated with [f_op] (applied to the untranslated
          arguments).  When [f_op] answers [None], the whole application is
          replaced by the placeholder [failf] and the arguments are dropped;
        - a [LetIn] translates its bound expression.  Inside the body, the
          target-side binder is converted back to a source-side value with
          [untransfer_interp_flat_type] and [f_var21];
        - a [Pair] translates both components. *)
    Fixpoint mapf_base_type
             {t} (e : exprf base_type_code1 op1 (var:=var1) t)
      : exprf base_type_code2 op2 (var:=var2) (lift_flat_type f_base t)
      := match e in exprf _ _ t' return exprf _ _ (lift_flat_type f_base t') with
         | TT => TT
         | Var _ v => Var (f_var12 _ v)
         | Op _ _ o a
           => match f_op _ _ _ o a with
              | Some o' => Op o' (@mapf_base_type _ a)
              | None => failf _
              end
         | LetIn tx bound _ body
           => LetIn (@mapf_base_type _ bound)
                    (fun y => @mapf_base_type
                                _ (body (untransfer_interp_flat_type (t:=tx) f_base f_var21 y)))
         | Pair _ lhs _ rhs
           => Pair (@mapf_base_type _ lhs) (@mapf_base_type _ rhs)
         end.

    (** Translate a whole function expression:
        1. peel off its [Abs] with [invert_Abs];
        2. translate the body with [mapf_base_type];
        3. re-wrap the result in an [Abs] over the lifted domain.
        The new binder is a target-side value, so it is converted back to a
        source-side one with [untransfer_interp_flat_type] and [f_var21]
        before being handed to the original body. *)
    Definition map_base_type
               {t} (e : expr base_type_code1 op1 t)
      : expr base_type_code2 op2 (Arrow (lift_flat_type f_base (domain t)) (lift_flat_type f_base (codomain t)))
      := Abs (src:=lift_flat_type f_base (domain t))
             (fun y => mapf_base_type
                         (invert_Abs e (untransfer_interp_flat_type _ f_var21 y))).
  End with_var.

  Section bool_gen.
    (** Parameters of the boolean checkers:
        - [check_base_type]: predicate that the base type of every [Var]
          must satisfy;
        - [var]: variable representation of the inspected expression;
        - [val]: one value of [var t] for every base type [t].  It is used
          to instantiate binders ([LetIn] bodies, the outer [Abs]) so that
          their bodies can be inspected. *)
    Context (check_base_type : base_type_code1 -> bool)
            {var : base_type_code1 -> Type}
            (val : forall t, var t).

    (** [check_mapf_base_type_gen e] is [true] iff all of the following hold:
        - every [Var] in [e] has a base type accepted by [check_base_type];
        - [f_op] returns [Some] for every [Op] in [e].
        [TT] is always accepted.  The body of a [LetIn] is inspected with its
        binder instantiated to [SmartValf _ val _].

        An [Op] rejected here (the [None] case) is exactly the case in which
        [mapf_base_type] would insert [failf].  Note that [f_op] is evaluated
        at this section's [var], which may differ from the representation
        used when actually mapping. *)
    Fixpoint check_mapf_base_type_gen
             {t} (e : exprf base_type_code1 op1 (var:=var) t)
    : bool
      := match e with
         | TT => true
         | Var bt _ => check_base_type bt
         | Op _ _ o a
           => match f_op _ _ _ o a with
              | Some _ => @check_mapf_base_type_gen _ a
              | None => false
              end
         | LetIn _ bound _ body
           => andb (@check_mapf_base_type_gen _ bound)
                   (@check_mapf_base_type_gen _ (body (SmartValf _ val _)))
         | Pair _ lhs _ rhs
           => andb (@check_mapf_base_type_gen _ lhs)
                   (@check_mapf_base_type_gen _ rhs)
         end.

    (** Whole-function counterpart of [check_mapf_base_type_gen]: strip the
        [Abs] with [invert_Abs], feed it [SmartValf _ val _] as argument, and
        check the resulting body. *)
    Definition check_map_base_type_gen
               {t} (e : expr base_type_code1 op1 (var:=var) t)
      : bool
      := check_mapf_base_type_gen (invert_Abs e (SmartValf _ val _)).
  End bool_gen.

  Section bool.
    (** Instances of the [_gen] checkers in which the variable representation
        is [fun _ => unit] and every binder is instantiated with [tt]. *)
    Definition check_mapf_base_type check_base_type {t} e
      := check_mapf_base_type_gen check_base_type (var:=fun _ => unit) (fun _ => tt) (t:=t) e.
    Definition check_map_base_type check_base_type {t} e
      := check_map_base_type_gen check_base_type (var:=fun _ => unit) (fun _ => tt) (t:=t) e.
  End bool.

  (** Unchecked translation of a closed expression [e].  For each target
      variable type [var], [map_base_type] is run with:
      - source variables at [fun bt => var (f_base bt)];
      - target variables at [var].
      Both variable conversions are then the identity.  Operations that
      [f_op] cannot translate become [failf] placeholders built from
      [failb var]. *)
  Definition MapBaseType'
             (failb : forall var t, exprf _ op2 (var:=var) (Tbase t))
             {t} (e : Expr base_type_code1 op1 t)
    : Expr base_type_code2 op2 (Arrow (lift_flat_type f_base (domain t)) (lift_flat_type f_base (codomain t)))
    := fun var
       => map_base_type
            (var1:=fun bt => var (f_base bt)) (var2:=var)
            (fun _ v => v) (fun _ v => v)
            (failb var) (e _).

  (** Checked translation of a closed expression [e].  The check accepts
      every base type, so [e] can only be rejected because [f_op] returned
      [None] for one of its operations.
      - When [check_map_base_type] accepts [e], the result is
        [Some (MapBaseType' failb e)].
      - Otherwise the result is [None]. *)
  Definition MapBaseType
             (failb : forall var t, exprf _ op2 (var:=var) (Tbase t))
             {t} (e : Expr base_type_code1 op1 t)
    : option (Expr base_type_code2 op2 (Arrow (lift_flat_type f_base (domain t)) (lift_flat_type f_base (codomain t))))
    := match check_map_base_type (fun _ => true (* any base type is allowed *)) (e _) with
       | true => Some (MapBaseType' failb e)
       | false => None
       end.
End language.