blob: 554bba4a2ddbf743ac91cb74b2124e16b6b06e55 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
|
Require Import Coq.Bool.Sumbool.
Require Import Crypto.Compilers.SmartMap.
Require Import Crypto.Compilers.Syntax.
Require Import Crypto.Compilers.ExprInversion.
(** Translation of expressions built over [base_type_code1] and [op1]
    into expressions over [base_type_code2] and [op2]. Every base type is
    sent through [f_base] (lifted to flat types with [lift_flat_type]) and
    every operator node is sent through [f_op]. *)
Section language.
(** [f_op var s d opc args] is asked for a target operator corresponding
    to the source operator [opc : op1 s d], given that operator's original
    (untranslated) argument expression [args]. A [None] answer is treated
    as failure: [mapf_base_type] substitutes a fallback expression and
    [check_mapf_base_type_gen] returns [false]. *)
Context {base_type_code1 base_type_code2 : Type}
{op1 : flat_type base_type_code1 -> flat_type base_type_code1 -> Type}
{op2 : flat_type base_type_code2 -> flat_type base_type_code2 -> Type}
(f_base : base_type_code1 -> base_type_code2)
(f_op : forall var s d,
op1 s d
-> exprf base_type_code1 op1 (var:=var) s
-> option (op2 (lift_flat_type f_base s) (lift_flat_type f_base d))).
(** Translation for a fixed choice of variable types: [var1] is the
    variable type of the source expression and [var2] that of the result.
    [f_var12] and [f_var21] convert variables between the two (with the
    base type mapped by [f_base]); [failb t] is the fallback expression of
    target base type [t], used when [f_op] returns [None]. *)
Section with_var.
Context {var1 : base_type_code1 -> Type}
{var2 : base_type_code2 -> Type}
(f_var12 : forall t, var1 t -> var2 (f_base t))
(f_var21 : forall t, var2 (f_base t) -> var1 t)
(failb : forall t, exprf _ op2 (var:=var2) (Tbase t)).
(** [failf t] is a fallback expression of flat type [t], assembled from
    [failb] with [SmartValf] and [SmartPairf] from SmartMap (presumably
    one [failb] per base-type leaf of [t]). *)
Local Notation failf t
:= (SmartPairf (SmartValf _ failb t)).
(** Translate an expression of flat type [t] over the source language
    into one of type [lift_flat_type f_base t] over the target language.
    The traversal is structural. The only case that inserts a fallback is
    an operator for which [f_op] returns [None]: the whole [Op] node is
    then replaced by [failf], and the translated arguments are dropped. *)
Fixpoint mapf_base_type
{t} (e : exprf base_type_code1 op1 (var:=var1) t)
: exprf base_type_code2 op2 (var:=var2) (lift_flat_type f_base t)
:= match e in exprf _ _ t return exprf _ _ (lift_flat_type f_base t) with
| TT => TT
| Var t x => Var (f_var12 _ x)
| Op t1 tR opc args
=> (* [f_op] is given the original, untranslated [args]; only after that
      does the next [let] shadow [args] with its translation. *)
let opc := f_op _ _ _ opc args in
let args := @mapf_base_type _ args in
match opc with
| Some opc => Op opc args
| None => failf _
end
| LetIn tx ex tC eC
=> let ex := @mapf_base_type _ ex in
let eC := fun x => @mapf_base_type _ (eC x) in
(* The new binder has target variable type; convert it back to source
   variables with [f_var21] before passing it to the source body. *)
LetIn ex (fun x => eC (untransfer_interp_flat_type (t:=tx) f_base f_var21 x))
| Pair tx ex ty ey
=> let ex := @mapf_base_type _ ex in
let ey := @mapf_base_type _ ey in
Pair ex ey
end.
(** Translate a whole function expression. Open the [Abs] with
    [invert_Abs] and translate its body with [mapf_base_type]. Then
    rebuild an [Abs] over the lifted domain, converting the incoming
    target-typed argument back to source variables with [f_var21]. *)
