(* Exported from a cgit view of src/BaseSystem.v
   (blob 48b6468cf5c82faada50a4011132b1c9f4fa66f9). *)
Require Import Coq.Lists.List.
Require Import Coq.ZArith.ZArith Coq.ZArith.Zdiv.
Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
Require Import Crypto.Util.Notations.
Require Export Crypto.Util.FixCoqMistakes.
Import Nat.

Local Open Scope Z.

(** [BaseVector base] collects the side conditions a list of weights has to
    satisfy before it is used as the base of a [BaseSystem]:
    - [base_positive]: every weight occurring in [base] is strictly positive;
    - [b0_1]: looking up index 0 yields 1 whatever default value is passed to
      [nth_default], so [base] cannot be empty and its first weight is 1;
    - [base_good]: whenever [i + j] is a valid index, [b i * b j] equals its
      quotient by [b (i+j)] multiplied back by [b (i+j)], i.e. the integer
      division leaves no remainder. *)
Class BaseVector (base : list Z):= {
  base_positive : forall b, In b base -> b > 0; (* nonzero would probably work too... *)
  b0_1 : forall x, nth_default x base 0 = 1; (** TODO(jadep,jgross): change to [nth_error base 0 = Some 1], then use [nth_error_value_eq_nth_default] to prove a [forall x, nth_default x base 0 = 1] as a lemma *)
  base_good :
    forall i j, (i+j < length base)%nat ->
    (* [b] reads a weight, falling back to 0 outside the list *)
    let b := nth_default 0 base in
    (* [r] is the cross coefficient relating [b i * b j] to [b (i+j)] *)
    let r := (b i * b j) / b (i+j)%nat in
    b i * b j = r * b (i+j)%nat
}.

