diff options
Diffstat (limited to 'src/Compilers/Z/Bounds/Interpretation.v')
-rw-r--r-- | src/Compilers/Z/Bounds/Interpretation.v | 177 |
1 files changed, 177 insertions, 0 deletions
diff --git a/src/Compilers/Z/Bounds/Interpretation.v b/src/Compilers/Z/Bounds/Interpretation.v new file mode 100644 index 000000000..b9880f097 --- /dev/null +++ b/src/Compilers/Z/Bounds/Interpretation.v @@ -0,0 +1,177 @@ +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Compilers.Z.Syntax. +Require Import Crypto.Compilers.Syntax. +Require Import Crypto.Compilers.Relations. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Util.ZRange. +Require Import Crypto.Util.Tactics.DestructHead. +Export Compilers.Syntax.Notations. + +Local Notation eta x := (fst x, snd x). +Local Notation eta3 x := (eta (fst x), snd x). +Local Notation eta4 x := (eta3 (fst x), snd x). + +Notation bounds := zrange. +Delimit Scope bounds_scope with bounds. +Local Open Scope Z_scope. + +Module Import Bounds. + Definition t := bounds. + Bind Scope bounds_scope with t. + Local Coercion Z.of_nat : nat >-> Z. + Section with_bitwidth. + Context (bit_width : option Z). + Definition four_corners (f : Z -> Z -> Z) : t -> t -> t + := fun x y + => let (lx, ux) := x in + let (ly, uy) := y in + {| lower := Z.min (f lx ly) (Z.min (f lx uy) (Z.min (f ux ly) (f ux uy))); + upper := Z.max (f lx ly) (Z.max (f lx uy) (Z.max (f ux ly) (f ux uy))) |}. + Definition two_corners (f : Z -> Z) : t -> t + := fun x + => let (lx, ux) := x in + {| lower := Z.min (f lx) (f ux); + upper := Z.max (f lx) (f ux) |}. + Definition truncation_bounds (b : t) + := match bit_width with + | Some bit_width => if ((0 <=? lower b) && (upper b <? 2^bit_width))%bool + then b + else {| lower := 0 ; upper := 2^bit_width - 1 |} + | None => b + end. + Definition BuildTruncated_bounds (l u : Z) : t + := truncation_bounds {| lower := l ; upper := u |}. + Definition t_map1 (f : Z -> Z) (x : t) + := truncation_bounds (two_corners f x). + Definition t_map2 (f : Z -> Z -> Z) : t -> t -> t + := fun x y => truncation_bounds (four_corners f x y). + Definition t_map4 (f : bounds -> bounds -> bounds -> bounds -> bounds) (x y z w : t) + := truncation_bounds (f x y z w). + Definition add : t -> t -> t := t_map2 Z.add. + Definition sub : t -> t -> t := t_map2 Z.sub. + Definition mul : t -> t -> t := t_map2 Z.mul. + Definition shl : t -> t -> t := t_map2 Z.shiftl. + Definition shr : t -> t -> t := t_map2 Z.shiftr. + Definition extreme_lor_land_bounds (x y : t) : t + := let (lx, ux) := x in + let (ly, uy) := y in + let lx := Z.abs lx in + let ly := Z.abs ly in + let ux := Z.abs ux in + let uy := Z.abs uy in + let max := Z.max (Z.max lx ly) (Z.max ux uy) in + {| lower := -2^(1 + Z.log2_up max) ; upper := 2^(1 + Z.log2_up max) |}. + Definition extermization_bounds (f : t -> t -> t) (x y : t) : t + := truncation_bounds + (let (lx, ux) := x in + let (ly, uy) := y in + if ((lx <? 0) || (ly <? 0))%Z%bool + then extreme_lor_land_bounds x y + else f x y). + Definition land : t -> t -> t + := extermization_bounds + (fun x y + => let (lx, ux) := x in + let (ly, uy) := y in + {| lower := Z.min 0 (Z.min lx ly) ; upper := Z.max 0 (Z.min ux uy) |}). + Definition lor : t -> t -> t + := extermization_bounds + (fun x y + => let (lx, ux) := x in + let (ly, uy) := y in + {| lower := Z.max lx ly; + upper := 2^(Z.max (Z.log2_up (ux+1)) (Z.log2_up (uy+1))) - 1 |}). + Definition opp : t -> t := t_map1 Z.opp. + Definition neg' (int_width : Z) : t -> t + := fun v + => let (lb, ub) := v in + let might_be_one := ((lb <=? 1) && (1 <=? ub))%Z%bool in + let must_be_one := ((lb =? 1) && (ub =? 1))%Z%bool in + if must_be_one + then {| lower := Z.ones int_width ; upper := Z.ones int_width |} + else if might_be_one + then {| lower := Z.min 0 (Z.ones int_width) ; upper := Z.max 0 (Z.ones int_width) |} + else {| lower := 0 ; upper := 0 |}. + Definition neg (int_width : Z) : t -> t + := fun v + => truncation_bounds (neg' int_width v). + Definition cmovne' (r1 r2 : t) : t + := let (lr1, ur1) := r1 in + let (lr2, ur2) := r2 in + {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}. + Definition cmovne (x y r1 r2 : t) : t + := truncation_bounds (cmovne' r1 r2). + Definition cmovle' (r1 r2 : t) : t + := let (lr1, ur1) := r1 in + let (lr2, ur2) := r2 in + {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}. + Definition cmovle (x y r1 r2 : t) : t + := truncation_bounds (cmovle' r1 r2). + End with_bitwidth. + + Module Export Notations. + Export Util.ZRange.Notations. + Infix "+" := (add _) : bounds_scope. + Infix "-" := (sub _) : bounds_scope. + Infix "*" := (mul _) : bounds_scope. + Infix "<<" := (shl _) : bounds_scope. + Infix ">>" := (shr _) : bounds_scope. + Infix "&'" := (land _) : bounds_scope. + Notation "- x" := (opp _ x) : bounds_scope. + End Notations. + + Definition interp_base_type (ty : base_type) : Set := t. + + Definition bit_width_of_base_type ty : option Z + := match ty with + | TZ => None + | TWord logsz => Some (2^Z.of_nat logsz)%Z + end. + + Definition interp_op {src dst} (f : op src dst) : interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst + := match f in op src dst return interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst with + | OpConst T v => fun _ => BuildTruncated_bounds (bit_width_of_base_type T) v v + | Add _ _ T => fun xy => add (bit_width_of_base_type T) (fst xy) (snd xy) + | Sub _ _ T => fun xy => sub (bit_width_of_base_type T) (fst xy) (snd xy) + | Mul _ _ T => fun xy => mul (bit_width_of_base_type T) (fst xy) (snd xy) + | Shl _ _ T => fun xy => shl (bit_width_of_base_type T) (fst xy) (snd xy) + | Shr _ _ T => fun xy => shr (bit_width_of_base_type T) (fst xy) (snd xy) + | Land _ _ T => fun xy => land (bit_width_of_base_type T) (fst xy) (snd xy) + | Lor _ _ T => fun xy => lor (bit_width_of_base_type T) (fst xy) (snd xy) + | Opp _ T => fun x => opp (bit_width_of_base_type T) x + end%bounds. + + Definition of_Z (z : Z) : t := ZToZRange z. + + Definition of_interp t (z : Syntax.interp_base_type t) : interp_base_type t + := ZToZRange (interpToZ z). + + Definition bounds_to_base_type (b : t) : base_type + := if (0 <=? lower b)%Z + then TWord (Z.to_nat (Z.log2_up (Z.log2_up (1 + upper b)))) + else TZ. + + Definition ComputeBounds {t} (e : Expr base_type op t) + (input_bounds : interp_flat_type interp_base_type (domain t)) + : interp_flat_type interp_base_type (codomain t) + := Interp (@interp_op) e input_bounds. + + Definition is_tighter_thanb' {T} : interp_base_type T -> interp_base_type T -> bool + := is_tighter_than_bool. + + Definition is_bounded_by' {T} : interp_base_type T -> Syntax.interp_base_type T -> Prop + := fun bounds val => is_bounded_by' (bit_width_of_base_type T) bounds (interpToZ val). + + Definition is_tighter_thanb {T} : interp_flat_type interp_base_type T -> interp_flat_type interp_base_type T -> bool + := interp_flat_type_relb_pointwise (@is_tighter_thanb'). + + Definition is_bounded_by {T} : interp_flat_type interp_base_type T -> interp_flat_type Syntax.interp_base_type T -> Prop + := interp_flat_type_rel_pointwise (@is_bounded_by'). + + Local Arguments interp_base_type !_ / . + Global Instance dec_eq_interp_flat_type {T} : DecidableRel (@eq (interp_flat_type interp_base_type T)) | 10. + Proof. + induction T; destruct_head base_type; simpl; exact _. + Defined. +End Bounds. |