aboutsummaryrefslogtreecommitdiff
path: root/src/Util/ZUtil/CPS.v
blob: e2b21933bfe1b58c225dfa412b5c510e36c2ceed (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Util.ZUtil.Definitions.
Require Import Crypto.Util.ZUtil.Tactics.LtbToLt.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Tactics.Head.

Local Open Scope Z_scope.

Module Z.
  (** * Continuation-passing wrappers for [Z] operations

      Every [foo_cps] below takes a final continuation [f] and hands it
      the result of the corresponding direct-style operation.  Each
      [foo_cps_correct] states [foo_cps ... f = f (foo ...)] and is
      registered in the [uncps] rewrite database, so that
      [autorewrite with uncps] turns CPS terms back into direct style. *)

  (** Decidable equality on [Z], in CPS. *)
  Definition eq_dec_cps {T} (x y : Z) (f : {x = y} + {x <> y} -> T) : T
    := f (Z.eq_dec x y).
  Definition eq_dec_cps_correct {T} x y f : @eq_dec_cps T x y f = f (Z.eq_dec x y)
    := eq_refl.
  Hint Rewrite @eq_dec_cps_correct : uncps.

  (** Boolean equality on [Z], in CPS. *)
  Definition eqb_cps {T} (x y : Z) (f : bool -> T) : T
    := f (Z.eqb x y).
  Definition eqb_cps_correct {T} x y f : @eqb_cps T x y f = f (Z.eqb x y)
    := eq_refl.
  Hint Rewrite @eqb_cps_correct : uncps.

  (** Proof script for the [*_full_cps] and [mul_split_cps] correctness
      lemmas below.  The argument is ignored; callers pass [()].

      If the goal has the shape [lhs f = f rhs], the head constants of
      [lhs] and [rhs] are unfolded first.  Afterwards it repeatedly tries,
      in order:
      - closing the goal with [reflexivity];
      - unfolding [Decidable.dec] and [Decidable.dec_eq_Z];
      - [Z.ltb_to_lt], a project tactic that presumably turns boolean [Z]
        comparisons into propositional ones;
      - [congruence];
      - rewriting with the [uncps] database;
      - [break_innermost_match_step], presumably one case split on an
        innermost [match]/[if]. *)
  Local Ltac prove_cps_correct _ :=
    try match goal with
        | [ |- ?lhs ?f = ?f ?rhs ]
          => let l := head lhs in
             let r := head rhs in
             cbv [l r] in *
        end;
    repeat first [ reflexivity
                 | progress cbv [Decidable.dec Decidable.dec_eq_Z] in *
                 | progress Z.ltb_to_lt
                 | congruence
                 | progress autorewrite with uncps
                 | break_innermost_match_step ].

  (** ** Carry and borrow at a bit width

      CPS wrappers around [Z.get_carry], [Z.add_with_get_carry],
      [Z.add_get_carry], [Z.get_borrow], [Z.sub_with_get_borrow] and
      [Z.sub_get_borrow] (presumably from [Crypto.Util.ZUtil.Definitions]).
      Each wrapper passes a pair of [Z]s to its continuation, and every
      correctness lemma holds by [eq_refl]. *)
  Definition get_carry_cps {T} (bitwidth : Z) (v : Z) (f : Z * Z -> T) : T
    := f (Z.get_carry bitwidth v).
  Definition get_carry_cps_correct {T} bitwidth v f
    : @get_carry_cps T bitwidth v f = f (Z.get_carry bitwidth v)
    := eq_refl.
  Hint Rewrite @get_carry_cps_correct : uncps.
  Definition add_with_get_carry_cps {T} (bitwidth : Z) (c : Z) (x y : Z) (f : Z * Z -> T) : T
    := f (Z.add_with_get_carry bitwidth c x y).
  Definition add_with_get_carry_cps_correct {T} bitwidth c x y f
    : @add_with_get_carry_cps T bitwidth c x y f = f (Z.add_with_get_carry bitwidth c x y)
    := eq_refl.
  Hint Rewrite @add_with_get_carry_cps_correct : uncps.
  Definition add_get_carry_cps {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : T
    := f (Z.add_get_carry bitwidth x y).
  Definition add_get_carry_cps_correct {T} bitwidth x y f
    : @add_get_carry_cps T bitwidth x y f = f (Z.add_get_carry bitwidth x y)
    := eq_refl.
  Hint Rewrite @add_get_carry_cps_correct : uncps.

  (* The [: T] result type is spelled out to match the other wrappers;
     it is the type that would be inferred anyway. *)
  Definition get_borrow_cps {T} (bitwidth : Z) (v : Z) (f : Z * Z -> T) : T
    := f (Z.get_borrow bitwidth v).
  Definition get_borrow_cps_correct {T} bitwidth v f
    : @get_borrow_cps T bitwidth v f = f (Z.get_borrow bitwidth v)
    := eq_refl.
  Hint Rewrite @get_borrow_cps_correct : uncps.
  Definition sub_with_get_borrow_cps {T} (bitwidth : Z) (c : Z) (x y : Z) (f : Z * Z -> T) : T
    := f (Z.sub_with_get_borrow bitwidth c x y).
  Definition sub_with_get_borrow_cps_correct {T} (bitwidth : Z) (c : Z) (x y : Z) (f : Z * Z -> T)
    : @sub_with_get_borrow_cps T bitwidth c x y f = f (Z.sub_with_get_borrow bitwidth c x y)
    := eq_refl.
  Hint Rewrite @sub_with_get_borrow_cps_correct : uncps.
  Definition sub_get_borrow_cps {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : T
    := f (Z.sub_get_borrow bitwidth x y).
  Definition sub_get_borrow_cps_correct {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T)
    : @sub_get_borrow_cps T bitwidth x y f = f (Z.sub_get_borrow bitwidth x y)
    := eq_refl.
  Hint Rewrite @sub_get_borrow_cps_correct : uncps.

  (** ** Carry and borrow at an arbitrary bound

      These split at [bound], not at [2^bitwidth].  They wrap the
      bit-width operations above so that [add_get_carry] and friends
      also work when [bound] is not known to be a power of 2.

      The test [2 ^ Z.log2 bound =? bound] chooses the route:
      - if it holds, the bit-width version at [Z.log2 bound] is used;
      - otherwise the pair is computed directly with [mod] and [/].

      Each lemma is proven by [prove_cps_correct] against the
      corresponding [Z.*_full] function. *)

  (** Low part and carry of [x + y] at [bound]. *)
  Definition add_get_carry_full_cps {T} (bound : Z) (x y : Z) (f : Z * Z -> T) : T
    := eqb_cps
         (2 ^ (Z.log2 bound)) bound
         (fun eqb
          => if eqb
             then add_get_carry_cps (Z.log2 bound) x y f
             else f ((x + y) mod bound, (x + y) / bound)).
  Lemma add_get_carry_full_cps_correct {T} (bound : Z) (x y : Z) (f : Z * Z -> T)
    : @add_get_carry_full_cps T bound x y f = f (Z.add_get_carry_full bound x y).
  Proof. prove_cps_correct (). Qed.
  Hint Rewrite @add_get_carry_full_cps_correct : uncps.

  (** Low part and carry of [c + x + y] at [bound]. *)
  Definition add_with_get_carry_full_cps {T} (bound : Z) (c x y : Z) (f : Z * Z -> T) : T
    := eqb_cps
         (2 ^ (Z.log2 bound)) bound
         (fun eqb
          => if eqb
             then add_with_get_carry_cps (Z.log2 bound) c x y f
             else f ((c + x + y) mod bound, (c + x + y) / bound)).
  Lemma add_with_get_carry_full_cps_correct {T} (bound : Z) (c x y : Z) (f : Z * Z -> T)
    : @add_with_get_carry_full_cps T bound c x y f = f (Z.add_with_get_carry_full bound c x y).
  Proof. prove_cps_correct (). Qed.
  Hint Rewrite @add_with_get_carry_full_cps_correct : uncps.

  (** Low part and borrow of [x - y] at [bound].  In the direct branch
      the borrow is the negation of the (floor) quotient. *)
  Definition sub_get_borrow_full_cps {T} (bound : Z) (x y : Z) (f : Z * Z -> T) : T
    := eqb_cps
         (2 ^ (Z.log2 bound)) bound
         (fun eqb
          => if eqb
             then sub_get_borrow_cps (Z.log2 bound) x y f
             else f ((x - y) mod bound, -((x - y) / bound))).
  Lemma sub_get_borrow_full_cps_correct {T} (bound : Z) (x y : Z) (f : Z * Z -> T)
    : @sub_get_borrow_full_cps T bound x y f = f (Z.sub_get_borrow_full bound x y).
  Proof. prove_cps_correct (). Qed.
  Hint Rewrite @sub_get_borrow_full_cps_correct : uncps.

  (** Low part and borrow of [x - y - c] at [bound].  In the direct
      branch the borrow is the negation of the (floor) quotient. *)
  Definition sub_with_get_borrow_full_cps {T} (bound : Z) (c x y : Z) (f : Z * Z -> T) : T
    := eqb_cps
         (2 ^ (Z.log2 bound)) bound
         (fun eqb
          => if eqb
             then sub_with_get_borrow_cps (Z.log2 bound) c x y f
             else f ((x - y - c) mod bound, -((x - y - c) / bound))).
  Lemma sub_with_get_borrow_full_cps_correct {T} (bound : Z) (c x y : Z) (f : Z * Z -> T)
    : @sub_with_get_borrow_full_cps T bound c x y f = f (Z.sub_with_get_borrow_full bound c x y).
  Proof. prove_cps_correct (). Qed.
  Hint Rewrite @sub_with_get_borrow_full_cps_correct : uncps.

  (** ** Multiplication split into low and high parts *)

  (** Split [x * y] at [2^bitwidth].  The product is bound once with
      [dlet] and shared by both components.

      - For [bitwidth >= 0], the low part is [Z.land xy (Z.ones bitwidth)]
        and the high part is [Z.shiftr xy bitwidth].
      - Otherwise [mod] and [/] by [2^bitwidth] are used.

      NOTE(review): presumably the bitwise forms are chosen as cheaper
      equivalents for non-negative [bitwidth]; confirm.  The [eq_refl]
      proof requires [Z.mul_split_at_bitwidth] to be convertible with
      this body. *)
  Definition mul_split_at_bitwidth_cps {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : T
    := dlet xy := x * y in
        f (if Z.geb bitwidth 0
           then Z.land xy (Z.ones bitwidth)
           else xy mod 2^bitwidth,
           if Z.geb bitwidth 0
           then Z.shiftr xy bitwidth
           else xy / 2^bitwidth).
  Definition mul_split_at_bitwidth_cps_correct {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T)
    : @mul_split_at_bitwidth_cps T bitwidth x y f = f (Z.mul_split_at_bitwidth bitwidth x y)
    := eq_refl.
  Hint Rewrite @mul_split_at_bitwidth_cps_correct : uncps.

  (** Split [x * y] at [s].

      - When [s =? 2 ^ Z.log2 s] holds, the CPS bit-width splitter at
        [Z.log2 s] is used.
      - Otherwise the pair is [(x * y) mod s] and [(x * y) / s].

      NOTE(review): the operands of [eqb_cps] are in the opposite order
      from the [*_full_cps] wrappers above.  This is harmless, since
      [Z.eqb] is symmetric. *)
  Definition mul_split_cps {T} (s x y : Z) (f : Z * Z -> T) : T
    := eqb_cps
         s (2^Z.log2 s)
         (fun b
          => if b
             then mul_split_at_bitwidth_cps (Z.log2 s) x y f
             else f ((x * y) mod s, (x * y) / s)).
  Lemma mul_split_cps_correct {T} (s x y : Z) (f : Z * Z -> T)
    : @mul_split_cps T s x y f = f (Z.mul_split s x y).
  Proof. prove_cps_correct (). Qed.
  Hint Rewrite @mul_split_cps_correct : uncps.

  (** Same as [mul_split_cps], except that the power-of-two branch calls
      the direct-style [Z.mul_split_at_bitwidth] rather than the CPS
      wrapper [mul_split_at_bitwidth_cps]. *)
  Definition mul_split_cps' {T} (s x y : Z) (f : Z * Z -> T) : T
    := eqb_cps
         s (2^Z.log2 s)
         (fun b
          => if b
             then f (Z.mul_split_at_bitwidth (Z.log2 s) x y)
             else f ((x * y) mod s, (x * y) / s)).
  Lemma mul_split_cps'_correct {T} (s x y : Z) (f : Z * Z -> T)
    : @mul_split_cps' T s x y f = f (Z.mul_split s x y).
  Proof. prove_cps_correct (). Qed.
  Hint Rewrite @mul_split_cps'_correct : uncps.
End Z.