aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--_CoqProject6
-rw-r--r--src/Reflection/Z/BoundsInterpretations.v207
-rw-r--r--src/Reflection/Z/Syntax.v19
3 files changed, 226 insertions, 6 deletions
diff --git a/_CoqProject b/_CoqProject
index f13bf2b2c..9894d6614 100644
--- a/_CoqProject
+++ b/_CoqProject
@@ -6,6 +6,7 @@ src/Algebra.v
src/BaseSystem.v
src/BaseSystemProofs.v
src/EdDSARepChange.v
+src/Karatsuba.v
src/MxDHRepChange.v
src/NewBaseSystem.v
src/Testbit.v
@@ -159,6 +160,7 @@ src/Reflection/Named/NameUtil.v
src/Reflection/Named/RegisterAssign.v
src/Reflection/Named/Syntax.v
src/Reflection/Z/BinaryNotationConstants.v
+src/Reflection/Z/BoundsInterpretations.v
src/Reflection/Z/CNotations.v
src/Reflection/Z/HexNotationConstants.v
src/Reflection/Z/Interpretations128.v
@@ -438,8 +440,8 @@ src/Test/Curve25519SpecTestVectors.v
src/Util/AdditionChainExponentiation.v
src/Util/AutoRewrite.v
src/Util/Bool.v
-src/Util/CaseUtil.v
src/Util/CPSUtil.v
+src/Util/CaseUtil.v
src/Util/Curry.v
src/Util/Decidable.v
src/Util/Equality.v
@@ -456,10 +458,10 @@ src/Util/LetIn.v
src/Util/LetInMonad.v
src/Util/ListUtil.v
src/Util/Logic.v
+src/Util/NUtil.v
src/Util/NatUtil.v
src/Util/Notations.v
src/Util/NumTheoryUtil.v
-src/Util/NUtil.v
src/Util/Option.v
src/Util/PartiallyReifiedProp.v
src/Util/PointedProp.v
diff --git a/src/Reflection/Z/BoundsInterpretations.v b/src/Reflection/Z/BoundsInterpretations.v
new file mode 100644
index 000000000..8da4ef39f
--- /dev/null
+++ b/src/Reflection/Z/BoundsInterpretations.v
@@ -0,0 +1,207 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Crypto.Reflection.Z.Syntax.
+Require Import Crypto.Reflection.Syntax.
+Require Import Crypto.Reflection.Relations.
+Require Import Crypto.Util.Option.
+Require Import Crypto.Util.Notations.
+Export Reflection.Syntax.Notations.
+
+Local Notation eta x := (fst x, snd x).
+Local Notation eta3 x := (eta (fst x), snd x).
+Local Notation eta4 x := (eta3 (fst x), snd x).
+
+Delimit Scope bounds_scope with bounds.
+Record bounds := { lower : Z ; upper : Z }.
+Bind Scope bounds_scope with bounds.
+
+Module Import Bounds.
+ Definition t := option bounds. (* TODO?: Separate out the bounds computation from the overflow computation? e.g., have [safety := in_bounds | overflow] and [t := bounds * safety]? *)
+ Bind Scope bounds_scope with t.
+ Local Coercion Z.of_nat : nat >-> Z.
+ Section with_bitwidth.
+ Context (bit_width : option Z).
+ Definition SmartBuildBounds (l u : Z)
+ := if ((0 <=? l) && (match bit_width with Some bit_width => u <? 2^bit_width | None => true end))%Z%bool
+ then Some {| lower := l ; upper := u |}
+ else None.
+ Definition SmartRebuildBounds (b : t) : t
+ := match b with
+ | Some b => SmartBuildBounds (lower b) (upper b)
+ | None => None
+ end.
+ Definition t_map1 (f : bounds -> bounds) (x : t)
+ := match x with
+ | Some x
+ => match f x with
+ | Build_bounds l u
+ => SmartBuildBounds l u
+ end
+ | _ => None
+ end%Z.
+ Definition t_map2 (f : bounds -> bounds -> bounds) (x y : t)
+ := match x, y with
+ | Some x, Some y
+ => match f x y with
+ | Build_bounds l u
+ => SmartBuildBounds l u
+ end
+ | _, _ => None
+ end%Z.
+ Definition t_map4 (f : bounds -> bounds -> bounds -> bounds -> bounds) (x y z w : t)
+ := match x, y, z, w with
+ | Some x, Some y, Some z, Some w
+ => match f x y z w with
+ | Build_bounds l u
+ => SmartBuildBounds l u
+ end
+ | _, _, _, _ => None
+ end%Z.
+ Definition add' : bounds -> bounds -> bounds
+ := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx + ly ; upper := ux + uy |}.
+ Definition add : t -> t -> t := t_map2 add'.
+ Definition sub' : bounds -> bounds -> bounds
+ := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx - uy ; upper := ux - ly |}.
+ Definition sub : t -> t -> t := t_map2 sub'.
+ Definition mul' : bounds -> bounds -> bounds
+ := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx * ly ; upper := ux * uy |}.
+ Definition mul : t -> t -> t := t_map2 mul'.
+ Definition shl' : bounds -> bounds -> bounds
+ := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := Z.shiftl lx ly ; upper := Z.shiftl ux uy |}.
+ Definition shl : t -> t -> t := t_map2 shl'.
+ Definition shr' : bounds -> bounds -> bounds
+ := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := Z.shiftr lx uy ; upper := Z.shiftr ux ly |}.
+ Definition shr : t -> t -> t := t_map2 shr'.
+ Definition land' : bounds -> bounds -> bounds
+ := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := 0 ; upper := Z.min ux uy |}.
+ Definition land : t -> t -> t := t_map2 land'.
+ Definition lor' : bounds -> bounds -> bounds
+ := fun x y => let (lx, ux) := x in let (ly, uy) := y in
+ {| lower := Z.max lx ly;
+ upper := 2^(Z.max (Z.log2_up (ux+1)) (Z.log2_up (uy+1))) - 1 |}.
+ Definition lor : t -> t -> t := t_map2 lor'.
+ Definition neg' (int_width : Z) : bounds -> bounds
+ := fun v
+ => let (lb, ub) := v in
+ let might_be_one := ((lb <=? 1) && (1 <=? ub))%Z%bool in
+ let must_be_one := ((lb =? 1) && (ub =? 1))%Z%bool in
+ if must_be_one
+ then {| lower := Z.ones int_width ; upper := Z.ones int_width |}
+ else if might_be_one
+ then {| lower := 0 ; upper := Z.ones int_width |}
+ else {| lower := 0 ; upper := 0 |}.
+ Definition neg (int_width : Z) : t -> t
+ := fun v
+ => if ((0 <=? int_width) (*&& (int_width <=? WordW.bit_width)*))%Z%bool
+ then t_map1 (neg' int_width) v
+ else None.
+ Definition cmovne' (r1 r2 : bounds) : bounds
+ := let (lr1, ur1) := r1 in let (lr2, ur2) := r2 in {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}.
+ Definition cmovne (x y r1 r2 : t) : t := t_map4 (fun _ _ => cmovne') x y r1 r2.
+ Definition cmovle' (r1 r2 : bounds) : bounds
+ := let (lr1, ur1) := r1 in let (lr2, ur2) := r2 in {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}.
+ Definition cmovle (x y r1 r2 : t) : t := t_map4 (fun _ _ => cmovle') x y r1 r2.
+ End with_bitwidth.
+
+ Module Export Notations.
+ Delimit Scope bounds_scope with bounds.
+ Notation "b[ l ~> u ]" := {| lower := l ; upper := u |}
+ (format "b[ l ~> u ]") : bounds_scope.
+ Infix "+" := (add _) : bounds_scope.
+ Infix "-" := (sub _) : bounds_scope.
+ Infix "*" := (mul _) : bounds_scope.
+ Infix "<<" := (shl _) : bounds_scope.
+ Infix ">>" := (shr _) : bounds_scope.
+ Infix "&'" := (land _) : bounds_scope.
+ End Notations.
+
+ Definition interp_base_type (ty : base_type) : Set := t.
+
+ Definition bit_width_of_base_type ty : option Z
+ := match ty with
+ | TZ => None
+ end.
+
+ Definition interp_op {src dst} (f : op src dst) : interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst
+ := match f in op src dst return interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst with
+ | OpConst v => fun _ => SmartBuildBounds None v v
+ | Add => fun xy => add (bit_width_of_base_type TZ) (fst xy) (snd xy)
+ | Sub => fun xy => sub (bit_width_of_base_type TZ) (fst xy) (snd xy)
+ | Mul => fun xy => mul (bit_width_of_base_type TZ) (fst xy) (snd xy)
+ | Shl => fun xy => shl (bit_width_of_base_type TZ) (fst xy) (snd xy)
+ | Shr => fun xy => shr (bit_width_of_base_type TZ) (fst xy) (snd xy)
+ | Land => fun xy => land (bit_width_of_base_type TZ) (fst xy) (snd xy)
+ | Lor => fun xy => lor (bit_width_of_base_type TZ) (fst xy) (snd xy)
+ | Neg int_width => fun x => neg (bit_width_of_base_type TZ) int_width x
+ | Cmovne => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne (bit_width_of_base_type TZ) x y z w
+ | Cmovle => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovle (bit_width_of_base_type TZ) x y z w
+ end%bounds.
+
+ Ltac inversion_bounds :=
+ let lower := (eval cbv [lower] in (fun x => lower x)) in
+ let upper := (eval cbv [upper] in (fun y => upper y)) in
+ repeat match goal with
+ | [ H : _ = _ :> bounds |- _ ]
+ => pose proof (f_equal lower H); pose proof (f_equal upper H); clear H;
+ cbv beta iota in *
+ | [ H : _ = _ :> t |- _ ]
+ => unfold t in H; inversion_option
+ end.
+
+ Definition ZToBounds (z : Z) : bounds := {| lower := z ; upper := z |}.
+ Definition of_Z (z : Z) : t := Some (ZToBounds z).
+
+ Definition of_interp t (z : Syntax.interp_base_type t) : interp_base_type t
+ := Some (ZToBounds (match t return Syntax.interp_base_type t -> Z with
+ | TZ => fun z => z
+ end z)).
+
+ Definition bounds_to_base_type' (b : bounds) : base_type
+ := TZ.
+ Definition bounds_to_base_type (b : t) : base_type
+ := match b with
+ | None => TZ
+ | Some b' => bounds_to_base_type' b'
+ end.
+
+ (*
+ Definition ComputeBounds {t} (e : Expr base_type op t)
+ (input_bounds : interp_flat_type interp_base_type (domain t))
+ : interp_flat_type interp_base_type (codomain t)
+ := Interp (@interp_op) e input_bounds.
+ *)
+
+ Definition bound_is_goodb : forall t, interp_base_type t -> bool
+ := fun t bs
+ => match bs with
+ | Some bs
+ => (*let l := lower bs in
+ let u := upper bs in
+ let bit_width := bit_width_of_base_type t in
+ ((0 <=? l) && (match bit_width with Some bit_width => Z.log2 u <? bit_width | None => true end))%Z%bool*)
+ true
+ | None => false
+ end.
+ Definition bound_is_good : forall t, interp_base_type t -> Prop
+ := fun t v => bound_is_goodb t v = true.
+ Definition bounds_are_good : forall {t}, interp_flat_type interp_base_type t -> Prop
+ := (@interp_flat_type_rel_pointwise1 _ _ bound_is_good).
+
+ Definition is_bounded_byb {T} : Syntax.interp_base_type T -> interp_base_type T -> bool
+ := fun val bound
+ => match bound with
+ | Some bounds'
+ => ((0 <=? lower bounds') && (lower bounds' <=? interpToZ val) && (interpToZ val <=? upper bounds'))
+ && (match bit_width_of_base_type T with
+ | Some sz => upper bounds' <? 2^sz
+ | None => true
+ end)
+ | None => true
+ end%bool%Z.
+ Definition is_bounded_by' {T} : Syntax.interp_base_type T -> interp_base_type T -> Prop
+ := fun val bound => is_bounded_byb val bound = true.
+
+ Definition is_bounded_by {T} : interp_flat_type Syntax.interp_base_type T -> interp_flat_type interp_base_type T -> Prop
+ := interp_flat_type_rel_pointwise (@is_bounded_by').
+ Definition is_bounded_by_bool {T} : interp_flat_type Syntax.interp_base_type T -> interp_flat_type interp_base_type T -> bool
+ := interp_flat_type_relb_pointwise (@is_bounded_byb).
+End Bounds.
diff --git a/src/Reflection/Z/Syntax.v b/src/Reflection/Z/Syntax.v
index 288876dc9..2060c6852 100644
--- a/src/Reflection/Z/Syntax.v
+++ b/src/Reflection/Z/Syntax.v
@@ -10,6 +10,11 @@ Inductive base_type := TZ.
Local Notation tZ := (Tbase TZ).
+Definition interp_base_type (v : base_type) : Type :=
+ match v with
+ | TZ => Z
+ end.
+
Inductive op : flat_type base_type -> flat_type base_type -> Type :=
| OpConst (z : Z) : op Unit tZ
| Add : op (tZ * tZ) tZ
@@ -23,10 +28,16 @@ Inductive op : flat_type base_type -> flat_type base_type -> Type :=
| Cmovne : op (tZ * tZ * tZ * tZ) tZ
| Cmovle : op (tZ * tZ * tZ * tZ) tZ.
-Definition interp_base_type (v : base_type) : Type :=
- match v with
- | TZ => Z
- end.
+Definition interpToZ {t} : interp_base_type t -> Z
+ := match t with
+ | TZ => fun x => x
+ end.
+Definition ZToInterp {t} : Z -> interp_base_type t
+ := match t return Z -> interp_base_type t with
+ | TZ => fun x => x
+ end.
+Definition cast_const {t1 t2} (v : interp_base_type t1) : interp_base_type t2
+ := ZToInterp (interpToZ v).
Local Notation eta x := (fst x, snd x).
Local Notation eta3 x := (eta (fst x), snd x).