aboutsummaryrefslogtreecommitdiff
path: root/src/Util/ZUtil/Definitions.v
blob: 8fe5772f55b294067b7beec1a19cd85aefc8c920 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Util.Decidable.
Require Import Crypto.Util.ZUtil.Notations.
Require Import Crypto.Util.LetIn.
Local Open Scope Z_scope.

Module Z.
  (** Computational definitions of word-level operations on [Z]:
      masking, selection, carries and borrows, two-word shifts, and
      splitting of products.  NOTE(review): other files presumably
      unfold these definitions in proofs, so their exact term structure
      (the [dlet] bindings, the [dec] versus [=?] tests, the branch
      order) should be treated as part of the interface. *)

  (** Keep the low [i] bits of [n] by masking it with [Z.ones i].
      [&'] is the bitwise-and notation from [Crypto.Util.ZUtil.Notations]
      (presumably [Z.land]); for [0 <= i] this should coincide with
      [n mod 2^i]. *)
  Definition pow2_mod n i := (n &' (Z.ones i)).

  (** Select on zero: returns [zero_case] when [cond = 0], and
      [nonzero_case] for every other value of [cond]. *)
  Definition zselect (cond zero_case nonzero_case : Z) :=
    if cond =? 0 then zero_case else nonzero_case.

  (** Add [x] and [y], then subtract [modulus] once if the sum is at
      least [modulus].  Only one conditional subtraction is performed,
      so the result equals [(x + y) mod modulus] only under a
      precondition such as [0 <= x < modulus] and [0 <= y < modulus].
      NOTE(review): presumably that is how callers use it; confirm. *)
  Definition add_modulo x y modulus :=
    if (modulus <=? x + y) then (x + y) - modulus else (x + y).

  (** Bitwise negation of [v] ([Z.lnot v], which is [-v - 1]), reduced
      modulo [modulus].  Application binds tighter than [mod], so this
      is [(Z.lnot v) mod modulus]. *)
  Definition lnot_modulo (v : Z) (modulus : Z) : Z
    := Z.lnot v mod modulus.

  (** Boolean negation on integers: [1] when [v = 0], [0] otherwise.
      The zero test goes through [dec] (from [Crypto.Util.Decidable])
      rather than [Z.eqb]. *)
  Definition bneg (v : Z) : Z
    := if dec (v = 0) then 1 else 0.

  (** Most significant bit of a word with bound [s].
      When [s] is a power of two, say [s = 2^k] with [k = Z.log2 s],
      this is [x] shifted right by [k - 1], i.e. bit [k - 1] of [x]
      when [0 <= x < s].  Otherwise it falls back to [x / (s / 2)].
      NOTE(review): for [s = 1] the shift amount is [-1]; presumably
      callers only use [s >= 2]. *)
  Definition cc_m s x := if dec (2 ^ (Z.log2 s) = s) then x >> (Z.log2 s - 1) else x / (s / 2).

  (** Least significant bit of [x], computed as [x mod 2]; this is
      always [0] or [1] because the divisor is positive. *)
  Definition cc_l x := x mod 2.

  (** Two-register right shift.  [hi] and [lo] are treated as the high
      and low words of the double-width value [hi * s + lo].  That value
      is shifted right by [n] and truncated back to one word of bound [s].
      When [s = 2^k], the high word is placed by a left shift of [k] and
      the result is masked with [Z.ones k]; otherwise the high word is
      multiplied by [s] and the result reduced with [mod s]. *)
  Definition rshi s hi lo n :=
       let k := Z.log2 s in
       if dec (2 ^ k = s)
       then ((lo + (hi << k)) >> n) &' (Z.ones k)
       else ((lo + hi * s) >> n) mod s.

  (** Split [v] at [2^bitwidth] into a pair (low part, carry), namely
      [v mod 2^bitwidth] and [v / 2^bitwidth].  [Z] division rounds
      toward negative infinity, so for [2^bitwidth > 0] the low part
      lies between [0] and [2^bitwidth - 1]. *)
  Definition get_carry (bitwidth : Z) (v : Z) : Z * Z
    := (v mod 2^bitwidth, v / 2^bitwidth).
  (** Sum of the carry-in [c] and the operands [x] and [y]. *)
  Definition add_with_carry (c : Z) (x y : Z) : Z
    := c + x + y.
  (** [c + x + y] split into (low part, carry) at [2^bitwidth].
      [dlet] (from [Crypto.Util.LetIn]) binds the sum explicitly.
      NOTE(review): presumably that binding is kept deliberately for
      sharing in later processing; do not inline it without checking
      the proofs that use this definition. *)
  Definition add_with_get_carry (bitwidth : Z) (c : Z) (x y : Z) : Z * Z
    := dlet v := add_with_carry c x y in get_carry bitwidth v.
  (** [add_with_get_carry] with a carry-in of [0]. *)
  Definition add_get_carry (bitwidth : Z) (x y : Z) : Z * Z
    := add_with_get_carry bitwidth 0 x y.

  (** Like [get_carry], but the high part is negated, so an underflowing
      (negative) [v] yields a positive borrow.  For example, for
      [0 <= bitwidth] and [v = -1] the low part is [2^bitwidth - 1] and
      the borrow is [1]. *)
  Definition get_borrow (bitwidth : Z) (v : Z) : Z * Z
    := let '(v, c) := get_carry bitwidth v in
       (v, -c).
  (** [x - y - c], expressed through [add_with_carry] as
      [(-c) + x + (-y)]. *)
  Definition sub_with_borrow (c : Z) (x y : Z) : Z
    := add_with_carry (-c) x (-y).
  (** [x - y - c] split into (low part, borrow) at [2^bitwidth]; the
      borrow follows the sign convention of [get_borrow]. *)
  Definition sub_with_get_borrow (bitwidth : Z) (c : Z) (x y : Z) : Z * Z
    := dlet v := sub_with_borrow c x y in get_borrow bitwidth v.
  (** [sub_with_get_borrow] with a borrow-in of [0]. *)
  Definition sub_get_borrow (bitwidth : Z) (x y : Z) : Z * Z
    := sub_with_get_borrow bitwidth 0 x y.

  (** The [_full] variants below split at [bound] rather than at
      [2^bitwidth].  They wrap [add_get_carry] and its relatives so they
      also work when [bound] is not known to be a power of two.
      If [2 ^ (Z.log2 bound) =? bound] holds, they delegate to the
      bitwidth-based version with bitwidth [Z.log2 bound].  Otherwise
      they compute the split directly with [mod] and [/] by [bound].
      The borrow in the fallback branches is again the negated quotient,
      matching [get_borrow]. *)
  (** [x + y] split into (low part, carry) at [bound]. *)
  Definition add_get_carry_full (bound : Z) (x y : Z) : Z * Z
    := if 2 ^ (Z.log2 bound) =? bound
       then add_get_carry (Z.log2 bound) x y
       else ((x + y) mod bound, (x + y) / bound).
  (** [c + x + y] split into (low part, carry) at [bound]. *)
  Definition add_with_get_carry_full (bound : Z) (c x y : Z) : Z * Z
    := if 2 ^ (Z.log2 bound) =? bound
       then add_with_get_carry (Z.log2 bound) c x y
       else ((c + x + y) mod bound, (c + x + y) / bound).
  (** [x - y] split into (low part, borrow) at [bound]. *)
  Definition sub_get_borrow_full (bound : Z) (x y : Z) : Z * Z
    := if 2 ^ (Z.log2 bound) =? bound
       then sub_get_borrow (Z.log2 bound) x y
       else ((x - y) mod bound, -((x - y) / bound)).
  (** [x - y - c] split into (low part, borrow) at [bound]. *)
  Definition sub_with_get_borrow_full (bound : Z) (c x y : Z) : Z * Z
    := if 2 ^ (Z.log2 bound) =? bound
       then sub_with_get_borrow (Z.log2 bound) c x y
       else ((x - y - c) mod bound, -((x - y - c) / bound)).

  (** Split the product [x * y] at [2^bitwidth] into (low part, high
      part).  The product is bound once with [dlet] and shared by both
      components.  For [0 <= bitwidth] the halves are computed bitwise:
      a mask with [Z.ones bitwidth] and a right shift by [bitwidth].
      For negative [bitwidth] the [mod] and [/] forms are used instead.
      NOTE(review): presumably the [Z.geb] guard exists because the
      bitwise and arithmetic forms are only relied on to agree for
      non-negative [bitwidth]; confirm against the accompanying lemmas. *)
  Definition mul_split_at_bitwidth (bitwidth : Z) (x y : Z) : Z * Z
    := dlet xy := x * y in
       (if Z.geb bitwidth 0
        then xy &' Z.ones bitwidth
        else xy mod 2^bitwidth,
        if Z.geb bitwidth 0
        then xy >> bitwidth
        else xy / 2^bitwidth).
  (** Split [x * y] at [s] into (low part, high part).  When [s] is a
      power of two this delegates to [mul_split_at_bitwidth]; otherwise
      it uses [mod s] and [/ s] directly.  The power-of-two test is
      written [s =? 2^Z.log2 s], the mirror image of the test in the
      [_full] definitions; the two are propositionally equivalent. *)
  Definition mul_split (s x y : Z) : Z * Z
    := if s =? 2^Z.log2 s
       then mul_split_at_bitwidth (Z.log2 s) x y
       else ((x * y) mod s, (x * y) / s).

  (** For [0 <= x], round up to the least value of the form [2^k - 1]
      that is [>= x] (binary 0b111...1); in particular [0] maps to [0].
      For [x < 0], round down to the greatest value of the form [-2^k]
      that is [<= x] (two's complement 0b...111000...0).
      NOTE(review): the name suggests this is used to bound results of
      [Z.lor] and [Z.land]; confirm with its users. *)
  Definition round_lor_land_bound (x : Z) : Z
    := if (0 <=? x)%Z
       then 2^(Z.log2_up (x+1))-1
       else -2^(Z.log2_up (-x)).
End Z.