aboutsummaryrefslogtreecommitdiff
path: root/src/Specific/Framework/OutputType.v
blob: 80fced9233dbc604b15526f80a9c5df98f123dfe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
Require Import Coq.ZArith.BinIntDef.
Require Import Crypto.Arithmetic.Core. Import B.
Require Import Crypto.Arithmetic.PrimeFieldTheorems.
Require Import Crypto.Compilers.Syntax.
Require Import Crypto.Compilers.SmartMap.
Require Import Crypto.Compilers.Tuple.
Require Import Crypto.Compilers.ExprInversion.
Require Import Crypto.Compilers.Z.Syntax.Util.
Require Import Crypto.Compilers.Z.Syntax.
Require Import Crypto.Compilers.Tuple.
Require Import Crypto.Specific.Framework.RawCurveParameters.
Require Import Crypto.Specific.Framework.ArithmeticSynthesis.Base.
Require Import Crypto.Util.Notations.
Import CurveParameters.Notations.

(* Allow [Z]-valued terms where a [nat] is expected, via [Z.to_nat]
   (negative values become [0]).  NOTE(review): presumably needed for
   arguments such as [Z.log2_up curve.(bitwidth)] given to [TWord]
   below -- confirm. *)
Local Coercion Z.to_nat : Z >-> nat.
(* Within this file, [interp_flat_type] is pre-applied to [interp_base_type]. *)
Local Notation interp_flat_type := (interp_flat_type interp_base_type).

(** Generic description of what a curve synthesis must output, as a
    function of the raw curve parameters [curve]. *)
