(* Source path: root/src/Reflection/Z/Bounds/Interpretation.v
   blob: 8e90213b60392fe6cb2a4860766dacc7b7242431 (plain) *)
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Reflection.Z.Syntax.
Require Import Crypto.Reflection.Syntax.
Require Import Crypto.Reflection.Relations.
Require Import Crypto.Util.Option.
Require Import Crypto.Util.Notations.
Require Import Crypto.Util.Decidable.
Require Import Crypto.Util.ZRange.
Require Import Crypto.Util.Tactics.DestructHead.
Export Reflection.Syntax.Notations.

(* [eta x] rebuilds the pair [x] from its two projections: [(fst x, snd x)]. *)
Local Notation eta x := (fst x, snd x).
(* [eta3 x] applies [eta] to [fst x] and pairs the result with [snd x], i.e.
   it expands a left-nested triple of the shape [((a, b), c)]. *)
Local Notation eta3 x := (eta (fst x), snd x).
(* [eta4 x] does the same one level deeper, for [(((a, b), c), d)]; [interp_op]
   uses it to split the four arguments of the conditional-move operations. *)
Local Notation eta4 x := (eta3 (fst x), snd x).

(* [bounds] is shorthand for [zrange]. NOTE(review): [zrange] and its
   [lower]/[upper] fields are presumably those of [Crypto.Util.ZRange],
   which is imported above -- confirm there. *)
Notation bounds := zrange.
(* Lets [(...)%bounds] select [bounds_scope]. *)
Delimit Scope bounds_scope with bounds.

