aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/Freeze.v
blob: e766e7aead89c3c8dc50330b98afc2064aedd0b8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
Require Import Coq.ZArith.ZArith Coq.micromega.Lia.
Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil.
Require Import Crypto.Arithmetic.BaseConversion.
Require Import Crypto.Arithmetic.Core.
Require Import Crypto.Arithmetic.ModOps.
Require Import Crypto.Arithmetic.Partition.
Require Import Crypto.Arithmetic.Saturated.
Require Import Crypto.Util.Decidable.
Require Import Crypto.Util.ListUtil.
Require Import Crypto.Util.Prod.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Tactics.DestructHead.
Require Import Crypto.Util.ZUtil.EquivModulo.
Require Import Crypto.Util.ZUtil.Opp.
Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem.
Require Import Crypto.Util.ZUtil.Tactics.PeelLe.
Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall.
Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo.

Require Import Crypto.Util.Notations.
Local Open Scope Z_scope.

(* TODO: rename this module? (Should it be, e.g., [Rows.freeze]?) *)
Module Freeze.
  Section Freeze.
    Context weight {wprops : @weight_properties weight}.

    (** [freeze n mask m p] subtracts [m] from [p] using [Rows.sub].
        It then calls [Rows.conditional_add] with the negated carry of that
        subtraction as the condition, conditionally adding [m] back.

        Only the resulting limbs are returned; the final carry is dropped.

        - [eval_freeze_eq] characterizes the value exactly.
        - [eval_freeze] shows it equals [p mod modulus] when
          [0 <= p < 2 * modulus].

        The correctness lemmas below require [List.map (Z.land mask) m = m]. *)
    Definition freeze n mask (m p:list Z) : list Z :=
      let '(p, carry) := Rows.sub weight n p m in
      let '(r, carry) := Rows.conditional_add weight n mask (-carry) p m in
      r.

    (** Integer-level core of [freeze]'s correctness.

        Assume a modulus [m = s - c] with [0 < c < s], and an input
        [0 <= y < 2*m]. Compute [y - m], add [m] back unless
        [-((y - m) / s) = 0], and reduce the sum modulo [s].
        The result is [y mod m]. *)
    Lemma freezeZ m s c y :
      m = s - c ->
      0 < c < s ->
      s <> 0 ->
      0 <= y < 2*m ->
      ((y - m) + (if (dec (-((y - m) / s) = 0)) then 0 else m)) mod s
      = y mod m.
    Proof using Type.
      clear; intros.
      transitivity ((y - m) mod m);
        repeat first [ progress intros
                     | progress subst
                     | rewrite Z.opp_eq_0_iff in *
                     | break_innermost_match_step
                     | progress autorewrite with zsimplify_fast
                     | rewrite Z.div_small_iff in * by auto
                     | progress (Z.rewrite_mod_small; push_Zmod; Z.rewrite_mod_small)
                     | progress destruct_head'_or
                     | lia ].
    Qed.

    (** [freeze] preserves the common length [n] of its two inputs. *)
    Lemma length_freeze n mask m p :
      length m = n -> length p = n -> length (freeze n mask m p) = n.
    Proof using wprops.
      cbv [freeze Rows.conditional_add Rows.add]; eta_expand; intros.
      distr_length; try assumption; cbn; intros; destruct_head'_or; destruct_head' False; subst;
        distr_length.
      erewrite Rows.length_sum_rows by (reflexivity || eassumption || distr_length); distr_length.
    Qed.

    (** Exact value of [freeze], with no bounds assumed on [p].

        Writing [P] and [M] for the evaluations of [p] and [m], the result
        evaluates to [P - M], plus [M] unless [-((P - M) / weight n) = 0],
        all reduced modulo [weight n]. *)
    Lemma eval_freeze_eq n mask m p
          (n_nonzero:n<>0%nat)
          (Hmask : List.map (Z.land mask) m = m)
          (Hplen : length p = n)
          (Hmlen : length m = n)
      : Positional.eval weight n (@freeze n mask m p)
        = (Positional.eval weight n p - Positional.eval weight n m +
           (if dec (-((Positional.eval weight n p - Positional.eval weight n m) / weight n) = 0) then 0 else Positional.eval weight n m))
            mod weight n.
    Proof using wprops.
      pose proof (@weight_positive weight wprops n).
      cbv [freeze Z.equiv_modulo]; eta_expand.
      repeat first [ solve [auto]
                   | rewrite Rows.conditional_add_partitions
                   | rewrite Rows.sub_partitions
                   | rewrite Rows.sub_div
                   | rewrite eval_partition
                   | progress distr_length
                   | progress pull_Zmod ].
    Qed.

    (** Functional correctness of [freeze].

        Suppose [m] encodes [modulus = weight n - Associational.eval c],
        with [0 < Associational.eval c < weight n], and the input satisfies
        [0 <= eval p < 2 * modulus]. Then [freeze] evaluates to
        [eval p mod modulus]. *)
    Lemma eval_freeze n c mask m p
          (n_nonzero:n<>0%nat)
          (Hc : 0 < Associational.eval c < weight n)
          (Hmask : List.map (Z.land mask) m = m)
          (modulus:=weight n - Associational.eval c)
          (Hm : Positional.eval weight n m = modulus)
          (Hp : 0 <= Positional.eval weight n p < 2*modulus)
          (Hplen : length p = n)
          (Hmlen : length m = n)
      : Positional.eval weight n (@freeze n mask m p)
        = Positional.eval weight n p mod modulus.
    Proof using wprops.
      pose proof (@weight_positive weight wprops n).
      rewrite eval_freeze_eq by assumption.
      erewrite freezeZ; try eassumption; try lia.
      f_equal; lia.
    Qed.

    (** Under the hypotheses of [eval_freeze], the output of [freeze] is
        exactly the canonical [partition weight n] of [eval p mod modulus]. *)
    Lemma freeze_partitions n c mask m p
          (n_nonzero:n<>0%nat)
          (Hc : 0 < Associational.eval c < weight n)
          (Hmask : List.map (Z.land mask) m = m)
          (modulus:=weight n - Associational.eval c)
          (Hm : Positional.eval weight n m = modulus)
          (Hp : 0 <= Positional.eval weight n p < 2*modulus)
          (Hplen : length p = n)
          (Hmlen : length m = n)
      : @freeze n mask m p = partition weight n (Positional.eval weight n p mod modulus).
    Proof using wprops.
      pose proof (@weight_positive weight wprops n).
      pose proof (fun v => Z.mod_pos_bound v (weight n) ltac:(lia)).
      pose proof (Z.mod_pos_bound (Positional.eval weight n p) modulus ltac:(lia)).
      subst modulus.
      erewrite <- eval_freeze by eassumption.
      cbv [freeze]; eta_expand.
      rewrite Rows.conditional_add_partitions by (auto; rewrite Rows.sub_partitions; auto; distr_length).
      rewrite !eval_partition by assumption.
      apply Partition.partition_Proper; [ assumption .. | ].
      cbv [Z.equiv_modulo].
      pull_Zmod; reflexivity.
    Qed.
  End Freeze.