Section BaseSystem.
  Context (base : list Z).
  (** [BaseSystem] implements a constrained positional number system.
      A wide variety of bases are supported: the base coefficients are not
      required to be powers of 2, and it is NOT necessarily the case that
      $b_{i+j} = b_i b_j$. Implementations of addition and multiplication are
      provided, with focus on near-optimal multiplication performance on
      non-trivial but small operands: maybe 10 32-bit integers or so. This
      module does not handle carries automatically: if no restrictions are put
      on the use of a [BaseSystem], each digit is unbounded. This has nothing
      to do with modular arithmetic either.
  *)
  (** A number is represented by a list of integer digits; the digit at
      position [k] is weighted by the element of [base] at position [k]
      (see [decode]). *)
  Definition digits : Type := list Z.

  (** [accumulate p acc] adds the product of the two components of the pair
      [p] to the running total [acc]. *)
  Definition accumulate p acc := fst p * snd p + acc.
  (** [decode' bs u] pairs the digits [u] with the weights [bs] position by
      position and sums the products; [combine] stops at the shorter list, so
      unpaired digits or weights contribute nothing. *)
  Definition decode' bs u  := fold_right accumulate 0 (combine u bs).
  (** [decode u] is [decode'] specialised to the section variable [base]. *)
  Definition decode := decode' base.

  (** [encode' z max i] builds a list of [i] digits for [z].
      [i] is the current index and counts down: each step recurses on [i']
      and appends the digit for position [i'] at the end of the result.
      The weight lookup [b] uses [max] as its out-of-range default, so
      [b (length base)] is [max]. *)
  Fixpoint encode' z max i : digits :=
    match i with
    | O => nil
    | S i' => let b := nth_default max base in
       (* digit [i']: reduce [z] modulo the next weight, then divide by the
          weight of this position *)
       encode' z max i' ++ ((z mod (b i)) / (b i')) :: nil
    end.

  (** [encode z max] produces one digit per element of [base].
      [max] must be greater than the input; it stands in for the weight one
      past the end of [base] and is used to truncate the last digit. *)
  Definition encode z max := encode' z max (length base).

  (** Digits of [us] at positions beyond [length bs] do not affect
      [decode' bs us], so they may be dropped with [firstn]. *)
  Lemma decode'_truncate : forall bs us, decode' bs us = decode' bs (firstn (length bs) us).
  Proof.
    (* Both sides are the same [fold_right]; only the [combine] argument
       differs. [combine_truncate_l] is not defined in this file
       (presumably Crypto.Util.ListUtil) -- confirm its exact statement. *)
    unfold decode'; intros; f_equal; apply combine_truncate_l.
  Qed.

  (** [add us vs] adds two digit lists position by position. When one list
      runs out, the remaining digits of the other are kept unchanged, so the
      result is as long as the longer argument. No carrying is performed. *)
  Fixpoint add (us vs:digits) : digits :=
    match us,vs with
      | u::us', v::vs' => u+v :: add us' vs'
      (* [vs] exhausted: keep whatever is left of [us] (possibly nothing) *)
      | _, nil => us
      (* only reached when [us] is empty and [vs] is not *)
      | _, _ => vs
    end.
  (* Infix notation for [add]. *)
  Infix ".+" := add.

  (* Let [auto]-style proof search try [ring] (at cost 1) on any goal that is
     an equality between integers. No hint database is named here. *)
  Hint Extern 1 (@eq Z _ _) => ring.

  (** [mul_each u ds] multiplies every digit of [ds] by the scalar [u]
      (the list argument is left implicit by partial application of [map]). *)
  Definition mul_each u := map (Z.mul u).
  (** [sub us vs] subtracts two digit lists position by position. Leftover
      digits of [us] are kept unchanged; leftover digits of [vs] are negated.
      No borrowing is performed, so digits may become negative. *)
  Fixpoint sub (us vs:digits) : digits :=
    match us,vs with
      | u::us', v::vs' => u-v :: sub us' vs'
      (* nothing left to subtract *)
      | _, nil => us
      (* [us] exhausted: treat its missing digits as 0 *)
      | nil, v::vs' => (0-v)::sub nil vs'
    end.

  (** [crosscoef i j] is the quotient [(b i * b j) / b (i+j)], where [b]
      reads a weight from [base] and returns 0 for an out-of-range index.
      This is the same quotient that appears as [r] in [base_good]. *)
  Definition crosscoef i j : Z :=
    let b := nth_default 0 base in
    (b(i) * b(j)) / b(i+j)%nat.
  (* Allow proof search to unfold [crosscoef]. *)
  Hint Unfold crosscoef.

  (** [zeros n] is the digit list consisting of [n] zeros. *)
  Fixpoint zeros n := match n with O => nil | S n' => 0::zeros n' end.

  (** [mul_bi'] is the worker for [mul_bi]: it works with the SECOND ARGUMENT
      REVERSED and produces its OUTPUT REVERSED.
      Because [vsr] is reversed, the head [v] sits at original position
      [length vsr'], which is the index passed to [crosscoef]. *)
  Fixpoint mul_bi' (i:nat) (vsr:digits) :=
    match vsr with
    | v::vsr' => v * crosscoef i (length vsr') :: mul_bi' i vsr'
    | nil => nil
    end.
  (** [mul_bi i vs] scales digit [j] of [vs] by [crosscoef i j] and shifts
      the result up by [i] positions by prepending [i] zeros. The two [rev]s
      convert to and from the reversed convention of [mul_bi']. *)
  Definition mul_bi (i:nat) (vs:digits) : digits :=
    zeros i ++ rev (mul_bi' i (rev vs)).

  (** [mul'] is multiplication with the FIRST ARGUMENT REVERSED.
      Because [usr] is reversed, the head [u] sits at original position
      [length usr']. Its partial product is [vs] shifted and rescaled by
      [mul_bi], then scaled by [u]; the partial products are summed with
      [.+]. *)
  Fixpoint mul' (usr vs:digits) : digits :=
    match usr with
      | u::usr' =>
        mul_each u (mul_bi (length usr') vs) .+ mul' usr' vs
      (* no digits left: empty digit list *)
      | _ => nil
    end.
  (** [mul us vs] multiplies two digit lists; it reverses [us] and delegates
      to [mul'] (the second operand is supplied by partial application). *)
  Definition mul us := mul' (rev us).

End BaseSystem.

(* Example : polynomial base system *)
Section PolynomialBaseCoefs.
  Context (b1 : positive) (baseLength : nat) (baseLengthNonzero : ltb 0 baseLength = true).
  (** PolynomialBaseCoefs generates base vectors for [BaseSystem]. *)
  (** [bi i] is the [i]-th power of the section parameter [b1], as an
      integer. *)
  Definition bi i := (Zpos b1)^(Z.of_nat i).
  (** [poly_base] lists [bi 0], [bi 1], ..., one entry per index below
      [baseLength]. *)
  Definition poly_base := map bi (seq 0 baseLength).

  (** The first weight of [poly_base] is 1, whatever default value is passed
      to [nth_default]. This is the [b0_1] obligation of [BaseVector]. *)
  Lemma poly_b0_1 : forall x, nth_default x poly_base 0 = 1.
  Proof.
    unfold poly_base, bi, nth_default.
    (* split on whether the base is empty *)
    case_eq baseLength; intros. {
      (* [baseLength = 0] contradicts the section hypothesis
         [baseLengthNonzero] *)
      assert ((0 < baseLength)%nat) by
        (rewrite <-ltb_lt; apply baseLengthNonzero).
      subst; omega.
    }
    (* non-empty base: the head of the list is [Zpos b1 ^ 0] *)
    auto.
  Qed.

  (** Every weight in [poly_base] is strictly positive. This is the
      [base_positive] obligation of [BaseVector]. *)
  Lemma poly_base_positive : forall b, In b poly_base -> b > 0.
  Proof.
    unfold poly_base.
    (* introduce [b], then the membership hypothesis as [H] *)
    intros until 0; intro H.
    (* membership in a [map] gives a preimage index *)
    rewrite in_map_iff in *.
    destruct H; destruct H.
    subst.
    (* [Z.pos_pow_nat_pos] is not defined in this file (presumably
       Crypto.Util.ZUtil) -- confirm its exact statement *)
    apply Z.pos_pow_nat_pos.
  Qed.

  (** For an in-range index [i], the [i]-th weight of [poly_base] is [bi i]. *)
  Lemma poly_base_defn : forall i, (i < length poly_base)%nat ->
      nth_default 0 poly_base i = bi i.
  Proof.
    (* [nth_tac] is a project tactic not defined in this file; presumably it
       discharges list-indexing goals -- confirm against its definition *)
    unfold poly_base, nth_default; nth_tac.
  Qed.

  (** Each weight of [poly_base] divides the next one exactly: for in-range
      [S i], [b (S i)] equals its quotient by [b i] multiplied back by
      [b i]. *)
  Lemma poly_base_succ :
    forall i, ((S i) < length poly_base)%nat ->
    let b := nth_default 0 poly_base in
    let r := (b (S i) / b i) in
    b (S i) = r * b i.
  Proof.
    intros; subst b; subst r.
    (* replace list lookups by the closed form [bi] *)
    repeat rewrite poly_base_defn in * by omega.
    unfold bi.
    (* peel one factor off the successor power *)
    replace (Z.pos b1 ^ Z.of_nat (S i))
      with (Z.pos b1 * (Z.pos b1 ^ Z.of_nat i)) by
      (rewrite Nat2Z.inj_succ; rewrite <- Z.pow_succ_r; intuition auto with zarith).
    (* the quotient is then exactly [Z.pos b1] *)
    replace (Z.pos b1 * Z.pos b1 ^ Z.of_nat i / Z.pos b1 ^ Z.of_nat i)
      with (Z.pos b1); auto.
    rewrite Z_div_mult_full; auto.
    (* side condition: the divisor, a power of a positive number, is nonzero *)
    apply Z.pow_nonzero; intuition auto with lia.
  Qed.

  (** For in-range [i + j], the product [b i * b j] is an exact multiple of
      [b (i+j)]. This is the [base_good] obligation of [BaseVector]. *)
  Lemma poly_base_good:
    forall i j, (i + j < length poly_base)%nat ->
    let b := nth_default 0 poly_base in
    let r := (b i * b j) / b (i+j)%nat in
    b i * b j = r * b (i+j)%nat.
  Proof.
    (* [nth_tac] is a project tactic not defined in this file; the remaining
       script expects it to leave a goal stated in terms of [bi] -- confirm *)
    unfold poly_base, nth_default; nth_tac.

    clear.
    unfold bi.
    (* split the power of a sum into a product of powers; [Zpower_exp] needs
       both exponents to be non-negative *)
    rewrite Nat2Z.inj_add, Zpower_exp by
      (replace 0 with (Z.of_nat 0) by auto; rewrite <- Nat2Z.inj_ge; omega).
    (* numerator and denominator now coincide, so the quotient is 1 *)
    rewrite Z_div_same_full; try ring.
    (* side condition: the product is nonzero because both factors are *)
    rewrite <- Z.neq_mul_0.
    split; apply Z.pow_nonzero; try apply Zle_0_nat; try solve [intro H; inversion H].
  Qed.

  (** [poly_base] satisfies all three [BaseVector] obligations; each field is
      filled by the corresponding lemma proved in this section. *)
  Instance PolyBaseVector : BaseVector poly_base := {
    base_positive := poly_base_positive;
    b0_1 := poly_b0_1;
    base_good := poly_base_good
  }.

End PolynomialBaseCoefs.

Import ListNotations.

Section BaseSystemExample.
  (** Number of weights in the example base. *)
  Definition baseLength := 32%nat.
  (** [baseLength] is nonzero, stated in the boolean form used by the
      [PolynomialBaseCoefs] section hypothesis. *)
  Lemma baseLengthNonzero : ltb 0 baseLength = true.
  Proof.
    (* closed boolean statement: evaluate it *)
    compute; reflexivity.
  Qed.
  (** Example base: the powers of 2 with exponents below [baseLength]. *)
  Definition base2 := poly_base 2 baseLength.

  (** Sanity check of [mul] by computation: with weights 1, 2, 4, ... the
      operands [1;1;0] and [0;1;0] denote 3 and 2, and the result [0;1;1;0;0]
      denotes 6. *)
  Example three_times_two : mul base2 [1;1;0] [0;1;0] = [0;1;1;0;0].
  Proof.
    reflexivity.
  Qed.

  (* Larger test operands, written with the digit of weight 1 first. The
     one-liner below generated the two digit lists from 19259 and 41781: *)
  (* python -c "e = lambda x: '['+''.join(reversed(bin(x)[2:])).replace('1','1;').replace('0','0;')[:-1]+']'; print(e(19259)); print(e(41781))" *)
  Definition a := [1;1;0;1;1;1;0;0;1;1;0;1;0;0;1].
  Definition b := [1;0;1;0;1;1;0;0;1;1;0;0;0;1;0;1].
  (** [a] decodes to 19259 in [base2] (checked by computation). *)
  Example da : decode base2 a = 19259.
  Proof.
    reflexivity.
  Qed.
  (** [b] decodes to 41781 in [base2] (checked by computation). *)
  Example db : decode base2 b = 41781.
  Proof.
    reflexivity.
  Qed.
  (** Digit-level result of [mul base2 a b]. The digits are not
      carry-normalised: several of them exceed 1 even though the base is
      binary. *)
  Example encoded_ab :
    mul base2 a b =[1;1;1;2;2;4;2;2;4;5;3;3;3;6;4;2;5;3;4;3;2;1;2;2;2;0;1;1;0;1].
  Proof.
    reflexivity.
  Qed.
  (** Decoding the product gives 804660279, which is 19259 * 41781. *)
  Example dab : decode base2 (mul base2 a b) = 804660279.
  Proof.
    reflexivity.
  Qed.
End BaseSystemExample.