aboutsummaryrefslogtreecommitdiff
path: root/src/Encoding/PointEncodingPre.v
blob: 73ced869b48fe561f99179442cae379ae194a28b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
Require Import Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
Require Import Coq.Numbers.Natural.Peano.NPeano.
Require Import Coq.Program.Equality.
Require Import Crypto.Encoding.EncodingTheorems.
Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Bedrock.Word.
Require Import Crypto.Encoding.ModularWordEncodingTheorems.
Require Import Crypto.Tactics.VerdiTactics.
Require Import Crypto.Util.ZUtil.

Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding Crypto.Spec.ModularArithmetic.

Local Open Scope F_scope.

(** * Encoding of twisted Edwards curve points as machine words

    A coordinate pair (x, y) is encoded in [S sz] bits: the head bit holds
    [sign_bit x] and the remaining [sz] bits hold the canonical encoding of y.
    Decoding recovers x from y through [E.solve_for_x2] and [sqrt_mod_q]. *)
Section PointEncoding.
  (* Section parameters:
     - [prm]: the twisted Edwards curve parameters (presumably the source of
       the field modulus [q] used throughout; confirm in the Spec files).
     - [sz]: width in bits of an encoded field element; [sz_nonzero] requires
       it to be positive and [bound_check] requires [Z.to_nat q < 2 ^ sz].
     - [q_5mod8], [sqrt_minus1_valid]: arithmetic facts about [q]
       ([q mod 8 = 5], and [2 ^ (q / 4)] squares to [opp 1]); presumably
       needed by the [sqrt_mod_q] lemmas below, whose side conditions are
       discharged with [eauto] / [try assumption].
     - [FqEncoding]: canonical encoding of [F q] as [word sz], providing the
       [enc] / [dec] used for the y-coordinate.
     - [sign_bit]: the bit stored to tell x apart from [opp x]; it is [false]
       on 0 ([sign_bit_zero]) and is negated by [opp] on nonzero elements
       ([sign_bit_opp]). *)
  Context {prm: TwistedEdwardsParams} {sz : nat} {sz_nonzero : (0 < sz)%nat}
   {bound_check : (Z.to_nat q < 2 ^ sz)%nat} {q_5mod8 : (q mod 8 = 5)%Z}
   {sqrt_minus1_valid : (@ZToField q 2 ^ Z.to_N (q / 4)) ^ 2 = opp 1}
   {FqEncoding : canonical encoding of (F q) as (word sz)}
   {sign_bit : F q -> bool} {sign_bit_zero : sign_bit 0 = false}
   {sign_bit_opp : forall x, x <> 0 -> negb (sign_bit x) = sign_bit (opp x)}.
  (* Make [prime_q] available to typeclass resolution (presumably the
     primality fact required by the field lemmas and the declaration below). *)
  Existing Instance prime_q.

  (* Register the field structure of [F q] so that [ring] / [field] can be
     used on goals over [F q] (e.g. in [solve_opp_onCurve]). *)
  Add Field Ffield : (@Ffield_theory q _)
    (morphism (@Fring_morph q),
     preprocess [Fpreprocess],
     postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption],
     constants [Fconstant],
     div (@Fmorph_div_theory q),
     power_tac (@Fpower_theory q) [Fexp_tac]).

  (** [sqrt_valid a] holds when [sqrt_mod_q a] really is a square root of [a],
      i.e. squaring it gives back [a]. *)
  Definition sqrt_valid (a : F q) : Prop :=
    (sqrt_mod_q a ^ 2 = a)%F.

  (** For a point on the curve, [sqrt_mod_q] is valid on [E.solve_for_x2] of
      its y-coordinate. *)
  Lemma solve_sqrt_valid : forall p, E.onCurve p ->
    sqrt_valid (E.solve_for_x2 (snd p)).
  Proof.
    intros ? onCurve_xy.
    destruct p as [x y]; simpl.
    (* Turn curve membership into [x ^ 2 = E.solve_for_x2 y] and use it to
       replace [E.solve_for_x2 y] by [x ^ 2] in the goal. *)
    rewrite (E.solve_correct x y) in onCurve_xy.
    rewrite <- onCurve_xy.
    unfold sqrt_valid.
    (* [x ^ 2] is a square by construction, so [sqrt_mod_q] is valid on it. *)
    eapply sqrt_mod_q_valid; eauto.
    unfold isSquare; eauto.
    Grab Existential Variables. eauto.
  Qed.

  (** If [sqrt_mod_q] is valid on [E.solve_for_x2 y], then that root paired
      with [y] is a point on the curve. *)
  Lemma solve_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) ->
    E.onCurve (sqrt_mod_q (E.solve_for_x2 y), y).
  Proof.
    intros.
    unfold sqrt_valid in *.
    apply E.solve_correct; auto.
  Qed.

  (** Same as [solve_onCurve], but for the opposite root [opp (sqrt_mod_q _)],
      which has the same square. *)
  Lemma solve_opp_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) ->
    E.onCurve (opp (sqrt_mod_q (E.solve_for_x2 y)), y).
  Proof.
    intros y sqrt_valid_x2.
    unfold sqrt_valid in *.
    apply E.solve_correct.
    (* Rewrite only the right-hand occurrence of [E.solve_for_x2 y] as the
       square of its root; what remains is an identity closed by [ring]. *)
    rewrite <- sqrt_valid_x2 at 2.
    ring.
  Qed.

  (** Encode a coordinate pair [(x, y)] in [S sz] bits: the head bit is
      [sign_bit x] and the tail is the field encoding [enc y]. *)
  Definition point_enc_coordinates (p : F q * F q) : Word.word (S sz) :=
    match p with
    | (x, y) => Word.WS (sign_bit x) (enc y)
    end.

  (** Encode a curve point through its underlying coordinate pair; the body is
      the same as [point_enc_coordinates (proj1_sig p)]. *)
  Let point_enc (p : E.point) : Word.word (S sz) :=
    match proj1_sig p with
    | (x, y) => Word.WS (sign_bit x) (enc y)
    end.

  (** Decode an [S sz]-bit word into a coordinate pair.
      The tail [Word.wtl w] is decoded as the y-coordinate. A candidate x is
      computed as [sqrt_mod_q (E.solve_for_x2 y)], and the head bit [whd w]
      chooses between that root and its opposite: the root is kept when the
      head bit equals its [sign_bit], otherwise [opp] of it is used.
      Returns [None] when y does not decode, when the candidate does not
      square to [E.solve_for_x2 y], or when the candidate is 0 while the head
      bit is [true].
      Note: the [sign_bit] argument shadows the section variable of the same
      name; every use in this section passes the section's [sign_bit]. *)
  Definition point_dec_coordinates (sign_bit : F q -> bool) (w : Word.word (S sz)) : option (F q * F q) :=
    match dec (Word.wtl w) with
    | None => None
    | Some y => let x2 := E.solve_for_x2 y in
        let x := sqrt_mod_q x2 in
        if F_eq_dec (x ^ 2) x2
        then
          let p := (if Bool.eqb (whd w) (sign_bit x) then x else opp x, y) in
          if (andb (F_eqb x 0) (whd w))
          then None (* x = 0 equals its own opposite, so the head bit cannot
                       choose between two roots; rejecting head bit [true]
                       here leaves 0 with a single accepted encoding. *)
          else Some p 
        else None
    end.

  (** Find a hypothesis of the form [Some a = Some b], invert it, and
      substitute the resulting equations. *)
  Ltac inversion_Some_eq :=
    match goal with
    | [ Hsome : Some ?lhs = Some ?rhs |- _ ] => inversion Hsome; subst
    end.

  (** Every coordinate pair returned by [point_dec_coordinates] lies on the
      curve. *)
  Lemma point_dec_coordinates_onCurve : forall w p, point_dec_coordinates sign_bit w = Some p -> E.onCurve p.
  Proof.
    unfold point_dec_coordinates; intros.
    (* Case on decoding the y-coordinate; the failing branch returns [None]. *)
    edestruct dec; [ | congruence].
    (* Square-root check; its failing branch returns [None]. *)
    break_if; [ | congruence].
    (* Zero special case; its rejecting branch returns [None]. *)
    break_if; [ congruence | ]. 
    (* Sign selection: either the root or its opposite, both on the curve. *)
    break_if; inversion_Some_eq; auto using solve_onCurve, solve_opp_onCurve.
  Qed.

  (** Decidable equality on pairs [A * A], given decidable equality on [A].
      Closed with [Defined] instead of [Qed] so the decision procedure stays
      transparent. [point_dec'] matches on it (through [option_eq_dec]), and
      an opaque procedure would block reducing that match by computation
      (the element decider passed in must be transparent as well). *)
  Lemma prod_eq_dec : forall {A} (A_eq_dec : forall a a' : A, {a = a'} + {a <> a'})
   (x y : (A * A)), {x = y} + {x <> y}.
  Proof.
    decide equality.
  Defined.

  (** Decidable equality on [option A], given decidable equality on [A].
      Closed with [Defined] instead of [Qed] so the decision procedure stays
      transparent. [point_dec'] matches on its result, and an opaque
      procedure would block reducing that match by computation (the element
      decider passed in must be transparent as well). *)
  Lemma option_eq_dec : forall {A} (A_eq_dec : forall a a' : A, {a = a'} + {a <> a'})
   (x y : option A), {x = y} + {x <> y}.
  Proof.
    decide equality.
  Defined.

  (** Package the coordinates [p] as a curve point, provided
      [point_dec_coordinates sign_bit w] returns exactly [Some p]. The
      equality test supplies the proof [EQ], from which
      [point_dec_coordinates_onCurve] produces the on-curve proof. *)
  Definition point_dec' w p : option E.point :=
    match (option_eq_dec (prod_eq_dec F_eq_dec) (point_dec_coordinates sign_bit w) (Some p)) with
      | left EQ => Some (exist _ p (point_dec_coordinates_onCurve w p EQ))
      | right _ => None (* unreachable when called from [point_dec], which passes the [p] that [point_dec_coordinates] just returned *)
    end.

  (** Decode [S sz] bits to a curve point: decode the coordinates with
      [point_dec_coordinates], then attach their on-curve proof through
      [point_dec']. Returns [None] when coordinate decoding fails. *)
  Definition point_dec (w : word (S sz)) : option E.point :=
    match point_dec_coordinates sign_bit w with
    | None => None
    | Some p => point_dec' w p
    end.

  (** Coordinate decoding is canonical: if [w] decodes to [p], encoding [p]
      gives back exactly [w]. *)
  Lemma point_coordinates_encoding_canonical : forall w p,
    point_dec_coordinates sign_bit w = Some p -> point_enc_coordinates p = w.
  Proof.
    unfold point_dec_coordinates, point_enc_coordinates; intros ? ? coord_dec_Some.
    (* If the tail fails to decode, [coord_dec_Some] is contradictory. *)
    case_eq (dec (wtl w)); [ intros ? dec_Some | intros dec_None; rewrite dec_None in *; congruence ].
    destruct p.
    (* Split [w] into head bit and tail and compare the parts separately. *)
    rewrite (shatter_word w).
    f_equal; rewrite dec_Some in *;
      do 2 (break_if; try congruence); inversion coord_dec_Some; subst.
    (* Head bit: relate it to the sign bit of the selected x. *)
    + destruct (F_eq_dec (sqrt_mod_q (E.solve_for_x2 f1)) 0%F) as [sqrt_0 | ?].
      (* Zero root: decoding only succeeds with head bit [false], which
         matches [sign_bit_zero]. *)
      - rewrite sqrt_0 in *.
        apply sqrt_mod_q_root_0 in sqrt_0; try assumption.
        rewrite sqrt_0 in *.
        break_if; [symmetry; auto using Bool.eqb_prop | ].
        rewrite sign_bit_zero in *.
        simpl in Heqb; rewrite Heqb in *.
        discriminate.
      (* Nonzero root: [sign_bit_opp] relates the two candidate roots. *)
      - break_if.
        symmetry; auto using Bool.eqb_prop.
        rewrite <- sign_bit_opp by assumption.
        destruct (whd w); inversion Heqb0; break_if; auto.
    (* Tail: follows from canonicity of the field-element encoding. *)
    + inversion coord_dec_Some; subst.
      auto using encoding_canonical.
Qed.

  (** Point decoding is canonical: if [w] decodes to the point [x],
      re-encoding [x] with [point_enc] yields exactly [w].
      Proof: after dropping on-curve proofs, [point_dec] agrees with
      [point_dec_coordinates]. This repeats the argument of
      [point_dec_coordinates_correct] inline, because that lemma is stated
      later in the file. The result then follows from
      [point_coordinates_encoding_canonical]. *)
  Lemma point_encoding_canonical : forall w x, point_dec w = Some x -> point_enc x = w.
  Proof.
    intros w.
    (* Dropping the on-curve proof of [point_dec w] gives [point_dec_coordinates]. *)
    assert (option_map (@proj1_sig _ _) (point_dec w) = point_dec_coordinates sign_bit w)
      as dec_coordinates_eq.
    { unfold point_dec, option_map.
      do 2 break_match; try congruence; unfold point_dec' in *;
        break_match; try congruence.
      inversion_Some_eq.
      reflexivity. }
    intros x point_dec_Some.
    (* [point_enc] is definitionally [point_enc_coordinates] on the coordinates. *)
    replace (point_enc x) with (point_enc_coordinates (proj1_sig x)) by reflexivity.
    apply point_coordinates_encoding_canonical.
    rewrite <- dec_coordinates_eq, point_dec_Some; reflexivity.
  Qed.

(** [point_dec] agrees with [point_dec_coordinates] once the on-curve proof
    carried by the decoded point is dropped. *)
Lemma point_dec_coordinates_correct w
  : option_map (@proj1_sig _ _) (point_dec w) = point_dec_coordinates sign_bit w.
Proof.
  unfold point_dec, option_map.
  do 2 break_match; try congruence; unfold point_dec' in *;
    break_match; try congruence.
  inversion_Some_eq. 
  reflexivity.
Qed.

(** Decoding the tail of an encoded coordinate pair gives back its
    y-coordinate. *)
Lemma y_decode : forall p, dec (wtl (point_enc_coordinates p)) = Some (snd p).
Proof.
  intros [x y].
  simpl.
  exact (encoding_valid y).
Qed.

(** For nonzero [y], [x] has a sign bit different from that of [y] exactly
    when it has the same sign bit as [opp y]. *)
Lemma sign_bit_opp_eq_iff : forall x y, y <> 0 ->
  (sign_bit x <> sign_bit y <-> sign_bit x = sign_bit (opp y)).
Proof.
  (* Both directions: case on both sign bits and use [sign_bit_opp]. *)
  split; intro sign_mismatch; case_eq (sign_bit x); case_eq (sign_bit y);
    try congruence; intros y_sign x_sign; rewrite <- sign_bit_opp in * by auto;
    rewrite y_sign, x_sign in *; reflexivity || discriminate.
Qed.

(** Two field elements with equal squares and equal sign bits are equal,
    provided the second one is nonzero. *)
Lemma sign_bit_squares : forall x y, y <> 0 -> x ^ 2 = y ^ 2 ->
  sign_bit x = sign_bit y -> x = y.
Proof.
  intros ? ? y_nonzero squares_eq sign_match.
  (* Equal squares: either [x = y] (done) or [x] is the opposite of [y]. *)
  destruct (sqrt_solutions _ _ squares_eq) as [? | eq_opp]; auto.
  (* In the opposite case the sign bits would have to differ. *)
  assert (sign_bit x = sign_bit (opp y)) as sign_mismatch by (f_equal; auto).
  apply sign_bit_opp_eq_iff in sign_mismatch; auto.
  congruence.
Qed.

(** Two curve points with the same y-coordinate and the same sign bit of x
    have the same x-coordinate. *)
Lemma sign_bit_match : forall x x' y : F q, E.onCurve (x, y) -> E.onCurve (x', y) ->
  sign_bit x = sign_bit x' -> x = x'.
Proof.
  intros ? ? ? onCurve_x onCurve_x' sign_match.
  (* Both x and x' square to [E.solve_for_x2 y]. *)
  apply E.solve_correct in onCurve_x.
  apply E.solve_correct in onCurve_x'.
  destruct (F_eq_dec x' 0).
  (* x' = 0: then [x ^ 2 = 0], which forces x = 0. *)
  + subst.
    rewrite Fq_pow_zero in onCurve_x' by congruence.
    rewrite <- onCurve_x' in *.
    eapply Fq_root_zero; eauto.
  (* x' <> 0: equal squares and equal sign bits give x = x'. *)
  + apply sign_bit_squares; auto.
    rewrite onCurve_x, onCurve_x'.
    reflexivity.
Qed.

(** Encoding then decoding the coordinates of an on-curve point returns
    those coordinates. *)
Lemma point_encoding_coordinates_valid : forall p, E.onCurve p ->
   point_dec_coordinates sign_bit (point_enc_coordinates p) = Some p.
Proof.
  intros p onCurve_p.
  unfold point_dec_coordinates.
  rewrite y_decode.
  pose proof (solve_sqrt_valid p onCurve_p) as solve_sqrt_valid_p.
  destruct p as [x y].
  unfold sqrt_valid in *.
  simpl.
  (* Use the curve equation to replace [E.solve_for_x2 y] by [x ^ 2]. *)
  replace (E.solve_for_x2 y) with (x ^ 2 : F q) in * by (apply E.solve_correct; assumption).
  case_eq (F_eqb x 0); intro eqb_x_0.
  (* x = 0: the root is 0 and the head bit is [sign_bit 0 = false]. *)
  + apply F_eqb_eq in eqb_x_0; rewrite eqb_x_0 in *.
    rewrite !Fq_pow_zero, sqrt_mod_q_of_0, Fq_pow_zero by congruence.
    rewrite if_F_eq_dec_if_F_eqb, sign_bit_zero. 
    reflexivity.
  (* x <> 0: the computed root is nonzero, so the zero case cannot fire. *)
  + assert (sqrt_mod_q (x ^ 2) <> 0) by (intro false_eq; apply sqrt_mod_q_root_0 in false_eq; try assumption;
      apply Fq_root_zero in false_eq; rewrite false_eq, F_eqb_refl in eqb_x_0; congruence).
    replace (F_eqb (sqrt_mod_q (x ^ 2)) 0) with false by (symmetry;
        apply F_eqb_neq_complete; assumption).
    break_if.
    (* Square-root check passes: show the selected root is x. *)
    - simpl.
      f_equal.
      break_if.
      (* Sign bits agree: the root itself is x, by [sign_bit_match]. *)
      * rewrite Bool.eqb_true_iff in Heqb.
        pose proof (solve_onCurve y solve_sqrt_valid_p).
        f_equal.
        apply (sign_bit_match _ _ y); auto.
        apply E.solve_correct in onCurve_p; rewrite onCurve_p in *.
        assumption.
      (* Sign bits differ: the opposite root is x, by [sign_bit_match]. *)
      * rewrite Bool.eqb_false_iff in Heqb.
        pose proof (solve_opp_onCurve y solve_sqrt_valid_p).
        f_equal.
        apply sign_bit_opp_eq_iff in Heqb; try assumption.
        apply (sign_bit_match _ _ y); auto.
        apply E.solve_correct in onCurve_p.
        rewrite onCurve_p; auto.
   (* Square-root check fails: contradicts [solve_sqrt_valid_p]. *)
   - simpl in solve_sqrt_valid_p.
     replace (E.solve_for_x2 y) with (x ^ 2 : F q) in * by (apply E.solve_correct; assumption).
     congruence.
Qed.

(** Given the encoding of a point's coordinates together with those
    coordinates, [point_dec'] returns that very point. *)
Lemma point_dec'_valid : forall p,
  point_dec' (point_enc_coordinates (proj1_sig p)) (proj1_sig p) = Some p.
Proof.
  unfold point_dec'; intros.
  break_match.
  (* Equality test succeeded: the packaged point equals [p] (presumably
     [E.point_eq] reduces point equality to equality of coordinates). *)
  + f_equal.
    destruct p.
    apply E.point_eq.
    reflexivity.
  (* Equality test failed: impossible by [point_encoding_coordinates_valid]. *)
  + rewrite point_encoding_coordinates_valid in n by apply (proj2_sig p).
    congruence.
Qed.

(** Decoding the encoding of any curve point returns that point. *)
Lemma point_encoding_valid : forall p, point_dec (point_enc p) = Some p.
Proof.
  intros.
  unfold point_dec.
  (* [point_enc] is definitionally [point_enc_coordinates] on the coordinates. *)
  replace (point_enc p) with (point_enc_coordinates (proj1_sig p)) by reflexivity.
  (* Coordinate decoding succeeds on the encoding; the [None] case is absurd. *)
  break_match; rewrite point_encoding_coordinates_valid in * by apply (proj2_sig p); try congruence.
  inversion_Some_eq.
  eapply point_dec'_valid.
Qed.

End PointEncoding.