aboutsummaryrefslogtreecommitdiff
path: root/src/Compilers/InlineConstAndOpInterp.v
blob: e464d1e8a2457ff6fc0f01f273eab185804a84a8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
(** * Inline: Remove some [Let] expressions, inline constants, interpret constant operations (interpretation-correctness proofs) *)
Require Import Crypto.Compilers.Syntax.
Require Import Crypto.Compilers.Wf.
Require Import Crypto.Compilers.SmartMap.
Require Import Crypto.Compilers.InlineConstAndOpWf.
Require Import Crypto.Compilers.InterpProofs.
Require Import Crypto.Compilers.Inline.
Require Import Crypto.Compilers.InlineInterp.
Require Import Crypto.Compilers.InlineConstAndOp.
Require Import Crypto.Util.Sigma Crypto.Util.Prod Crypto.Util.Option.
Require Import Crypto.Util.Tactics.BreakMatch.


(* Make the notations registered in [ctype_scope] usable within this file. *)
Local Open Scope ctype_scope.
(* Everything in this section is parameterized over an object language:
   its base type codes, the interpretation of those codes, its operations
   (indexed by source and destination flat types), and the interpretation
   of the operations. *)
Section language.
  Context (base_type_code : Type)
          (interp_base_type : base_type_code -> Type)
          (op : flat_type base_type_code -> flat_type base_type_code -> Type)
          (interp_op : forall src dst, op src dst -> interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst).

  (* Section-local shorthands: each name below stands for the definition of
     the same name with the section parameters already supplied. *)
  Local Notation flat_type := (flat_type base_type_code).
  Local Notation type := (type base_type_code).
  Local Notation interp_type := (interp_type interp_base_type).
  Local Notation interp_flat_type := (interp_flat_type interp_base_type).
  Local Notation exprf := (@exprf base_type_code op).
  Local Notation expr := (@expr base_type_code op).
  Local Notation Expr := (@Expr base_type_code op).
  Local Notation wff := (@wff base_type_code op).
  Local Notation wf := (@wf base_type_code op).

  (* Lemmas in this section are generic in how constants are recognized
     ([invert_Const]) and how they are rebuilt as expressions ([Const]).
     Both act on expressions whose variable type is [interp_base_type].
     The proofs only rely on the two hypotheses [Hinvert_Const] and
     [interpf_Const] stated below. *)
  Section with_invert.
    Context (invert_Const : forall s d, op s d -> @exprf interp_base_type s -> option (interp_flat_type d))
            (Const : forall t, interp_base_type t -> @exprf interp_base_type (Tbase t))
            (* Soundness of [invert_Const]: whenever it answers [Some v], [v]
               equals [interp_op] of [opc] applied to the interpretation of [e]. *)
            (Hinvert_Const
             : forall s d opc e v, invert_Const s d opc e = Some v
                                   -> interp_op s d opc (interpf interp_op e) = v)
            (* Interpreting an expression built by [Const] gives back the value. *)
            (interpf_Const : forall t v, interpf interp_op (Const t v) = v).

    (* The post-processing step preserves interpretation: interpreting the
       expression obtained (through [exprf_of_inline_directive]) from
       [postprocess_for_const_and_op] applied to [e] gives the same value as
       interpreting [e] itself. *)
    Lemma interpf_postprocess_for_const_and_op {t} (e : exprf t)
      : interpf interp_op
                (exprf_of_inline_directive
                   (postprocess_for_const_and_op interp_op invert_Const Const e))
        = interpf interp_op e.
    Proof.
      (* Structural induction on [e]; cases not closed by [reflexivity] are
         simplified, and then every remaining goal is handled by repeatedly
         trying the alternatives below in order: case analysis on the
         innermost match, unfolding and simplification, and rewriting with
         the SmartPairf / SmartVarfMap lemmas, the [interpf_Const]
         hypothesis, and [ExprInversion.interpf_invert_PairsConst]. *)
      induction e; try reflexivity; simpl in *.
      all:repeat first [ reflexivity
                       | break_innermost_match_step
                       | progress cbv [SmartVarVarf]
                       | progress cbn [interpf exprf_of_inline_directive interpf_step LetIn.Let_In SmartVarVarf fst snd] in *
                       | solve [ auto ]
                       | rewrite SmartPairf_Pair
                       | apply (f_equal2 (@pair _ _))
                       | rewrite interpf_SmartPairf
                       | rewrite SmartVarfMap_compose
                       | rewrite SmartVarfMap_id
                       | setoid_rewrite interpf_Const
                       | erewrite ExprInversion.interpf_invert_PairsConst by eassumption ].
    Qed.

    (* Interpretation correctness of [inline_const_and_op_genf] on
       [exprf]-level terms.  Given two related expressions [e1] and [e2]
       (related by [wff] under the context [G]) and the guarantee [H] that
       every pair [(x, x')] listed in [G] satisfies
       [interpf interp_op x = x'], interpreting the transformed [e1] gives
       the same value as interpreting [e2]. *)
    Lemma interpf_inline_const_and_op_genf
          G {t} e1 e2
          (wf : @wff _ _ G t e1 e2)
          (H : forall t x x',
              List.In
                (existT
                   (fun t : base_type_code => (exprf (Tbase t) * interp_base_type t)%type)
                   t
                   (x, x'))
                G
              -> interpf interp_op x = x')
      : interpf interp_op
                (inline_const_and_op_genf (t:=t) interp_op invert_Const Const e1)
        = interpf interp_op e2.
    Proof.
      (* Unfold the definition so that the lemma about the generic
         constant inliner applies; the post-processing obligation is
         discharged by [interpf_postprocess_for_const_and_op]. *)
      unfold inline_const_and_op_genf.
      eapply interpf_inline_const_genf;
        eauto using interpf_postprocess_for_const_and_op.
    Qed.

    (* Interpretation correctness of [inline_const_and_op_gen] on
       [expr]-level terms.  For [wf]-related [e1] and [e2], the transformed
       [e1] and [e2] interpret to the same result on every argument [x]. *)
    Lemma interpf_inline_const_and_op_gen
          {t} e1 e2
          (Hwf : @wf _ _ t e1 e2)
      : forall x,
        interp interp_op
               (inline_const_and_op_gen (t:=t) interp_op invert_Const Const e1)
               x
        = interp interp_op e2 x.
    Proof.
      (* Unfold the definition, then reuse the lemma about the generic
         constant inliner with the post-processing lemma as a hint. *)
      unfold inline_const_and_op_gen.
      eapply interp_inline_const_gen;
        eauto using interpf_postprocess_for_const_and_op.
    Qed.
  End with_invert.

  (* Lemmas in this section cover the case where constants are represented
     by an operation [OpConst] from [Unit] to a base type.  The hypothesis
     [interp_op_OpConst] says that interpreting [OpConst t v] on [tt] gives
     [v]. *)
  Section const_unit.
    Context (OpConst : forall t, interp_base_type t -> op Unit (Tbase t))
            (interp_op_OpConst : forall t v, interp_op _ _ (OpConst t v) tt = v).

    (* Soundness of [invert_ConstUnit]: if it returns [Some v] for the
       operation [opc] and argument expression [e], then [v] equals
       [interp_op] of [opc] applied to the interpretation of [e].
       [Proof using Type] records that no section hypothesis ([OpConst],
       [interp_op_OpConst]) is needed. *)
    Lemma interpf_invert_ConstUnit s d opc e v
          (H : invert_ConstUnit interp_op opc e = Some v)
      : interp_op s d opc (interpf interp_op e) = v.
    Proof using Type.
      (* Case analysis on the source type [s]; after simplification the
         option equality [H] is inverted and substituted. *)
      destruct s; simpl in *; inversion_option; subst.
      (* NOTE(review): presumably only the [Unit] case survives here and
         the interpreted argument is destructed to [tt] -- confirm against
         the definition of [invert_ConstUnit]. *)
      edestruct interpf; reflexivity.
    Qed.

    (* Interpreting the constant expression [Const OpConst v] gives back
       [v].  The proof is a direct application of the section hypothesis
       [interp_op_OpConst], which is the only section hypothesis recorded
       as a dependency by [Proof using]. *)
    Lemma interpf_Const t v
      : interpf interp_op (Const OpConst (t:=t) v) = v.
    Proof using interp_op_OpConst.
      apply interp_op_OpConst.
    Qed.

    (* Interpretation correctness of [inline_const_and_opf] (the
       [OpConst]-based instance) on [exprf]-level terms.  Given
       [wff]-related [e1] and [e2] under the context [G], and the guarantee
       [H] that every pair [(x, x')] listed in [G] satisfies
       [interpf interp_op x = x'], the transformed [e1] interprets to the
       same value as [e2]. *)
    Lemma interpf_inline_const_and_opf
          G {t} e1 e2
          (wf : @wff _ _ G t e1 e2)
          (H : forall t x x',
              List.In
                (existT
                   (fun t : base_type_code => (exprf (Tbase t) * interp_base_type t)%type)
                   t
                   (x, x'))
                G
              -> interpf interp_op x = x')
      : interpf interp_op
                (inline_const_and_opf (t:=t) interp_op OpConst e1)
        = interpf interp_op e2.
    Proof.
      (* Unfold the definition, then reuse the lemma about the generic
         constant inliner; the side conditions are closed by the
         post-processing lemma and the two [const_unit] lemmas. *)
      unfold inline_const_and_opf.
      eapply interpf_inline_const_genf;
        eauto using interpf_postprocess_for_const_and_op,
                    interpf_invert_ConstUnit,
                    interpf_Const.
    Qed.

    (* Interpretation correctness of [inline_const_and_op] (the
       [OpConst]-based instance) on [expr]-level terms.  For [wf]-related
       [e1] and [e2], the transformed [e1] and [e2] interpret to the same
       result on every argument [x]. *)
    Lemma interpf_inline_const_and_op
          {t} e1 e2
          (Hwf : @wf _ _ t e1 e2)
      : forall x,
        interp interp_op
               (inline_const_and_op (t:=t) interp_op OpConst e1)
               x
        = interp interp_op e2 x.
    Proof.
      (* Unfold the definition, then reuse the lemma about the generic
         constant inliner; the side conditions are closed by the
         post-processing lemma and the two [const_unit] lemmas. *)
      unfold inline_const_and_op.
      eapply interp_inline_const_gen;
        eauto using interpf_postprocess_for_const_and_op,
                    interpf_invert_ConstUnit,
                    interpf_Const.
    Qed.
  End const_unit.

  (* [Expr]-level interpretation correctness of [InlineConstAndOpGen].
     Arguments:
     - [invert_Const], [Const]: constant recognizer and constant builder,
       both polymorphic in the variable type [var];
     - [e], [wf]: the expression and its well-formedness proof ([Wf e]);
     - [Hinvert_Const]: soundness of [invert_Const] with respect to
       [interp_op] and [interpf];
     - [interpf_Const]: interpreting [Const _ t v] gives [v].
     Conclusion: the transformed expression and [e] have the same [Interp]
     on every argument [x]. *)
  Lemma InterpInlineConstAndOpGen
        (invert_Const : forall var s d, op s d -> @exprf var s -> option (interp_flat_type d))
        (Const : forall var t, interp_base_type t -> @exprf var (Tbase t))

        {t} (e : Expr t)
        (wf : Wf e)
        (Hinvert_Const
         : forall s d opc e v,
            invert_Const _ s d opc e = Some v
            -> interp_op s d opc (interpf interp_op e) = v)
        (interpf_Const : forall t v, interpf interp_op (Const _ t v) = v)
    : forall x, Interp interp_op (InlineConstAndOpGen interp_op invert_Const Const e) x = Interp interp_op e x.
  Proof using Type.
    (* Reduce to the generic [InterpInlineConstGen] lemma and let [eauto]
       close the side conditions.
       NOTE(review): the binder [interpf_Const] above has the same name as
       the lemma from the [const_unit] section, so the hint
       [interpf_Const] presumably resolves to the local hypothesis here
       -- confirm. *)
    eapply InterpInlineConstGen;
      eauto using interpf_postprocess_for_const_and_op, interpf_invert_ConstUnit, interpf_Const.
  Qed.

  (* [Expr]-level interpretation correctness of [InlineConstAndOp], the
     instance where constants are given by an operation [OpConst] from
     [Unit] to a base type.
     Arguments:
     - [OpConst]: the constant-producing operation;
     - [e], [wf]: the expression and its well-formedness proof ([Wf e]);
     - [interp_op_OpConst]: interpreting [OpConst t v] on [tt] gives [v].
     Conclusion: the transformed expression and [e] have the same [Interp]
     on every argument [x]. *)
  Lemma InterpInlineConstAndOp
        (OpConst : forall t, interp_base_type t -> op Unit (Tbase t))
        {t} (e : Expr t)
        (wf : Wf e)
        (interp_op_OpConst : forall t v, interp_op _ _ (OpConst t v) tt = v)
    : forall x, Interp interp_op (InlineConstAndOp interp_op OpConst e) x = Interp interp_op e x.
  Proof using Type.
    (* Goes through the generic [InterpInlineConstGen] lemma directly; the
       side conditions are closed by [eauto] using the post-processing
       lemma together with [interpf_invert_ConstUnit] and [interpf_Const]. *)
    eapply InterpInlineConstGen;
      eauto using interpf_postprocess_for_const_and_op, interpf_invert_ConstUnit, interpf_Const.
  Qed.
End language.

(*Hint Rewrite @InterpInlineConst @interp_inline_const @interpf_inline_constf using solve_wf_side_condition : reflective_interp.*)