blob: 07ab6140af49c782069b7e9cee81db795d527a91 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
|
(** * Context made from an associative list, without modules *)
Require Import Coq.Bool.Sumbool.
Require Import Coq.Lists.List.
Require Import Crypto.Compilers.Named.Context.
Require Import Crypto.Compilers.Named.ContextDefinitions.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Equality.
Local Open Scope list_scope.
Section ctx.
(** Section parameters.  Every definition and lemma below is abstracted
    over the ones it uses when the section is closed.
    - [key], [key_beq]: the type of context keys and a boolean equality
      test on it.  [key_bl] and [key_lb] state that [key_beq k1 k2 = true]
      implies, and is implied by, [k1 = k2].
    - [base_type_code], [var]: a type of codes and a family of types
      indexed by those codes.  Context entries are dependent pairs
      [{ t : _ & var t }].
    - [base_type_code_beq], [base_type_code_bl_transparent],
      [base_type_code_lb]: the same boolean test and two implications,
      for codes.  [var_cast] pattern-matches on the equality proofs
      produced by [base_type_code_bl_transparent].
      NOTE(review): the [_transparent] suffix presumably asks callers to
      supply a transparent proof so that [var_cast] can compute --
      confirm against the call sites. *)
Context (key : Type)
(key_beq : key -> key -> bool)
(key_bl : forall k1 k2, key_beq k1 k2 = true -> k1 = k2)
(key_lb : forall k1 k2, k1 = k2 -> key_beq k1 k2 = true)
base_type_code (var : base_type_code -> Type)
(base_type_code_beq : base_type_code -> base_type_code -> bool)
(base_type_code_bl_transparent : forall x y, base_type_code_beq x y = true -> x = y)
(base_type_code_lb : forall x y, x = y -> base_type_code_beq x y = true).
(** [var_cast x] tries to reinterpret [x : var a] as a [var b].
    It returns [Some x], transported along a proof of [a = b], when both
    [base_type_code_beq a b] and [base_type_code_beq b b] are [true];
    it returns [None] as soon as either test is [false]. *)
Definition var_cast {a b} (x : var a) : option (var b)
:= match Sumbool.sumbool_of_bool (base_type_code_beq a b), Sumbool.sumbool_of_bool (base_type_code_beq b b) with
(* The transport proof is the proof of [a = b] obtained from [pf],
   composed with the inverse of the proof of [b = b] obtained from [pf'].
   NOTE(review): presumably this shape is chosen so that, when [a] and
   [b] coincide, the composite is a proof followed by its own inverse,
   which the [rewrite concat_pV] step in the proof of [AListContextOk]
   can cancel -- confirm. *)
| left pf, left pf' => match eq_trans (base_type_code_bl_transparent _ _ pf) (eq_sym (base_type_code_bl_transparent _ _ pf')) with
| eq_refl => Some x
end
(* Either boolean test failed: no cast is possible. *)
| right _, _ | _, right _ => None
end.
(** [find k xs] walks the association list [xs] from its head and returns
    the value stored with the first key that [key_beq] reports equal to
    [k].  The result is [None] when no entry matches. *)
Fixpoint find (k : key) (xs : list (key * { t : _ & var t })) {struct xs}
  : option { t : _ & var t }
  := match xs with
     | nil => None
     | entry :: rest =>
       match key_beq k (fst entry) with
       | true => Some (snd entry)
       | false => find k rest
       end
     end.
(** [remove k xs] rebuilds [xs] without any entry whose key [key_beq]
    reports equal to [k].  Every matching entry is dropped, not only the
    first one; the remaining entries keep their relative order. *)
Fixpoint remove (k : key) (xs : list (key * { t : _ & var t })) {struct xs}
  : list (key * { t : _ & var t })
  := match xs with
     | nil => nil
     | entry :: rest =>
       match key_beq k (fst entry) with
       | true => remove k rest
       | false => entry :: remove k rest
       end
     end.
(** [add k x xs] binds [k] to [x] by pushing the pair onto the front of
    [xs].  Older entries for [k] are left in place; since [find] returns
    the first match, the new binding takes precedence over them. *)
Definition add (k : key) (x : { t : _ & var t }) (xs : list (key * { t : _ & var t }))
  : list (key * { t : _ & var t })
  := cons (k, x) xs.
