(* root/src/Arithmetic/Karatsuba.v (blob 1873e5ef19eaa996d5540268eff9bed6c1623bc0) *)
Require Import Coq.ZArith.ZArith.
Require Import Coq.micromega.Lia.
Require Import Crypto.Algebra.Nsatz.
Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil.
Require Import Crypto.Arithmetic.Core. Import B. Import Positional.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.IdfunWithAlt.
Require Import Crypto.Util.ZUtil.EquivModulo.
Local Open Scope Z_scope.

Section Karatsuba.
  (* [weight] is the positional weight function passed to every
     Positional/Associational operation below; [weight_0] and
     [weight_nonzero] are hypotheses available to the proofs in this
     section. *)
  Context (weight : nat -> Z)
          (weight_0 : weight 0%nat = 1%Z)
          (weight_nonzero : forall i, weight i <> 0).
  (* [tuple Z n] is the "half-length" type,
     [tuple Z n2] is the "full-length" type *)
  Context {n n2 : nat} (n_nonzero : n <> 0%nat) (n2_nonzero : n2 <> 0%nat).
  Let T := tuple Z n.
  Let T2 := tuple Z n2.

  (*
     Karatsuba multiplication.

     If x = x0 + s*x1 and y = y0 + s*y1, then xy = s^2 * z2 + s * z1 + z0,
     with:

     z2 = x1y1
     z0 = x0y0
     z1 = (x1+x0)(y1+y0) - (z2 + z0)

     Computing z1 one operation at a time:
     sum_z = z0 + z2
     sum_x = x1 + x0
     sum_y = y1 + y0
     mul_sumxy = sum_x * sum_y
     z1 = mul_sumxy - sum_z

     The result is then assembled as (s^2 * z2 + s * z1) + z0.  Both
     inputs are full-length (T2); they are split into halves with
     split_cps (m1 := n, m2 := n), and the result is passed to the
     continuation [f] at type T2.
   *)
  Definition karatsuba_mul_cps s (x y : T2) {R} (f:T2->R) :=
    split_cps (n:=n2) (m1:=n) (m2:=n) weight s x
      (fun x0_x1 => split_cps weight s y
      (fun y0_y1 => mul_cps weight (fst x0_x1) (fst y0_y1)
      (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1)
      (fun z2 => add_cps weight z0 z2
      (fun sum_z => add_cps weight (fst x0_x1) (snd x0_x1)
      (fun sum_x => add_cps weight (fst y0_y1) (snd y0_y1)
      (fun sum_y => mul_cps weight sum_x sum_y
      (fun mul_sumxy => unbalanced_sub_cps weight mul_sumxy sum_z
      (fun z1 => scmul_cps weight s z1
      (fun sz1 => scmul_cps weight (s^2) z2
      (fun s2z2 => add_cps weight s2z2 sz1
      (fun add_s2z2_sz1 => add_cps weight add_s2z2_sz1 z0 f)))))))))))).

  (* Direct-style Karatsuba multiplication: [karatsuba_mul_cps] with the
     identity continuation. *)
  Definition karatsuba_mul s x y := @karatsuba_mul_cps s x y _ id.

  (* Running the CPS version with continuation [f] is the same as
     applying [f] to the direct-style result. *)
  Lemma karatsuba_mul_id s x y R f :
    @karatsuba_mul_cps s x y R f = f (karatsuba_mul s x y).
  Proof.
    cbv [karatsuba_mul karatsuba_mul_cps].
    repeat autounfold.
    autorewrite with cancel_pair push_id uncps.
    reflexivity.
  Qed.
  (* Register [karatsuba_mul] (opaque) and [karatsuba_mul_id] with the
     [uncps] hint databases. *)
  Hint Opaque karatsuba_mul : uncps.
  Hint Rewrite karatsuba_mul_id : uncps.

  (* Correctness: for any nonzero split point [s], [karatsuba_mul]
     evaluates to exactly the product of the evaluated inputs.  The
     proof rewrites the split halves back through
     Associational.eval_split and closes the polynomial identity with
     nsatz. *)
  Lemma eval_karatsuba_mul s x y (s_nonzero:s <> 0) :
    eval weight (karatsuba_mul s x y) = eval weight x * eval weight y.
  Proof.
    cbv [karatsuba_mul karatsuba_mul_cps]; repeat autounfold.
    autorewrite with cancel_pair push_id uncps push_basesystem_eval.
    repeat match goal with
           | _ => rewrite <-eval_to_associational
           | |- context [(to_associational ?w ?x)] =>
             rewrite <-(Associational.eval_split
                          s (to_associational w x)) by assumption
           | _ => rewrite <-Associational.eval_split by assumption
           | _ => setoid_rewrite Associational.eval_nil
           end.
    ring_simplify.
    nsatz.
  Qed.

  (* These definitions are intended to make bounds analysis go through
    for karatsuba. Essentially, we provide a version of the code to
    actually run and a version to bounds-check, along with a proof
    that they are exactly equal. This works around cases where the
    bounds proof requires high-level reasoning. *)
  Local Notation id_with_alt_bounds_cps := id_tuple_with_alt_cps'.

  (*
     Goldilocks-style Karatsuba multiplication.

     If:
         s^2 mod p = (s + 1) mod p
         x = x0 + s*x1
         y = y0 + s*y1
     Then, with z0 and z2 as before (x0y0 and x1y1 respectively), let
     z1 = ((x0 + x1) * (y0 + y1)) - z0.  Modulo p:

       xy = z0 + s(x0y1 + x1y0) + s^2 z2
          = z0 + s(x0y1 + x1y0) + (s + 1) z2
          = (z0 + z2) + s(x0y1 + x1y0 + z2)
          = sum_z + s * z1

     Computing xy one operation at a time:
     sum_z = z0 + z2
     sum_x = x0 + x1
     sum_y = y0 + y1
     mul_sumxy = sum_x * sum_y
     z1 = mul_sumxy - z0
     sz1 = s * z1
     xy = sum_z + sz1

     The result is congruent to xy modulo p but is not reduced.

     The subtraction in the computation of z1 presents issues for
     bounds analysis. In particular, just analyzing the upper and lower
     bounds of the values would indicate that it could underflow--we
     know it won't because

     mul_sumxy - z0 = ((x0+x1) * (y0+y1)) - x0y0
                    = (x0y0 + x1y0 + x0y1 + x1y1) - x0y0
                    = x1y0 + x0y1 + x1y1

     Therefore, we use id_with_alt_bounds to indicate that the
     bounds-checker should check the non-subtracting form.
   *)

  (*
  NOTE(review): the following is an incomplete sketch, kept for
  reference only.  The steps between splitting [ys] and the [z1]
  continuation are missing and the parentheses do not balance, so it
  would not parse if uncommented.
  Definition goldilocks_mul_cps_for_bounds_checker
             s (xs ys : T2) {R} (f:T2->R) :=
    split_cps (m1:=n) (m2:=n) weight s xs
      (fun x0_x1 => split_cps weight s ys

      (fun z1 => Positional.to_associational_cps weight z1
      (fun z1 => Associational.mul_cps (pair s 1::nil) z1
      (fun sz1 => Positional.from_associational_cps weight n2 sz1
      (fun sz1 => add_cps weight sum_z sz1 f)))))))))))).
   *)

  (* Output type of the Goldilocks multiplication: the associational
     result is converted back to positional form at length n2 + n. *)
  Let T3 := tuple Z (n2+n).
  Definition goldilocks_mul_cps s (xs ys : T2) {R} (f:T3->R) :=
    split_cps (m1:=n) (m2:=n) weight s xs
      (fun x0_x1 => split_cps weight s ys
      (fun y0_y1 => mul_cps weight (fst x0_x1) (fst y0_y1)
      (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1)
      (fun z2 => add_cps weight z0 z2
      (fun sum_z : tuple _ n2 => add_cps weight (fst x0_x1) (snd x0_x1)
      (fun sum_x => add_cps weight (fst y0_y1) (snd y0_y1)
      (fun sum_y => mul_cps weight sum_x sum_y
      (fun mul_sumxy =>

      (* z1 = mul_sumxy - z0.  The first argument is the code that is
         actually run; the second is the subtraction-free form
         z1 = x0y1 + x1y0 + z2 that the bounds checker analyzes
         instead (see the explanation above). *)
      id_with_alt_bounds_cps (fun f =>
      (unbalanced_sub_cps weight mul_sumxy z0 f)) (fun f =>

      (* NOTE(review): z0, z2 and sum_z are recomputed in this branch,
         shadowing the outer bindings, and this sum_z is never used. *)
      (mul_cps weight (fst x0_x1) (snd y0_y1)
      (fun x0_y1 => mul_cps weight (snd x0_x1) (fst y0_y1)
      (fun x1_y0 => mul_cps weight (fst x0_x1) (fst y0_y1)
      (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1)
      (fun z2 => add_cps weight z0 z2
      (fun sum_z => add_cps weight x0_y1 x1_y0
      (fun z1' => add_cps weight z1' z2 f)))))))) (fun z1 =>

      (* Result = sum_z + s * z1, built in associational form: multiply
         z1 by the one-term list (pair s 1::nil), presumably
         representing s, append the terms of sum_z, and convert back to
         positional form. *)
                 Positional.to_associational_cps weight z1
      (fun z1 => Associational.mul_cps (pair s 1::nil) z1
      (fun sz1 => Positional.to_associational_cps weight sum_z
      (fun sum_z => Positional.from_associational_cps weight _ (sum_z++sz1) f
      )))))))))))).

  (* Direct-style Goldilocks multiplication: [goldilocks_mul_cps] with
     the identity continuation. *)
  Definition goldilocks_mul s xs ys := goldilocks_mul_cps s xs ys id.

  (* Running the CPS version with continuation [f] is the same as
     applying [f] to the direct-style result. *)
  Lemma goldilocks_mul_id s xs ys R f :
    @goldilocks_mul_cps s xs ys R f = f (goldilocks_mul s xs ys).
  Proof.
    cbv [goldilocks_mul goldilocks_mul_cps Let_In].
    repeat autounfold. autorewrite with uncps push_id.
    reflexivity.
  Qed.
  (* Register [goldilocks_mul] (opaque) and [goldilocks_mul_id] with the
     [uncps] hint databases. *)
  Hint Opaque goldilocks_mul : uncps.
  Hint Rewrite goldilocks_mul_id : uncps.

  (* Instances for Z.equiv_modulo (equivalence and compatibility with
     +, * and mod), used for the rewriting modulo p in the proof of
     [goldilocks_mul_correct]. *)
  Local Existing Instances Z.equiv_modulo_Reflexive
        RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric
        Z.equiv_modulo_Transitive Z.mul_mod_Proper Z.add_mod_Proper
        Z.modulo_equiv_modulo_Proper.

  (* Correctness of Goldilocks multiplication: if s is nonzero and
     s^2 = s + 1 modulo the nonzero modulus p, then [goldilocks_mul]
     evaluates to the product of the evaluated inputs modulo p. *)
  Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys :
    (eval weight (goldilocks_mul s xs ys)) mod p = (eval weight xs * eval weight ys) mod p.
  Proof.
    cbv [goldilocks_mul_cps goldilocks_mul Let_In].
    Zmod_to_equiv_modulo.
    progress autounfold.
    progress autorewrite with push_id cancel_pair uncps push_basesystem_eval.
    (* Replace the run/bounds-check pair by the code that actually runs. *)
    rewrite !unfold_id_tuple_with_alt.
    repeat match goal with
    | _ => rewrite <-eval_to_associational
    | |- context [(to_associational ?w ?x)] =>
      rewrite <-(Associational.eval_split
                   s (to_associational w x)) by assumption
    | _ => rewrite <-Associational.eval_split by assumption
    | _ => setoid_rewrite Associational.eval_nil
    end.
    progress autorewrite with push_id cancel_pair uncps push_basesystem_eval.
    repeat (rewrite ?eval_from_associational, ?eval_to_associational).
    progress autorewrite with push_id cancel_pair uncps push_basesystem_eval.
    repeat match goal with
    | _ => rewrite <-eval_to_associational
    | |- context [(to_associational ?w ?x)] =>
      rewrite <-(Associational.eval_split
                   s (to_associational w x)) by assumption
    | _ => rewrite <-Associational.eval_split by assumption
    | _ => setoid_rewrite Associational.eval_nil
    end.
    ring_simplify.
    (* Use s^2 = s + 1 (mod p) to match the two sides. *)
    setoid_rewrite s2_modp.
    apply f_equal2; nsatz.
    (* Remaining side goals: two closed from the context, one by linear
       arithmetic (lia; omega is deprecated and Lia is imported). *)
    assumption. assumption. lia.
  Qed.

  (* [goldilocks_mul_correct] restated with [mod_eq] (defined outside
     this file) for a positive modulus; proved by applying it. *)
  Lemma eval_goldilocks_mul (p : positive) s (s_nonzero : s <> 0) (s2_modp : mod_eq p (s^2) (s+1)) xs ys :
    mod_eq p (eval weight (goldilocks_mul s xs ys)) (eval weight xs * eval weight ys).
  Proof.
    apply goldilocks_mul_correct; auto; lia.
  Qed.