Module Import Bounds.
  (* A bound is either [Some r] for a concrete range [r], or [None].  [None] is
     what [SmartBuildBounds] produces when a range has a negative lower bound or
     does not fit the bit width. *)
  Definition t := option bounds. (* TODO?: Separate out the bounds computation from the overflow computation? e.g., have [safety := in_bounds | overflow] and [t := bounds * safety]? *)
  (* Arguments whose type is [t] are parsed in [bounds_scope]. *)
  Bind Scope bounds_scope with t.
  (* Lets a [nat] be used where a [Z] is expected; declared [Local]. *)
  Local Coercion Z.of_nat : nat >-> Z.
  Section with_bitwidth.
    (* Width limit shared by every operation in this section: [Some w] requires
       upper bounds to stay strictly below [2^w]; [None] places no limit on the
       upper bound. *)
    Context (bit_width : option Z).
    (* [SmartBuildBounds l u] is [Some] of the range with lower bound [l] and
       upper bound [u] when [0 <= l] and [u] respects [bit_width]; otherwise it
       is [None]. *)
    Definition SmartBuildBounds (l u : Z)
      := if (andb (0 <=? l)
                  (match bit_width with
                   | None => true
                   | Some width => u <? 2^width
                   end))%Z
         then Some {| lower := l ; upper := u |}
         else None.
    (* Re-validates an existing bound: a [Some] range is passed back through
       [SmartBuildBounds] with its own [lower] and [upper], so it survives only
       if it satisfies the checks made there; [None] stays [None]. *)
    Definition SmartRebuildBounds (b : t) : t
      := match b with
         | None => None
         | Some r => SmartBuildBounds (lower r) (upper r)
         end.
    (* Lifts a unary range transformer [f] to [t]: on [Some r] it computes
       [f r] and re-validates the resulting endpoints with [SmartBuildBounds];
       on [None] it returns [None]. *)
    Definition t_map1 (f : bounds -> bounds) (x : t)
      := match x with
         | None => None
         | Some r => let (l, u) := f r in SmartBuildBounds l u
         end.
    (* Lifts a binary range transformer [f] to [t]: when both arguments are
       [Some], the endpoints of [f] applied to the two ranges are re-validated
       with [SmartBuildBounds]; if either argument is [None] the result is
       [None]. *)
    Definition t_map2 (f : bounds -> bounds -> bounds) (x y : t)
      := match x, y with
         | Some rx, Some ry
           => let (l, u) := f rx ry in SmartBuildBounds l u
         | _, _ => None
         end.
    (* Four-argument analogue of [t_map2]: all four arguments must be [Some]
       for [f] to be applied; its result is re-validated with
       [SmartBuildBounds].  Any [None] argument makes the result [None]. *)
    Definition t_map4 (f : bounds -> bounds -> bounds -> bounds -> bounds) (x y z w : t)
      := match x, y, z, w with
         | Some rx, Some ry, Some rz, Some rw
           => let (l, u) := f rx ry rz rw in SmartBuildBounds l u
         | _, _, _, _ => None
         end.
    (* Range of a sum: the lower bounds are added and the upper bounds are
       added. *)
    Definition add' (x y : bounds) : bounds
      := let '{| lower := lx ; upper := ux |} := x in
         let '{| lower := ly ; upper := uy |} := y in
         {| lower := lx + ly ; upper := ux + uy |}.
    (* [add'] lifted to [t] with [t_map2], so the sum is re-validated by
       [SmartBuildBounds]. *)
    Definition add : t -> t -> t := t_map2 add'.
    (* Range of a difference [x - y]: the smallest result pairs the lower bound
       of [x] with the upper bound of [y], the largest pairs the upper bound of
       [x] with the lower bound of [y]. *)
    Definition sub' (x y : bounds) : bounds
      := match x, y with
         | {| lower := lx ; upper := ux |}, {| lower := ly ; upper := uy |}
           => {| lower := lx - uy ; upper := ux - ly |}
         end.
    (* [sub'] lifted to [t] with [t_map2]; a difference whose lower end is
       negative is therefore rejected by [SmartBuildBounds]. *)
    Definition sub : t -> t -> t := t_map2 sub'.
    (* Range of a product.  A product over a box of operands attains its
       extreme values at the corners of the box, so the lower (resp. upper) end
       is the minimum (resp. maximum) of the four endpoint products.  Taking
       just [lx * ly] and [ux * uy] is only an enclosure when both operands are
       nonnegative: for [-3, 2] times [-3, 2] it would give the inverted range
       [9, 4], which [SmartBuildBounds] would accept because [0 <= 9].  For
       nonnegative operands the four-corner form evaluates to the same
       endpoints as before.
       NOTE(review): lemmas outside this file that unfold [mul'] may need to be
       adjusted to the new shape. *)
    Definition mul' : bounds -> bounds -> bounds
      := fun x y
         => let (lx, ux) := x in
            let (ly, uy) := y in
            {| lower := Z.min (Z.min (lx * ly) (lx * uy)) (Z.min (ux * ly) (ux * uy)) ;
               upper := Z.max (Z.max (lx * ly) (lx * uy)) (Z.max (ux * ly) (ux * uy)) |}.
    (* [mul'] lifted to [t] with [t_map2], so the product is re-validated by
       [SmartBuildBounds]. *)
    Definition mul : t -> t -> t := t_map2 mul'.
    (* Range of a left shift of [x] by [y]: the lower end shifts the lower bound
       of [x] by the lower bound of [y], the upper end shifts the upper bound of
       [x] by the upper bound of [y].
       NOTE(review): this looks like an enclosure only when values and shift
       amounts are nonnegative -- confirm callers guarantee that. *)
    Definition shl' (x y : bounds) : bounds
      := let '{| lower := lx ; upper := ux |} := x in
         let '{| lower := ly ; upper := uy |} := y in
         {| lower := Z.shiftl lx ly ; upper := Z.shiftl ux uy |}.
    (* [shl'] lifted to [t] with [t_map2]. *)
    Definition shl : t -> t -> t := t_map2 shl'.
    (* Range of a right shift of [x] by [y]: the lower end shifts the lower
       bound of [x] by the largest amount, the upper end shifts the upper bound
       of [x] by the smallest amount.
       NOTE(review): this looks like an enclosure only when the shifted values
       are nonnegative -- confirm callers guarantee that. *)
    Definition shr' (x y : bounds) : bounds
      := match x, y with
         | {| lower := lx ; upper := ux |}, {| lower := ly ; upper := uy |}
           => {| lower := Z.shiftr lx uy ; upper := Z.shiftr ux ly |}
         end.
    (* [shr'] lifted to [t] with [t_map2]. *)
    Definition shr : t -> t -> t := t_map2 shr'.
    (* Range of a bitwise AND: the lower end is fixed at [0] and the upper end
       is the smaller of the two upper bounds; the operands' lower bounds are
       not consulted.
       NOTE(review): presumably relies on both operands being nonnegative --
       confirm. *)
    Definition land' (x y : bounds) : bounds
      := let '{| lower := _ ; upper := ux |} := x in
         let '{| lower := _ ; upper := uy |} := y in
         {| lower := 0 ; upper := Z.min ux uy |}.
    (* [land'] lifted to [t] with [t_map2]. *)
    Definition land : t -> t -> t := t_map2 land'.
    (* Range of a bitwise OR: the lower end is the larger of the two lower
       bounds; the upper end is [2^k - 1], where [k] is the larger of
       [Z.log2_up (ux + 1)] and [Z.log2_up (uy + 1)]. *)
    Definition lor' (x y : bounds) : bounds
      := match x, y with
         | {| lower := lx ; upper := ux |}, {| lower := ly ; upper := uy |}
           => {| lower := Z.max lx ly ;
                 upper := 2^(Z.max (Z.log2_up (ux+1)) (Z.log2_up (uy+1))) - 1 |}
         end.
    (* [lor'] lifted to [t] with [t_map2]. *)
    Definition lor : t -> t -> t := t_map2 lor'.
    (* Range produced for the [neg] operation with mask width [int_width], given
       the operand range [v]:
       - if [v] is exactly [1, 1], the result is exactly [Z.ones int_width];
       - otherwise, if [v] contains [1], the result lies in
         [0, Z.ones int_width];
       - otherwise the result is exactly [0]. *)
    Definition neg' (int_width : Z) (v : bounds) : bounds
      := let '{| lower := lo ; upper := hi |} := v in
         let can_be_one := (andb (lo <=? 1) (1 <=? hi))%Z in
         let is_exactly_one := (andb (lo =? 1) (hi =? 1))%Z in
         if is_exactly_one
         then {| lower := Z.ones int_width ; upper := Z.ones int_width |}
         else if can_be_one
              then {| lower := 0 ; upper := Z.ones int_width |}
              else {| lower := 0 ; upper := 0 |}.
    (* [neg'] lifted to [t] with [t_map1], guarded by [0 <= int_width]; a
       negative width yields [None].  An additional upper limit on the width,
       [int_width <=? WordW.bit_width], is currently disabled. *)
    Definition neg (int_width : Z) (v : t) : t
      := if Z.leb 0 int_width (* && (int_width <=? WordW.bit_width) *)
         then t_map1 (neg' int_width) v
         else None.
    (* Range of a conditional move that may return either [r1] or [r2]: the
       smallest range covering both, i.e. the minimum of the lower bounds and
       the maximum of the upper bounds. *)
    Definition cmovne' (r1 r2 : bounds) : bounds
      := match r1, r2 with
         | {| lower := lo1 ; upper := hi1 |}, {| lower := lo2 ; upper := hi2 |}
           => {| lower := Z.min lo1 lo2 ; upper := Z.max hi1 hi2 |}
         end.
    (* [cmovne'] lifted to [t] with [t_map4].  The compared operands [x] and [y]
       do not influence the range, but [t_map4] still requires them to be
       [Some]. *)
    Definition cmovne (x y r1 r2 : t) : t := t_map4 (fun _ _ => cmovne') x y r1 r2.
    (* Range of a conditional move that may return either [r1] or [r2]: the
       minimum of the two lower bounds and the maximum of the two upper bounds.
       The formula is the same as the one in [cmovne']. *)
    Definition cmovle' (r1 r2 : bounds) : bounds
      := match r1, r2 with
         | {| lower := lo1 ; upper := hi1 |}, {| lower := lo2 ; upper := hi2 |}
           => {| lower := Z.min lo1 lo2 ; upper := Z.max hi1 hi2 |}
         end.
    (* [cmovle'] lifted to [t] with [t_map4].  The compared operands [x] and [y]
       do not influence the range, but [t_map4] still requires them to be
       [Some]. *)
    Definition cmovle (x y r1 r2 : t) : t := t_map4 (fun _ _ => cmovle') x y r1 r2.
  End with_bitwidth.

  (* Infix notations for the bounds operations, all in [bounds_scope].  The [_]
     in each right-hand side is the [bit_width] argument the operations take
     once the section above is closed; it is left to be inferred.  No notation
     is declared here for [lor], [neg], [cmovne] or [cmovle]. *)
  Module Export Notations.
    (* Re-export the notations of [Util.ZRange]. *)
    Export Util.ZRange.Notations.
    Infix "+" := (add _) : bounds_scope.
    Infix "-" := (sub _) : bounds_scope.
    Infix "*" := (mul _) : bounds_scope.
    Infix "<<" := (shl _) : bounds_scope.
    Infix ">>" := (shr _) : bounds_scope.
    Infix "&'" := (land _) : bounds_scope.
  End Notations.

  (* Bounds interpretation of base types: every base type is interpreted as
     [t]; the [ty] argument is ignored. *)
  Definition interp_base_type (ty : base_type) : Set := t.

  (* Bit width attached to a base type: [TZ] has none ([None]); [TWord k] has
     width [2^k]. *)
  Definition bit_width_of_base_type ty : option Z
    := match ty with
       | TWord k => Some (Z.pow 2 (Z.of_nat k))
       | TZ => None
       end.

  (* Bounds interpretation of each operation [f : op src dst].
     - Constants become the single-point range [v, v], validated by
       [SmartBuildBounds] against the bit width of their type ([None] for
       [TZ]).
     - Binary operations apply the matching bounds function to the two
       components of the argument pair, using the bit width of the operation's
       type [T].
     - [Neg] forwards its [int_width] to [neg].
     - [Cmovne] and [Cmovle] split their four-tuple argument with [eta4].
     - [Cast] re-validates the incoming bound against the bit width of the
       target type with [SmartRebuildBounds]. *)
  Definition interp_op {src dst} (f : op src dst) : interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst
    := match f in op src dst return interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst with
       | OpConst TZ v => fun _ => SmartBuildBounds None v v
       | OpConst (TWord _ as T) v => fun _ => SmartBuildBounds (bit_width_of_base_type T) ((*FixedWordSizes.wordToZ*) v) ((*FixedWordSizes.wordToZ*) v)
       | Add T => fun xy => add (bit_width_of_base_type T) (fst xy) (snd xy)
       | Sub T => fun xy => sub (bit_width_of_base_type T) (fst xy) (snd xy)
       | Mul T => fun xy => mul (bit_width_of_base_type T) (fst xy) (snd xy)
       | Shl T => fun xy => shl (bit_width_of_base_type T) (fst xy) (snd xy)
       | Shr T => fun xy => shr (bit_width_of_base_type T) (fst xy) (snd xy)
       | Land T => fun xy => land (bit_width_of_base_type T) (fst xy) (snd xy)
       | Lor T => fun xy => lor (bit_width_of_base_type T) (fst xy) (snd xy)
       | Neg T int_width => fun x => neg (bit_width_of_base_type T) int_width x
       | Cmovne T => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne (bit_width_of_base_type T) x y z w
       | Cmovle T => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovle (bit_width_of_base_type T) x y z w
       | Cast _ T => fun x => SmartRebuildBounds (bit_width_of_base_type T) x
       end%bounds.

  (* Bound for a known integer: [Some (ZToZRange z)], with no validation by
     [SmartBuildBounds].  NOTE(review): [ZToZRange] is presumably the
     single-point range [z, z] from [Crypto.Util.ZRange] -- confirm there. *)
  Definition of_Z (z : Z) : t := Some (ZToZRange z).

  (* Bound for a concrete interpreted value [z] of base type [t]: the value is
     converted to a [Z] by a match on [t] and wrapped as
     [Some (ZToZRange ...)].  Both branches currently return the value
     unchanged; the [TWord] branch carries a commented-out
     [FixedWordSizes.wordToZ] conversion. *)
  Definition of_interp t (z : Syntax.interp_base_type t) : interp_base_type t
    := Some (ZToZRange (match t return Syntax.interp_base_type t -> Z with
                        | TZ => fun z => z
                        | TWord logsz => fun z => z (*FixedWordSizes.wordToZ*)
                        end z)).

  (* Chooses a base type for a range.  A range with a nonnegative lower bound
     maps to [TWord k], where [k] is [Z.log2_up] applied twice to
     [1 + upper b] and then converted to [nat]; any other range maps to
     [TZ]. *)
  Definition bounds_to_base_type' (b : bounds) : base_type
    := if Z.leb 0 (lower b)
       then TWord (Z.to_nat (Z.log2_up (Z.log2_up (1 + upper b))))
       else TZ.
  (* Lifts [bounds_to_base_type'] to [t]: an unknown bound ([None]) maps to
     [TZ]. *)
  Definition bounds_to_base_type (b : t) : base_type
    := match b with
       | Some r => bounds_to_base_type' r
       | None => TZ
       end.

  (* Computes output bounds for the expression [e] from the given
     [input_bounds] by running [Interp] with [interp_op] as the interpretation
     of operations. *)
  Definition ComputeBounds {t} (e : Expr base_type op t)
             (input_bounds : interp_flat_type interp_base_type (domain t))
    : interp_flat_type interp_base_type (codomain t)
    := Interp (@interp_op) e input_bounds.

  (* Boolean check that a bound is usable: [true] for any [Some] range and
     [false] for [None].  The type argument [t] is not consulted.  A stricter
     check is currently disabled and kept here for reference:
       let l := lower bs in
       let u := upper bs in
       let bit_width := bit_width_of_base_type t in
       ((0 <=? l) && (match bit_width with Some bit_width => Z.log2 u <? bit_width | None => true end))%Z%bool *)
  Definition bound_is_goodb (t : base_type) (bs : interp_base_type t) : bool
    := match bs with
       | None => false
       | Some _ => true
       end.
  (* Propositional form of [bound_is_goodb]: the boolean check returns
     [true]. *)
  Definition bound_is_good (t : base_type) (v : interp_base_type t) : Prop
    := bound_is_goodb t v = true.
  (* [bound_is_good] applied across a flat type via
     [interp_flat_type_rel_pointwise1]. *)
  Definition bounds_are_good : forall {t}, interp_flat_type interp_base_type t -> Prop
    := (@interp_flat_type_rel_pointwise1 _ _ bound_is_good).

  (* [is_bounded_by' val bound] holds trivially when [bound] is [None]; for
     [Some r] it defers to the previously defined [is_bounded_by'], called with
     the bit width of [T], the range [r] and [val].  This is a plain
     [Definition], so the inner name refers to an earlier definition, not to
     this one.  NOTE(review): that earlier definition is presumably the one
     from [Crypto.Util.ZRange] -- confirm. *)
  Definition is_bounded_by' {T} (val : Syntax.interp_base_type T) (bound : interp_base_type T) : Prop
    := match bound with
       | None => True
       | Some r => is_bounded_by' (bit_width_of_base_type T) r val
       end.

  (* [is_bounded_by'] applied across a flat type via
     [interp_flat_type_rel_pointwise]. *)
  Definition is_bounded_by {T} : interp_flat_type Syntax.interp_base_type T -> interp_flat_type interp_base_type T -> Prop
    := interp_flat_type_rel_pointwise (@is_bounded_by').

  (* Reduction hint for [interp_base_type], declared [Local].
     NOTE(review): [!_ /] presumably lets [simpl] unfold it in the proof below
     -- confirm against the Coq [Arguments] documentation. *)
  Local Arguments interp_base_type !_ / .
  (* Equality on [interp_flat_type interp_base_type T] is decidable; the
     instance is registered with priority 10.  The proof goes by induction on
     [T], case analysis on any [base_type] in the context, simplification, and
     then instance resolution ([exact _]).  It ends with [Defined], so the
     instance is transparent. *)
  Global Instance dec_eq_interp_flat_type {T} : DecidableRel (@eq (interp_flat_type interp_base_type T)) | 10.
  Proof.
    induction T; destruct_head base_type; simpl; exact _.
  Defined.
End Bounds.