aboutsummaryrefslogtreecommitdiff
path: root/src/Util/IterAssocOp.v
blob: d630698e7400a491095041ef2105e3ffa4bba6f9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
Require Import Coq.Setoids.Setoid Coq.Classes.Morphisms Coq.Classes.Equivalence.
Require Import Coq.NArith.NArith Coq.PArith.BinPosDef.
Require Import Coq.Numbers.Natural.Peano.NPeano.
Require Import Crypto.Algebra.
Import Monoid.

Local Open Scope equiv_scope.

Generalizable All Variables.
Section IterAssocOp.
  Context {T eq op id} {moinoid : @monoid T eq op id} (testbit : nat -> bool).
  Local Infix "===" := eq. Local Infix "===" := eq : type_scope.

  Local Notation nat_iter_op := (ScalarMult.scalarmult_ref (add:=op) (zero:=id)).

  Lemma nat_iter_op_plus m n a :
    op (nat_iter_op m a) (nat_iter_op n a) === nat_iter_op (m + n) a.
  Proof. symmetry; eapply ScalarMult.scalarmult_add_l. Qed.

  Definition N_iter_op n a :=
    match n with
    | 0%N => id
    | Npos p => Pos.iter_op op p a
    end.

  Lemma Pos_iter_op_succ : forall p a, Pos.iter_op op (Pos.succ p) a === op a (Pos.iter_op op p a).
  Proof.
   induction p; intros; simpl; rewrite ?associative, ?IHp; reflexivity.
  Qed.

  Lemma N_iter_op_succ : forall n a, N_iter_op (N.succ n) a  === op a (N_iter_op n a).
  Proof.
    destruct n; simpl; intros; rewrite ?Pos_iter_op_succ, ?right_identity; reflexivity.
  Qed.

  Lemma N_iter_op_is_nat_iter_op : forall n a, N_iter_op n a === nat_iter_op (N.to_nat n) a.
  Proof.
    induction n using N.peano_ind; intros; rewrite ?N2Nat.inj_succ, ?N_iter_op_succ, ?IHn; reflexivity.
  Qed.

  Context {sel:bool->T->T->T} {sel_correct:forall b x y, sel b x y = if b then y else x}.

  Fixpoint funexp {A} (f : A -> A) (a : A) (exp : nat) : A :=
    match exp with
    | O => a
    | S exp' => f (funexp f a exp')
    end.

  Definition test_and_op a (state : nat * T) :=
    let '(i, acc) := state in
    let acc2 := op acc acc in
    let acc2a := op a acc2 in
    match i with
    | O => (0, acc)
    | S i' => (i', sel (testbit i') acc2 acc2a)
    end.

  Definition iter_op bound a : T :=
    snd (funexp (test_and_op a) (bound, id) bound).

  (* correctness reference *)
  Context {x:N} {testbit_correct : forall i, testbit i = N.testbit_nat x i}.

  Definition test_and_op_inv a (s : nat * T) :=
    snd s === nat_iter_op (N.to_nat (N.shiftr_nat x (fst s))) a.

  Hint Rewrite
    N.succ_double_spec
    N.add_1_r
    Nat2N.inj_succ
    Nat2N.inj_mul
    N2Nat.id: N_nat_conv
 .

  Lemma Nsucc_double_to_nat : forall n,
    N.succ_double n = N.of_nat (S (2 * N.to_nat n)).
  Proof.
    intros.
    replace 2 with (N.to_nat 2) by auto.
    autorewrite with N_nat_conv.
    reflexivity.
  Qed.

  Lemma Ndouble_to_nat : forall n,
    N.double n = N.of_nat (2 * N.to_nat n).
  Proof.
    intros.
    replace 2 with (N.to_nat 2) by auto.
    autorewrite with N_nat_conv.
    reflexivity.
  Qed.

  Lemma Nshiftr_succ : forall n i,
    N.to_nat (N.shiftr_nat n i) =
    if N.testbit_nat n i
    then S (2 * N.to_nat (N.shiftr_nat n (S i)))
    else (2 * N.to_nat (N.shiftr_nat n (S i))).
  Proof.
    intros.
    rewrite Nshiftr_nat_S.
    case_eq (N.testbit_nat n i); intro testbit_i;
      pose proof (Nshiftr_nat_spec n i 0) as shiftr_n_odd;
      rewrite Nbit0_correct in shiftr_n_odd; simpl in shiftr_n_odd;
      rewrite testbit_i in shiftr_n_odd.
    + pose proof (Ndiv2_double_plus_one (N.shiftr_nat n i) shiftr_n_odd) as Nsucc_double_shift.
      rewrite Nsucc_double_to_nat in Nsucc_double_shift.
      apply Nat2N.inj.
      rewrite Nsucc_double_shift.
      apply N2Nat.id.
    + pose proof (Ndiv2_double (N.shiftr_nat n i) shiftr_n_odd) as Nsucc_double_shift.
      rewrite Ndouble_to_nat in Nsucc_double_shift.
      apply Nat2N.inj.
      rewrite Nsucc_double_shift.
      apply N2Nat.id.
  Qed.

  Lemma test_and_op_inv_step : forall a s,
    test_and_op_inv a s ->
    test_and_op_inv a (test_and_op a s).
  Proof.
    destruct s as [i acc].
    unfold test_and_op_inv, test_and_op; simpl; intro Hpre.
    destruct i; [ apply Hpre | ].
    simpl.
    rewrite Nshiftr_succ.
    rewrite sel_correct.
    case_eq (testbit i); intro testbit_eq; simpl;
      rewrite testbit_correct in testbit_eq; rewrite testbit_eq;
      rewrite Hpre, <- plus_n_O, nat_iter_op_plus; reflexivity.
  Qed.

  Lemma test_and_op_inv_holds : forall a i s,
    test_and_op_inv a s ->
    test_and_op_inv a (funexp (test_and_op a) s i).
  Proof.
    induction i; intros; auto; simpl; apply test_and_op_inv_step; auto.
  Qed.

  Lemma funexp_test_and_op_index : forall a x acc y,
    fst (funexp (test_and_op a) (x, acc) y) = x - y.
  Proof.
    induction y; simpl; rewrite <- ?Minus.minus_n_O; try reflexivity.
    match goal with |- context[funexp ?a ?b ?c] => destruct (funexp a b c) as [i acc'] end.
    simpl in IHy.
    unfold test_and_op.
    destruct i; rewrite Nat.sub_succ_r; subst; rewrite <- IHy; simpl; reflexivity.
  Qed.

  Lemma iter_op_termination : forall a bound,
    N.size_nat x <= bound ->
    test_and_op_inv a
      (funexp (test_and_op a) (bound, id) bound) ->
    iter_op bound a === nat_iter_op (N.to_nat x) a.
  Proof.
    unfold test_and_op_inv, iter_op; simpl; intros ? ? ? Hinv.
    rewrite Hinv, funexp_test_and_op_index, Minus.minus_diag.
    reflexivity.
  Qed.

  Lemma Nsize_nat_equiv : forall n, N.size_nat n = N.to_nat (N.size n).
  Proof.
    destruct n; auto; simpl; induction p; simpl; auto; rewrite IHp, Pnat.Pos2Nat.inj_succ; reflexivity.
  Qed.

  Lemma Nshiftr_size : forall n bound, N.size_nat n <= bound ->
    N.shiftr_nat n bound = 0%N.
  Proof.
    intros.
    rewrite <- (Nat2N.id bound).
    rewrite Nshiftr_nat_equiv.
    destruct (N.eq_dec n 0); subst; [apply N.shiftr_0_l|].
    apply N.shiftr_eq_0.
    rewrite Nsize_nat_equiv in *.
    rewrite N.size_log2 in * by auto.
    apply N.le_succ_l.
    rewrite <- N.compare_le_iff.
    rewrite N2Nat.inj_compare.
    rewrite <- Compare_dec.nat_compare_le.
    rewrite Nat2N.id.
    auto.
  Qed.

  Lemma iter_op_correct : forall a bound, N.size_nat x <= bound ->
    iter_op bound a === nat_iter_op (N.to_nat x) a.
  Proof.
    intros.
    apply iter_op_termination; auto.
    apply test_and_op_inv_holds.
    unfold test_and_op_inv.
    simpl.
    rewrite Nshiftr_size by auto.
    reflexivity.
  Qed.
End IterAssocOp.

Require Import Coq.Classes.Morphisms.
Require Import Crypto.Util.Tactics.
Require Import Crypto.Util.Relations.

Global Instance Proper_funexp {T R} {Equivalence_R:Equivalence R}
  : Proper ((R==>R) ==> R ==> Logic.eq ==> R) (@funexp T).
Proof.
  repeat intro; subst.
  match goal with [n0 : nat |- _ ] => rename n0 into n; induction n end; [solve [trivial]|].
  match goal with
      [H: (_ ==> _)%signature _ _ |- _ ] =>
      etransitivity; solve [eapply (H _ _ IHn)|reflexivity]
  end.
Qed.

Global Instance Proper_test_and_op {T R} {Equivalence_R:@Equivalence T R} :
  Proper ((R==>R==>R)
            ==> pointwise_relation _ Logic.eq
            ==> (Logic.eq==>R==>R==>R)
            ==> R
            ==> (fun nt NT => Logic.eq (fst nt) (fst NT) /\ R (snd nt) (snd NT))
            ==> (fun nt NT => Logic.eq (fst nt) (fst NT) /\ R (snd nt) (snd NT))
         ) (@test_and_op T).
Proof.
  repeat match goal with
           | _ => intro
           | _ => split
           | [p:prod _ _ |- _ ] => destruct p
           | [p:and _ _ |- _ ] => destruct p
           | _ => progress (cbv [fst snd test_and_op pointwise_relation respectful] in * )
           | _ => progress subst
           | _ => progress break_match
           | _ => solve [ congruence | eauto 99 ]
         end.
Qed.

Global Instance Proper_iter_op {T R} {Equivalence_R:@Equivalence T R} :
  Proper ((R==>R==>R)
            ==> R
            ==> pointwise_relation _ Logic.eq
            ==> (Logic.eq==>R==>R==>R)
            ==> Logic.eq
            ==> R
            ==> R)
  (@iter_op T).
Proof.
  repeat match goal with
           | _ => solve [ reflexivity | congruence | eauto 99 ]
           | _ => progress eapply (Proper_funexp (R:=(fun nt NT => Logic.eq (fst nt) (fst NT) /\ R (snd nt) (snd NT))))
           | _ => progress eapply Proper_test_and_op
           | _ => progress split
           | _ => progress (cbv [fst snd pointwise_relation respectful] in * )
           | _ => intro
         end.
Qed.