Section gen.
  Context (curve : RawCurveParameters.CurveParameters).

  (** Definitions that depend on the base type [b] used for each limb. *)
  Section gen_base_type.
    Context (b : base_type).

    (** The modulus, computed as [s - eval c] and converted with [Z.to_pos]
        (which maps any non-positive value to [1]). *)
    Definition m := Z.to_pos (curve.(s) - Associational.eval curve.(c))%Z.
    (** Flat type of a field element: [sz] copies of [Tbase b]. *)
    Definition rT := ((Tbase b)^curve.(sz))%ctype.
    (** Gallina interpretation of [rT]. *)
    Definition T' := (interp_flat_type rT).
    (** A constant of type [rT], taken behind a [Unit] argument. *)
    Definition RT := (Unit -> rT)%ctype.
    (** Weight function handed to [Positional.Fencode]/[Positional.Fdecode],
        generated from [curve.(base)] by [wt_gen]. *)
    Definition wt := (wt_gen curve.(base)).
    (** Encode a field element as a constant expression: [Positional.Fencode]
        produces [sz] limbs, each limb becomes an [OpConst] node, and the
        resulting tuple is wrapped in an abstraction that ignores its
        [Unit] argument. *)
    Definition encode : F m -> Expr RT
      := fun v var
         => Abs
              (fun _
               => SmartPairf
                    (flat_interp_untuple
                       (T:=Tbase _)
                       (Tuple.map
                          (fun v => Op (OpConst v) TT)
                          (@Positional.Fencode wt curve.(sz) m (@modulo_cps) (@div_cps) v)))).
    (** Decode an interpreted limb tuple: convert each limb to [Z] with
        [interpToZ], then apply [Positional.Fdecode] with weights [wt]. *)
    Definition decode : T' -> F m
      := fun v => @Positional.Fdecode
                    wt (curve.(sz)) m
                    (Tuple.map interpToZ (flat_interp_tuple (T:=Tbase _) v)).
  End gen_base_type.

  (** Shorthands for the two limb representations used below: [TZ] limbs
      and [TW] limbs, where [TW] is [TWord (Z.log2_up curve.(bitwidth))]. *)
  Local Notation TW := (TWord (Z.log2_up curve.(bitwidth))).
  Local Notation RTZ := (RT TZ).
  Local Notation rTZ := (rT TZ).
  Local Notation RTW := (RT TW).
  Local Notation rTW := (rT TW).

  (** The output of a synthesis for [curve]:
      - field-operation expressions over [TZ] limbs (fields ending in [Z])
        and over [TW] limbs (fields ending in [W]);
      - a predicate [PZ] on [TZ]-limb tuples, with [PW] derived from it;
      - proofs that decoding after encoding is the identity and that each
        operation preserves the predicate;
      - ring structures on the subset types [T_Z]/[T_W] and proofs that
        encoding and decoding are ring homomorphisms to and from [F m]. *)
  Record SynthesisOutput :=
    {
      (* ---- Expressions over TZ limbs ---- *)
      zeroZ : Expr RTZ;
      oneZ : Expr RTZ;
      addZ : Expr (rTZ * rTZ -> rTZ); (* does not include carry *)
      subZ : Expr (rTZ * rTZ -> rTZ); (* does not include carry *)
      carry_mulZ : Expr (rTZ * rTZ -> rTZ); (* includes carry *)
      carry_squareZ : Expr (rTZ -> rTZ); (* includes carry *)
      oppZ : Expr (rTZ -> rTZ); (* does not include carry *)
      carryZ : Expr (rTZ -> rTZ); (* carry step; composed with addZ/subZ/oppZ below *)
      (* Carrying add/sub/opp, defined by composing carryZ after the raw op. *)
      carry_addZ : Expr (rTZ * rTZ -> rTZ)
      := (carryZ ∘ addZ)%expr;
      carry_subZ : Expr (rTZ * rTZ -> rTZ)
      := (carryZ ∘ subZ)%expr;
      carry_oppZ : Expr (rTZ -> rTZ)
      := (carryZ ∘ oppZ)%expr;

      (* ---- Expressions over TW limbs ---- *)
      zeroW : Expr RTW;
      oneW : Expr RTW;
      carry_addW : Expr (rTW * rTW -> rTW); (* includes carry *)
      carry_subW : Expr (rTW * rTW -> rTW); (* includes carry *)
      carry_mulW : Expr (rTW * rTW -> rTW); (* includes carry *)
      carry_squareW : Expr (rTW -> rTW); (* includes carry *)
      carry_oppW : Expr (rTW -> rTW); (* includes carry *)

      (* Predicate on interpreted TZ-limb tuples.  NOTE(review): presumably
         a limb-bounds invariant -- confirm against the synthesis code. *)
      PZ : T' TZ -> Prop;
      (* PW v holds when PZ holds of v with every limb mapped to Z by interpToZ. *)
      PW : T' TW -> Prop
      := fun v => PZ (tuple_map (A:=Tbase TW) (B:=Tbase TZ) (SmartVarfMap (@interpToZ)) v);

      (* ---- Facts about the TZ representation ---- *)
      (* Encoding any field element yields a tuple satisfying PZ.  The type is
         spelled out (rather than left as a hole to be solved by unification
         with encodeZ_sig) so the field states its own contract. *)
      encodeZ_valid : forall v : F m, PZ (Interp (encode TZ v) tt);
      encodeZ_sig := fun v => exist PZ (Interp (encode TZ v) tt) (encodeZ_valid v);
      (* Decoding after encoding is the identity on F m. *)
      encodeZ_correct : forall v, decode TZ (Interp (encode TZ v) tt) = v;

      decodeZ_sig := fun v : sig PZ => decode TZ (proj1_sig v);

      zeroZ_correct : zeroZ = encode _ 0%F; (* which equality to use here? currently Leibniz equality on Expr *)
      oneZ_correct : oneZ = encode _ 1%F; (* which equality to use here? currently Leibniz equality on Expr *)
      (* Defined via encodeZ_sig rather than zeroZ/oneZ; the two _correct
         fields above relate them. *)
      zeroZ_sig := encodeZ_sig 0%F;
      oneZ_sig := encodeZ_sig 1%F;

      (* Each carrying operation preserves PZ; the _sig versions lift the
         operation to sig PZ. *)
      oppZ_valid : forall x, PZ x -> PZ (Interp carry_oppZ x);
      oppZ_sig := fun x => exist PZ _ (@oppZ_valid (proj1_sig x) (proj2_sig x));

      addZ_valid : forall x y, PZ x -> PZ y -> PZ (Interp carry_addZ (x, y));
      addZ_sig := fun x y => exist PZ _ (@addZ_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));

      subZ_valid : forall x y, PZ x -> PZ y -> PZ (Interp carry_subZ (x, y));
      subZ_sig := fun x y => exist PZ _ (@subZ_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));

      mulZ_valid : forall x y, PZ x -> PZ y -> PZ (Interp carry_mulZ (x, y));
      mulZ_sig := fun x y => exist PZ _ (@mulZ_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));

      (* On inputs satisfying PZ, squaring agrees with multiplying x by itself. *)
      squareZ_correct : forall x, PZ x -> Interp carry_squareZ x = Interp carry_mulZ (x, x);


      (* ---- Facts about the TW representation (mirrors the TZ block) ---- *)
      (* Encoding any field element yields a tuple satisfying PW. *)
      encodeW_valid : forall v : F m, PW (Interp (encode TW v) tt);
      encodeW_sig := fun v => exist PW (Interp (encode TW v) tt) (encodeW_valid v);
      encodeW_correct : forall v, decode TW (Interp (encode TW v) tt) = v;

      decodeW_sig := fun v : sig PW => decode TW (proj1_sig v);

      zeroW_correct : zeroW = encode _ 0%F; (* which equality to use here? currently Leibniz equality on Expr *)
      oneW_correct : oneW = encode _ 1%F; (* which equality to use here? currently Leibniz equality on Expr *)
      zeroW_sig := encodeW_sig 0%F;
      oneW_sig := encodeW_sig 1%F;

      oppW_valid : forall x, PW x -> PW (Interp carry_oppW x);
      oppW_sig := fun x => exist PW _ (@oppW_valid (proj1_sig x) (proj2_sig x));

      addW_valid : forall x y, PW x -> PW y -> PW (Interp carry_addW (x, y));
      addW_sig := fun x y => exist PW _ (@addW_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));

      subW_valid : forall x y, PW x -> PW y -> PW (Interp carry_subW (x, y));
      subW_sig := fun x y => exist PW _ (@subW_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));

      mulW_valid : forall x y, PW x -> PW y -> PW (Interp carry_mulW (x, y));
      mulW_sig := fun x y => exist PW _ (@mulW_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));

      squareW_correct : forall x, PW x -> Interp carry_squareW x = Interp carry_mulW (x, x);

      (* ---- Algebraic structure on the TZ subset type ---- *)
      (* Two elements are equal when they decode to the same field element. *)
      T_Z := { v : T' TZ | PZ v };
      eqTZ : T_Z -> T_Z -> Prop
      := fun x y => eq (decode _ (proj1_sig x)) (decode _ (proj1_sig y));
      ringZ : @Hierarchy.ring
                T_Z eqTZ zeroZ_sig oneZ_sig oppZ_sig addZ_sig subZ_sig mulZ_sig;
      encodeZ_homomorphism
      : @Ring.is_homomorphism
          (F m) eq 1%F F.add F.mul
          T_Z eqTZ oneZ_sig addZ_sig mulZ_sig
          encodeZ_sig;
      decodeZ_homomorphism
      : @Ring.is_homomorphism
          T_Z eqTZ oneZ_sig addZ_sig mulZ_sig
          (F m) eq 1%F F.add F.mul
          decodeZ_sig;

      (* ---- Algebraic structure on the TW subset type ---- *)
      T_W := { v : T' TW | PW v };
      eqTW : T_W -> T_W -> Prop
      := fun x y => eq (decode _ (proj1_sig x)) (decode _ (proj1_sig y));
      ringW : @Hierarchy.ring
                T_W eqTW zeroW_sig oneW_sig oppW_sig addW_sig subW_sig mulW_sig;
      encodeW_homomorphism
      : @Ring.is_homomorphism
          (F m) eq 1%F F.add F.mul
          T_W eqTW oneW_sig addW_sig mulW_sig
          encodeW_sig;
      decodeW_homomorphism
      : @Ring.is_homomorphism
          T_W eqTW oneW_sig addW_sig mulW_sig
          (F m) eq 1%F F.add F.mul
          decodeW_sig
    }.
End gen.