Definition map_base_type
{t} (e : expr base_type_code1 op1 t)
: expr base_type_code2 op2 (Arrow (lift_flat_type f_base (domain t)) (lift_flat_type f_base (codomain t)))
:= let f := invert_Abs e in
let f := fun x => mapf_base_type (f x) in
Abs (src:=lift_flat_type f_base (domain t))
(fun x => f (untransfer_interp_flat_type _ f_var21 x)).
End with_var.
(** A boolean pre-check that walks an expression the same way
    [mapf_base_type] does, without building anything. [check_base_type]
    is consulted on the base type of every variable occurrence. [val]
    supplies a dummy variable of each base type, used to instantiate
    binders. *)
Section bool_gen.
Context (check_base_type : base_type_code1 -> bool)
{var : base_type_code1 -> Type}
(val : forall t, var t).
(** Returns [false] in two cases. The first is an [Op] node whose
    operator [f_op] maps to [None], which is where [mapf_base_type] would
    insert [failf]. The second is a [Var] node whose base type is rejected
    by [check_base_type]. Otherwise it returns [true]. *)
Fixpoint check_mapf_base_type_gen
{t} (e : exprf base_type_code1 op1 (var:=var) t)
: bool
:= match e with
| TT => true
| Var t x => check_base_type t
| Op t1 tR opc args
=> let opc := f_op _ _ _ opc args in
let check_args := @check_mapf_base_type_gen _ args in
match opc with
| Some opc => check_args
| None => false
end
| LetIn tx ex tC eC
=> let check_ex := @check_mapf_base_type_gen _ ex in
let check_eC := fun x => @check_mapf_base_type_gen _ (eC x) in
(* The body is checked with its binder instantiated from [val]. *)
andb check_ex (check_eC (SmartValf _ val _))
| Pair tx ex ty ey
=> let check_ex := @check_mapf_base_type_gen _ ex in
let check_ey := @check_mapf_base_type_gen _ ey in
andb check_ex check_ey
end.
(** Check a whole function expression. Instantiate its argument with
    [SmartValf _ val _], then check the body with
    [check_mapf_base_type_gen]. *)
Definition check_map_base_type_gen
{t} (e : expr base_type_code1 op1 (var:=var) t)
: bool
:= let f := invert_Abs e in
let f := fun x => check_mapf_base_type_gen (f x) in
f (SmartValf _ val _).
End bool_gen.
(** The checkers specialized to the variable type [fun _ => unit], with
    [tt] as the dummy variable value. *)
Section bool.
Definition check_mapf_base_type check_base_type {t} e
:= @check_mapf_base_type_gen check_base_type (fun _ => unit) (fun _ => tt) t e.
Definition check_map_base_type check_base_type {t} e
:= @check_map_base_type_gen check_base_type (fun _ => unit) (fun _ => tt) t e.
End bool.
(** Unconditionally translate an [Expr], which is a function of the
    variable type (see [fun var => ...] and [e _] below). The source is
    instantiated at [fun t => var (f_base t)] and the target at [var], so
    both variable conversions are the identity. Operators that [f_op]
    rejects become [failb]-based fallback expressions. *)
Definition MapBaseType'
(failb : forall var t, exprf _ op2 (var:=var) (Tbase t))
{t} (e : Expr base_type_code1 op1 t)
: Expr base_type_code2 op2 (Arrow (lift_flat_type f_base (domain t)) (lift_flat_type f_base (codomain t)))
:= fun var => map_base_type
(var1:=fun t => var (f_base t)) (var2:=var)
(fun _ x => x) (fun _ x => x) (failb _) (e _).
(** Checked translation. First run [check_map_base_type] on [e]
    instantiated at unit variables, accepting every base type; the check
    can then fail only because [f_op] returned [None] for some operator.
    Returns [Some] of the [MapBaseType'] translation if the check passes,
    and [None] otherwise.
    NOTE(review): the check and the translation instantiate [e] (and
    hence [f_op]) at different variable types. This relies on [f_op]
    answering the same way for both; confirm. *)
Definition MapBaseType
(failb : forall var t, exprf _ op2 (var:=var) (Tbase t))
{t} (e : Expr base_type_code1 op1 t)
: option (Expr base_type_code2 op2 (Arrow (lift_flat_type f_base (domain t)) (lift_flat_type f_base (codomain t))))
:= if check_map_base_type (fun _ => true (* any base type is allowed *)) (e _)
then Some (MapBaseType' failb e)
else None.
End language.
|