aboutsummaryrefslogtreecommitdiff
path: root/src/Specific/Framework/ArithmeticSynthesis/Ladderstep.v
blob: 2f2ef07a5ed7d330dde69d748f20e7d208cf2034 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
Require Import Coq.ZArith.BinIntDef.
Require Import Crypto.Arithmetic.Core. Import B.
Require Import Crypto.Arithmetic.PrimeFieldTheorems.
Require Import Crypto.Curves.Montgomery.XZ.
Require Import Crypto.Specific.Framework.ArithmeticSynthesis.HelperTactics.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.Notations.
Require Import Crypto.Util.Tactics.PoseTermWithName.
Require Import Crypto.Util.Tactics.CacheTerm.
Require Import Crypto.Util.Option.

(* Local shorthand for Tuple.tuple, used throughout the signatures below. *)
Local Notation tuple := Tuple.tuple.
Local Open Scope list_scope.
Local Open Scope Z_scope.
(* In type position, [Z^sz] is read as [tuple Z sz]. *)
Local Infix "^" := tuple : type_scope.

(** TODO(jadep,andreser): Move to NewBaseSystemTest? *)
(** Field-level ladder step: [M.donnaladderstep] instantiated over the
    field [F m] with [F.add], [F.sub] and [F.mul].  The remaining
    explicit arguments are supplied as [a24 x1 Q Q'] in
    [pose_Mxzladderstep_sig], where this definition is the
    specification that the limb-level [Mxzladderstep] is proved
    against. *)
Definition FMxzladderstep {m} := @M.donnaladderstep (F m) F.add F.sub F.mul.

(** Limb-level x/z ladder step, generic over the arithmetic on
    [tuple Z sz].  The five operations are section parameters, so
    [Mxzladderstep] takes [sz add sub mul square carry] as its leading
    arguments once the section is closed. *)
Section with_notations.
  Context sz (add sub mul : tuple Z sz -> tuple Z sz -> tuple Z sz)
          (square carry : tuple Z sz -> tuple Z sz).
  (* Addition and subtraction are the bare operations, with no carry. *)
  Local Infix "+" := add.
  (* Multiplication and squaring are always followed by [carry]. *)
  Local Notation "a * b" := (carry (mul a b)).
  Local Notation "x ^ 2" := (carry (square x)).
  Local Infix "-" := sub.
  (** [Mxzladderstep a24 x1 Q Q'] destructs [Q = (x, z)] and
      [Q' = (x', z')] and returns [((x2, z2), (x3, z3))].  [a24] and
      [x1] are limb tuples that enter only as operands of the carried
      multiplication.  Each intermediate value is named by a [dlet]
      (from Crypto.Util.LetIn), which fixes the order of the bindings.
      The correctness proof in [pose_Mxzladderstep_sig] unfolds this
      definition and [M.donnaladderstep], rewrites, and closes by
      [reflexivity].  Any change here must therefore keep the result
      convertible with that reference definition. *)
  Definition Mxzladderstep a24 x1 Q Q'
    := match Q, Q' with
       | (x, z), (x', z') =>
         (* Sums and differences of the input coordinates: after these
            four bindings x = x + z, z = x - z, and likewise for x', z'. *)
         dlet origx := x in
         dlet x := x + z in
         dlet z := origx - z in
         dlet origx' := x' in
         dlet x' := x' + z' in
         dlet z' := origx' - z' in
         (* Cross products of the sums and differences. *)
         dlet xx' := x' * z in
         dlet zz' := x * z' in
         (* origx' is rebound (shadowing the earlier one) to keep the old
            xx' before xx' is overwritten by the sum. *)
         dlet origx' := xx' in
         dlet xx' := xx' + zz' in
         dlet zz' := origx' - zz' in
         (* Second output pair: x3 is the squared sum of the cross
            products; z3 is the squared difference times x1. *)
         dlet x3 := xx'^2 in
         dlet zzz' := zz'^2 in
         dlet z3 := zzz' * x1 in
         (* First output pair, built from the squares of x + z and
            x - z together with a24. *)
         dlet xx := x^2 in
         dlet zz := z^2 in
         dlet x2 := xx * zz in
         dlet zz := xx - zz in
         dlet zzz := zz * a24 in
         dlet zzz := zzz + xx in
         dlet z2 := zz * zzz in
         ((x2, z2), (x3, z3))%core
       end.
End with_notations.

(** Fully evaluates [invert_Some a24] with [vm_compute] and passes the
    resulting term to [cache_term_with_type_by] at type [Z], bound
    under the name [a24'].  NOTE(review): presumably fails if [a24]
    does not evaluate to a [Some] of a [Z]; confirm against
    Crypto.Util.Option.invert_Some. *)
Ltac pose_a24' a24 a24' :=
  let a24_value := (eval vm_compute in (invert_Some a24)) in
  cache_term_with_type_by Z ltac:(exact a24_value) a24'.

(** Asks [cache_term_with_type_by] for a limb tuple [a24t : Z^sz] whose
    decoding under [Positional.Fdecode (m:=m) wt] equals
    [F.of_Z m a24'].  The witness and proof are built by
    [solve_constant_sig], and the result is bound under the name
    [a24_sig]. *)
Ltac pose_a24_sig sz m wt a24' a24_sig :=
  let a24_sig_type :=
      constr:({ a24t : Z^sz | Positional.Fdecode (m:=m) wt a24t = F.of_Z m a24' }) in
  cache_term_with_type_by a24_sig_type solve_constant_sig a24_sig.

(** Builds, via [cache_term_with_type_by] under the name
    [Mxzladderstep_sig], a limb-level ladder step packaged with its
    correctness statement.  Decoding both output pairs with
    [B.Positional.Fdecode wt] must give [FMxzladderstep (m:=m)] applied
    to the decoded inputs.  The witness is [Mxzladderstep sz] applied to
    the implementations ([proj1_sig]) of [add_sig], [sub_sig],
    [mul_sig], [square_sig] and [carry_sig].  The proof unfolds both
    ladder steps and rewrites each operation with its [proj2_sig]. *)
Ltac pose_Mxzladderstep_sig sz wt m add_sig sub_sig mul_sig square_sig carry_sig Mxzladderstep_sig :=
  cache_term_with_type_by
    { xzladderstep : tuple Z sz -> tuple Z sz -> tuple Z sz * tuple Z sz -> tuple Z sz * tuple Z sz -> tuple Z sz * tuple Z sz * (tuple Z sz * tuple Z sz)
    | forall a24 x1 Q Q', let eval := B.Positional.Fdecode wt in Tuple.map (n:=2) (Tuple.map (n:=2) eval) (xzladderstep a24 x1 Q Q') = FMxzladderstep (m:=m) (eval a24) (eval x1) (Tuple.map (n:=2) eval Q) (Tuple.map (n:=2) eval Q') }
    ltac:((exists (Mxzladderstep sz (proj1_sig add_sig) (proj1_sig sub_sig) (proj1_sig mul_sig) (proj1_sig square_sig) (proj1_sig carry_sig)));
          (* fresh names avoid clashing with hypotheses already in context *)
          let a24 := fresh "a24" in
          let x1 := fresh "x1" in
          let Q := fresh "Q" in
          let Q' := fresh "Q'" in
          let eval := fresh "eval" in
          (* eval is introduced as a local definition (the let in the goal) *)
          intros a24 x1 Q Q' eval;
          (* expose the dlet chains of both the limb and field ladder steps *)
          cbv [Mxzladderstep FMxzladderstep M.donnaladderstep];
          (* split the input pairs so the tuple maps and projections compute,
             then inline the Let_In bindings and the local eval abbreviation *)
          destruct Q, Q'; cbv [map map' fst snd Let_In eval];
          (* replace each limb-level operation by its decoded meaning using
             the correctness half of its sig; NOTE(review): assumes every
             proj2_sig here is an equation that rewrites in this goal *)
          repeat match goal with
                 | [ |- context[@proj1_sig ?a ?b ?s] ]
                   => rewrite !(@proj2_sig a b s)
                 end;
          reflexivity)
           Mxzladderstep_sig.