aboutsummaryrefslogtreecommitdiff
path: root/src/Compilers/Named/AListContext.v
blob: 07ab6140af49c782069b7e9cee81db795d527a91 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
(** * Context made from an association list, without modules *)
Require Import Coq.Bool.Sumbool.
Require Import Coq.Lists.List.
Require Import Crypto.Compilers.Named.Context.
Require Import Crypto.Compilers.Named.ContextDefinitions.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Equality.

Local Open Scope list_scope.
(** Association-list implementation of [Context]: a context is a list of
    [(key, existT t v)] entries that is searched from the front. *)
Section ctx.
  (** Section parameters:
      - [key_beq] is a boolean equality on [key]; [key_bl] says it is sound
        ([true] implies equal) and [key_lb] says it is complete.
      - [var] gives the type of the values stored at each [base_type_code].
      - [base_type_code_beq] is a boolean equality on type codes, with
        soundness [base_type_code_bl_transparent] and completeness
        [base_type_code_lb].  The soundness proof is used inside the body of
        [var_cast]; the [_transparent] suffix suggests it is expected to be a
        transparent (computable) proof so that casts can reduce -- TODO
        confirm with the callers that instantiate this section. *)
  Context (key : Type)
          (key_beq : key -> key -> bool)
          (key_bl : forall k1 k2, key_beq k1 k2 = true -> k1 = k2)
          (key_lb : forall k1 k2, k1 = k2 -> key_beq k1 k2 = true)
          base_type_code (var : base_type_code -> Type)
          (base_type_code_beq : base_type_code -> base_type_code -> bool)
          (base_type_code_bl_transparent : forall x y, base_type_code_beq x y = true -> x = y)
          (base_type_code_lb : forall x y, x = y -> base_type_code_beq x y = true).

  (** [var_cast x] tries to view [x : var a] as a [var b].
      - It returns [None] when [base_type_code_beq a b] or
        [base_type_code_beq b b] is [false].
      - Otherwise it transports [x] along a proof of [a = b] and returns
        [Some].

      That proof is built as [eq_trans p (eq_sym q)], with [p : a = b] and
      [q : b = b] both obtained from [base_type_code_bl_transparent].  It is
      not simply [p].  This appears intended to let [AListContextOk] collapse
      the proof with [rewrite concat_pV] once [a] and [b] coincide.
      Presumably [concat_pV] states [eq_trans p (eq_sym p) = eq_refl]; confirm
      this in Crypto.Util.Equality. *)
  Definition var_cast {a b} (x : var a) : option (var b)
    := match Sumbool.sumbool_of_bool (base_type_code_beq a b), Sumbool.sumbool_of_bool (base_type_code_beq b b) with
       | left pf, left pf' => match eq_trans (base_type_code_bl_transparent _ _ pf) (eq_sym (base_type_code_bl_transparent _ _ pf')) with
                              | eq_refl => Some x
                              end
       | right _, _ | _, right _ => None
       end.

  (** [find k xs] returns the value of the first entry of [xs] whose key
      [k'] satisfies [key_beq k k' = true].  If there is no such entry, it
      returns [None].  Entries after the first match are never inspected. *)
  Fixpoint find (k : key) (xs : list (key * { t : _ & var t })) {struct xs}
    : option { t : _ & var t }
    := match xs with
       | nil => None
       | k'x :: xs' =>
         if key_beq k (fst k'x)
         then Some (snd k'x)
         else find k xs'
       end.

  (** [remove k xs] drops every entry of [xs] whose key [k'] satisfies
      [key_beq k k' = true], not just the first one.  The remaining entries
      keep their original order. *)
  Fixpoint remove (k : key) (xs : list (key * { t : _ & var t })) {struct xs}
    : list (key * { t : _ & var t })
    := match xs with
       | nil => nil
       | k'x :: xs' =>
         if key_beq k (fst k'x)
         then remove k xs'
         else k'x :: remove k xs'
       end.

  (** [add k x xs] prepends the binding [(k, x)].  Older entries for [k] are
      kept in the list.  They are shadowed because [find] stops at the first
      match. *)
  Definition add (k : key) (x : { t : _ & var t }) (xs : list (key * { t : _ & var t }))
    : list (key * { t : _ & var t })
    := (k, x) :: xs.

  (** Removing a key different from [k] does not change the result of
      looking up [k]. *)
  Lemma find_remove_neq k k' xs (H : k <> k')
    : find k (remove k' xs) = find k xs.
  Proof.
    (* Induct on the list.  In the cons case, case-split on the [key_beq]
       tests, turn [true] results into equalities with [key_bl] and
       reflexive tests into [true] with [key_lb], then close each branch. *)
    induction xs as [|x xs IHxs]; [ reflexivity | simpl ].
    break_innermost_match;
      repeat match goal with
             | [ H : key_beq _ _ = true |- _ ] => apply key_bl in H
             | [ H : context[key_beq ?x ?x] |- _ ] => rewrite (key_lb x x) in H by reflexivity
             | [ |- context[key_beq ?x ?x] ] => rewrite (key_lb x x) by reflexivity
             | [ H : ?x = false |- context[?x] ] => rewrite H
             | _ => congruence
             | _ => assumption
             | _ => progress subst
             | _ => progress simpl
             end.
  Qed.

  (** Variant of [find_remove_neq] whose hypothesis is a boolean key
      comparison returning [false]. *)
  Lemma find_remove_nbeq k k' xs (H : key_beq k k' = false)
    : find k (remove k' xs) = find k xs.
  Proof.
    (* The side goal is [k <> k'].  Assume [k = k']; then [key_lb] makes
       [key_beq k k] equal to [true], which contradicts [H]. *)
    rewrite find_remove_neq; [ reflexivity | intro; subst ].
    rewrite key_lb in H by reflexivity; congruence.
  Qed.

  (** The association-list context:
      - [lookupb ctx n t] finds the first entry for [n] and casts its value
        to type [t] with [var_cast].  It yields [None] if [n] is absent, or
        if that first entry cannot be cast to [t].  Older entries for [n] are
        not consulted.
      - [extendb] prepends a binding via [add].
      - [removeb ctx n t] removes every entry for [n].  The type argument
        [t] is ignored.
      - [empty] is the empty list. *)
  Definition AListContext : @Context base_type_code key var
    := {| ContextT := list (key * { t : _ & var t });
          lookupb ctx n t
          := match find n ctx with
             | Some (existT t' v)
               => var_cast v
             | None => None
             end;
          extendb ctx n t v
          := add n (existT _ t v) ctx;
          removeb ctx n t
          := remove n ctx;
          empty := nil |}.

  (** Extending never deduplicates: the list grows by exactly one entry.
      This holds by computation. *)
  Lemma length_extendb (ctx : AListContext) k t v
    : length (@extendb _ _ _ AListContext ctx k t v) = S (length ctx).
  Proof. reflexivity. Qed.

  (** [AListContext] satisfies the [ContextOk] specification. *)
  Lemma AListContextOk : @ContextOk base_type_code key var AListContext.
  Proof using base_type_code_lb key_bl key_lb.
    (* Break [ContextOk] into its obligations with [split].  Then, in a
       loop:
       - simplify;
       - rewrite lookups through [remove] using the two lemmas above;
       - case-split on key comparisons and on remaining matches;
       - unfold [var_cast];
       - rewrite reflexive comparisons to [true];
       until each goal closes by [reflexivity] or [congruence]. *)
    split;
      repeat first [ reflexivity
                   | progress simpl in *
                   | progress intros
                   | rewrite find_remove_nbeq by eassumption
                   | rewrite find_remove_neq by congruence
                   | match goal with
                     | [ |- context[key_beq ?x ?y] ]
                       => destruct (key_beq x y) eqn:?
                     | [ H : key_beq ?x ?y = true |- _ ]
                       => apply key_bl in H
                     end
                   | break_innermost_match_step
                   | progress unfold var_cast
                   | rewrite key_lb in * by reflexivity
                   (* When both type codes are equal, this turns the
                      [sumbool_of_bool] tests inside [var_cast] into [true].
                      [concat_pV] then presumably collapses the resulting
                      [eq_trans p (eq_sym p)] proof to [eq_refl]. *)
                   | rewrite base_type_code_lb in * by reflexivity
                   | rewrite concat_pV
                   | congruence ].
  Qed.
End ctx.