aboutsummaryrefslogtreecommitdiff
path: root/src/LegacyArithmetic/Double/Core.v
blob: 53f20801caf6333488c994f989aebc4ce46fe112 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
(*** Implementing Large Bounded Arithmetic via pairs *)
Require Import Coq.ZArith.ZArith.
Require Import Crypto.LegacyArithmetic.Interface.
Require Import Crypto.LegacyArithmetic.InterfaceProofs.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.ListUtil.
Require Import Crypto.Util.Notations.
Require Import Crypto.Util.LetIn.
Import Bug5107WorkAround.

Require Crypto.LegacyArithmetic.BaseSystem.
Require Crypto.LegacyArithmetic.Pow2Base.

(** Scope setup.  When a notation is defined in several open scopes,
    the most recently opened one wins, so [Z_scope] takes precedence
    over [nat_scope], and [type_scope] over both.  Arguments whose type
    has a bound scope (such as [Z]) are still parsed in that scope. *)
Local Open Scope nat_scope.
Local Open Scope Z_scope.
Local Open Scope type_scope.

(** Lets a [nat] be used where a [Z] is expected, e.g. the tuple
    length [k] in the width [k * n] of [tuple_decoder]. *)
Local Coercion Z.of_nat : nat >-> Z.
(** [eta x] rebuilds a pair from its two projections.  Destructuring
    [eta x] instead of [x] in [let (a, b) := ... in] matches on an
    explicit [pair], which reduces even when [x] is not syntactically a
    pair (presumably the reason it is used throughout this file). *)
Local Notation eta x := (fst x, snd x).

(** Decodes a [k]-tuple of [n]-bit words as a single [k * n]-bit value.
    The tuple is low to high: its first component is the least
    significant word.  [Tuple.to_list] returns the components in the
    reverse order, so the list is reversed with [List.rev].  That makes
    it low to high as well.  Each word is then decoded with [decode],
    and the results are combined by [BaseSystem.decode] over a base in
    which every limb is [n] bits wide ([List.repeat n k]). *)
Definition tuple_decoder {n W} {decode : decoder n W} {k : nat} : decoder (k * n) (tuple W k)
  := {| decode w := BaseSystem.decode (Pow2Base.base_from_limb_widths (List.repeat n k))
                                      (List.map decode (List.rev (Tuple.to_list _ w))) |}.
(** Keep [simpl] from unfolding the tuple decoder. *)
Global Arguments tuple_decoder : simpl never.
(** Typeclass hint for goals [decoder _ (tuple W k)].  [Z.of_nat k] is
    first normalized with [simpl] (to a literal when [k] is concrete).
    The resulting instance is therefore stated at width [kv * n] rather
    than [Z.of_nat k * n]. *)
Hint Extern 3 (decoder _ (tuple ?W ?k)) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode => (@tuple_decoder n W decode k : decoder (kv * n) (tuple W k))) : typeclass_instances.

