aboutsummaryrefslogtreecommitdiff
path: root/src/Specific/Framework/ArithmeticSynthesis/Defaults.v
blob: 1d3c3c89c8e6f1565648542779f82828d61d7f77 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
Require Import Coq.ZArith.ZArith Coq.ZArith.BinIntDef.
Require Import Coq.QArith.QArith_base.
Require Import Coq.Lists.List. Import ListNotations.
Require Import Crypto.Arithmetic.CoreUnfolder.
Require Import Crypto.Arithmetic.Core. Import B.
Require Import Crypto.Arithmetic.PrimeFieldTheorems.
Require Crypto.Specific.Framework.CurveParameters.
Require Import Crypto.Specific.Framework.ArithmeticSynthesis.HelperTactics.
Require Import Crypto.Specific.Framework.ArithmeticSynthesis.Base.
Require Import Crypto.Util.Decidable.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Tactics.DestructHead.
Require Import Crypto.Util.Tactics.PoseTermWithName.
Require Import Crypto.Util.Tactics.CacheTerm.
Require Crypto.Util.Tuple.

(* [tuple] abbreviates [Tuple.tuple]; in [type_scope], [T^n] is notation for
   [tuple T n].  Both notations are local to this file. *)
Local Notation tuple := Tuple.tuple.
Local Open Scope list_scope.
(* Opened after [list_scope], so [Z_scope] interpretations take precedence. *)
Local Open Scope Z_scope.
Local Infix "^" := tuple : type_scope.

(* Re-exports [Coq.setoid_ring.ZArithRing] to every file that imports this
   one.  Presumably needed because the tactics defined here (for example
   [internal_solve_code_correct], which ends in [ring]) run in the caller's
   environment. *)
Module Export Exports.
  Export Coq.setoid_ring.ZArithRing.
End Exports.

