aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/Karatsuba.v
blob: f17623da70023f25dc97517cbc2a6c359b33e6da (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Algebra.Nsatz.
Require Import Crypto.Util.ZUtil Crypto.Util.LetIn Crypto.Util.CPSUtil Crypto.Util.Tactics.
Require Import Crypto.Arithmetic.Core. Import B. Import Positional.
Require Import Crypto.Util.Tuple.
Local Open Scope Z_scope.

Section Karatsuba.
(* Section parameters: a positional weight function together with the two
   facts about it that are assumed throughout this section. *)
Context (weight : nat -> Z)
        (* the weight at index 0 is 1 *)
        (weight_0 : weight 0%nat = 1%Z)
        (* no weight is zero *)
        (weight_nonzero : forall i, weight i <> 0).
  (* [tuple Z n] is the "half-length" type,
     [tuple Z n2] is the "full-length" type *)
  (* [n] and [n2] are implicit and both assumed nonzero.  No hypothesis in
     this context relates [n2] to [n]; presumably callers instantiate
     [n2] as twice [n] -- TODO confirm against the use sites. *)
  Context {n n2 : nat} (n_nonzero : n <> 0%nat) (n2_nonzero : n2 <> 0%nat).
  (* Section-local abbreviations for the half-length and full-length tuples. *)
  Let T := tuple Z n.
  Let T2 := tuple Z n2.

  (* 
     If x = x0 + sx1 and y = y0 + sy1, then xy = s^2 * z2 + s * z1 + z0,
     with:
     
     z2 = x1y1
     z0 = x0y0
     z1 = (x1+x0)(y1+y0) - (z2 + z0)

     Computing z1 one operation at a time:
     sum_z = z0 + z2
     sum_x = x1 + x0
     sum_y = y1 + y0
     mul_sumxy = sum_x * sum_y
     z1 = mul_sumxy - sum_z
  *)
  (* Continuation-passing Karatsuba multiplication of two full-length
     tuples [x] and [y], split at [s].

     Arguments:
       s    -- the split point handed to [split_cps]; also the scalar used
               in the two [scmul_cps] calls (as [s] and [s^2])
       x, y -- the full-length ([T2]) operands
       f    -- continuation that receives the final [T2] result

     Each step passes its result to the next lambda; the binder names
     follow the decomposition x = x0 + s*x1, y = y0 + s*y1.  The exact
     behaviour of [split_cps], [mul_cps], [add_cps], [unbalanced_sub_cps]
     and [scmul_cps] is defined elsewhere and is not restated here.

     The final value passed to [f] is (s^2 * z2 + s * z1) + z0. *)
  Definition karatsuba_mul_cps s (x y : T2) {R} (f:T2->R) :=
    (* split both operands into a pair of half-length pieces *)
    split_cps (n:=n2) (m1:=n) (m2:=n) weight s x
      (fun x0_x1 => split_cps weight s y
      (* z0: product of the first components of the two splits *)
      (fun y0_y1 => mul_cps weight (fst x0_x1) (fst y0_y1)
      (* z2: product of the second components of the two splits *)
      (fun z0 => mul_cps weight(snd x0_x1) (snd y0_y1)
      (* sum_z = z0 + z2 *)
      (fun z2 => add_cps weight z0 z2
      (* sum_x: sum of the two pieces of x *)
      (fun sum_z => add_cps weight (fst x0_x1) (snd x0_x1)
      (* sum_y: sum of the two pieces of y *)
      (fun sum_x => add_cps weight (fst y0_y1) (snd y0_y1)
      (* mul_sumxy = sum_x * sum_y *)
      (fun sum_y => mul_cps weight sum_x sum_y
      (* z1 = mul_sumxy - sum_z *)
      (fun mul_sumxy => unbalanced_sub_cps weight mul_sumxy sum_z
      (* sz1 = s * z1 *)
      (fun z1 => scmul_cps weight s z1
      (* s2z2 = s^2 * z2 *)
      (fun sz1 => scmul_cps weight (s^2) z2
      (* recombine: (s2z2 + sz1) + z0, then hand the result to [f] *)
      (fun s2z2 => add_cps weight s2z2 sz1
      (fun add_s2z2_sz1 => add_cps weight add_s2z2_sz1 z0 f)))))))))))).

  (* Direct-style Karatsuba multiplication: [karatsuba_mul_cps] run with
     the identity continuation [id]. *)
  Definition karatsuba_mul s x y := @karatsuba_mul_cps s x y _ id.
  (* Running [karatsuba_mul_cps] with any continuation [f] is the same as
     applying [f] to the direct-style result [karatsuba_mul s x y]. *)
  Lemma karatsuba_mul_id s x y R f :
    @karatsuba_mul_cps s x y R f = f (karatsuba_mul s x y).
  Proof.
    (* expose the chain of CPS calls on both sides *)
    cbv [karatsuba_mul karatsuba_mul_cps].
    repeat autounfold.
    (* rewrite with the [cancel_pair], [push_id] and [uncps] hint databases;
       presumably this brings both sides to the same direct-style form --
       the hints themselves live outside this file *)
    autorewrite with cancel_pair push_id uncps.
    reflexivity.
  Qed.

  (* Correctness of [karatsuba_mul]: for a nonzero split point [s], the
     result evaluates (under [weight]) to the product of the evaluations
     of the two inputs. *)
  Lemma eval_karatsuba_mul s x y (s_nonzero:s <> 0) :
    eval weight (karatsuba_mul s x y) = eval weight x * eval weight y.
  Proof.
    (* unfold to the underlying operations and rewrite with the listed hint
       databases, including [push_basesystem_eval] *)
    cbv [karatsuba_mul karatsuba_mul_cps]; repeat autounfold.
    autorewrite with cancel_pair push_id uncps push_basesystem_eval.
    (* Repeatedly rewrite right-to-left with [eval_to_associational] and
       [Associational.eval_split] (at split point [s]; side conditions are
       closed by [assumption], e.g. [s_nonzero]), and rewrite with
       [Associational.eval_nil].  The first branch that makes progress is
       taken on each iteration. *)
    repeat match goal with
           | _ => rewrite <-eval_to_associational
           | |- context [(to_associational ?w ?x)] =>
             rewrite <-(Associational.eval_split
                          s (to_associational w x)) by assumption
           | _ => rewrite <-Associational.eval_split by assumption
           | _ => setoid_rewrite Associational.eval_nil
           end.
    (* normalise the resulting ring expression, then let [nsatz] close the
       remaining equation *)
    ring_simplify.
    nsatz.
  Qed.

  (*
    If:
        s^2 mod p = (s + 1) mod p
        x = x0 + sx1
        y = y0 + sy1
    Then, with z0 and z2 as before and z1 = ((x0 + x1) * (y0 + y1)) - z0,
        xy mod p = (z0 + z2 + sz1) mod p
    
    Computing xy one operation at a time:
    sum_z = z0 + z2
    sum_x = x0 + x1
    sum_y = y0 + y1
    mul_sumxy = sum_x * sum_y
    z1 = mul_sumxy - z0
    sz1 = s * z1
    xy = sum_z + sz1   (mod p)
   
  *)
  (* Continuation-passing multiplication of two full-length tuples [xs] and
     [ys], split at [s], using the variant recombination whose correctness
     modulo [p] is stated in [goldilocks_mul_correct] under the hypothesis
     s^2 mod p = (s + 1) mod p.

     Arguments:
       s      -- the split point handed to [split_cps]; also the scalar
                 used in the single [scmul_cps] call
       xs, ys -- the full-length ([T2]) operands
       f      -- continuation that receives the final [T2] result

     The binder names follow the decomposition x = x0 + s*x1,
     y = y0 + s*y1.  The value passed to [f] is sum_z + s * z1, where
     sum_z = z0 + z2 and z1 = (sum_x * sum_y) - z0. *)
  Definition goldilocks_mul_cps s (xs ys : T2) {R} (f:T2->R) :=
    (* split both operands into a pair of half-length pieces *)
    split_cps (m1:=n) (m2:=n) weight s xs
      (fun x0_x1 => split_cps weight s ys
      (* z0: product of the first components of the two splits *)
      (fun y0_y1 => mul_cps weight (fst x0_x1) (fst y0_y1)
      (* z2: product of the second components of the two splits *)
      (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1)
      (* sum_z = z0 + z2 *)
      (fun z2 => add_cps weight z0 z2
      (* sum_x: sum of the two pieces of xs *)
      (fun sum_z => add_cps weight (fst x0_x1) (snd x0_x1)
      (* sum_y: sum of the two pieces of ys *)
      (fun sum_x => add_cps weight (fst y0_y1) (snd y0_y1)
      (* mul_sumxy = sum_x * sum_y *)
      (fun sum_y => mul_cps weight sum_x sum_y
      (* z1 = mul_sumxy - z0; only z0 is subtracted here, not sum_z *)
      (fun mul_sumxy => unbalanced_sub_cps weight mul_sumxy z0
      (* sz1 = s * z1 *)
      (fun z1 => scmul_cps weight s z1
      (* result: sum_z + sz1, handed to [f] *)
      (fun sz1 => add_cps weight sum_z sz1 f)))))))))).

  (* Direct-style version: [goldilocks_mul_cps] run with the identity
     continuation [id]. *)
  Definition goldilocks_mul s xs ys := @goldilocks_mul_cps s xs ys _ id.
  (* Running [goldilocks_mul_cps] with any continuation [f] is the same as
     applying [f] to the direct-style result [goldilocks_mul s xs ys].
     Note that [R] is declared implicit here, unlike in
     [karatsuba_mul_id]. *)
  Lemma goldilocks_mul_id s xs ys {R} f :
    @goldilocks_mul_cps s xs ys R f = f (goldilocks_mul s xs ys).
  Proof.
    (* expose the chain of CPS calls on both sides *)
    cbv [goldilocks_mul goldilocks_mul_cps].
    repeat autounfold.
    (* rewrite with the [cancel_pair], [push_id] and [uncps] hint databases;
       presumably this brings both sides to the same direct-style form --
       the hints themselves live outside this file *)
    autorewrite with cancel_pair push_id uncps.
    reflexivity.
  Qed.
    
  (* Register these lemmas as typeclass instances, locally to this section.
     NOTE(review): presumably needed so that the [setoid_rewrite] steps in
     [goldilocks_mul_correct] can rewrite under addition and multiplication
     up to equivalence modulo [p] -- confirm which ones are required. *)
  Local Existing Instances Z.equiv_modulo_Reflexive
        RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric
        Z.equiv_modulo_Transitive Z.mul_mod_Proper Z.add_mod_Proper
        Z.modulo_equiv_modulo_Proper.

  (* Correctness of [goldilocks_mul] modulo [p]: given nonzero [p] and [s]
     with s^2 mod p = (s + 1) mod p, the result evaluates (under [weight])
     to a value congruent modulo [p] to the product of the evaluations of
     the two inputs. *)
  Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys :
    (eval weight (goldilocks_mul s xs ys)) mod p = (eval weight xs * eval weight ys) mod p.
  Proof.
    (* unfold the definitions; [Zmod_to_equiv_modulo] is defined elsewhere --
       presumably it restates the [mod] equalities (goal and hypotheses) as
       an equivalence-modulo relation, TODO confirm *)
    cbv [goldilocks_mul_cps goldilocks_mul]; Zmod_to_equiv_modulo.
    repeat autounfold; autorewrite with push_id cancel_pair uncps push_basesystem_eval.
    (* Repeatedly rewrite right-to-left with [eval_to_associational] and
       [Associational.eval_split] (at split point [s]; side conditions are
       closed by [assumption], e.g. [s_nonzero]), and rewrite with
       [Associational.eval_nil].  The first branch that makes progress is
       taken on each iteration. *)
    repeat match goal with
           | _ => rewrite <-eval_to_associational
           | |- context [(to_associational ?w ?x)] =>
             rewrite <-(Associational.eval_split
                          s (to_associational w x)) by assumption
           | _ => rewrite <-Associational.eval_split by assumption
           | _ => setoid_rewrite Associational.eval_nil
           end.

    (* normalise the ring expression, rewrite using the hypothesis
       [s2_modp] about s^2, then split the goal with [f_equal2] and close
       the pieces with [nsatz] *)
    ring_simplify.
    setoid_rewrite s2_modp.
    apply f_equal2; nsatz.
  Qed.
End Karatsuba.