aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/ExtendedBaseVector.v
blob: ef8c9716a22c41d6ef1e7a12190160312cb397b6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith.
Require Import Coq.Lists.List.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.Tactics.VerdiTactics.
Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.Pow2BaseProofs.
Require Import Crypto.BaseSystemProofs.
Require Crypto.BaseSystem.
Local Open Scope Z_scope.

Section ExtendedBaseVector.
  (* [limb_widths] lists the width of each limb; [limb_widths_nonnegative]
   * requires every listed width to be nonnegative. *)
  Context (limb_widths : list Z)
          (limb_widths_nonnegative : forall x, In x limb_widths -> 0 <= x).
  (* [k] is the sum of all limb widths, and [base] is the base vector built
   * from [limb_widths]. *)
  Local Notation k := (sum_firstn limb_widths (length limb_widths)).
  Local Notation base := (base_from_limb_widths limb_widths).

  (* This section defines a new BaseVector, [ext_base], that is twice as long as
   * [base] (see [extended_base_length]). Its coefficients are as follows:
   *
   * ext_base[i] = if (i < length base) then base[i] else 2^k * base[i - length base]
   *
   * The purpose of this construction is that it allows us to multiply numbers expressed
   * using [base], obtaining a number expressed using [ext_base]. (Numbers are "expressed" as
   * vectors of digits; the value of a digit vector is obtained by doing a dot product with
   * the base vector.) So if x, y are digit vectors, there is a digit vector z with:
   *
   * (x \dot base) * (y \dot base) = (z \dot ext_base)
   *
   * Then we can separate z into its first and second halves: z1 holds the first
   * [length base] digits of z and z2 holds the rest, so that
   *
   * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base)
   *
   * Now, if we want to reduce the product modulo 2^k - c, we use the fact that
   * 2^k = c (mod 2^k - c):
   *
   * (z \dot ext_base) mod (2^k - c) = ((z1 \dot base) + (2 ^ k) * (z2 \dot base)) mod (2^k - c)
   *                                 = ((z1 \dot base) + c * (z2 \dot base)) mod (2^k - c)
   *
   * This sum may be short enough to express using base; if not, we can reduce again.
   *)
  (* The limb widths of the extended base: [limb_widths] repeated twice. *)
  Definition ext_limb_widths := limb_widths ++ limb_widths.
  (* The extended base vector, built from [ext_limb_widths]. *)
  Definition ext_base := base_from_limb_widths ext_limb_widths.
  (* [ext_base] is [base] followed by a copy of [base] in which every entry
   * is multiplied by 2^k. *)
  Lemma ext_base_alt : ext_base = base ++ (map (Z.mul (2^k)) base).
  Proof.
    unfold ext_base, ext_limb_widths.
    rewrite base_from_limb_widths_app by auto.
    (* NOTE(review): presumably restates a [two_p] power as [Z.pow] so the
     * goal matches [2^k] syntactically; confirm against [two_p_equiv]. *)
    rewrite two_p_equiv.
    reflexivity.
  Qed.

  (* Every entry of [ext_base] is positive. Each width in [ext_limb_widths]
   * comes from one of the two copies of [limb_widths], so it is nonnegative
   * by [limb_widths_nonnegative]. *)
  Lemma ext_base_positive : forall b, In b ext_base -> b > 0.
  Proof.
    apply base_positive; unfold ext_limb_widths.
    (* Split membership in [limb_widths ++ limb_widths] into the two halves. *)
    intros ? H. apply in_app_or in H; destruct H; auto.
  Qed.

  (* If the first entry of [base] (with default [x]) is 1, then so is the
   * first entry of [ext_base], because [ext_base] starts with [base]. The
   * name matches the [b0_1] field that [ExtBaseVector] fills in with this
   * lemma. *)
  Lemma b0_1 : forall x, nth_default x base 0 = 1 -> nth_default x ext_base 0 = 1.
  Proof.
    intros. rewrite ext_base_alt, nth_default_app.
    destruct base; assumption.
  Qed.

  (* Indexing into the scaled upper half of [ext_base]: for [n] in range,
   * entry [n] of [map (Z.mul (2 ^ k)) base] is [2 ^ k] times entry [n] of
   * [base]. *)
  Lemma map_nth_default_base_high : forall n, (n < (length base))%nat ->
    nth_default 0 (map (Z.mul (2 ^ k)) base) n =
    (2 ^ k) * (nth_default 0 base n).
  Proof.
    intros.
    erewrite map_nth_default; auto.
  Qed.

  (* Hypotheses under which the [base_good] property of [base] carries over
   * to [ext_base]:
   * - [two_k_nonzero]: 2^k is nonzero;
   * - [base_good]: for in-range [i + j], entry [i + j] of [base] divides the
   *   product of entries [i] and [j] exactly;
   * - [limb_widths_match_modulus]: for limb indices [i], [j] whose sum
   *   reaches past the last limb, k plus the width sum of the first
   *   [i + j - length limb_widths] limbs is at most the width sum of the
   *   first [i] limbs plus that of the first [j] limbs. *)
  Section base_good.
    Context (two_k_nonzero : 2^k <> 0)
            (base_good : forall i j, (i+j < length base)%nat ->
                                     let b := nth_default 0 base in
                                     let r := (b i * b j) / b (i+j)%nat in
                                     b i * b j = r * b (i+j)%nat)
            (limb_widths_match_modulus : forall i j,
                (i < length limb_widths)%nat ->
                (j < length limb_widths)%nat ->
                (i + j >= length limb_widths)%nat ->
                let w_sum := sum_firstn limb_widths in
                k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j).

    (* The [base_good] identity for indices [i] and [j'] of [base] still holds
     * after the product and the divisor are both scaled by [2 ^ k]. It is used
     * for the cases of [ext_base_good] in which exactly one of the two indices
     * lands in the upper half of [ext_base]. *)
    Lemma base_good_over_boundary
      : forall (i : nat)
               (l : (i < length base)%nat)
               (j' : nat)
               (Hj': (i + j' < length base)%nat),
        2 ^ k * (nth_default 0 base i * nth_default 0 base j') =
        (2 ^ k * (nth_default 0 base i * nth_default 0 base j'))
          / (2 ^ k * nth_default 0 base (i + j')) *
        (2 ^ k * nth_default 0 base (i + j')).
    Proof.
      (* Not needed here; presumably cleared so that this lemma does not
       * depend on it once the section closes. *)
      clear limb_widths_match_modulus.
      intros.
      remember (nth_default 0 base) as b.
      (* Cancel the common factor 2^k inside the quotient. *)
      rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero).
      replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat))
      with  ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring.
      (* Cancel the 2^k factor on both sides of the equation. *)
      rewrite Z.mul_cancel_l by (exact two_k_nonzero).
      replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat))
      with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring.
      subst b.
      (* What remains is exactly the [base_good] hypothesis for [base]. *)
      apply (base_good i j'); omega.
    Qed.

    (* [ext_base] satisfies the [base_good] property: for in-range [i + j],
     * the product of entries [i] and [j] equals its quotient by entry
     * [i + j] times entry [i + j]. The proof splits on whether [i], [j]
     * and [i + j] index the lower half ([base]) or the upper half (scaled
     * [base]); [omega] discards the impossible combinations. *)
    Lemma ext_base_good :
      forall i j, (i+j < length ext_base)%nat ->
                  let b := nth_default 0 ext_base in
                  let r := (b i * b j) / b (i+j)%nat in
                  b i * b j = r * b (i+j)%nat.
    Proof.
      intros.
      subst b. subst r.
      rewrite ext_base_alt in *.
      rewrite app_length in H; rewrite map_length in H.
      repeat rewrite nth_default_app.
      repeat break_if; try omega.
      { (* i < length base, j < length base, i + j < length base *)
        (* NOTE(review): the section hypothesis [base_good] is also in context
         * for [auto]; confirm which of it and the [BaseSystem.base_good] hint
         * actually closes this goal. *)
        auto using BaseSystem.base_good.
      } { (* i < length base, j < length base, i + j >= length base *)
        (* The product wraps into the upper half; this is the case that needs
         * [limb_widths_match_modulus], via [base_matches_modulus]. *)
        rewrite (map_nth_default _ _ _ _ 0) by omega.
        apply base_matches_modulus; auto using limb_widths_nonnegative, limb_widths_match_modulus;
        distr_length.
        assumption.
      } { (* i < length base, j >= length base, i + j >= length base *)
        (* Re-index [j] into [base] as [j'] and reduce to
         * [base_good_over_boundary]. *)
        do 2 rewrite map_nth_default_base_high by omega.
        remember (j - length base)%nat as j'.
        replace (i + j - length base)%nat with (i + j')%nat by omega.
        replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j'))
        with (2 ^ k * (nth_default 0 base i * nth_default 0 base j'))
          by ring.
        eapply base_good_over_boundary; eauto; omega.
      } { (* i >= length base, j < length base, i + j >= length base *)
        (* Symmetric to the previous case: re-index [i] as [i'] and swap the
         * factors so that [base_good_over_boundary] applies with [j] first. *)
        do 2 rewrite map_nth_default_base_high by omega.
        remember (i - length base)%nat as i'.
        replace (i + j - length base)%nat with (j + i')%nat by omega.
        replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j)
        with (2 ^ k * (nth_default 0 base j * nth_default 0 base i'))
          by ring.
        eapply base_good_over_boundary; eauto; omega.
      }
    Qed.
  End base_good.

  (* [ext_base] is exactly twice as long as [base]. *)
  Lemma extended_base_length:
      length ext_base = (length base + length base)%nat.
  Proof.
    (* Not needed here; presumably cleared so that this lemma does not depend
     * on it once the section closes. *)
    clear limb_widths_nonnegative.
    unfold ext_base, ext_limb_widths; autorewrite with distr_length; reflexivity.
  Qed.

  (* For a digit vector no longer than [base], the prefix of [base] of that
   * length equals the prefix of [ext_base] of the same length. *)
  Lemma firstn_us_base_ext_base : forall (us : BaseSystem.digits),
      (length us <= length base)%nat
      -> firstn (length us) base = firstn (length us) ext_base.
  Proof.
    rewrite ext_base_alt; intros.
    rewrite firstn_app_inleft; auto; omega.
  Qed.

  (* A digit vector no longer than [base] decodes to the same value under
   * [base] and under [ext_base]. *)
  Lemma decode_short : forall (us : BaseSystem.digits),
    (length us <= length base)%nat ->
    BaseSystem.decode base us = BaseSystem.decode ext_base us.
  Proof. auto using decode_short_initial, firstn_us_base_ext_base. Qed.

  (* Given a [BaseVector] instance for [base] and the
   * [limb_widths_match_modulus] condition, [ext_base] is also a
   * [BaseVector]. The [2^k <> 0] premise of [ext_base_good] is discharged
   * by [two_sum_firstn_limb_widths_nonzero]. *)
  Section BaseVector.
    Context {bv : BaseSystem.BaseVector base}
            (limb_widths_match_modulus : forall i j,
                (i < length limb_widths)%nat ->
                (j < length limb_widths)%nat ->
                (i + j >= length limb_widths)%nat ->
                let w_sum := sum_firstn limb_widths in
                k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j).

    Instance ExtBaseVector : BaseSystem.BaseVector ext_base :=
      { base_positive := ext_base_positive;
        b0_1 x := b0_1 x (BaseSystem.b0_1 _);
        base_good := ext_base_good (two_sum_firstn_limb_widths_nonzero limb_widths_nonnegative _) BaseSystem.base_good limb_widths_match_modulus }.
  End BaseVector.
End ExtendedBaseVector.

(* Register [extended_base_length] in the [distr_length] rewrite database so
 * that [autorewrite with distr_length] can rewrite [length ext_base]. *)
Hint Rewrite @extended_base_length : distr_length.