Section ripple_carry_definitions.
  (** The tuple is low to high: a [tuple' T (S k')] is a pair [(xs, x)]
      whose last component [x] is the most significant word.
      [Tuple.to_list] reverses this order, so the list it produces is
      high to low.

      [ripple_carry_tuple' f k xs ys carry] combines [xs] and [ys] word
      by word using the step function [f].  The lower words are processed
      first by the recursive call, and the carry it returns is fed into
      [f x y] on the top words.  The final carry is returned together
      with the result words.  Intermediate values are bound with [dlet]
      (from [Crypto.Util.LetIn]), presumably to keep them shared rather
      than duplicated. *)
  Fixpoint ripple_carry_tuple' {T} (f : T -> T -> bool -> bool * T) k
    : forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k
    := match k return forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k with
       (* [tuple' T 0] is a single word: apply the step [f] directly. *)
       | O => f
       (* Lower words [xs]/[ys] first, then the top words [x]/[y]. *)
       | S k' => fun xss yss carry => dlet xss := xss in
                                      dlet yss := yss in
                                      let (xs, x) := eta xss in
                                      let (ys, y) := eta yss in
                                      dlet addv := (@ripple_carry_tuple' _ f k' xs ys carry) in
                                      let (carry, zs) := eta addv in
                                      dlet fxy := (f x y carry) in
                                      let (carry, z) := eta fxy in
                                      (carry, (zs, z))
       end.

  (** Ripple carry on [k]-tuples.  A [tuple T 0] is empty ([tt]), so
      the input carry is returned unchanged.  A [tuple T (S k')] is a
      [tuple' T k'] and is handled by [ripple_carry_tuple']. *)
  Definition ripple_carry_tuple {T} (f : T -> T -> bool -> bool * T) k
    : forall (xs ys : tuple T k) (carry : bool), bool * tuple T k
    := match k return forall (xs ys : tuple T k) (carry : bool), bool * tuple T k with
       | O => fun xs ys carry => (carry, tt)
       | S k' => ripple_carry_tuple' f k'
       end.
End ripple_carry_definitions.

(** Add-with-carry on [k]-tuples of words.  The single-word [adc] is
    rippled across the words, starting from the least significant one,
    by [ripple_carry_tuple]. *)
Global Instance ripple_carry_adc {W} (adc : add_with_carry W) {k}
  : add_with_carry (tuple W k)
  := { adc := @ripple_carry_tuple W adc k }.

(** Subtract-with-borrow on [k]-tuples of words.  The single-word
    [subc] is rippled across the words, starting from the least
    significant one, by [ripple_carry_tuple]. *)
Global Instance ripple_carry_subc {W} (subc : sub_with_carry W) {k}
  : sub_with_carry (tuple W k)
  := { subc := @ripple_carry_tuple W subc k }.

(** Constructions on double words [tuple W 2]; the first component is the low word. *)
Section tuple2.
  Section select_conditional.
    Context {W}
            {selc : select_conditional W}.

    (** Conditional select on double words.  [selc b] is applied to the
        low words [x1] and [y1] and, separately, to the high words [x2]
        and [y2].  Which input is chosen when [b = true] is decided by
        the single-word [selc]. *)
    Definition select_conditional_double (b : bool) (x : tuple W 2) (y : tuple W 2) : tuple W 2
      := dlet x := x in
         dlet y := y in
         let (x1, x2) := eta x in
         let (y1, y2) := eta y in
         (selc b x1 y1, selc b x2 y2).

    Global Instance selc_double : select_conditional (tuple W 2)
      := { selc := select_conditional_double }.
  End select_conditional.

  Section load_immediate.
    Context (n : Z) {W}
            {ldi : load_immediate W}.

    (** Loads the integer [r] as a double word of [n]-bit words.  The
        low word is loaded from [r mod 2^n] and the high word from
        [r / 2^n] (floor division).  No truncation happens here:
        [r / 2^n] is passed to [ldi] unchanged. *)
    Definition load_immediate_double (r : Z) : tuple W 2
      := pair (ldi (Z.modulo r (Z.pow 2 n)))
              (ldi (Z.div r (Z.pow 2 n))).

    (** The [decoder] argument is not used by the instance body.  It
        is required only so that typeclass search can resolve [n]. *)
    Global Instance ldi_double {decode : decoder n W}
      : load_immediate (tuple W 2)
      := { ldi := load_immediate_double }.
  End load_immediate.

  Section bitwise_or.
    Context {W}
            {or : bitwise_or W}.

    (** Bitwise or on double words.  The single-word [or] is applied to
        the low words [x1] and [y1] and to the high words [x2] and [y2]
        independently; no bits cross between the two words. *)
    Definition bitwise_or_double (x : tuple W 2) (y : tuple W 2) : tuple W 2
      := dlet x := x in
         dlet y := y in
         let (x1, x2) := eta x in
         let (y1, y2) := eta y in
         (or x1 y1, or x2 y2).

    Global Instance or_double : bitwise_or (tuple W 2)
      := { or := bitwise_or_double }.
  End bitwise_or.

  Section bitwise_and.
    Context {W}
            {and : bitwise_and W}.

    (** Bitwise and on double words.  The single-word [and] is applied
        to the low words [x1] and [y1] and to the high words [x2] and
        [y2] independently; no bits cross between the two words. *)
    Definition bitwise_and_double (x : tuple W 2) (y : tuple W 2) : tuple W 2
      := dlet x := x in
         dlet y := y in
         let (x1, x2) := eta x in
         let (y1, y2) := eta y in
         (and x1 y1, and x2 y2).

    Global Instance and_double : bitwise_and (tuple W 2)
      := { and := bitwise_and_double }.
  End bitwise_and.

  Section spread_left.
    Context (n : Z) {W}
            {ldi : load_immediate W}
            {shl : shift_left_immediate W}
            {shr : shift_right_immediate W}.

    (** Spreads one word [r] over a double word, shifted left by
        [count].  The low word is [shl r count].  The high word is
        [shr r (n - count)], i.e. the top [count] bits of [r] moved down,
        assuming [shl]/[shr] are logical shifts on [n]-bit words.
        When [count = 0], the high word is [ldi 0] instead of
        [shr r n].  Presumably a shift by the full width [n] is not
        guaranteed to yield zero -- TODO confirm against the
        [shift_right_immediate] interface. *)
    Definition spread_left_from_shift (r : W) (count : Z) : tuple W 2
      := dlet r := r in
         (shl r count, if count =? 0 then ldi 0 else shr r (n - count)).

    (** Require a [decoder] instance to aid typeclass search in
        resolving [n] *)
    Global Instance sprl_from_shift {decode : decoder n W} : spread_left_immediate W
      := { sprl := spread_left_from_shift }.
  End spread_left.

  Section shl_shr.
    Context (n : Z) {W}
            {ldi : load_immediate W}
            {shl : shift_left_immediate W}
            {shr : shift_right_immediate W}
            {or : bitwise_or W}.

    (** Shifts the double word [r = (r1, r2)] (low word [r1], high word
        [r2]) left by [count] bits:
        - [count = 0]: [r] is returned unchanged, so no shift by
          [n - 0 = n] is ever performed;
        - [0 < count < n]: the low word is [shl r1 count]; the high word
          is [shr r1 (n - count)] (the bits leaving [r1], assuming
          logical shifts) or'ed with [shl r2 count];
        - [count >= n]: the low word is [ldi 0] and the high word is
          [shl r1 (count - n)].
        NOTE(review): a negative [count] takes the [count <? n] branch;
        callers presumably pass [0 <= count]. *)
    Definition shift_left_immediate_double (r : tuple W 2) (count : Z) : tuple W 2
      := dlet r := r in
         let (r1, r2) := eta r in
         (if count =? 0
          then r1
          else if count <? n
               then shl r1 count
               else ldi 0,
          if count =? 0
          then r2
          else if count <? n
               then or (shr r1 (n - count)) (shl r2 count)
               else shl r1 (count - n)).

    (** Shifts the double word [r = (r1, r2)] (low word [r1], high word
        [r2]) right by [count] bits:
        - [count = 0]: [r] is returned unchanged;
        - [0 < count < n]: the low word is [shr r1 count] or'ed with
          [shl r2 (n - count)] (the bits entering from [r2], assuming
          logical shifts); the high word is [shr r2 count];
        - [count >= n]: the low word is [shr r2 (count - n)] and the
          high word is [ldi 0]. *)
    Definition shift_right_immediate_double (r : tuple W 2) (count : Z) : tuple W 2
      := dlet r := r in
         let (r1, r2) := eta r in
         (if count =? 0
          then r1
          else if count <? n
               then or (shr r1 count) (shl r2 (n - count))
               else shr r2 (count - n),
          if count =? 0
          then r2
          else if count <? n
               then shr r2 count
               else ldi 0).

    (** Require a [decoder] instance to aid typeclass search in
        resolving [n] *)
    Global Instance shl_double {decode : decoder n W} : shift_left_immediate (tuple W 2)
      := { shl := shift_left_immediate_double }.
    Global Instance shr_double {decode : decoder n W} : shift_right_immediate (tuple W 2)
      := { shr := shift_right_immediate_double }.
  End shl_shr.

  Section shrd.
    Context (n : Z) {W}
            {ldi : load_immediate W}
            {shrd : shift_right_doubleword_immediate W}.

    (** Double-word version of [shrd].  [high] and [low] are double
        words, giving four words [low1], [low2], [high1], [high2] from
        least to most significant.  Each result word is computed by a
        single-word [shrd] on the adjacent pair of input words selected
        by [count]:
        - [count = 0]: the result is [(low1, low2)], i.e. [low] itself;
        - [0 < count < n]: [shrd low2 low1 count] and
          [shrd high1 low2 count];
        - [n <= count < 2 * n]: [shrd high1 low2 (count - n)] and
          [shrd high2 high1 (count - n)];
        - [count >= 2 * n]: [shrd high2 high1 (count - 2 * n)] and
          [shrd (ldi 0) high2 (count - 2 * n)], so zero is shifted in
          above [high2].
        This is a right shift of the four-word value by [count],
        assuming the single-word [shrd hi lo c] returns the [n] bits of
        [hi:lo] starting at bit [c] -- TODO confirm against the
        interface. *)
    Definition shift_right_doubleword_immediate_double (high low : tuple W 2) (count : Z) : tuple W 2
      := dlet high := high in
         dlet low := low in
         let (high1, high2) := eta high in
         let (low1, low2) := eta low in
         (if count =? 0
          then low1
          else if count <? n
               then shrd low2 low1 count
               else if count <? 2 * n
                    then shrd high1 low2 (count - n)
                    else shrd high2 high1 (count - 2 * n),
          if count =? 0
          then low2
          else if count <? n
               then shrd high1 low2 count
               else if count <? 2 * n
                    then shrd high2 high1 (count - n)
                    else shrd (ldi 0) high2 (count - 2 * n)).

    (** Require a [decoder] instance to aid typeclass search in
        resolving [n] *)
    Global Instance shrd_double {decode : decoder n W} : shift_right_doubleword_immediate (tuple W 2)
      := { shrd := shift_right_doubleword_immediate_double }.
  End shrd.

  Section double_from_half.
    Context {half_n : Z} {W}
            {mulhwll : multiply_low_low W}
            {mulhwhl : multiply_high_low W}
            {mulhwhh : multiply_high_high W}
            {adc : add_with_carry W}
            {shl : shift_left_immediate W}
            {shr : shift_right_immediate W}
            {ldi : load_immediate W}.

    (** Computes the product of two words [a] and [b] as a double word
        (low word first), using half-word multiplies.  [half_n] is half
        the bit width of [W], as expressed by the [decoder (2 * half_n) W]
        required by [mul_double_multiply] below.
        - The result starts as [(mulhwll a b, mulhwhh a b)] (presumably
          the low-half and high-half products).
        - Each cross product [tmp] ([mulhwhl a b], then [mulhwhl b a]) is
          added in, shifted left by [half_n] bits.  As a double word,
          that shifted value is [(shl tmp half_n, shr tmp half_n)].
        Both additions use [ripple_carry_adc] with an initial carry of
        [false] and discard the carry out, presumably because the exact
        product always fits in two words.
        NOTE(review): the [ldi] section variable is not used here. *)
    Definition mul_double (a b : W) : tuple W 2
      := dlet a              := a in
         dlet b              := b in
         let out : tuple W 2 := (mulhwll a b, mulhwhh a b) in
         dlet out            := out in
         dlet tmp            := mulhwhl a b in
         dlet addv           := (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in
         let (_, out)        := eta addv in
         dlet tmp            := mulhwhl b a in
         dlet addv           := (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in
         let (_, out)        := eta addv in
         out.

    (** Require a dummy [decoder] for these instances to allow
            typeclass inference of the [half_n] argument *)
    Global Instance mul_double_multiply {decode : decoder (2 * half_n) W} : multiply_double W
      := { muldw a b := mul_double a b }.
  End double_from_half.

  (** Half-word multiplies on double words [tuple W 2] (first component
      = low word).  Each returns, via [muldw], the full double-word
      product of the selected words:
      - [mulhwll]: low word of the first argument times low word of the
        second;
      - [mulhwhl]: high word of the first argument times low word of the
        second;
      - [mulhwhh]: high word of the first argument times high word of the
        second. *)
  Global Instance mul_double_multiply_low_low {W} {muldw : multiply_double W}
    : multiply_low_low (tuple W 2)
    := { mulhwll := fun x y => muldw (fst x) (fst y) }.
  Global Instance mul_double_multiply_high_low {W} {muldw : multiply_double W}
    : multiply_high_low (tuple W 2)
    := { mulhwhl := fun x y => muldw (snd x) (fst y) }.
  Global Instance mul_double_multiply_high_high {W} {muldw : multiply_double W}
    : multiply_high_high (tuple W 2)
    := { mulhwhh := fun x y => muldw (snd x) (snd y) }.
End tuple2.

(** When [Section tuple2] closes, [mul_double] is generalized over the
    section variables it uses.  [half_n] is made explicit because it
    does not occur in the types of [a] and [b] and so cannot be
    inferred from them.  [W] and the operation instances stay implicit,
    and the two operands [a] and [b] stay explicit. *)
Global Arguments mul_double half_n {_ _ _ _ _ _ _} _ _.