(* Source: src/Testbit.v (blob 2bfcc3df655d301d07b1a63d1b4af0b7242918e5) *)
Require Import Coq.Lists.List.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
Require Import Crypto.BaseSystem Crypto.BaseSystemProofs.
Require Import Coq.ZArith.ZArith Coq.ZArith.Zdiv.
Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith.
Import Nat.

Local Open Scope Z.

(* [testbit limb_width n us] reads one bit out of the limb list [us]:
   it selects the limb at index [n / limb_width] (falling back to [0] via
   [nth_default] when that index is past the end of [us]) and tests bit
   number [n mod limb_width] of that limb with [Z.testbit].
   The division and modulus are on [nat]; only the final bit index is
   converted to [Z]. *)
Definition testbit (limb_width n : nat) (us : list Z) :=
  Z.testbit (nth_default 0 us (n / limb_width)) (Z.of_nat (n mod limb_width)).

(* Identical limb widths: [uniform_base l r] states that every in-range
   position [n] of [l] holds [r ^ n].  The statement is quantified over the
   default value [d] passed to [nth], so it holds whichever default is used. *)
Definition uniform_base (l : list Z) r := forall n d, (n < length l)%nat ->
  nth n l d = r ^ (Z.of_nat n).

(* [successive_powers l r] states that each entry of [l] after the first is
   [r] times the entry before it: for every [n] with [S n] in range,
   [nth (S n) l d = r * nth n l d].  Nothing is required of the first entry,
   and the default [d] is universally quantified. *)
Definition successive_powers (l : list Z) r := forall n d, (S n < length l)%nat ->
  nth (S n) l d = r * nth n l d.

(* [unfold_bits limb_width us] combines the limbs of [us] into a single
   integer, head limb first (least significant):
   - the empty list yields [0];
   - for [u0 :: us'], the low [limb_width] bits of [u0] (masked with
     [Z.ones limb_width]) are added to the result for [us'] shifted left by
     [limb_width] bits. *)
Fixpoint unfold_bits (limb_width : Z) (us : list Z) :=
  match us with
  | nil => 0
  | u0 :: us' => Z.land u0 (Z.ones limb_width) + Z.shiftl (unfold_bits limb_width us') limb_width
  end.

(* The first entry of a non-empty uniform base is [1]: [uniform_base] at
   index 0 gives [b0 = r ^ 0]. *)
Lemma uniform_base_first : forall b0 bs r,
  uniform_base (b0 :: bs) r -> b0 = 1.
Proof.
  boring.
  (* Instantiate the hypothesis at index 0 (the default value 0 is
     irrelevant); the remaining side condition is the length bound,
     discharged by [omega]. *)
  match goal with
  | [ H : uniform_base _ _ |- _ ] => unfold uniform_base in H;
    specialize (H 0%nat 0); simpl in H; eapply H; omega
  end.
Qed.

(* The second entry of a uniform base with at least two entries is the
   radix [r] itself: [uniform_base] at index 1 gives [b1 = r ^ 1]. *)
Lemma uniform_base_second : forall b0 b1 bs r,
  uniform_base (b0 :: b1 :: bs) r -> b1 = r.
Proof.
  boring.
  (* Instantiate the hypothesis at index 1, compute [nth] and [length] on
     the concrete list prefix, and simplify [r ^ 1] to [r] with
     [Z.pow_1_r]; the length bound is left to [omega]. *)
  match goal with
  | [ H : uniform_base _ _ |- _ ] => unfold uniform_base in H;
    specialize (H 1%nat 0); cbv [nth length] in H;
    rewrite Z.pow_1_r in H; apply H; omega
  end.
Qed.


(* In a list of successive powers with at least two entries, the second
   entry is [r] times the first: [successive_powers] at [n = 0] gives
   [b1 = r * b0]. *)
Lemma successive_powers_second : forall b0 b1 bs r,
  successive_powers (b0 :: b1 :: bs) r -> b1 = r * b0.
Proof.
  boring.
  (* Unfold the definition that the hypothesis actually mentions
     ([successive_powers]), instantiate it at [n = 0] (the default value 0
     is irrelevant), compute [nth]/[length] on the concrete list prefix,
     and discharge the length bound with [omega]. *)
  match goal with
  | [ H : successive_powers _ _ |- _ ] => unfold successive_powers in H;
    specialize (H 0%nat 0); cbv [nth length] in H; apply H; omega
  end.