(** Removing a key [k'] that is different from [k] does not change the
    result of looking up [k]. *)
Lemma find_remove_neq k k' xs (H : k <> k')
: find k (remove k' xs) = find k xs.
Proof.
(* Induction on the list; the [nil] case holds by computation, the cons
   case is simplified to expose the [key_beq] tests. *)
induction xs as [|x xs IHxs]; [ reflexivity | simpl ].
(* [break_innermost_match] comes from [Crypto.Util.Tactics.BreakMatch];
   presumably it case-splits on the discriminees of the remaining
   matches -- confirm.  Each resulting goal is then closed by the loop
   below. *)
break_innermost_match;
repeat match goal with
(* Turn a successful boolean test into a propositional equality. *)
| [ H : key_beq _ _ = true |- _ ] => apply key_bl in H
(* A key always tests equal to itself. *)
| [ H : context[key_beq ?x ?x] |- _ ] => rewrite (key_lb x x) in H by reflexivity
| [ |- context[key_beq ?x ?x] ] => rewrite (key_lb x x) by reflexivity
(* Use a known-false test to simplify the goal. *)
| [ H : ?x = false |- context[?x] ] => rewrite H
| _ => congruence
| _ => assumption
| _ => progress subst
| _ => progress simpl
end.
Qed.
(** Variant of [find_remove_neq] whose hypothesis is phrased with the
    boolean test: if [key_beq k k' = false], removing [k'] does not change
    the result of looking up [k]. *)
Lemma find_remove_nbeq k k' xs (H : key_beq k k' = false)
: find k (remove k' xs) = find k xs.
Proof.
(* Reduce to [find_remove_neq]; what remains is to refute [k = k']. *)
rewrite find_remove_neq; [ reflexivity | intro; subst ].
(* With [k = k'] substituted, [key_lb] rewrites the test in [H] to
   [true], contradicting [H]. *)
rewrite key_lb in H by reflexivity; congruence.
Qed.
(** The [Context] instance backed by an association list of keys and
    dependently-typed values. *)
Definition AListContext : @Context base_type_code key var
:= {| ContextT := list (key * { t : _ & var t });
(* Lookup: find the first entry for [n], then try to cast the stored
   value with [var_cast]; the result is [None] if there is no entry or
   the cast fails.  NOTE(review): the target type of the cast is left
   implicit and is presumably [t], fixed by the type of the [lookupb]
   field in [Context] -- confirm. *)
lookupb ctx n t
:= match find n ctx with
| Some (existT t' v)
=> var_cast v
| None => None
end;
(* Extension: package [v] with its type code [t] and push it on the
   front of the list. *)
extendb ctx n t v
:= add n (existT _ t v) ctx;
(* Removal: [t] is not used; every entry for [n] is dropped whatever
   type code it was stored with. *)
removeb ctx n t
:= remove n ctx;
empty := nil |}.
(** Extending an [AListContext] makes the underlying list exactly one
    element longer.  The statement holds by computation, since [extendb]
    is [add], which conses a single pair. *)
Lemma length_extendb (ctx : AListContext) k t v
: length (@extendb _ _ _ AListContext ctx k t v) = S (length ctx).
Proof. reflexivity. Qed.
(** [AListContext] satisfies the [ContextOk] specification. *)
Lemma AListContextOk : @ContextOk base_type_code key var AListContext.
(* The section hypotheses the proof script is declared to depend on. *)
Proof using base_type_code_lb key_bl key_lb.
(* [split] breaks the goal up (presumably one subgoal per field of
   [ContextOk] -- confirm); every subgoal is then handled by the same
   saturation loop, which tries the alternatives below in order until
   none applies. *)
split;
repeat first [ reflexivity
| progress simpl in *
| progress intros
(* Lookups after a removal of a different key: use the two lemmas
   proved above. *)
| rewrite find_remove_nbeq by eassumption
| rewrite find_remove_neq by congruence
(* Case-split on key tests in the goal, and turn successful tests in the
   hypotheses into propositional equalities. *)
| match goal with
| [ |- context[key_beq ?x ?y] ]
=> destruct (key_beq x y) eqn:?
| [ H : key_beq ?x ?y = true |- _ ]
=> apply key_bl in H
end
| break_innermost_match_step
| progress unfold var_cast
(* Reflexive boolean tests on keys and on type codes rewrite to [true]. *)
| rewrite key_lb in * by reflexivity
| rewrite base_type_code_lb in * by reflexivity
(* NOTE(review): [concat_pV] is not defined in this file (presumably it
   comes from [Crypto.Util.Equality]) and presumably cancels an equality
   proof composed with its own inverse, as built in [var_cast] --
   confirm. *)
| rewrite concat_pV
| congruence ].
Qed.
End ctx.
|