(* Solves a goal of the form
     { c : Z^sz | Positional.Fdecode (m:=M) wt c = v }
   which is the statement of [constant_sig'] below.  The witness is
   [Positional.encode wt (F.to_Z v)].  The context is then cleared down to
   the [sz <> 0%nat] and [(1 <= _)%Q] hypotheses (plus whatever they depend
   on), and the remaining obligation is discharged inside [abstract] by
   rewriting with [Positional.Fdecode_Fencode_id]. *)
Local Ltac solve_constant_local_sig :=
  idtac;
  lazymatch goal with
  | [ |- { c : Z^?sz | Positional.Fdecode (m:=?M) ?wt c = ?v } ]
    => (exists (Positional.encode (n:=sz) (modulo_cps:=@modulo_cps) (div_cps:=@div_cps) wt (F.to_Z (m:=M) v)));
       (* keep only the two facts the abstracted proof below relies on *)
       lazymatch goal with
       | [ sz_nonzero : sz <> 0%nat, base_pos : (1 <= _)%Q |- _ ]
         => clear -base_pos sz_nonzero
       end
  end;
  abstract (
      setoid_rewrite Positional.Fdecode_Fencode_id;
      (* first goal: the rewritten equation; remaining goals: the premises
         of [Fdecode_Fencode_id] *)
      [ reflexivity
      | auto using wt_gen0_1, wt_gen_nonzero, wt_gen_divides', div_mod;
        intros; autorewrite with uncps push_id; auto using div_mod.. ]
    ).

(* Generic synthesis of the field operations on [Z^sz], i.e. [sz]-tuples of
   integers interpreted with the weight function [wt_gen base].  Each
   [*_sig'] definition pairs an implementation with a proof that, after
   decoding into [F m] with [Positional.Fdecode], it agrees with the
   corresponding operation of [F m]. *)
Section gen.
  (* Parameters:
     - [m]: the field is [F m];
     - [base]: argument of the weight function [wt_gen];
     - [sz]: tuple length (number of limbs);
     - [s], [c]: reduction constants, related to [m] by [m_correct] below;
     - [carry_chains]: carry order used by [carry_sig'];
     - [coef]: constant tuple used by [sub_sig'] and [opp_sig'] (see [coef_mod]);
     - [mul_code], [square_code]: optional hand-written implementations that,
       when [Some], are used by [mul_sig'] and [square_sig'] in place of the
       generic [mul_cps]/[reduce_cps] pipeline. *)
  Context (m : positive)
          (base : Q)
          (sz : nat)
          (s : Z)
          (c : list limb)
          (carry_chains : list (list nat))
          (coef : Z^sz)
          (mul_code : option (Z^sz -> Z^sz -> Z^sz))
          (square_code : option (Z^sz -> Z^sz))
          (sz_nonzero : sz <> 0%nat)
          (s_nonzero : s <> 0)
          (base_pos : (1 <= base)%Q)
          (sz_le_log2_m : Z.of_nat sz <= Z.log2_up (Z.pos m)).

  (* Shorthands for the weight function, the length argument [sz2] passed
     to [mul_cps]/[reduce_cps], and the facts about [wt] that the proofs
     below use. *)
  Local Notation wt := (wt_gen base).
  Local Notation sz2 := (sz2' sz).
  Local Notation wt_divides' := (wt_gen_divides' base base_pos).
  Local Notation wt_nonzero := (wt_gen_nonzero base base_pos).

  (* The hand-written [mul_code]/[square_code] must agree with the generic
     [mul_cps] + [reduce_cps] computation.  These side conditions are stated
     in unreduced form; proving them needs
     cbv [Positional.mul_cps Positional.reduce_cps]
     (as done by [internal_solve_code_correct]).  A [None] code imposes
     only [True]. *)
  Context (mul_code_correct
           : match mul_code with
             | None => True
             | Some v
               => forall a b,
                 v a b
                 = Positional.mul_cps (n:=sz) (m:=sz2) wt a b
                                      (fun ab => Positional.reduce_cps (n:=sz) (m:=sz2) wt s c ab id)
             end)
          (square_code_correct
           : match square_code with
             | None => True
             | Some v
               => forall a,
                 v a
                 = Positional.mul_cps (n:=sz) (m:=sz2) wt a a
                                      (fun ab => Positional.reduce_cps (n:=sz) (m:=sz2) wt s c ab id)
             end).

  (* [coef] evaluates to a multiple of [m]; [m] equals [s - eval c]. *)
  Context (coef_mod : mod_eq m (Positional.eval wt coef) 0)
          (m_correct : Z.pos m = s - Associational.eval c).


  (* Performs a full carry loop (as specified by [carry_chains]) using
     [Positional.chained_carries_reduce]; the proof shows that the value
     decoded into [F m] is unchanged. *)
  Definition carry_sig'
    : { carry : (Z^sz -> Z^sz)%type
      | forall a : Z^sz,
          let eval := Positional.Fdecode (m := m) wt in
          eval (carry a) = eval a }.
  Proof.
    let a := fresh "a" in
    eexists; cbv beta zeta; intros a.
    pose proof (wt_gen0_1 base).
    pose proof wt_nonzero; pose proof div_mod.
    pose proof (wt_gen_divides_chains base base_pos carry_chains).
    pose proof wt_divides'.
    let x := constr:(Positional.chained_carries_reduce (n:=sz) (modulo_cps:=@modulo_cps) (div_cps:=@div_cps) wt s c a carry_chains) in
    presolve_op_F constr:(wt) x;
      [ autorewrite with pattern_runtime; reflexivity | ].
    reflexivity.
  Defined.

  (* A tuple that decodes to the given field element [v], built by
     [solve_constant_local_sig]. *)
  Definition constant_sig' v
    : { c : Z^sz | Positional.Fdecode (m:=m) wt c = v}.
  Proof. solve_constant_local_sig. Defined.

  (* [constant_sig'] at [0%F]; [Eval hnf] head-normalizes the term when the
     definition is declared. *)
  Definition zero_sig'
    : { zero : Z^sz | Positional.Fdecode (m:=m) wt zero = 0%F}
    := Eval hnf in constant_sig' _.

  (* [constant_sig'] at [1%F], head-normalized the same way. *)
  Definition one_sig'
    : { one : Z^sz | Positional.Fdecode (m:=m) wt one = 1%F}
    := Eval hnf in constant_sig' _.

  (* Addition, implemented with [Positional.add_cps]. *)
  Definition add_sig'
    : { add : (Z^sz -> Z^sz -> Z^sz)%type
      | forall a b : Z^sz,
          let eval := Positional.Fdecode (m:=m) wt in
          eval (add a b) = (eval a + eval b)%F }.
  Proof.
    eexists; cbv beta zeta; intros a b.
    pose proof wt_nonzero.
    pose proof (wt_gen0_1 base).
    let x := constr:(
               Positional.add_cps (n := sz) wt a b id) in
    presolve_op_F constr:(wt) x;
      [ autorewrite with pattern_runtime; reflexivity | ].
    reflexivity.
  Defined.

  (* Subtraction, implemented with [Positional.sub_cps] using the constant
     tuple [coef]. *)
  Definition sub_sig'
    : { sub : (Z^sz -> Z^sz -> Z^sz)%type
      | forall a b : Z^sz,
          let eval := Positional.Fdecode (m:=m) wt in
          eval (sub a b) = (eval a - eval b)%F }.
  Proof.
    let a := fresh "a" in
    let b := fresh "b" in
    eexists; cbv beta zeta; intros a b.
    pose proof wt_nonzero.
    pose proof (wt_gen0_1 base).
    let x := constr:(
               Positional.sub_cps (n:=sz) (coef := coef) wt a b id) in
    presolve_op_F constr:(wt) x;
      [ autorewrite with pattern_runtime; reflexivity | ].
    reflexivity.
  Defined.

  (* Negation, implemented with [Positional.opp_cps] using the constant
     tuple [coef]. *)
  Definition opp_sig'
    : { opp : (Z^sz -> Z^sz)%type
      | forall a : Z^sz,
          let eval := Positional.Fdecode (m := m) wt in
          eval (opp a) = F.opp (eval a) }.
  Proof.
    eexists; cbv beta zeta; intros a.
    pose proof wt_nonzero.
    pose proof (wt_gen0_1 base).
    let x := constr:(
               Positional.opp_cps (n:=sz) (coef := coef) wt a id) in
    presolve_op_F constr:(wt) x;
      [ autorewrite with pattern_runtime; reflexivity | ].
    reflexivity.
  Defined.

  (* Multiplication: [Positional.mul_cps] followed by
     [Positional.reduce_cps].  The implementation is chosen by matching on
     [mul_code]: [None] keeps the generic term, [Some v] uses [v], justified
     by [mul_code_correct]. *)
  Definition mul_sig'
    : { mul : (Z^sz -> Z^sz -> Z^sz)%type
      | forall a b : Z^sz,
          let eval := Positional.Fdecode (m := m) wt in
          eval (mul a b) = (eval a * eval b)%F }.
  Proof.
    eexists; cbv beta zeta; intros a b.
    pose proof wt_nonzero.
    pose proof (wt_gen0_1 base).
    pose proof (sz2'_nonzero sz sz_nonzero).
    let x := constr:(
               Positional.mul_cps (n:=sz) (m:=sz2) wt a b
                                  (fun ab => Positional.reduce_cps (n:=sz) (m:=sz2) wt s c ab id)) in
    presolve_op_F constr:(wt) x; [ | reflexivity ].
    (* make the implementation a [match] on [mul_code] *)
    let rhs := match goal with |- _ = ?rhs => rhs end in
    transitivity (match mul_code with
                  | None => rhs
                  | Some v => v a b
                  end);
      [ reflexivity | ].
    destruct mul_code; try reflexivity.
    transitivity (Positional.mul_cps (n:=sz) (m:=sz2) wt a b
                                     (fun ab => Positional.reduce_cps (n:=sz) (m:=sz2) wt s c ab id)); [ | reflexivity ].
    auto.
  Defined.

  (* Squaring: same scheme as [mul_sig'] with both operands equal to [a],
     selecting [square_code] when it is [Some]. *)
  Definition square_sig'
    : { square : (Z^sz -> Z^sz)%type
      | forall a : Z^sz,
          let eval := Positional.Fdecode (m := m) wt in
          eval (square a) = (eval a * eval a)%F }.
  Proof.
    eexists; cbv beta zeta; intros a.
    pose proof wt_nonzero.
    pose proof (wt_gen0_1 base).
    pose proof (sz2'_nonzero sz sz_nonzero).
    let x := constr:(
               Positional.mul_cps (n:=sz) (m:=sz2) wt a a
                                  (fun ab => Positional.reduce_cps (n:=sz) (m:=sz2) wt s c ab id)) in
    presolve_op_F constr:(wt) x; [ | reflexivity ].
    (* make the implementation a [match] on [square_code] *)
    let rhs := match goal with |- _ = ?rhs => rhs end in
    transitivity (match square_code with
                  | None => rhs
                  | Some v => v a
                  end);
      [ reflexivity | ].
    destruct square_code; try reflexivity.
    transitivity (Positional.mul_cps (n:=sz) (m:=sz2) wt a a
                                     (fun ab => Positional.reduce_cps (n:=sz) (m:=sz2) wt s c ab id)); [ | reflexivity ].
    auto.
  Defined.

  (* Section-local package of the ring structure on [Z^sz], obtained from
     [Ring.ring_by_isomorphism] with [Positional.Fencode]/[Positional.Fdecode]
     as the isomorphism.  It is abstracted over the six operation sigs.
     Stating it as [{ T : _ & T }] and using [eexists] lets Coq infer the
     statement [T] from the [refine] below. *)
  Let ring_pkg : { T : _ & T }.
  Proof.
    eexists.
    refine (fun zero_sig one_sig add_sig sub_sig mul_sig opp_sig
            => Ring.ring_by_isomorphism
                 (F := F m)
                 (H := Z^sz)
                 (phi := Positional.Fencode wt)
                 (phi' := Positional.Fdecode wt)
                 (zero := proj1_sig zero_sig)
                 (one := proj1_sig one_sig)
                 (opp := proj1_sig opp_sig)
                 (add := proj1_sig add_sig)
                 (sub := proj1_sig sub_sig)
                 (mul := proj1_sig mul_sig)
                 (phi'_zero := _)
                 (phi'_one := _)
                 (phi'_opp := _)
                 (Positional.Fdecode_Fencode_id
                    (sz_nonzero := sz_nonzero)
                    (div_mod := div_mod)
                    wt (wt_gen0_1 base) wt_nonzero wt_divides')
                 (Positional.eq_Feq_iff wt)
                 _ _ _);
      (* Remaining holes: any goal mentioning [proj1_sig x] is closed by the
         corresponding [proj2_sig x]; the others (presumably the
         modulo/div identity obligations) are left to [eauto]. *)
      lazymatch goal with
      | [ |- context[@proj1_sig ?A ?P ?x] ]
        => pattern (@proj1_sig A P x);
             exact (@proj2_sig A P x)
      | _ => eauto using @Core.modulo_id, @Core.div_id with nocore
      end.
  Defined.

  (* [ring_pkg] applied to the six sigs.  [ring_pkg] and [projT2] are
     unfolded by [cbv], so the body does not refer to the section-local
     [ring_pkg]. *)
  Definition ring' zero_sig one_sig add_sig sub_sig mul_sig opp_sig
    := Eval cbv [ring_pkg projT2] in
        projT2 ring_pkg zero_sig one_sig add_sig sub_sig mul_sig opp_sig.
End gen.

(* Proves a side condition shaped like [mul_code_correct] or
   [square_code_correct] from section [gen].  After [hnf], the [None] case
   is the goal [True] and is closed by [constructor].  Otherwise the generic
   [mul_cps]/[reduce_cps] computation is unfolded and partially evaluated on
   the right-hand side, and the caller-supplied thunk [P_tac] is run (it is
   invoked as [P_tac ()]).  Finally the goal is split componentwise with
   [f_equal2 pair] and closed by [ring]. *)
Ltac internal_solve_code_correct P_tac :=
  hnf;
  lazymatch goal with
  | [ |- True ] => constructor
  | _
    => cbv [Positional.mul_cps Positional.reduce_cps];
       intros;
       autorewrite with pattern_runtime;
       repeat autounfold;
       autorewrite with pattern_runtime;
       basesystem_partial_evaluation_RHS;
       P_tac ();
       (* case-split remaining matches, expose the runtime operations, and
          turn [Z.shiftl] into multiplication by a power of two so that
          [ring] can close each component *)
       break_match; cbv [Let_In runtime_mul runtime_add]; repeat apply (f_equal2 pair); rewrite ?Z.shiftl_mul_pow2 by omega; ring
  end.

(* Caches, under [new_id], a proof that the optional hand-written
   multiplication [code] agrees with the generic [mul_cps] + [reduce_cps]
   computation.  A [None] code yields the trivial statement [True].  The
   proof is searched by [internal_solve_code_correct] with the extra
   caller-supplied tactic [extra_tac]. *)
Ltac pose_mul_code_correct extra_tac nlimbs nlimbs2 weight s_val c_val code new_id :=
  cache_proof_with_type_by
    (match code with
     | None => True
     | Some v =>
       forall a b,
         v a b
         = Positional.mul_cps
             (n:=nlimbs) (m:=nlimbs2) weight a b
             (fun ab =>
                Positional.reduce_cps (n:=nlimbs) (m:=nlimbs2) weight s_val c_val ab id)
     end)
    ltac:(internal_solve_code_correct extra_tac)
    new_id.

(* Caches, under [new_id], a proof that the optional hand-written squaring
   [code] agrees with the generic [mul_cps] + [reduce_cps] computation
   applied to [a] and [a].  A [None] code yields the trivial statement
   [True].  The proof is searched by [internal_solve_code_correct] with the
   extra caller-supplied tactic [extra_tac]. *)
Ltac pose_square_code_correct extra_tac nlimbs nlimbs2 weight s_val c_val code new_id :=
  cache_proof_with_type_by
    (match code with
     | None => True
     | Some v =>
       forall a,
         v a
         = Positional.mul_cps
             (n:=nlimbs) (m:=nlimbs2) weight a a
             (fun ab =>
                Positional.reduce_cps (n:=nlimbs) (m:=nlimbs2) weight s_val c_val ab id)
     end)
    ltac:(internal_solve_code_correct extra_tac)
    new_id.

(* Forwards to [cache_sig_with_type_by_existing_sig_helper] (defined
   elsewhere).  As its first argument it supplies a thunk that
   [cbv]-unfolds every [*_sig'] definition of section [gen]; the target
   type, the existing sig and the name to bind are passed through. *)
Ltac cache_sig_with_type_by_existing_sig sig_ty src_sig new_id :=
  cache_sig_with_type_by_existing_sig_helper
    ltac:(fun _ =>
            cbv [carry_sig' constant_sig' zero_sig' one_sig'
                 add_sig' sub_sig' mul_sig' square_sig' opp_sig'])
    sig_ty
    src_sig
    new_id.

(* Binds [new_id] to the carry operation of [carry_sig'], stated at the
   explicit sig type below. *)
Ltac pose_carry_sig weight modulus limb_base nlimbs s_val c_val chains new_id :=
  cache_sig_with_type_by_existing_sig
    { carry : (Z^nlimbs -> Z^nlimbs)%type
    | forall a : Z^nlimbs,
        let eval := Positional.Fdecode (m := modulus) weight in
        eval (carry a) = eval a }
    (carry_sig' modulus limb_base nlimbs s_val c_val chains)
    new_id.

(* Binds [new_id] to the zero constant of [zero_sig'] using
   [cache_vm_sig_with_type]. *)
Ltac pose_zero_sig weight modulus limb_base nlimbs nlimbs_nonzero limb_base_pos new_id :=
  cache_vm_sig_with_type
    { zero : Z^nlimbs | Positional.Fdecode (m:=modulus) weight zero = 0%F }
    (zero_sig' modulus limb_base nlimbs nlimbs_nonzero limb_base_pos)
    new_id.

(* Binds [new_id] to the one constant of [one_sig'] using
   [cache_vm_sig_with_type]. *)
Ltac pose_one_sig weight modulus limb_base nlimbs nlimbs_nonzero limb_base_pos new_id :=
  cache_vm_sig_with_type
    { one : Z^nlimbs | Positional.Fdecode (m:=modulus) weight one = 1%F }
    (one_sig' modulus limb_base nlimbs nlimbs_nonzero limb_base_pos)
    new_id.

(* Binds [new_id] to the addition of [add_sig'], stated at the explicit
   sig type below. *)
Ltac pose_add_sig weight modulus limb_base nlimbs new_id :=
  cache_sig_with_type_by_existing_sig
    { add : (Z^nlimbs -> Z^nlimbs -> Z^nlimbs)%type
    | forall a b : Z^nlimbs,
        let eval := Positional.Fdecode (m:=modulus) weight in
        eval (add a b) = (eval a + eval b)%F }
    (add_sig' modulus limb_base nlimbs)
    new_id.

(* Binds [new_id] to the subtraction of [sub_sig'] (parameterized by the
   constant tuple [coef_tuple]), stated at the explicit sig type below. *)
Ltac pose_sub_sig weight modulus limb_base nlimbs coef_tuple new_id :=
  cache_sig_with_type_by_existing_sig
    { sub : (Z^nlimbs -> Z^nlimbs -> Z^nlimbs)%type
    | forall a b : Z^nlimbs,
        let eval := Positional.Fdecode (m:=modulus) weight in
        eval (sub a b) = (eval a - eval b)%F }
    (sub_sig' modulus limb_base nlimbs coef_tuple)
    new_id.

(* Binds [new_id] to the negation of [opp_sig'] (parameterized by the
   constant tuple [coef_tuple]), stated at the explicit sig type below. *)
Ltac pose_opp_sig weight modulus limb_base nlimbs coef_tuple new_id :=
  cache_sig_with_type_by_existing_sig
    { opp : (Z^nlimbs -> Z^nlimbs)%type
    | forall a : Z^nlimbs,
        let eval := Positional.Fdecode (m := modulus) weight in
        eval (opp a) = F.opp (eval a) }
    (opp_sig' modulus limb_base nlimbs coef_tuple)
    new_id.

(* Binds [new_id] to the multiplication of [mul_sig'], stated at the
   explicit sig type below.  [code] and [code_correct] are the optional
   hand-written implementation and its correctness proof. *)
Ltac pose_mul_sig weight modulus limb_base nlimbs s_val c_val code nlimbs_nonzero s_val_nonzero limb_base_pos code_correct new_id :=
  cache_sig_with_type_by_existing_sig
    { mul : (Z^nlimbs -> Z^nlimbs -> Z^nlimbs)%type
    | forall a b : Z^nlimbs,
        let eval := Positional.Fdecode (m := modulus) weight in
        eval (mul a b) = (eval a * eval b)%F }
    (mul_sig' modulus limb_base nlimbs s_val c_val code
              nlimbs_nonzero s_val_nonzero limb_base_pos code_correct)
    new_id.

(* Binds [new_id] to the squaring of [square_sig'], stated at the explicit
   sig type below.  [code] and [code_correct] are the optional hand-written
   implementation and its correctness proof. *)
Ltac pose_square_sig weight modulus limb_base nlimbs s_val c_val code nlimbs_nonzero s_val_nonzero limb_base_pos code_correct new_id :=
  cache_sig_with_type_by_existing_sig
    { square : (Z^nlimbs -> Z^nlimbs)%type
    | forall a : Z^nlimbs,
        let eval := Positional.Fdecode (m := modulus) weight in
        eval (square a) = (eval a * eval a)%F }
    (square_sig' modulus limb_base nlimbs s_val c_val code
                 nlimbs_nonzero s_val_nonzero limb_base_pos code_correct)
    new_id.

(* Caches, under the name [ring], the result of [Ring.ring_by_isomorphism].
   That lemma is instantiated with [Positional.Fencode]/[Positional.Fdecode]
   between [F m] and [Z^sz], and with the operations and correctness proofs
   carried by the given sigs.  Unlike [ring_pkg] in section [gen], every
   argument is supplied explicitly:
   - the modulo/div identities are [Core.modulo_id] and [Core.div_id];
   - the argument that [ring_pkg] fills with [wt_gen0_1 base] is given here
     as [eq_refl], so it must hold by computation for the supplied [wt]. *)
Ltac pose_ring sz m wt wt_divides' sz_nonzero wt_nonzero zero_sig one_sig opp_sig add_sig sub_sig mul_sig ring :=
  cache_term
    (Ring.ring_by_isomorphism
       (F := F m)
       (H := Z^sz)
       (phi := Positional.Fencode wt)
       (phi' := Positional.Fdecode wt)
       (zero := proj1_sig zero_sig)
       (one := proj1_sig one_sig)
       (opp := proj1_sig opp_sig)
       (add := proj1_sig add_sig)
       (sub := proj1_sig sub_sig)
       (mul := proj1_sig mul_sig)
       (phi'_zero := proj2_sig zero_sig)
       (phi'_one := proj2_sig one_sig)
       (phi'_opp := proj2_sig opp_sig)
       (Positional.Fdecode_Fencode_id
          (sz_nonzero := sz_nonzero)
          (div_mod := div_mod)
          (modulo_cps_id:=@Core.modulo_id)
          (div_cps_id:=@Core.div_id)
          wt eq_refl wt_nonzero wt_divides')
       (Positional.eq_Feq_iff wt)
       (proj2_sig add_sig)
       (proj2_sig sub_sig)
       (proj2_sig mul_sig)
    )
    ring.

(*
Eval cbv [proj1_sig add_sig] in (proj1_sig add_sig).
Eval cbv [proj1_sig sub_sig] in (proj1_sig sub_sig).
Eval cbv [proj1_sig opp_sig] in (proj1_sig opp_sig).
Eval cbv [proj1_sig mul_sig] in (proj1_sig mul_sig).
Eval cbv [proj1_sig carry_sig] in (proj1_sig carry_sig).
 *)