1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
|
Require Import Crypto.Reflection.Syntax.
Require Import Crypto.Reflection.SmartCast.
Require Import Crypto.Reflection.TypeUtil.
Require Import Crypto.Reflection.Inline.
Require Import Crypto.Util.Notations.
Local Open Scope expr_scope.
Local Open Scope ctype_scope.
(** * Cast inlining

    This section defines [InlineCast], an [InlineConstGen] pass whose
    inlining directive ([push_cast]) rewrites cast operators applied to
    variables, constants, or other casts. *)
Section language.
(** Section parameters:
    - [base_type_code]: the base types of the expression language.
    - [op]: operators, indexed by their input and output flat types.
    - [base_type_beq]: boolean test on base types, used by [squash_cast]
      to compare two minima.
    - [base_type_bl_transparent]: soundness of [base_type_beq] (a [true]
      result implies equality). It is passed to [SmartCast_base].
    - [base_type_leb]: boolean comparison on base types, used to
      compute [base_type_min].
    - [Cast]: builds an expression converting a [Tbase A] expression
      into a [Tbase A'] expression.
    - [is_cast] / [is_const]: flag the operators that this pass treats
      as casts and as constants, respectively.

    NOTE(review): the rewrites below discard operators flagged by
    [is_cast] and rebuild the conversion with [Cast] / [SmartCast_base].
    They therefore assume that every [is_cast] operator is semantically
    the same as [Cast] between its base types. Confirm this against the
    instantiation. *)
Context {base_type_code : Type}
{op : flat_type base_type_code -> flat_type base_type_code -> Type}
(base_type_beq : base_type_code -> base_type_code -> bool)
(base_type_bl_transparent : forall x y, base_type_beq x y = true -> x = y)
(base_type_leb : base_type_code -> base_type_code -> bool)
(Cast : forall var A A', exprf base_type_code op (var:=var) (Tbase A) -> exprf base_type_code op (var:=var) (Tbase A'))
(is_cast : forall src dst, op src dst -> bool)
(is_const : forall src dst, op src dst -> bool).
(* NOTE(review): the two infix notations below are not used in this
   section; [squash_cast] calls [base_type_beq] directly, and
   [base_type_leb] is only reached through [base_type_min]. *)
Local Infix "<=?" := base_type_leb : expr_scope.
Local Infix "=?" := base_type_beq : expr_scope.
(* Specialize the generic helpers to this section's parameters. *)
Local Notation base_type_min := (base_type_min base_type_leb).
Local Notation SmartCast_base := (@SmartCast_base _ op _ base_type_bl_transparent Cast).
Local Notation flat_type := (flat_type base_type_code).
Local Notation exprf := (@exprf base_type_code op).
Local Notation Expr := (@Expr base_type_code op).
(** We can squash [a -> b -> c] into [a -> c] if [min a b c = min a
    c], i.e., if the narrowest type passed through by the original
    chain is the same as the narrowest type passed through by the
    shortened one.

    When this test succeeds, the two casts are replaced by the single
    [SmartCast_base] from [a] to [c]. The SmartCast module presumably
    emits no cast at all when [a = c]; confirm there. Otherwise both
    casts are kept, rebuilt as [Cast _ b c (Cast _ a b x)]. *)
Definition squash_cast {var} (a b c : base_type_code)
: @exprf var (Tbase a) -> @exprf var (Tbase c)
:= if base_type_beq (base_type_min b (base_type_min a c)) (base_type_min a c)
then SmartCast_base
else fun x => Cast _ b c (Cast _ a b x).
(** [maybe_push_cast v] returns [Some e], where [e] is a simplified
    replacement for [v], when one of the following patterns applies. It
    returns [None] otherwise.
    - A variable is returned unchanged.
    - A constant (an [is_const] operator whose input type is [Unit]) is
      returned with its argument replaced by [TT].
    - For a cast (an [is_cast] operator from [Tbase b] to [Tbase c]),
      the argument is first simplified recursively. Then:
      - if the argument simplified to another base-to-base cast, the
        two casts are merged with [squash_cast];
      - if it simplified to a constant, the constant is cast directly
        to [c] with [SmartCast_base];
      - if it simplified to a variable, the variable is cast directly
        to [c] with [SmartCast_base];
      - in every other case, including when the recursive call returns
        [None], the result is [None].

    All other inputs yield [None]: non-cast base-to-base operators,
    other operator shapes, [Pair], [LetIn] and [TT]. *)
Fixpoint maybe_push_cast {var t} (v : @exprf var t) : option (@exprf var t)
:= match v in Syntax.exprf _ _ t return option (exprf t) with
| Var _ _ as v'
=> Some v'
| Op t1 tR opc args
=> match t1, tR return op t1 tR -> exprf t1 -> option (exprf tR) with
| Tbase b, Tbase c
=> fun opc args
=> if is_cast _ _ opc
then match @maybe_push_cast _ _ args with
| Some (Op t1 tR opc' args')
=> match t1, tR return op t1 tR -> exprf t1 -> option (exprf (Tbase c)) with
(* This [b] rebinds the outer [b]. It names the inner operator's
   output type, which by typing is the outer cast's input type. That
   makes it the intermediate type of the chain a -> b -> c. *)
| Tbase a, Tbase b
=> fun opc' args'
=> if is_cast _ _ opc'
then Some (squash_cast a b c args')
else None
(* The outer cast's argument simplified to a constant. Cast the
   constant straight to [c]. *)
| Unit, Tbase _
=> fun opc' args'
=> if is_const _ _ opc'
then Some (SmartCast_base (Op opc' TT))
else None
| _, _ => fun _ _ => None
end opc' args'
(* The argument simplified to a variable: cast it straight to [c]. *)
| Some (Var _ v as v') => Some (SmartCast_base v')
| Some _ => None
| None => None
end
else None
(* Constant operator: its [Unit] argument is replaced by [TT]. *)
| Unit, _
=> fun opc args
=> if is_const _ _ opc
then Some (Op opc TT)
else None
| _, _
=> fun _ _ => None
end opc args
| Pair _ _ _ _
| LetIn _ _ _ _
| TT
=> None
end.
(** Inlining directive used by [InlineCast].
    - For a base-typed expression, it uses [inline] on the simplified
      form returned by [maybe_push_cast] when there is one. Otherwise it
      falls back to [default_inline] on the original expression.
    - Expressions of any other flat type always get [default_inline].

    NOTE(review): the precise meaning of [inline] versus
    [default_inline] is defined in Crypto.Reflection.Inline, not here. *)
Definition push_cast {var t} : @exprf var t -> @inline_directive _ op var t
:= match t with
| Tbase _ => fun v => match maybe_push_cast v with
| Some e => inline e
| None => default_inline v
end
| _ => default_inline
end.
(** Runs [InlineConstGen] over [e] with [push_cast] as its inlining
    directive. Presumably the directive is consulted for each let-bound
    value; see Crypto.Reflection.Inline. *)
Definition InlineCast {t} (e : Expr t) : Expr t
:= InlineConstGen (@push_cast) e.
End language.
|