(* Extracted from a repository web view.
   path: root/src/Util/IterAssocOp.v
   blob: 8d39c82755f8f9760f86794aaaa828326d98c85c *)
Require Import Coq.Setoids.Setoid Coq.Classes.Morphisms Coq.Classes.Equivalence.
Require Import Coq.NArith.NArith Coq.PArith.BinPosDef.
Require Import Coq.Numbers.Natural.Peano.NPeano.
Require Import Crypto.Algebra.
Import Monoid.

Local Open Scope equiv_scope.

Generalizable All Variables.
(** * Iterated monoid operation by binary double-and-add

    This section defines [iter_op sc a bound], which combines [a] with
    itself [scToN sc] times under the monoid operation [op] by scanning
    the bits of the scalar [sc] from the most significant end, and proves
    in [iter_op_spec] that it agrees with the unary reference iteration
    [nat_iter_op] whenever [bound] covers all bits of the scalar. *)
Section IterAssocOp.
  (** Section parameters.
      - [T], [eq], [op], [id] are the arguments of the monoid structure
        [@monoid T eq op id] from Crypto.Algebra (carrier, equality,
        operation, identity); [moinoid] is the monoid proof.
        NOTE(review): [moinoid] looks like a misspelling of monoid. It is
        kept as is because renaming it could break callers that pass
        this argument by name.
      - [scalar] is the type of multipliers, read as a binary natural
        through [scToN].
      - [testbit] is a bit accessor on scalars; [testbit_spec] requires it
        to agree with [N.testbit_nat] on [scToN x]. *)
  Context {T eq op id} {moinoid : @monoid T eq op id}
          {scalar : Type} (scToN : scalar -> N)
          (testbit : scalar -> nat -> bool)
          (testbit_spec : forall x i, testbit x i = N.testbit_nat (scToN x) i).

  (** Unary reference semantics: [nat_iter_op n a] is [n] copies of [a]
      combined with [op], nested to the right and closed by [id],
      i.e. [op a (op a (... (op a id)))]; [nat_iter_op 0 a = id]. *)
  Fixpoint nat_iter_op n a : T :=
    match n with
    | 0%nat => id
    | S n' => op a (nat_iter_op n' a)
    end.

  (** An [m]-fold iteration combined with an [n]-fold iteration is the
      [(m + n)]-fold iteration. Induction on [m]: the base case uses
      [left_identity], the step uses [associative] and the hypothesis. *)
  Lemma nat_iter_op_plus : forall m n a,
    op (nat_iter_op m a) (nat_iter_op n a) === nat_iter_op (m + n) a.
  Proof.
    induction m; intros; simpl; rewrite ?left_identity, <-?IHm, ?associative; reflexivity.
  Qed.

  (** Iteration indexed by a binary natural: [0] gives [id] and a
      positive [p] defers to the standard library [Pos.iter_op op p a]. *)
  Definition N_iter_op n a :=
    match n with
    | 0%N => id
    | Npos p => Pos.iter_op op p a
    end.

  (** Peeling one successor off [Pos.iter_op]: the [Pos.succ p]-fold
      result is [op a] applied to the [p]-fold result. *)
  Lemma Pos_iter_op_succ : forall p a, Pos.iter_op op (Pos.succ p) a === op a (Pos.iter_op op p a).
  Proof.
   induction p; intros; simpl; rewrite ?associative, ?IHp; reflexivity.
  Qed.

  (** [N] analogue of [Pos_iter_op_succ]; the [0] case needs
      [right_identity] because the [N.succ 0]-fold iteration is just [a]. *)
  Lemma N_iter_op_succ : forall n a, N_iter_op (N.succ n) a  === op a (N_iter_op n a).
  Proof.
    destruct n; simpl; intros; rewrite ?Pos_iter_op_succ, ?right_identity; reflexivity.
  Qed.

  (** The binary-indexed and unary iterations agree; proved by Peano
      induction on [n] using [N_iter_op_succ]. *)
  Lemma N_iter_op_is_nat_iter_op : forall n a, N_iter_op n a === nat_iter_op (N.to_nat n) a.
  Proof.
    induction n using N.peano_ind; intros; rewrite ?N2Nat.inj_succ, ?N_iter_op_succ, ?IHn; reflexivity.
  Qed.

  (** [funexp f a exp] applies [f] to [a] exactly [exp] times:
      [funexp f a 0 = a] and [funexp f a (S e) = f (funexp f a e)]. *)
  Fixpoint funexp {A} (f : A -> A) (a : A) (exp : nat) : A :=
    match exp with
    | O => a
    | S exp' => f (funexp f a exp')
    end.

  (** One step of the bit-scanning loop. The state is [(i, acc)], where
      [i] is the number of low bits of [sc] not yet processed.
      - If [i = S i'], the accumulator is doubled to [op acc acc] and,
        when bit [i'] of [sc] is set, [a] is also combined in on the
        left; the counter drops to [i'].
      - If [i = 0], the state is returned unchanged, so running extra
        steps after all bits are consumed is harmless. *)
  Definition test_and_op sc a (state : nat * T) :=
    let '(i, acc) := state in
    let acc2 := op acc acc in
    match i with
    | O => (0, acc)
    | S i' => (i', if testbit sc i' then op a acc2 else acc2)
    end.

  (** Most-significant-bit-first double-and-add: start from [(bound, id)],
      run [bound] steps of [test_and_op], and return the accumulator.
      Only bits [bound - 1] down to [0] of [sc] are examined, so
      [iter_op_spec] requires [N.size_nat (scToN sc) <= bound]. *)
  Definition iter_op sc a bound : T :=
    snd (funexp (test_and_op sc a) (bound, id) bound).

  (** Loop invariant for [test_and_op]: in state [(i, acc)], [acc] equals
      the unary iteration of [a] by [scToN sc] shifted right by [i] bits,
      i.e. by the bits at positions [i] and above, which are exactly the
      ones already processed. *)
  Definition test_and_op_inv sc a (s : nat * T) :=
    snd s === nat_iter_op (N.to_nat (N.shiftr_nat (scToN sc) (fst s))) a.

  (** Rewrite database used by [autorewrite] in [Nsucc_double_to_nat] and
      [Ndouble_to_nat] to push [N.of_nat] through nat arithmetic and
      cancel it against [N.to_nat]. *)
  Hint Rewrite
    N.succ_double_spec
    N.add_1_r
    Nat2N.inj_succ
    Nat2N.inj_mul
    N2Nat.id: N_nat_conv
 .

  (** [N.succ_double n] (that is, 2n+1) expressed via nat conversion. *)
  Lemma Nsucc_double_to_nat : forall n,
    N.succ_double n = N.of_nat (S (2 * N.to_nat n)).
  Proof.
    intros.
    (* Write the nat literal 2 as [N.to_nat 2] so that [N2Nat.id] can
       cancel the conversion once [N.of_nat] is pushed inward. *)
    replace 2 with (N.to_nat 2) by auto.
    autorewrite with N_nat_conv.
    reflexivity.
  Qed.

  (** [N.double n] (that is, 2n) expressed via nat conversion. *)
  Lemma Ndouble_to_nat : forall n,
    N.double n = N.of_nat (2 * N.to_nat n).
  Proof.
    intros.
    (* Same trick as in [Nsucc_double_to_nat]. *)
    replace 2 with (N.to_nat 2) by auto.
    autorewrite with N_nat_conv.
    reflexivity.
  Qed.

  (** Shifting [n] right by [i] bits gives twice the shift by [S i] bits,
      plus one exactly when bit [i] of [n] is set (stated on nat). *)
  Lemma Nshiftr_succ : forall n i,
    N.to_nat (N.shiftr_nat n i) =
    if N.testbit_nat n i
    then S (2 * N.to_nat (N.shiftr_nat n (S i)))
    else (2 * N.to_nat (N.shiftr_nat n (S i))).
  Proof.
    intros.
    rewrite Nshiftr_nat_S.
    (* Case on bit [i] of [n]; [shiftr_n_odd] ties that bit to the
       lowest bit of [N.shiftr_nat n i]. *)
    case_eq (N.testbit_nat n i); intro testbit_i;
      pose proof (Nshiftr_nat_spec n i 0) as shiftr_n_odd;
      rewrite Nbit0_correct in shiftr_n_odd; simpl in shiftr_n_odd;
      rewrite testbit_i in shiftr_n_odd.
    (* Bit [i] set: reconstruct the shift with [Ndiv2_double_plus_one]. *)
    + pose proof (Ndiv2_double_plus_one (N.shiftr_nat n i) shiftr_n_odd) as Nsucc_double_shift.
      rewrite Nsucc_double_to_nat in Nsucc_double_shift.
      apply Nat2N.inj.
      rewrite Nsucc_double_shift.
      apply N2Nat.id.
    (* Bit [i] clear: reconstruct the shift with [Ndiv2_double]. *)
    + pose proof (Ndiv2_double (N.shiftr_nat n i) shiftr_n_odd) as Nsucc_double_shift.
      rewrite Ndouble_to_nat in Nsucc_double_shift.
      apply Nat2N.inj.
      rewrite Nsucc_double_shift.
      apply N2Nat.id.
  Qed.

  (** One [test_and_op] step preserves [test_and_op_inv]. With [i = 0]
      the state is unchanged; otherwise [Nshiftr_succ] splits the shift
      on bit [i], [testbit_spec] identifies that bit with [testbit sc i],
      and [nat_iter_op_plus] merges the doubled accumulator. *)
  Lemma test_and_op_inv_step : forall sc a s,
    test_and_op_inv sc a s ->
    test_and_op_inv sc a (test_and_op sc a s).
  Proof.
    destruct s as [i acc].
    unfold test_and_op_inv, test_and_op; simpl; intro Hpre.
    destruct i; [ apply Hpre | ].
    simpl.
    rewrite Nshiftr_succ.
    case_eq (testbit sc i); intro testbit_eq; simpl;
      rewrite testbit_spec in testbit_eq; rewrite testbit_eq;
      rewrite Hpre, <- plus_n_O, nat_iter_op_plus; reflexivity.
  Qed.

  (** The invariant survives any number [i] of steps (induction on [i]). *)
  Lemma test_and_op_inv_holds : forall sc a i s,
    test_and_op_inv sc a s ->
    test_and_op_inv sc a (funexp (test_and_op sc a) s i).
  Proof.
    induction i; intros; auto; simpl; apply test_and_op_inv_step; auto.
  Qed.

  (** After [y] steps from [(x, acc)] the counter is [x - y] (truncated nat
      subtraction): each step decrements it until it reaches [0]. *)
  Lemma funexp_test_and_op_index : forall n a x acc y,
    fst (funexp (test_and_op n a) (x, acc) y) === x - y.
  Proof.
    induction y; simpl; rewrite <- ?Minus.minus_n_O; try reflexivity.
    destruct (funexp (test_and_op n a) (x, acc) y) as [i acc'].
    simpl in IHy.
    unfold test_and_op.
    destruct i; rewrite Nat.sub_succ_r; subst; rewrite <- IHy; simpl; reflexivity.
  Qed.

  (** If the invariant holds in the final state of [iter_op], the result
      is the full unary iteration: the counter is then [bound - bound = 0],
      and shifting by [0] bits leaves [scToN sc] unchanged.
      NOTE(review): the size hypothesis is introduced anonymously and does
      not appear to be used by the proof; it is kept in the statement. *)
  Lemma iter_op_termination : forall sc a bound,
    N.size_nat (scToN sc) <= bound ->
    test_and_op_inv sc a
      (funexp (test_and_op sc a) (bound, id) bound) ->
    iter_op sc a bound === nat_iter_op (N.to_nat (scToN sc)) a.
  Proof.
    unfold test_and_op_inv, iter_op; simpl; intros ? ? ? ? Hinv.
    rewrite Hinv, funexp_test_and_op_index, Minus.minus_diag.
    reflexivity.
  Qed.

  (** [N.size_nat] agrees with [N.size] converted to nat. *)
  Lemma Nsize_nat_equiv : forall n, N.size_nat n = N.to_nat (N.size n).
  Proof.
    destruct n; auto; simpl; induction p; simpl; auto; rewrite IHp, Pnat.Pos2Nat.inj_succ; reflexivity.
  Qed.

  (** Shifting [n] right by at least its bit size yields [0]. The [n = 0]
      case is handled by [N.shiftr_0_l]; otherwise the goal is reduced to
      [N.shiftr_eq_0], and the size bound is turned into a bound on
      [N.log2 n] via [N.size_log2] and transferred from nat to [N]. *)
  Lemma Nshiftr_size : forall n bound, N.size_nat n <= bound ->
    N.shiftr_nat n bound = 0%N.
  Proof.
    intros.
    rewrite <- (Nat2N.id bound).
    rewrite Nshiftr_nat_equiv.
    destruct (N.eq_dec n 0); subst; [apply N.shiftr_0_l|].
    apply N.shiftr_eq_0.
    rewrite Nsize_nat_equiv in *.
    rewrite N.size_log2 in * by auto.
    apply N.le_succ_l.
    rewrite <- N.compare_le_iff.
    rewrite N2Nat.inj_compare.
    rewrite <- Compare_dec.nat_compare_le.
    rewrite Nat2N.id.
    auto.
  Qed.

  (** Main correctness theorem: when [bound] is at least the bit size of
      [scToN sc], [iter_op] computes the [scToN sc]-fold iteration of [a].
      The invariant holds initially because shifting by [bound] bits gives
      [0] ([Nshiftr_size]) and [nat_iter_op 0 a] is [id]. *)
  Lemma iter_op_spec : forall sc a bound, N.size_nat (scToN sc) <= bound ->
    iter_op sc a bound === nat_iter_op (N.to_nat (scToN sc)) a.
  Proof.
    intros.
    apply iter_op_termination; auto.
    apply test_and_op_inv_holds.
    unfold test_and_op_inv.
    simpl.
    rewrite Nshiftr_size by auto.
    reflexivity.
  Qed.

End IterAssocOp.