Qed.

(* When the context contains a hypothesis [uniform_base (b0 :: b1 :: _) _],
   rewrite the goal using [uniform_base_first] for [b0] (whose conclusion is
   [b0 = 1]) and then [uniform_base_second] for [b0], [b1] (whose conclusion
   is [b1 = r]).  Arguments left as evars by [erewrite], and the
   [uniform_base] premises, are handed to [eauto]. *)
Ltac uniform_base_subst :=
  match goal with
  | [H : uniform_base (?b0 :: ?b1 :: _) _ |- _ ]=>
      erewrite (uniform_base_first  b0); eauto;
      erewrite (uniform_base_second b0 b1); eauto
  end.

(* [successive_powers] is preserved when the head of the list is dropped. *)
Lemma successive_powers_tail : forall x0 xs r, successive_powers (x0 :: xs) r ->
  successive_powers xs r.
Proof.
  (* After unfolding, the project automation tactic [boring] closes the goal. *)
  unfold successive_powers; boring.
Qed.

(* For a base of successive powers of [2 ^ limb_width] that has at least one
   more entry than [us], decoding [mul_each (2 ^ limb_width) us] gives the
   same value as decoding [us] with a [0] limb prepended.
   NOTE(review): [decode] and [mul_each] come from [BaseSystem]; their
   definitions are not shown in this file. *)
Lemma decode_uniform_shift : forall us base limb_width, (S (length us) <= length base)%nat ->
  successive_powers base (2 ^ limb_width) ->
  decode base (mul_each (2 ^ limb_width) us) = decode base (0 :: us).
Proof.
  (* Unfold the decoding machinery, then induct on both lists; the cases
     where either list is empty are closed by [boring]. *)
  unfold decode, decode', accumulate, mul_each;
    induction us; induction base; try solve [boring].
  intros; simpl in *.
  (* The base needs a second entry here; a one-entry base contradicts the
     length hypothesis. *)
  destruct base; [ boring; omega | ].
  simpl; f_equal.
  (* Head term: relate the second base entry to the first via
     [successive_powers_second]. *)
  + erewrite (successive_powers_second _ z); eauto.
    ring.
  (* Tail: apply the induction hypothesis on [us] to the tail of the base. *)
  + apply IHus; [ omega | ].
    eapply successive_powers_tail; eassumption.
Qed.

(* Every uniform base is a list of successive powers: since entry [n] is
   [r ^ n], entry [S n] is [r ^ (S n) = r * r ^ n]. *)
Lemma uniform_base_successive_powers : forall xs r, uniform_base xs r ->
  successive_powers xs r.
Proof.
  intros xs r Hbase.
  unfold successive_powers; intros idx dflt Hlt.
  unfold uniform_base in Hbase.
  (* Replace both list lookups by the corresponding powers of [r]; the
     index bounds follow from [Hlt]. *)
  rewrite (Hbase (S idx) dflt) by omega.
  rewrite (Hbase idx dflt) by omega.
  (* [Z.of_nat (S idx)] becomes [Z.succ (Z.of_nat idx)], which is the shape
     [Z.pow_succ_r] expects. *)
  rewrite Nat2Z.inj_succ.
  apply Z.pow_succ_r, Nat2Z.is_nonneg.
Qed.

(* A non-empty uniform base with positive radix [r] is a [BaseVector].
   NOTE(review): [BaseVector] is declared in [BaseSystem]; the meaning of
   its three obligations below is inferred from the proof steps and should
   be confirmed against that definition. *)
Lemma uniform_base_BaseVector : forall base r, (r > 0) -> (0 < length base)%nat ->
  uniform_base base r -> BaseVector base.
