diff options
Diffstat (limited to 'src/Reflection')
-rw-r--r-- | src/Reflection/Z/Bounds/Interpretation.v | 188 | ||||
-rw-r--r-- | src/Reflection/Z/Bounds/Pipeline/Glue.v | 149 | ||||
-rw-r--r-- | src/Reflection/Z/Bounds/Relax.v | 33 | ||||
-rw-r--r-- | src/Reflection/Z/Syntax.v | 1 |
4 files changed, 243 insertions, 128 deletions
diff --git a/src/Reflection/Z/Bounds/Interpretation.v b/src/Reflection/Z/Bounds/Interpretation.v index 0a0bb28f0..cad5d87b3 100644 --- a/src/Reflection/Z/Bounds/Interpretation.v +++ b/src/Reflection/Z/Bounds/Interpretation.v @@ -2,7 +2,6 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.Reflection.Z.Syntax. Require Import Crypto.Reflection.Syntax. Require Import Crypto.Reflection.Relations. -Require Import Crypto.Util.Option. Require Import Crypto.Util.Notations. Require Import Crypto.Util.Decidable. Require Import Crypto.Util.ZRange. @@ -15,73 +14,70 @@ Local Notation eta4 x := (eta3 (fst x), snd x). Notation bounds := zrange. Delimit Scope bounds_scope with bounds. +Local Open Scope Z_scope. Module Import Bounds. - Definition t := option bounds. (* TODO?: Separate out the bounds computation from the overflow computation? e.g., have [safety := in_bounds | overflow] and [t := bounds * safety]? *) + Definition t := bounds. Bind Scope bounds_scope with t. Local Coercion Z.of_nat : nat >-> Z. Section with_bitwidth. Context (bit_width : option Z). - Definition SmartBuildBounds (l u : Z) - := if ((0 <=? l) && (match bit_width with Some bit_width => u <? 2^bit_width | None => true end))%Z%bool - then Some {| lower := l ; upper := u |} - else None. - Definition SmartRebuildBounds (b : t) : t - := match b with - | Some b => SmartBuildBounds (lower b) (upper b) - | None => None + Definition four_corners (f : Z -> Z -> Z) : t -> t -> t + := fun x y + => let (lx, ux) := x in + let (ly, uy) := y in + {| lower := Z.min (f lx ly) (Z.min (f lx uy) (Z.min (f ux ly) (f ux uy))); + upper := Z.max (f lx ly) (Z.max (f lx uy) (Z.max (f ux ly) (f ux uy))) |}. + Definition truncation_bounds (b : t) + := match bit_width with + | Some bit_width => if ((0 <=? lower b) && (upper b <? 2^bit_width))%bool + then b + else {| lower := 0 ; upper := 2^bit_width - 1 |} + | None => b end. + Definition BuildTruncated_bounds (l u : Z) : t + := truncation_bounds {| lower := l ; upper := u |}. Definition t_map1 (f : bounds -> bounds) (x : t) - := match x with - | Some x - => match f x with - | {| lower := l ; upper := u |} - => SmartBuildBounds l u - end - | _ => None - end%Z. - Definition t_map2 (f : bounds -> bounds -> bounds) (x y : t) - := match x, y with - | Some x, Some y - => match f x y with - | {| lower := l ; upper := u |} - => SmartBuildBounds l u - end - | _, _ => None - end%Z. + := truncation_bounds (f x). + Definition t_map2 (f : Z -> Z -> Z) : t -> t -> t + := fun x y => truncation_bounds (four_corners f x y). Definition t_map4 (f : bounds -> bounds -> bounds -> bounds -> bounds) (x y z w : t) - := match x, y, z, w with - | Some x, Some y, Some z, Some w - => match f x y z w with - | {| lower := l ; upper := u |} - => SmartBuildBounds l u - end - | _, _, _, _ => None - end%Z. - Definition add' : bounds -> bounds -> bounds - := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx + ly ; upper := ux + uy |}. - Definition add : t -> t -> t := t_map2 add'. - Definition sub' : bounds -> bounds -> bounds - := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx - uy ; upper := ux - ly |}. - Definition sub : t -> t -> t := t_map2 sub'. - Definition mul' : bounds -> bounds -> bounds - := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := lx * ly ; upper := ux * uy |}. - Definition mul : t -> t -> t := t_map2 mul'. - Definition shl' : bounds -> bounds -> bounds - := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := Z.shiftl lx ly ; upper := Z.shiftl ux uy |}. - Definition shl : t -> t -> t := t_map2 shl'. - Definition shr' : bounds -> bounds -> bounds - := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := Z.shiftr lx uy ; upper := Z.shiftr ux ly |}. - Definition shr : t -> t -> t := t_map2 shr'. - Definition land' : bounds -> bounds -> bounds - := fun x y => let (lx, ux) := x in let (ly, uy) := y in {| lower := 0 ; upper := Z.min ux uy |}. - Definition land : t -> t -> t := t_map2 land'. - Definition lor' : bounds -> bounds -> bounds - := fun x y => let (lx, ux) := x in let (ly, uy) := y in - {| lower := Z.max lx ly; - upper := 2^(Z.max (Z.log2_up (ux+1)) (Z.log2_up (uy+1))) - 1 |}. - Definition lor : t -> t -> t := t_map2 lor'. - Definition neg' (int_width : Z) : bounds -> bounds + := truncation_bounds (f x y z w). + Definition add : t -> t -> t := t_map2 Z.add. + Definition sub : t -> t -> t := t_map2 Z.sub. + Definition mul : t -> t -> t := t_map2 Z.mul. + Definition shl : t -> t -> t := t_map2 Z.shiftl. + Definition shr : t -> t -> t := t_map2 Z.shiftr. + Definition extreme_lor_land_bounds (x y : t) : t + := let (lx, ux) := x in + let (ly, uy) := y in + let lx := Z.abs lx in + let ly := Z.abs ly in + let ux := Z.abs ux in + let uy := Z.abs uy in + let max := Z.max (Z.max lx ly) (Z.max ux uy) in + {| lower := -2^(1 + Z.log2_up max) ; upper := 2^(1 + Z.log2_up max) |}. + Definition extermization_bounds (f : t -> t -> t) (x y : t) : t + := truncation_bounds + (let (lx, ux) := x in + let (ly, uy) := y in + if ((lx <? 0) || (ly <? 0))%Z%bool + then extreme_lor_land_bounds x y + else f x y). + Definition land : t -> t -> t + := extermization_bounds + (fun x y + => let (lx, ux) := x in + let (ly, uy) := y in + {| lower := Z.min 0 (Z.min lx ly) ; upper := Z.max 0 (Z.min ux uy) |}). + Definition lor : t -> t -> t + := extermization_bounds + (fun x y + => let (lx, ux) := x in + let (ly, uy) := y in + {| lower := Z.max lx ly; + upper := 2^(Z.max (Z.log2_up (ux+1)) (Z.log2_up (uy+1))) - 1 |}). + Definition neg' (int_width : Z) : t -> t := fun v => let (lb, ub) := v in let might_be_one := ((lb <=? 1) && (1 <=? ub))%Z%bool in @@ -89,19 +85,23 @@ Module Import Bounds. if must_be_one then {| lower := Z.ones int_width ; upper := Z.ones int_width |} else if might_be_one - then {| lower := 0 ; upper := Z.ones int_width |} + then {| lower := Z.min 0 (Z.ones int_width) ; upper := Z.max 0 (Z.ones int_width) |} else {| lower := 0 ; upper := 0 |}. Definition neg (int_width : Z) : t -> t := fun v - => if ((0 <=? int_width) (*&& (int_width <=? WordW.bit_width)*))%Z%bool - then t_map1 (neg' int_width) v - else None. - Definition cmovne' (r1 r2 : bounds) : bounds - := let (lr1, ur1) := r1 in let (lr2, ur2) := r2 in {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}. - Definition cmovne (x y r1 r2 : t) : t := t_map4 (fun _ _ => cmovne') x y r1 r2. - Definition cmovle' (r1 r2 : bounds) : bounds - := let (lr1, ur1) := r1 in let (lr2, ur2) := r2 in {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}. - Definition cmovle (x y r1 r2 : t) : t := t_map4 (fun _ _ => cmovle') x y r1 r2. + => truncation_bounds (neg' int_width v). + Definition cmovne' (r1 r2 : t) : t + := let (lr1, ur1) := r1 in + let (lr2, ur2) := r2 in + {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}. + Definition cmovne (x y r1 r2 : t) : t + := truncation_bounds (cmovne' r1 r2). + Definition cmovle' (r1 r2 : t) : t + := let (lr1, ur1) := r1 in + let (lr2, ur2) := r2 in + {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}. + Definition cmovle (x y r1 r2 : t) : t + := truncation_bounds (cmovle' r1 r2). End with_bitwidth. Module Export Notations. @@ -124,8 +124,8 @@ Module Import Bounds. Definition interp_op {src dst} (f : op src dst) : interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst := match f in op src dst return interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst with - | OpConst TZ v => fun _ => SmartBuildBounds None v v - | OpConst (TWord _ as T) v => fun _ => SmartBuildBounds (bit_width_of_base_type T) ((*FixedWordSizes.wordToZ*) v) ((*FixedWordSizes.wordToZ*) v) + | OpConst TZ v => fun _ => BuildTruncated_bounds None v v + | OpConst (TWord _ as T) v => fun _ => BuildTruncated_bounds (bit_width_of_base_type T) ((*FixedWordSizes.wordToZ*) v) ((*FixedWordSizes.wordToZ*) v) | Add T => fun xy => add (bit_width_of_base_type T) (fst xy) (snd xy) | Sub T => fun xy => sub (bit_width_of_base_type T) (fst xy) (snd xy) | Mul T => fun xy => mul (bit_width_of_base_type T) (fst xy) (snd xy) @@ -136,63 +136,29 @@ Module Import Bounds. | Neg T int_width => fun x => neg (bit_width_of_base_type T) int_width x | Cmovne T => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne (bit_width_of_base_type T) x y z w | Cmovle T => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovle (bit_width_of_base_type T) x y z w - | Cast _ T => fun x => SmartRebuildBounds (bit_width_of_base_type T) x + | Cast _ T => fun x => truncation_bounds (bit_width_of_base_type T) x end%bounds. - Definition of_Z (z : Z) : t := Some (ZToZRange z). + Definition of_Z (z : Z) : t := ZToZRange z. Definition of_interp t (z : Syntax.interp_base_type t) : interp_base_type t - := Some (ZToZRange (match t return Syntax.interp_base_type t -> Z with - | TZ => fun z => z - | TWord logsz => fun z => z (*FixedWordSizes.wordToZ*) - end z)). + := ZToZRange (interpToZ z). - Definition bounds_to_base_type' (b : bounds) : base_type + Definition bounds_to_base_type (b : t) : base_type := if (0 <=? lower b)%Z then TWord (Z.to_nat (Z.log2_up (Z.log2_up (1 + upper b)))) else TZ. - Definition bounds_to_base_type (b : t) : base_type - := match b with - | None => TZ - | Some b' => bounds_to_base_type' b' - end. Definition ComputeBounds {t} (e : Expr base_type op t) (input_bounds : interp_flat_type interp_base_type (domain t)) : interp_flat_type interp_base_type (codomain t) := Interp (@interp_op) e input_bounds. - Definition bound_is_goodb : forall t, interp_base_type t -> bool - := fun t bs - => match bs with - | Some bs - => (*let l := lower bs in - let u := upper bs in - let bit_width := bit_width_of_base_type t in - ((0 <=? l) && (match bit_width with Some bit_width => Z.log2 u <? bit_width | None => true end))%Z%bool*) - true - | None => false - end. - Definition bound_is_good : forall t, interp_base_type t -> Prop - := fun t v => bound_is_goodb t v = true. - Definition bounds_are_good : forall {t}, interp_flat_type interp_base_type t -> Prop - := (@interp_flat_type_rel_pointwise1 _ _ bound_is_good). - Definition is_tighter_thanb' {T} : interp_base_type T -> interp_base_type T -> bool - := fun bounds1 bounds2 - => match bounds1, bounds2 with - | Some bounds1, Some bounds2 => is_tighter_than_bool bounds1 bounds2 - | _, None => true - | None, Some _ => false - end. + := is_tighter_than_bool. Definition is_bounded_by' {T} : interp_base_type T -> Syntax.interp_base_type T -> Prop - := fun bounds val - => match bounds with - | Some bounds' - => is_bounded_by' (bit_width_of_base_type T) bounds' val - | None => True - end. + := fun bounds val => is_bounded_by' (bit_width_of_base_type T) bounds (interpToZ val). Definition is_tighter_thanb {T} : interp_flat_type interp_base_type T -> interp_flat_type interp_base_type T -> bool := interp_flat_type_relb_pointwise (@is_tighter_thanb'). diff --git a/src/Reflection/Z/Bounds/Pipeline/Glue.v b/src/Reflection/Z/Bounds/Pipeline/Glue.v new file mode 100644 index 000000000..93986a42f --- /dev/null +++ b/src/Reflection/Z/Bounds/Pipeline/Glue.v @@ -0,0 +1,149 @@ +(** * Reflective Pipeline: Glue Code *) +(** This file defines the tactics that transform a non-reflective goal + into a goal the that the reflective machinery can handle. *) +Require Import Crypto.Reflection.Syntax. +Require Import Crypto.Reflection.Reify. +Require Import Crypto.Reflection.Z.Syntax. +Require Import Crypto.Reflection.Z.Syntax.Util. +Require Import Crypto.Reflection.Z.Reify. +Require Import Crypto.Reflection.Z.Bounds.Interpretation. +Require Import Crypto.Util.Tactics.Head. +Require Import Crypto.Util.Curry. +Require Import Crypto.Util.FixedWordSizes. +Require Import Crypto.Util.BoundedWord. +Require Import Crypto.Util.Tuple. + +(** The [do_curry] tactic takes a goal of the form +<< +BoundedWordToZ (?f a b ... z) = F A B ... Z +>> + and turns it into a goal of the form +<< +BoundedWordToZ (f' (a, b, ..., z)) = F' (A, B, ..., Z) +>> + *) +Ltac do_curry := + lazymatch goal with + | [ |- ?BWtoZ ?f_bw = ?f_Z ] + => let f_bw := head f_bw in + let f_Z := head f_Z in + change_with_curried f_Z; + let f_bw_name := fresh f_bw in + set (f_bw_name := f_bw); + change_with_curried f_bw_name + end. +(** The [split_BoundedWordToZ] tactic takes a goal of the form +<< +BoundedWordToZ (f args) = F ARGS +>> + and splits it into a conjunction, one part about the computational + behavior, and another part about the boundedness. *) +Ltac count_tuple_length T := + lazymatch T with + | (?A * ?B)%type => let a := count_tuple_length A in + let b := count_tuple_length B in + (eval compute in (a + b)%nat) + | _ => constr:(1%nat) + end. +Ltac make_evar_for_first_projection := + lazymatch goal with + | [ |- @map ?N1 ?A ?B wordToZ (@proj1_sig _ ?P ?f) = ?fZ ?argsZ ] + => let T := type of argsZ in + let N := count_tuple_length T in + let map' := (eval compute in (@map N)) in + let proj1_sig' := (eval compute in @proj1_sig) in + let f1 := fresh f in + let f2 := fresh f in + let pf := fresh in + revert f; refine (_ : let f := exist P _ _ in _); + intro f; + pose (proj1_sig f) as f1; + pose (proj2_sig f : P f1) as f2; + change f with (exist _ f1 f2); + subst f; cbn [proj1_sig proj2_sig] in f1, f2 |- *; revert f2; + lazymatch goal with + | [ |- let f' := _ in @?P f' ] + => refine (let pf := _ in (proj2 pf : let f' := proj1 pf in P f')) + end + end. +Ltac split_BoundedWordToZ := + match goal with + | [ |- BoundedWordToZ _ _ _ ?x = _ ] + => revert x + end; + repeat match goal with + | [ |- context[BoundedWordToZ _ _ _ ?x] ] + => is_var x; + first [ clearbody x; fail 1 + | instantiate (1:=ltac:(destruct x)); destruct x ] + end; + cbv beta iota; intro; + unfold BoundedWordToZ; cbn [proj1_sig]; + make_evar_for_first_projection. +(** The [zrange_to_reflective] tactic takes a goal of the form +<< +is_bounded_by _ bounds (map wordToZ (?fW args)) /\ map wordToZ (?fW args) = fZ argsZ +>> + and uses [cut] and a small lemma to turn it into a goal that the + reflective machinery can handle. The goal left by this tactic + should be fully solvable by the reflective pipeline. *) + +Ltac const_tuple T val := + lazymatch T with + | (?A * ?B)%type => let a := const_tuple A val in + let b := const_tuple B val in + constr:((a, b)%core) + | _ => val + end. +Lemma adjust_goal_for_reflective {T P} (LHS RHS : T) + : P RHS /\ LHS = RHS -> P LHS /\ LHS = RHS. +Proof. intros [? ?]; subst; tauto. Qed. +Ltac adjust_goal_for_reflective := apply adjust_goal_for_reflective. +Ltac unmap_wordToZ_tuple term := + lazymatch term with + | (?x, ?y) => let x' := unmap_wordToZ_tuple x in + let y' := unmap_wordToZ_tuple y in + constr:((x', y')) + | map wordToZ ?x => x + end. +Ltac zrange_to_reflective_hyps_step := + match goal with + | [ H : @ZRange.is_bounded_by ?option_bit_width ?count ?bounds (Tuple.map wordToZ ?arg) |- _ ] + => let rT := constr:(Syntax.tuple (Tbase TZ) count) in + let is_bounded_by' := constr:(@Bounds.is_bounded_by rT) in + let map' := constr:(@cast_back_flat_const (@Bounds.interp_base_type) rT (fun _ => Bounds.bounds_to_base_type) bounds) in + (* we use [cut] and [abstract] rather than [change] to catch inefficiencies in conversion early, rather than allowing [Defined] to take forever *) + let H' := fresh H in + rename H into H'; + assert (H : is_bounded_by' bounds (map' arg)) by (clear -H'; abstract exact H'); + clear H'; move H at top + end. +Ltac zrange_to_reflective_hyps := repeat zrange_to_reflective_hyps_step. +Ltac zrange_to_reflective_goal := + lazymatch goal with + | [ |- @ZRange.is_bounded_by ?option_bit_width ?count ?bounds (Tuple.map wordToZ ?reified_f_evar) + /\ Tuple.map wordToZ ?reified_f_evar = ?f ?Zargs ] + => let T := type of f in + let f_domain := lazymatch T with ?A -> ?B => A end in + let T := (eval compute in T) in + let rT := reify_type T in + let is_bounded_by' := constr:(@Bounds.is_bounded_by (codomain rT)) in + let input_bounds := const_tuple f_domain bounds in + let map_t := constr:(fun t bs => @cast_back_flat_const (@Bounds.interp_base_type) t (fun _ => Bounds.bounds_to_base_type) bs) in + let map_output := constr:(map_t (codomain rT) bounds) in + let map_input := constr:(map_t (domain rT) input_bounds) in + let args := unmap_wordToZ_tuple Zargs in + (* we use [cut] and [abstract] rather than [change] to catch inefficiencies in conversion early, rather than allowing [Defined] to take forever *) + cut (is_bounded_by' bounds (map_output reified_f_evar) /\ map_output reified_f_evar = f (map_input args)); + [ generalize reified_f_evar; clear; clearbody f; let x := fresh in intros ? x; abstract exact x + | ]; + cbv beta + end; + adjust_goal_for_reflective. +Ltac zrange_to_reflective := zrange_to_reflective_hyps; zrange_to_reflective_goal. + +(** The tactic [refine_to_reflective_glue] is the public-facing one. *) +Ltac refine_to_reflective_glue := + do_curry; + split_BoundedWordToZ; + zrange_to_reflective. diff --git a/src/Reflection/Z/Bounds/Relax.v b/src/Reflection/Z/Bounds/Relax.v index 5269639e1..e77ef423a 100644 --- a/src/Reflection/Z/Bounds/Relax.v +++ b/src/Reflection/Z/Bounds/Relax.v @@ -58,23 +58,22 @@ Proof. cbv [Bounds.is_bounded_by cast_back_flat_const Bounds.is_tighter_thanb] in *. apply interp_flat_type_rel_pointwise_iff_relb in Hrelax. induction t; unfold SmartFlatTypeMap in *; simpl @smart_interp_flat_map in *; inversion_flat_type. - { cbv [Bounds.is_tighter_thanb' ZRange.is_tighter_than_bool is_true SmartFlatTypeMap Bounds.bounds_to_base_type' Bounds.bounds_to_base_type ZRange.is_bounded_by' ZRange.is_bounded_by Bounds.is_bounded_by' Bounds.bit_width_of_base_type] in *; simpl in *. - progress destruct_head_hnf option; - repeat first [ progress inversion_flat_type - | progress inversion_base_type - | progress destruct_head bounds - | progress split_andb - | progress Z.ltb_to_lt - | progress break_match_hyps - | progress destruct_head'_and - | progress simpl in * - | rewrite helper in * - | omega - | tauto - | congruence - | progress destruct_head @eq; (reflexivity || omega) - | progress break_innermost_match_step - | apply conj ]. } + { cbv [Bounds.is_tighter_thanb' ZRange.is_tighter_than_bool is_true SmartFlatTypeMap Bounds.bounds_to_base_type ZRange.is_bounded_by' ZRange.is_bounded_by Bounds.is_bounded_by' Bounds.bit_width_of_base_type] in *; simpl in *. + repeat first [ progress inversion_flat_type + | progress inversion_base_type + | progress destruct_head bounds + | progress split_andb + | progress Z.ltb_to_lt + | progress break_match_hyps + | progress destruct_head'_and + | progress simpl in * + | rewrite helper in * + | omega + | tauto + | congruence + | progress destruct_head @eq; (reflexivity || omega) + | progress break_innermost_match_step + | apply conj ]. } { compute in *; tauto. } { simpl in *. specialize (fun Hv => IHt1 (fst tight_output_bounds) (fst relaxed_output_bounds) Hv (fst v)). diff --git a/src/Reflection/Z/Syntax.v b/src/Reflection/Z/Syntax.v index 3a005e59e..ad8215fe5 100644 --- a/src/Reflection/Z/Syntax.v +++ b/src/Reflection/Z/Syntax.v @@ -2,6 +2,7 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.Reflection.Syntax. Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations. +Require Import Crypto.Util.FixedWordSizes. Export Syntax.Notations. Local Set Boolean Equality Schemes. |