aboutsummaryrefslogtreecommitdiff
path: root/src/Specific/Framework/ArithmeticSynthesis/Base.v
blob: 490ff35c7003c9f92d5381509c09c62679a42ee1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
Require Import Coq.ZArith.ZArith Coq.ZArith.BinIntDef.
Require Import Coq.Lists.List. Import ListNotations.
Require Import Coq.micromega.Lia.
Require Import Coq.QArith.QArith_base.
Require Import Crypto.Arithmetic.Core. Import B.
Require Import Crypto.Specific.Framework.CurveParameters.
Require Import Crypto.Specific.Framework.ArithmeticSynthesis.HelperTactics.
Require Import Crypto.Util.QUtil.
Require Import Crypto.Util.Decidable.
Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo.
Require Crypto.Util.Tuple.
Require Import Crypto.Util.Tactics.CacheTerm.

  (* emacs for adjusting definitions *)
  (* Query replace regexp (default Definition \([a-zA-Z_0-9]+\) : \([A-Za-z0-9_]+\) := P.compute \(.*\)\.\(.*\) -> Ltac pose_\1 \1 :=\4^J  cache_term_with_type_by^J      \2^J      ltac:(let v := P.do_compute \3 in exact v)^J      \1.):  *)
  (* Query replace regexp (default Definition \([a-zA-Z_0-9]+\) : \([A-Za-z0-9_]+\) := P.compute \(.*\)\.\(.*\) -> Ltac pose_\1 \1 :=\4^J  cache_term_with_type_by^J      \2^J      ltac:(let v := P.do_compute \3 in exact v)^J      \1.):  *)
  (* Query replace regexp (default Definition \([a-zA-Z_0-9]+\) : \([A-Za-z0-9_ \.]*\) := P.compute \(.*\)\.\(.*\) -> Ltac pose_\1 \1 :=\4^J  cache_term_with_type_by^J      (\2)^J      ltac:(let v := P.do_compute \3 in exact v)^J      \1.): *)
  (* Query replace regexp (default Definition \([a-zA-Z_0-9]+\) := P.compute \(.*\)\.\(.*\) -> Ltac pose_\1 \1 :=\3^J  let v := P.do_compute \2 in cache_term v \1.):  *)

Ltac pose_r bitwidth r :=
  cache_term_with_type_by
    positive
    ltac:(let v := (eval vm_compute in (Z.to_pos (2^bitwidth))) in exact v)
           r.

Ltac pose_m s c m := (* modulus *)
  let v := (eval vm_compute in (Z.to_pos (s - Associational.eval c))) in
  cache_term v m.

Section wt.
  Import QArith Qround.
  Local Coercion QArith_base.inject_Z : Z >-> Q.
  Local Coercion Z.of_nat : nat >-> Z.
  Local Coercion Z.pos : positive >-> Z.
  Definition wt_gen (base : Q) (i:nat) : Z := 2^Qceiling(base*i).
End wt.