Proof.
  unfold uniform_base.
  intros ? ? r_gt_0 base_nonempty uniform.
  constructor.
  (* First obligation, about an arbitrary member [b] of [base]: locate its
     index [x], rewrite [b] as [r ^ x] using [uniform], and conclude by
     case analysis on the sign of [r] (zero and negative contradict
     [r > 0]). *)
  + intros b In_b_base.
    apply In_nth_error_value in In_b_base.
    destruct In_b_base as [x nth_error_x].
    pose proof (nth_error_value_length _ _ _ _ nth_error_x) as index_bound.
    specialize (uniform x 0 index_bound).
    rewrite <- nth_default_eq in uniform.
    erewrite nth_error_value_eq_nth_default in uniform; eauto.
    subst.
    destruct r; [ | apply pos_pow_nat_pos | pose proof (Zlt_neg_0 p) ] ; omega.
  (* Second obligation: a statement about an [nth_default] lookup, settled
     by rewriting it with [uniform]. *)
  + intros.
    rewrite nth_default_eq.
    rewrite uniform; auto.
  (* Third obligation: an identity between base entries, reduced via
     [uniform] to an identity between powers of [r] and closed with the
     exponent laws [Z.pow_add_r] / [Z.pow_sub_r]. *)
  + intros.
    subst b.
    subst r0.
    repeat rewrite nth_default_eq.
    repeat rewrite uniform by omega; auto.
    rewrite <- Z.pow_add_r by apply Nat2Z.is_nonneg.
    rewrite Nat2Z.inj_add.
    rewrite <- Z.pow_sub_r, <- Z.pow_add_r by omega.
    f_equal.
    omega.
Qed.

(* [no_overflow us limb_width] states that masking any limb of [us] with
   [Z.ones limb_width] leaves it unchanged.  It is stated for every index
   [n], with out-of-range indices reading the default [0] through
   [nth_default]. *)
Definition no_overflow us limb_width := forall n,
  Z.land (nth_default 0 us n) (Z.ones limb_width) = nth_default 0 us n.

(* If a non-empty limb list does not overflow, then its head limb is
   unchanged by masking with [Z.ones limb_width]. *)
Lemma no_overflow_cons : forall u0 us limb_width,
  no_overflow (u0 :: us) limb_width -> Z.land u0 (Z.ones limb_width) = u0.
Proof.
  intros u0 us limb_width Hno.
  unfold no_overflow in Hno.
  (* Read the hypothesis at index 0, i.e. at the head limb. *)
  pose proof (Hno 0%nat) as Hhead.
  rewrite nth_default_cons in Hhead.
  exact Hhead.
Qed.

(* [no_overflow] is preserved when the head limb is dropped. *)
Lemma no_overflow_tail : forall u0 us limb_width,
  no_overflow (u0 :: us) limb_width -> no_overflow us limb_width.
Proof.
  unfold no_overflow; intros.
  (* Re-express a lookup in the tail as a lookup at the successor index of
     the full list (rewriting right-to-left with [nth_default_cons_S]), so
     that the hypothesis applies; [eauto] finishes and fills the evars. *)
  erewrite <- nth_default_cons_S; eauto.
Qed.

(* For a non-negative [limb_width], a non-empty uniform base of radix
   [2 ^ limb_width] at least as long as [us], and limbs that do not
   overflow, [BaseSystem.decode base us] equals [unfold_bits limb_width us]. *)
Lemma unfold_bits_decode : forall limb_width us base, (0 <= limb_width) ->
  (length us <= length base)%nat -> (0 < length base)%nat ->
  no_overflow us limb_width ->
  uniform_base base (2 ^ limb_width) ->
  BaseSystem.decode base us = unfold_bits limb_width us.
Proof.
  induction us; boring.
  (* Replace the recursive [unfold_bits] on the tail by [decode] using the
     induction hypothesis. *)
  rewrite <- (IHus base) by (omega || eauto using no_overflow_tail).
  (* [decode_cons] needs [BaseVector base], supplied by
     [uniform_base_BaseVector] once [2 ^ limb_width > 0] is shown. *)
  rewrite decode_cons by (eapply uniform_base_BaseVector; eauto;
    rewrite gt_lt_symmetry; apply Z_pow_gt0; omega).
  simpl.
  f_equal.
  (* Head limb: the mask is the identity because of [no_overflow]. *)
  + symmetry. eapply no_overflow_cons; eauto.
  (* Tail: turn the left shift into a multiplication by [2 ^ limb_width]
     and relate it to decoding with a prepended zero limb via
     [decode_uniform_shift]. *)
  + rewrite Z.shiftl_mul_pow2 by assumption.
    erewrite <- decode_uniform_shift; eauto using uniform_base_successive_powers.
    rewrite mul_each_rep.
    unfold decode.
    apply Z.mul_comm.
