aboutsummaryrefslogtreecommitdiff
path: root/src/Reflection/Z/Bounds/Interpretation.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Reflection/Z/Bounds/Interpretation.v')
-rw-r--r--src/Reflection/Z/Bounds/Interpretation.v188
1 files changed, 77 insertions, 111 deletions
diff --git a/src/Reflection/Z/Bounds/Interpretation.v b/src/Reflection/Z/Bounds/Interpretation.v
index 0a0bb28f0..cad5d87b3 100644
--- a/src/Reflection/Z/Bounds/Interpretation.v
+++ b/src/Reflection/Z/Bounds/Interpretation.v
@@ -2,7 +2,6 @@ Require Import Coq.ZArith.ZArith.
Require Import Crypto.Reflection.Z.Syntax.
Require Import Crypto.Reflection.Syntax.
Require Import Crypto.Reflection.Relations.
-Require Import Crypto.Util.Option.
Require Import Crypto.Util.Notations.
Require Import Crypto.Util.Decidable.
Require Import Crypto.Util.ZRange.
@@ -15,73 +14,70 @@ Local Notation eta4 x := (eta3 (fst x), snd x).
Notation bounds := zrange.
Delimit Scope bounds_scope with bounds.
+Local Open Scope Z_scope.
Module Import Bounds.
- Definition t := option bounds. (* TODO?: Separate out the bounds computation from the overflow computation? e.g., have [safety := in_bounds | overflow] and [t := bounds * safety]? *)
+ Definition t := bounds.
Bind Scope bounds_scope with t.
Local Coercion Z.of_nat : nat >-> Z.
Section with_bitwidth.
Context (bit_width : option Z).
- Definition SmartBuildBounds (l u : Z)
- := if ((0 <=? l) && (match bit_width with Some bit_width => u <? 2^bit_width | None => true end))%Z%bool
- then Some {| lower := l ; upper := u |}
- else None.
- Definition SmartRebuildBounds (b : t) : t
- := match b with
- | Some b => SmartBuildBounds (lower b) (upper b)
- | None => None
+ Definition four_corners (f : Z -> Z -> Z) : t -> t -> t
+ := fun x y
+ => let (lx, ux) := x in
+ let (ly, uy) := y in
+ {| lower := Z.min (f lx ly) (Z.min (f lx uy) (Z.min (f ux ly) (f ux uy)));
+ upper := Z.max (f lx ly) (Z.max (f lx uy) (Z.max (f ux ly) (f ux uy))) |}.
+ Definition truncation_bounds (b : t)
+ := match bit_width with
+ | Some bit_width => if ((0 <=? lower b) && (upper b <? 2^bit_width))%bool
+ then b
+ else {| lower := 0 ; upper := 2^bit_width - 1 |}
+ | None => b
end.
+ Definition BuildTruncated_bounds (l u : Z) : t
+ := truncation_bounds {| lower := l ; upper := u |}.
Definition t_map1 (f : bounds -> bounds) (x : t)
- := match x with
- | Some x
- => match f x with
- | {| lower := l ; upper := u |}
- => SmartBuildBounds l u
- end
- | _ => None
- end%Z.
- Definition t_map2 (f : bounds -> bounds -> bounds) (x y : t)
- := match x, y with
- | Some x, Some y
- => match f x y with
- | {| lower := l ; upper := u |}
- => SmartBuildBounds l u
- end
- | _, _ => None
- end%Z.
+ := truncation_bounds (f x).
+ Definition t_map2 (f : Z -> Z -> Z) : t -> t -> t
+ := fun x y => truncation_bounds (four_corners f x y).
Definition t_map4 (f : bounds -> bounds -> bounds -> bounds -> bounds) (x y z w : t)
- := match x, y, z, w with
- | Some x, Some y, Some z, Some w
- => match f x y z w with
- | {| lower := l ; upper := u |}
- => SmartBuildBounds l u
- end
- | _, _, _, _ => None
- end%Z.
- Definition add' : bounds -> bounds -> bounds
- := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx + ly ; upper := ux + uy |}.
- Definition add : t -> t -> t := t_map2 add'.
- Definition sub' : bounds -> bounds -> bounds
- := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx - uy ; upper := ux - ly |}.
- Definition sub : t -> t -> t := t_map2 sub'.
- Definition mul' : bounds -> bounds -> bounds
- := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx * ly ; upper := ux * uy |}.
- Definition mul : t -> t -> t := t_map2 mul'.
- Definition shl' : bounds -> bounds -> bounds
- := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := Z.shiftl lx ly ; upper := Z.shiftl ux uy |}.
- Definition shl : t -> t -> t := t_map2 shl'.
- Definition shr' : bounds -> bounds -> bounds
- := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := Z.shiftr lx uy ; upper := Z.shiftr ux ly |}.
- Definition shr : t -> t -> t := t_map2 shr'.
- Definition land' : bounds -> bounds -> bounds
- := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := 0 ; upper := Z.min ux uy |}.
- Definition land : t -> t -> t := t_map2 land'.
- Definition lor' : bounds -> bounds -> bounds
- := fun x y => let (lx, ux) := x in let (ly, uy) := y in
- {| lower := Z.max lx ly;
- upper := 2^(Z.max (Z.log2_up (ux+1)) (Z.log2_up (uy+1))) - 1 |}.
- Definition lor : t -> t -> t := t_map2 lor'.
- Definition neg' (int_width : Z) : bounds -> bounds
+ := truncation_bounds (f x y z w).
+ Definition add : t -> t -> t := t_map2 Z.add.
+ Definition sub : t -> t -> t := t_map2 Z.sub.
+ Definition mul : t -> t -> t := t_map2 Z.mul.
+ Definition shl : t -> t -> t := t_map2 Z.shiftl.
+ Definition shr : t -> t -> t := t_map2 Z.shiftr.
+ Definition extreme_lor_land_bounds (x y : t) : t
+ := let (lx, ux) := x in
+ let (ly, uy) := y in
+ let lx := Z.abs lx in
+ let ly := Z.abs ly in
+ let ux := Z.abs ux in
+ let uy := Z.abs uy in
+ let max := Z.max (Z.max lx ly) (Z.max ux uy) in
+ {| lower := -2^(1 + Z.log2_up max) ; upper := 2^(1 + Z.log2_up max) |}.
+ Definition extermization_bounds (f : t -> t -> t) (x y : t) : t
+ := truncation_bounds
+ (let (lx, ux) := x in
+ let (ly, uy) := y in
+ if ((lx <? 0) || (ly <? 0))%Z%bool
+ then extreme_lor_land_bounds x y
+ else f x y).
+ Definition land : t -> t -> t
+ := extermization_bounds
+ (fun x y
+ => let (lx, ux) := x in
+ let (ly, uy) := y in
+ {| lower := Z.min 0 (Z.min lx ly) ; upper := Z.max 0 (Z.min ux uy) |}).
+ Definition lor : t -> t -> t
+ := extermization_bounds
+ (fun x y
+ => let (lx, ux) := x in
+ let (ly, uy) := y in
+ {| lower := Z.max lx ly;
+ upper := 2^(Z.max (Z.log2_up (ux+1)) (Z.log2_up (uy+1))) - 1 |}).
+ Definition neg' (int_width : Z) : t -> t
:= fun v
=> let (lb, ub) := v in
let might_be_one := ((lb <=? 1) && (1 <=? ub))%Z%bool in
@@ -89,19 +85,23 @@ Module Import Bounds.
if must_be_one
then {| lower := Z.ones int_width ; upper := Z.ones int_width |}
else if might_be_one
- then {| lower := 0 ; upper := Z.ones int_width |}
+ then {| lower := Z.min 0 (Z.ones int_width) ; upper := Z.max 0 (Z.ones int_width) |}
else {| lower := 0 ; upper := 0 |}.
Definition neg (int_width : Z) : t -> t
:= fun v
- => if ((0 <=? int_width) (*&& (int_width <=? WordW.bit_width)*))%Z%bool
- then t_map1 (neg' int_width) v
- else None.
- Definition cmovne' (r1 r2 : bounds) : bounds
- := let (lr1, ur1) := r1 in let (lr2, ur2) := r2 in {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}.
- Definition cmovne (x y r1 r2 : t) : t := t_map4 (fun _ _ => cmovne') x y r1 r2.
- Definition cmovle' (r1 r2 : bounds) : bounds
- := let (lr1, ur1) := r1 in let (lr2, ur2) := r2 in {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}.
- Definition cmovle (x y r1 r2 : t) : t := t_map4 (fun _ _ => cmovle') x y r1 r2.
+ => truncation_bounds (neg' int_width v).
+ Definition cmovne' (r1 r2 : t) : t
+ := let (lr1, ur1) := r1 in
+ let (lr2, ur2) := r2 in
+ {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}.
+ Definition cmovne (x y r1 r2 : t) : t
+ := truncation_bounds (cmovne' r1 r2).
+ Definition cmovle' (r1 r2 : t) : t
+ := let (lr1, ur1) := r1 in
+ let (lr2, ur2) := r2 in
+ {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}.
+ Definition cmovle (x y r1 r2 : t) : t
+ := truncation_bounds (cmovle' r1 r2).
End with_bitwidth.
Module Export Notations.
@@ -124,8 +124,8 @@ Module Import Bounds.
Definition interp_op {src dst} (f : op src dst) : interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst
:= match f in op src dst return interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst with
- | OpConst TZ v => fun _ => SmartBuildBounds None v v
- | OpConst (TWord _ as T) v => fun _ => SmartBuildBounds (bit_width_of_base_type T) ((*FixedWordSizes.wordToZ*) v) ((*FixedWordSizes.wordToZ*) v)
+ | OpConst TZ v => fun _ => BuildTruncated_bounds None v v
+ | OpConst (TWord _ as T) v => fun _ => BuildTruncated_bounds (bit_width_of_base_type T) ((*FixedWordSizes.wordToZ*) v) ((*FixedWordSizes.wordToZ*) v)
| Add T => fun xy => add (bit_width_of_base_type T) (fst xy) (snd xy)
| Sub T => fun xy => sub (bit_width_of_base_type T) (fst xy) (snd xy)
| Mul T => fun xy => mul (bit_width_of_base_type T) (fst xy) (snd xy)
@@ -136,63 +136,29 @@ Module Import Bounds.
| Neg T int_width => fun x => neg (bit_width_of_base_type T) int_width x
| Cmovne T => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne (bit_width_of_base_type T) x y z w
| Cmovle T => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovle (bit_width_of_base_type T) x y z w
- | Cast _ T => fun x => SmartRebuildBounds (bit_width_of_base_type T) x
+ | Cast _ T => fun x => truncation_bounds (bit_width_of_base_type T) x
end%bounds.
- Definition of_Z (z : Z) : t := Some (ZToZRange z).
+ Definition of_Z (z : Z) : t := ZToZRange z.
Definition of_interp t (z : Syntax.interp_base_type t) : interp_base_type t
- := Some (ZToZRange (match t return Syntax.interp_base_type t -> Z with
- | TZ => fun z => z
- | TWord logsz => fun z => z (*FixedWordSizes.wordToZ*)
- end z)).
+ := ZToZRange (interpToZ z).
- Definition bounds_to_base_type' (b : bounds) : base_type
+ Definition bounds_to_base_type (b : t) : base_type
:= if (0 <=? lower b)%Z
then TWord (Z.to_nat (Z.log2_up (Z.log2_up (1 + upper b))))
else TZ.
- Definition bounds_to_base_type (b : t) : base_type
- := match b with
- | None => TZ
- | Some b' => bounds_to_base_type' b'
- end.
Definition ComputeBounds {t} (e : Expr base_type op t)
(input_bounds : interp_flat_type interp_base_type (domain t))
: interp_flat_type interp_base_type (codomain t)
:= Interp (@interp_op) e input_bounds.
- Definition bound_is_goodb : forall t, interp_base_type t -> bool
- := fun t bs
- => match bs with
- | Some bs
- => (*let l := lower bs in
- let u := upper bs in
- let bit_width := bit_width_of_base_type t in
- ((0 <=? l) && (match bit_width with Some bit_width => Z.log2 u <? bit_width | None => true end))%Z%bool*)
- true
- | None => false
- end.
- Definition bound_is_good : forall t, interp_base_type t -> Prop
- := fun t v => bound_is_goodb t v = true.
- Definition bounds_are_good : forall {t}, interp_flat_type interp_base_type t -> Prop
- := (@interp_flat_type_rel_pointwise1 _ _ bound_is_good).
-
Definition is_tighter_thanb' {T} : interp_base_type T -> interp_base_type T -> bool
- := fun bounds1 bounds2
- => match bounds1, bounds2 with
- | Some bounds1, Some bounds2 => is_tighter_than_bool bounds1 bounds2
- | _, None => true
- | None, Some _ => false
- end.
+ := is_tighter_than_bool.
Definition is_bounded_by' {T} : interp_base_type T -> Syntax.interp_base_type T -> Prop
- := fun bounds val
- => match bounds with
- | Some bounds'
- => is_bounded_by' (bit_width_of_base_type T) bounds' val
- | None => True
- end.
+ := fun bounds val => is_bounded_by' (bit_width_of_base_type T) bounds (interpToZ val).
Definition is_tighter_thanb {T} : interp_flat_type interp_base_type T -> interp_flat_type interp_base_type T -> bool
:= interp_flat_type_relb_pointwise (@is_tighter_thanb').