End Karatsuba.
(* Outside the section, register the uncps hints again, presumably
   because the copies declared inside Section Karatsuba are local to
   it.  Also add the evaluation lemmas to push_basesystem_eval; their
   side conditions are discharged by assumption or by
   div_mod_cps_t followed by auto. *)
Hint Opaque karatsuba_mul : uncps.
Hint Opaque goldilocks_mul : uncps.
Hint Rewrite karatsuba_mul_id : uncps.
Hint Rewrite goldilocks_mul_id : uncps.

Hint Rewrite @eval_karatsuba_mul
     using (assumption || (div_mod_cps_t; auto)) : push_basesystem_eval.
Hint Rewrite @eval_goldilocks_mul
     using (assumption || (div_mod_cps_t; auto)) : push_basesystem_eval.
Hint Rewrite @goldilocks_mul_correct
     using (assumption || (div_mod_cps_t; auto)) : push_basesystem_eval.

(* Partial-evaluation unfolder for this file: delta-unfold the
   Karatsuba and Goldilocks multiplications (direct and CPS forms) in
   [t], then pass the result to the unfolder from Arithmetic.Core. *)
Ltac basesystem_partial_evaluation_unfolder t :=
  let t_mul_unfolded :=
      (eval cbv delta [karatsuba_mul karatsuba_mul_cps
                       goldilocks_mul goldilocks_mul_cps] in t) in
  let t_core_unfolded :=
      Arithmetic.Core.basesystem_partial_evaluation_unfolder t_mul_unfolded in
  t_core_unfolded.

(* Redefine (via ::=) the default unfolder declared in Arithmetic.Core so
   that it uses the extended unfolder defined just above, which also
   unfolds the Karatsuba/Goldilocks multiplications. *)
Ltac Arithmetic.Core.basesystem_partial_evaluation_default_unfolder t ::=
  basesystem_partial_evaluation_unfolder t.