Qed.


(* For a positive [limb_width] and limbs that do not overflow, limb [i] of
   [us] (read with default [0]) is recovered from the unfolded integer by
   shifting right by [i * limb_width] bits and masking with
   [Z.ones limb_width]. *)
Lemma unfold_bits_indexing : forall us i limb_width, (0 < limb_width)%nat ->
  no_overflow us (Z.of_nat limb_width) ->
  nth_default 0 us i =
  Z.land (Z.shiftr (unfold_bits (Z.of_nat limb_width) us) (Z.of_nat (i * limb_width))) (Z.ones (Z.of_nat limb_width)).
Proof.
  induction us; intros.
  (* Empty list: both sides are [0]. *)
  + rewrite nth_default_nil.
    rewrite Z.shiftr_0_l.
    auto using Z.land_0_l.
  + destruct i; simpl.
    (* Index 0: no shift is applied, and masking the sum leaves just the
       masked head limb, which equals the head limb by [no_overflow]. *)
    - rewrite nth_default_cons.
      rewrite Z.shiftr_0_r, Z_land_add_land by omega.
      symmetry; eapply no_overflow_cons; eauto.
    (* Index [S i]: use the induction hypothesis on the tail, then show the
       right shift reaches past the head limb into the shifted tail. *)
    - rewrite nth_default_cons_S.
      erewrite IHus; eauto using no_overflow_tail.
      remember (i * limb_width)%nat as k.
      rewrite Z_shiftr_add_land by omega.
      replace (limb_width + k - limb_width)%nat with k by omega.
      reflexivity.
Qed.

(* For a positive [limb_width] and limbs that do not overflow, [testbit] on
   the limb list agrees with [Z.testbit] at position [n] of the unfolded
   integer. *)
Lemma unfold_bits_testbit : forall limb_width us n, (0 < limb_width)%nat ->
  no_overflow us (Z.of_nat limb_width) ->
  testbit limb_width n us = Z.testbit (unfold_bits (Z.of_nat limb_width) us) (Z.of_nat n).
Proof.
  unfold testbit; intros.
  (* Express the selected limb as a shifted-and-masked slice of the
     unfolded integer. *)
  erewrite unfold_bits_indexing; eauto.
  (* The tested bit index [n mod limb_width] lies below [limb_width], so
     the mask can be removed. *)
  rewrite <- Z_testbit_low by
    (split; try apply Nat2Z.inj_lt; pose proof (mod_bound_pos n limb_width); omega).
  rewrite Z.shiftr_spec by apply Nat2Z.is_nonneg.
  f_equal.
  (* Remaining goal is the index identity, which reduces to the
     division/modulus decomposition [Nat.div_mod] on [nat]. *)
  rewrite <- Nat2Z.inj_add.
  apply Z2Nat.inj; try apply Nat2Z.is_nonneg.
  rewrite !Nat2Z.id.
  symmetry.
  rewrite Nat.add_comm, Nat.mul_comm.
  apply Nat.div_mod; omega.
Qed.

(* Specification of [testbit]: for a positive [limb_width], a non-empty
   uniform base of radix [2 ^ limb_width] at least as long as [us], and
   limbs that do not overflow, [testbit limb_width n us] equals bit [n] of
   [BaseSystem.decode base us]. *)
Lemma testbit_spec : forall n us base limb_width, (0 < limb_width)%nat ->
  (0 < length base)%nat -> (length us <= length base)%nat ->
  no_overflow us (Z.of_nat limb_width) ->
  uniform_base base (2 ^ (Z.of_nat limb_width)) ->
  testbit limb_width n us = Z.testbit (BaseSystem.decode base us) (Z.of_nat n).
Proof.
  intros.
  (* Chain [unfold_bits_testbit] with [unfold_bits_decode]; side conditions
     go to [eauto] and [omega]. *)
  erewrite unfold_bits_testbit, unfold_bits_decode; eauto; omega.
Qed.