End Freeze.
(* Register [Freeze.length_freeze] in the [distr_length] rewrite database.
   This lets [length (freeze ...)] be rewritten to [n] once its side
   conditions are discharged. *)
Hint Rewrite Freeze.length_freeze : distr_length.

Section freeze_mod_ops.
  Import Positional.
  Import Freeze.
  Local Coercion Z.of_nat : nat >-> Z.
  Local Coercion QArith_base.inject_Z : Z >-> Q.
  (* Design constraints:
     - inputs must be [Z] (because reification does not support [Q])
     - internal structure must not match on the arguments (because
       reification does not support [positive]) *)
  (* Section parameters:
     - the limb weights are [@weight limbwidth_num limbwidth_den];
     - the modulus is [s - Associational.eval c];
     - field elements have [n] limbs;
     - [m_enc] is the limb representation of the modulus passed to [freeze];
     - [bitwidth] determines the mask [Z.ones bitwidth] given to [freeze]. *)
  Context (limbwidth_num limbwidth_den : Z)
          (limbwidth_good : 0 < limbwidth_den <= limbwidth_num)
          (s : Z)
          (c : list (Z*Z))
          (n : nat)
          (bitwidth : Z)
          (m_enc : list Z)
          (m_nz:s - Associational.eval c <> 0) (s_nz:s <> 0)
          (Hn_nz : n <> 0%nat).
  (* Byte representation: limb width 8/1. *)
  Local Notation bytes_weight := (@weight 8 1).
  Local Notation weight := (@weight limbwidth_num limbwidth_den).
  (* The modulus. *)
  Let m := (s - Associational.eval c).

  (* [s] equals [weight n], and [c] evaluates strictly between 0 and [weight n].
     [m_enc] is an [n]-limb encoding of [m] whose limbs are unchanged by
     masking with [Z.ones bitwidth]. *)
  Context (Hs : s = weight n).
  Context (c_small : 0 < Associational.eval c < weight n)
          (m_enc_bounded : List.map (BinInt.Z.land (Z.ones bitwidth)) m_enc = m_enc)
          (m_enc_correct : Positional.eval weight n m_enc = m)
          (Hm_enc_len : length m_enc = n).

  (* Weight properties for the byte weighting; its side condition is
     discharged by [lia]. *)
  Definition wprops_bytes := (@wprops 8 1 ltac:(clear; lia)).
  Local Notation wprops := (@wprops limbwidth_num limbwidth_den limbwidth_good).

  Local Notation wunique := (@weight_unique limbwidth_num limbwidth_den limbwidth_good).
  Local Notation wunique_bytes := (@weight_unique 8 1 ltac:(clear; lia)).

  (* Make the weight lemmas for both weightings available to [auto]/[eauto]. *)
  Local Hint Immediate (wprops).
  Local Hint Immediate (wprops_bytes).
  Local Hint Immediate (weight_0 wprops).
  Local Hint Immediate (weight_positive wprops).
  Local Hint Immediate (weight_multiples wprops).
  Local Hint Immediate (weight_divides wprops).
  Local Hint Immediate (weight_0 wprops_bytes).
  Local Hint Immediate (weight_positive wprops_bytes).
  Local Hint Immediate (weight_multiples wprops_bytes).
  Local Hint Immediate (weight_divides wprops_bytes).
  Local Hint Immediate (wunique) (wunique_bytes).
  Local Hint Resolve (wunique) (wunique_bytes).

  (* Number of byte limbs: [Z.to_nat (Qceiling (Z.log2_up (weight n) / 8))].
     The [Eval cbv] unfolds the [Q] operations in the body, presumably so
     that no [Q] term remains (see the design constraints above). *)
  Definition bytes_n
    := Eval cbv [Qceiling Qdiv inject_Z Qfloor Qmult Qopp Qnum Qden Qinv Pos.mul]
      in Z.to_nat (Qceiling (Z.log2_up (weight n) / 8)).

  (* [bytes_n] byte limbs span at least [weight n]. *)
  Lemma weight_bytes_weight_matches
    : weight n <= bytes_weight bytes_n.
  Proof using limbwidth_good.
    clear -limbwidth_good.
    cbv [weight bytes_n].
    autorewrite with zsimplify_const.
    rewrite Z.log2_up_pow2, !Z2Nat.id, !Z.opp_involutive by (Z.div_mod_to_quot_rem; nia).
    Z.peel_le.
    Z.div_mod_to_quot_rem; nia.
  Qed.

  (* Convert [n] limbs in base [weight] to [bytes_n] byte limbs. *)
  Definition to_bytes (v : list Z)
    := BaseConversion.convert_bases weight bytes_weight n bytes_n v.

  (* Convert [bytes_n] byte limbs to [n] limbs in base [weight]. *)
  Definition from_bytes (v : list Z)
    := BaseConversion.convert_bases bytes_weight weight bytes_n n v.

  (* Freeze [f] against [m_enc] with mask [Z.ones bitwidth], then convert
     the result to bytes. *)
  Definition freeze_to_bytesmod (f : list Z) : list Z
    := to_bytes (freeze weight n (Z.ones bitwidth) m_enc f).

  (* Alias of [to_bytes]. *)
  Definition to_bytesmod (f : list Z) : list Z
    := to_bytes f.

  (* Alias of [from_bytes]. *)
  Definition from_bytesmod (f : list Z) : list Z
    := from_bytes f.

  (* The byte representation has at least one limb. *)
  Lemma bytes_nz : bytes_n <> 0%nat.
  Proof using limbwidth_good Hn_nz.
    clear -limbwidth_good Hn_nz.
    cbv [bytes_n].
    cbv [Qceiling Qdiv inject_Z Qfloor Qmult Qopp Qnum Qden Qinv].
    autorewrite with zsimplify_const.
    change (Z.pos (1*8)) with 8.
    cbv [weight].
    rewrite Z.log2_up_pow2 by (Z.div_mod_to_quot_rem; nia).
    autorewrite with zsimplify_fast.
    rewrite <- Z2Nat.inj_0, Z2Nat.inj_iff by (Z.div_mod_to_quot_rem; nia).
    Z.div_mod_to_quot_rem; nia.
  Qed.

  (* Same statement as [weight_bytes_weight_matches].

     The [Proof using] list still names [Hn_nz], so this lemma's signature
     outside the section is unchanged even though the proof does not need it. *)
  Lemma bytes_n_big : weight n <= bytes_weight bytes_n.
  Proof using limbwidth_good Hn_nz.
    clear -limbwidth_good Hn_nz.
    exact weight_bytes_weight_matches.
  Qed.

  (* [to_bytes] preserves the value of an [n]-limb input. *)
  Lemma eval_to_bytes
    : forall (f : list Z)
        (Hf : length f = n),
      eval bytes_weight bytes_n (to_bytes f) = eval weight n f.
  Proof using limbwidth_good Hn_nz.
    generalize wprops wprops_bytes; clear -Hn_nz limbwidth_good.
    intros.
    cbv [to_bytes].
    rewrite BaseConversion.eval_convert_bases
      by (auto using bytes_nz; distr_length; auto using wprops).
    reflexivity.
  Qed.

  (* For an input whose value lies in [0, weight n), [to_bytes] yields
     exactly the byte partition of that value. *)
  Lemma to_bytes_partitions
    : forall (f : list Z)
             (Hf : length f = n)
             (Hf_small : 0 <= eval weight n f < weight n),
      to_bytes f = partition bytes_weight bytes_n (Positional.eval weight n f).
  Proof using Hn_nz limbwidth_good.
    clear -Hn_nz limbwidth_good.
    intros; cbv [to_bytes].
    pose proof weight_bytes_weight_matches.
    apply BaseConversion.convert_bases_partitions; eauto; lia.
  Qed.

  (* Combined value and partition facts for [to_bytesmod]. *)
  Lemma eval_to_bytesmod
    : forall (f : list Z)
             (Hf : length f = n)
             (Hf_small : 0 <= eval weight n f < weight n),
      eval bytes_weight bytes_n (to_bytesmod f) = eval weight n f
      /\ to_bytesmod f = partition bytes_weight bytes_n (Positional.eval weight n f).
  Proof using Hn_nz limbwidth_good.
    split; apply eval_to_bytes || apply to_bytes_partitions; assumption.
  Qed.

  (* For an [n]-limb [f] with [0 <= eval f < 2 * m], [freeze_to_bytesmod f]
     evaluates to [eval f mod m] and is the byte partition of that value. *)
  Lemma eval_freeze_to_bytesmod_and_partitions
    : forall (f : list Z)
        (Hf : length f = n)
        (Hf_bounded : 0 <= eval weight n f < 2 * m),
      (eval bytes_weight bytes_n (freeze_to_bytesmod f)) = (eval weight n f) mod m
      /\ freeze_to_bytesmod f = partition bytes_weight bytes_n (Positional.eval weight n f mod m).
  Proof using m_enc_correct Hs limbwidth_good Hn_nz c_small Hm_enc_len m_enc_bounded.
    clear -m_enc_correct Hs limbwidth_good Hn_nz c_small Hm_enc_len m_enc_bounded.
    intros; subst m s.
    cbv [freeze_to_bytesmod].
    rewrite eval_to_bytes, to_bytes_partitions;
      erewrite ?eval_freeze by eauto using wprops;
      autorewrite with distr_length; eauto.
    Z.div_mod_to_quot_rem; nia.
  Qed.

  (* Value half of [eval_freeze_to_bytesmod_and_partitions]. *)
  Lemma eval_freeze_to_bytesmod
    : forall (f : list Z)
        (Hf : length f = n)
        (Hf_bounded : 0 <= eval weight n f < 2 * m),
      (eval bytes_weight bytes_n (freeze_to_bytesmod f)) = (eval weight n f) mod m.
  Proof using m_enc_correct Hs limbwidth_good Hn_nz c_small Hm_enc_len m_enc_bounded.
    intros; now apply eval_freeze_to_bytesmod_and_partitions.
  Qed.

  (* Partition half of [eval_freeze_to_bytesmod_and_partitions]. *)
  Lemma freeze_to_bytesmod_partitions
    : forall (f : list Z)
        (Hf : length f = n)
        (Hf_bounded : 0 <= eval weight n f < 2 * m),
      freeze_to_bytesmod f = partition bytes_weight bytes_n (Positional.eval weight n f mod m).
  Proof using m_enc_correct Hs limbwidth_good Hn_nz c_small Hm_enc_len m_enc_bounded.
    intros; now apply eval_freeze_to_bytesmod_and_partitions.
  Qed.

  (* [from_bytes] preserves the value of a [bytes_n]-limb input. *)
  Lemma eval_from_bytes
    : forall (f : list Z)
        (Hf : length f = bytes_n),
      eval weight n (from_bytes f) = eval bytes_weight bytes_n f.
  Proof using limbwidth_good Hn_nz.
    generalize wprops wprops_bytes; clear -Hn_nz limbwidth_good.
    intros.
    cbv [from_bytes].
    rewrite BaseConversion.eval_convert_bases
      by (auto using bytes_nz; distr_length; auto using wprops).
    reflexivity.
  Qed.

  (* For byte input whose value lies in [0, weight n), [from_bytes] yields
     exactly the [n]-limb partition of that value (no length hypothesis
     is required). *)
  Lemma from_bytes_partitions
    : forall (f : list Z)
             (Hf_small : 0 <= eval bytes_weight bytes_n f < weight n),
      from_bytes f = partition weight n (Positional.eval bytes_weight bytes_n f).
  Proof using limbwidth_good.
    clear -limbwidth_good.
    intros; cbv [from_bytes].
    pose proof weight_bytes_weight_matches.
    apply BaseConversion.convert_bases_partitions; eauto; lia.
  Qed.

  (* [eval_from_bytes], stated for the [from_bytesmod] alias. *)
  Lemma eval_from_bytesmod
    : forall (f : list Z)
             (Hf : length f = bytes_n),
      eval weight n (from_bytesmod f) = eval bytes_weight bytes_n f.
  Proof using Hn_nz limbwidth_good. apply eval_from_bytes. Qed.

  (* [from_bytes_partitions], stated for the [from_bytesmod] alias. *)
  Lemma from_bytesmod_partitions
    : forall (f : list Z)
             (Hf_small : 0 <= eval bytes_weight bytes_n f < weight n),
      from_bytesmod f = partition weight n (Positional.eval bytes_weight bytes_n f).
  Proof using limbwidth_good. apply from_bytes_partitions. Qed.

  (* Combined value and partition facts for [from_bytesmod]. *)
  Lemma eval_from_bytesmod_and_partitions
    : forall (f : list Z)
             (Hf : length f = bytes_n)
             (Hf_small : 0 <= eval bytes_weight bytes_n f < weight n),
      eval weight n (from_bytesmod f) = eval bytes_weight bytes_n f
      /\ from_bytesmod f = partition weight n (Positional.eval bytes_weight bytes_n f).
  Proof using limbwidth_good Hn_nz.
    now (split; [ apply eval_from_bytesmod | apply from_bytesmod_partitions ]).
  Qed.
End freeze_mod_ops.
(* Register this file's evaluation lemmas for the byte conversions in the
   [push_eval] rewrite database. *)
Hint Rewrite eval_freeze_to_bytesmod eval_to_bytes eval_to_bytesmod eval_from_bytes eval_from_bytesmod : push_eval.