Section gen.
  Context (base : Q)
          (m : positive)
          (sz : nat)
          (coef_div_modulus : nat)
          (base_pos : (1 <= base)%Q).

  Local Notation wt := (wt_gen base).

  Definition sz2' := ((sz * 2) - 1)%nat.

  Definition half_sz' := (sz / 2)%nat.

  Definition m_enc'
    := Positional.encode (modulo_cps:=@modulo_cps) (div_cps:=@div_cps) (n:=sz) wt (Z.pos m).

  Lemma sz2'_nonzero
        (sz_nonzero : sz <> 0%nat)
    : sz2' <> 0%nat.
  Proof using Type. clear -sz_nonzero; cbv [sz2']; omega. Qed.

  Local Ltac Q_cbv :=
    cbv [wt_gen Qround.Qceiling QArith_base.Qmult QArith_base.Qdiv QArith_base.inject_Z QArith_base.Qden QArith_base.Qnum QArith_base.Qopp Qround.Qfloor QArith_base.Qinv QArith_base.Qle QArith_base.Qeq Z.of_nat] in *.
  Lemma wt_gen0_1 : wt 0 = 1.
  Proof using Type.
    Q_cbv; simpl.
    autorewrite with zsimplify_const; reflexivity.
  Qed.

  Lemma wt_gen_nonzero : forall i, wt i <> 0.
  Proof using base_pos.
    eapply pow_ceil_mul_nat_nonzero; [ omega | ].
    destruct base; Q_cbv; lia.
  Qed.

  Lemma wt_gen_nonneg : forall i, 0 <= wt i.
  Proof using Type. apply pow_ceil_mul_nat_nonneg; omega. Qed.

  Lemma wt_gen_pos : forall i, wt i > 0.
  Proof using base_pos.
    intro i; pose proof (wt_gen_nonzero i); pose proof (wt_gen_nonneg i).
    omega.
  Qed.

  Lemma wt_gen_multiples : forall i, wt (S i) mod (wt i) = 0.
  Proof using base_pos.
    apply pow_ceil_mul_nat_multiples; destruct base; Q_cbv; lia.
  Qed.

  Section divides.
    Lemma wt_gen_divides
      : forall i, wt (S i) / wt i > 0.
    Proof using base_pos.
      apply pow_ceil_mul_nat_divide; [ omega | ].
      destruct base; Q_cbv; lia.
    Qed.

    Lemma wt_gen_divides'
      : forall i, wt (S i) / wt i <> 0.
    Proof using base_pos.
      symmetry; apply Z.lt_neq, Z.gt_lt_iff, wt_gen_divides; assumption.
    Qed.

    Lemma wt_gen_div_bound
      : forall i, wt (S i) / wt i <= wt 1.
    Proof using base_pos.
      intro; etransitivity.
      eapply pow_ceil_mul_nat_divide_upperbound; [ omega | ].
      all:destruct base; Q_cbv; autorewrite with zsimplify_const;
        rewrite ?Pos.mul_1_l, ?Pos.mul_1_r; try assumption; omega.
    Qed.
    Lemma wt_gen_divides_chain
          carry_chain
      : forall i (H:In i carry_chain), wt (S i) / wt i <> 0.
    Proof using base_pos. intros i ?; apply wt_gen_divides'; assumption. Qed.

    Lemma wt_gen_divides_chains
          carry_chains
      : List.fold_right
          and
          True
          (List.map
             (fun carry_chain
              => forall i (H:In i carry_chain), wt (S i) / wt i <> 0)
             carry_chains).
    Proof using base_pos.
      induction carry_chains as [|carry_chain carry_chains IHcarry_chains];
        constructor; eauto using wt_gen_divides_chain.
    Qed.
  End divides.

  Definition coef'
    := (fix addm (acc: Z^sz) (ctr : nat) : Z^sz :=
          match ctr with
          | O => acc
          | S n => addm (Positional.add_cps wt acc m_enc' id) n
          end) (Positional.zeros sz) coef_div_modulus.

  Lemma coef_mod'
    : mod_eq m (Positional.eval (n:=sz) wt coef') 0.
  Proof using base_pos.
    cbv [coef' m_enc'].
    remember (Positional.zeros sz) as v eqn:Hv.
    assert (Hv' : mod_eq m (Positional.eval wt v) 0)
      by (subst v; autorewrite with push_basesystem_eval; reflexivity);
      clear Hv.
    revert dependent v.
    induction coef_div_modulus as [|n IHn]; clear coef_div_modulus;
      intros; [ assumption | ].
    rewrite IHn; [ reflexivity | ].
    pose proof wt_gen0_1.
    pose proof wt_gen_nonzero.
    pose proof div_mod.
    pose proof wt_gen_divides'.
    destruct (Nat.eq_dec sz 0).
    { subst; reflexivity. }
    { repeat autounfold; autorewrite with uncps push_id push_basesystem_eval.
      rewrite Positional.eval_encode by (intros; autorewrite with uncps; unfold id; auto).
      cbv [mod_eq] in *.
      push_Zmod; rewrite Hv'; pull_Zmod.
      reflexivity. }
  Qed.
End gen.

Ltac pose_wt base wt :=
  let v := (eval cbv [wt_gen] in (wt_gen base)) in
  cache_term v wt.

Ltac pose_sz2 sz sz2 :=
  let v := (eval vm_compute in (sz2' sz)) in
  cache_term v sz2.

Ltac pose_half_sz sz half_sz :=
  let v := (eval compute in (half_sz' sz)) in
  cache_term v half_sz.

Ltac pose_s_nonzero s s_nonzero :=
  cache_proof_with_type_by
    (s <> 0)
    ltac:(vm_decide_no_check)
           s_nonzero.

Ltac pose_sz_le_log2_m sz m sz_le_log2_m :=
  cache_proof_with_type_by
    (Z.of_nat sz <= Z.log2_up (Z.pos m))
    ltac:(vm_decide_no_check)
           sz_le_log2_m.

Ltac pose_base_pos base base_pos :=
  cache_proof_with_type_by
    ((1 <= base)%Q)
    ltac:(vm_decide_no_check)
           base_pos.

Ltac pose_m_correct m s c m_correct :=
  cache_proof_with_type_by
    (Z.pos m = s - Associational.eval c)
    ltac:(vm_decide_no_check)
           m_correct.

Ltac pose_m_enc base m sz m_enc :=
  let v := (eval vm_compute in (m_enc' base m sz)) in
  let v := (eval compute in v) in (* compute away the type arguments *)
  cache_term v m_enc.

Ltac pose_coef base m sz coef_div_modulus coef := (* subtraction coefficient *)
  let v := (eval vm_compute in (coef' base m sz coef_div_modulus)) in
  cache_term v coef.

Ltac pose_coef_mod wt coef base m sz coef_div_modulus base_pos coef_mod :=
  cache_proof_with_type_by
    (mod_eq m (Positional.eval (n:=sz) wt coef) 0)
    ltac:(vm_cast_no_check (coef_mod' base m sz coef_div_modulus base_pos))
           coef_mod.
Ltac pose_sz_nonzero sz sz_nonzero :=
  cache_proof_with_type_by
    (sz <> 0%nat)
    ltac:(vm_decide_no_check)
           sz_nonzero.

Ltac pose_wt_nonzero wt wt_nonzero :=
  cache_proof_with_type_by
    (forall i, wt i <> 0)
    ltac:(apply wt_gen_nonzero; vm_decide_no_check)
           wt_nonzero.
Ltac pose_wt_nonneg wt wt_nonneg :=
  cache_proof_with_type_by
    (forall i, 0 <= wt i)
    ltac:(apply wt_gen_nonneg; vm_decide_no_check)
           wt_nonneg.
Ltac pose_wt_divides wt wt_divides :=
  cache_proof_with_type_by
    (forall i, wt (S i) / wt i > 0)
    ltac:(apply wt_gen_divides; vm_decide_no_check)
           wt_divides.
Ltac pose_wt_divides' wt wt_divides wt_divides' :=
  cache_proof_with_type_by
    (forall i, wt (S i) / wt i <> 0)
    ltac:(apply wt_gen_divides'; vm_decide_no_check)
           wt_divides'.

Ltac pose_wt_divides_chains wt carry_chains wt_divides_chains :=
  let T := (eval cbv [carry_chains List.fold_right List.map] in
               (List.fold_right
                  and
                  True
                  (List.map
                     (fun carry_chain
                      => forall i (H:In i carry_chain), wt (S i) / wt i <> 0)
                     carry_chains))) in
  cache_proof_with_type_by
    T
    ltac:(refine (@wt_gen_divides_chains _ _ carry_chains); vm_decide_no_check)
           wt_divides_chains.

Ltac pose_wt_pos wt wt_pos :=
  cache_proof_with_type_by
    (forall i, wt i > 0)
    ltac:(apply wt_gen_pos; vm_decide_no_check)
           wt_pos.

Ltac pose_wt_multiples wt wt_multiples :=
  cache_proof_with_type_by
    (forall i, wt (S i) mod (wt i) = 0)
    ltac:(apply wt_gen_multiples; vm_decide_no_check)
           wt_multiples.

Ltac pose_c_small c wt sz c_small :=
  cache_proof_with_type_by
    (0 < Associational.eval c < wt sz)
    ltac:(vm_decide_no_check)
           c_small.

Ltac pose_base_le_bitwidth base bitwidth base_le_bitwidth := (* this is purely a sanity check *)
  cache_proof_with_type_by
    (base <= inject_Z bitwidth)%Q
    ltac:(cbv -[Z.le]; vm_decide_no_check)
           base_le_bitwidth.


Ltac pose_m_enc_bounded sz bitwidth m_enc m_enc_bounded :=
  cache_proof_with_type_by
    (Tuple.map (n:=sz) (BinInt.Z.land (Z.ones bitwidth)) m_enc = m_enc)
    ltac:(vm_compute; reflexivity)
           m_enc_bounded.