From be5f86b5483d2e00ec9002b8db00a1ff8ecb9cfe Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Fri, 12 May 2017 18:07:47 -0400 Subject: Add reflective machinery for adc, zselect --- src/Compilers/Z/ArithmeticSimplifier.v | 14 +- src/Compilers/Z/Bounds/Interpretation.v | 59 ++++++- .../Z/Bounds/InterpretationLemmas/IsBoundedBy.v | 179 ++++++++++++++++++++- .../Z/Bounds/InterpretationLemmas/PullCast.v | 2 +- src/Compilers/Z/CommonSubexpressionElimination.v | 28 +++- src/Compilers/Z/Reify.v | 15 +- src/Compilers/Z/Syntax.v | 7 + src/Compilers/Z/Syntax/Equality.v | 9 ++ src/Compilers/Z/Syntax/Util.v | 3 + src/Specific/IntegrationTestFreeze.v | 119 ++++++++++++++ src/Specific/IntegrationTestFreezeDisplay.v | 4 + 11 files changed, 420 insertions(+), 19 deletions(-) create mode 100644 src/Specific/IntegrationTestFreeze.v create mode 100644 src/Specific/IntegrationTestFreezeDisplay.v (limited to 'src') diff --git a/src/Compilers/Z/ArithmeticSimplifier.v b/src/Compilers/Z/ArithmeticSimplifier.v index b2621c625..f0d3e19ab 100644 --- a/src/Compilers/Z/ArithmeticSimplifier.v +++ b/src/Compilers/Z/ArithmeticSimplifier.v @@ -20,9 +20,18 @@ Section language. : option (interp_flat_type inverted_expr t) := match x in Syntax.exprf _ _ t return option (interp_flat_type _ t) with | Op t1 (Tbase _) opc args - => Some (match opc in op src dst return exprf dst -> exprf src -> inverted_expr match dst with Tbase t => t | _ => TZ end with + => Some (match opc in op src dst + return exprf dst + -> exprf src + -> match dst with + | Tbase t => inverted_expr t + | Prod _ _ => True + | _ => inverted_expr TZ + end + with | OpConst _ z => fun _ _ => const_of _ z | Opp TZ TZ => fun _ args => neg_expr _ args + | AddWithGetCarry _ _ _ _ _ _ => fun _ _ => I | _ => fun e _ => gen_expr _ e end (Op opc args) args) | TT => Some tt @@ -175,6 +184,9 @@ Section language. | Lor _ _ _ as opc | OpConst _ _ as opc | Opp _ _ as opc + | Zselect _ _ _ _ as opc + | AddWithCarry _ _ _ _ as opc + | AddWithGetCarry _ _ _ _ _ _ as opc => Op opc end. End with_var. diff --git a/src/Compilers/Z/Bounds/Interpretation.v b/src/Compilers/Z/Bounds/Interpretation.v index ab2a8ef43..f4cbb3bbd 100644 --- a/src/Compilers/Z/Bounds/Interpretation.v +++ b/src/Compilers/Z/Bounds/Interpretation.v @@ -5,6 +5,7 @@ Require Import Crypto.Compilers.Relations. Require Import Crypto.Util.Notations. Require Import Crypto.Util.Decidable. Require Import Crypto.Util.ZRange. +Require Import Crypto.Util.ZUtil.Definitions. Require Import Crypto.Util.Tactics.DestructHead. Export Compilers.Syntax.Notations. @@ -22,17 +23,25 @@ Module Import Bounds. Local Coercion Z.of_nat : nat >-> Z. Section with_bitwidth. Context (bit_width : option Z). - Definition four_corners (f : Z -> Z -> Z) : t -> t -> t - := fun x y - => let (lx, ux) := x in - let (ly, uy) := y in - {| lower := Z.min (f lx ly) (Z.min (f lx uy) (Z.min (f ux ly) (f ux uy))); - upper := Z.max (f lx ly) (Z.max (f lx uy) (Z.max (f ux ly) (f ux uy))) |}. Definition two_corners (f : Z -> Z) : t -> t := fun x => let (lx, ux) := x in {| lower := Z.min (f lx) (f ux); upper := Z.max (f lx) (f ux) |}. + Definition four_corners (f : Z -> Z -> Z) : t -> t -> t + := fun x y + => let (lx, ux) := x in + let (lfl, ufl) := two_corners (f lx) y in + let (lfu, ufu) := two_corners (f ux) y in + {| lower := Z.min lfl lfu; + upper := Z.max ufl ufu |}. + Definition eight_corners (f : Z -> Z -> Z -> Z) : t -> t -> t -> t + := fun x y z + => let (lx, ux) := x in + let (lfl, ufl) := four_corners (f lx) y z in + let (lfu, ufu) := four_corners (f ux) y z in + {| lower := Z.min lfl lfu; + upper := Z.max ufl ufu |}. Definition truncation_bounds (b : t) := match bit_width with | Some bit_width => if ((0 <=? lower b) && (upper b Z -> Z) : t -> t -> t := fun x y => truncation_bounds (four_corners f x y). + Definition t_map3' (f : Z -> Z -> Z -> Z) : t -> t -> t -> t + := fun x y z => eight_corners f x y z. + Definition t_map3 (f : Z -> Z -> Z -> Z) : t -> t -> t -> t + := fun x y z => truncation_bounds (eight_corners f x y z). Definition t_map4 (f : bounds -> bounds -> bounds -> bounds -> bounds) (x y z w : t) := truncation_bounds (f x y z w). Definition add : t -> t -> t := t_map2 Z.add. @@ -82,6 +95,27 @@ Module Import Bounds. {| lower := Z.max lx ly; upper := 2^(Z.max (Z.log2_up (ux+1)) (Z.log2_up (uy+1))) - 1 |}). Definition opp : t -> t := t_map1 Z.opp. + Definition zselect' (r1 r2 : t) : t + := let (lr1, ur1) := r1 in + let (lr2, ur2) := r2 in + {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}. + Definition zselect (c r1 r2 : t) : t + := truncation_bounds (zselect' r1 r2). + Definition add_with_carry' : t -> t -> t -> t + := t_map3' Z.add_with_carry. + Definition add_with_carry : t -> t -> t -> t + := t_map3 Z.add_with_carry. + Definition modulo_pow2_constant : Z -> t -> t + := fun e x + => let d := 2^e in + let (l, u) := (lower x, upper x) in + truncation_bounds {| lower := if l / d =? u / d then Z.min (l mod d) (u mod d) else Z.min 0 (d + 1); + upper := if l / d =? u / d then Z.max (l mod d) (u mod d) else Z.max 0 (d - 1) |}. + Definition div_pow2_constant : Z -> t -> t + := fun e x + => let d := 2^e in + let (l, u) := (lower x, upper x) in + truncation_bounds {| lower := l / d ; upper := u / d |}. Definition neg' (int_width : Z) : t -> t := fun v => let (lb, ub) := v in @@ -108,6 +142,14 @@ Module Import Bounds. Definition cmovle (x y r1 r2 : t) : t := truncation_bounds (cmovle' r1 r2). End with_bitwidth. + Section with_bitwidth2. + Context (bit_width1 bit_width2 : option Z). + Definition add_with_get_carry (carry_boundary_bit_width : Z) : t -> t -> t -> t * t + := fun c x y + => let xpy := add_with_carry' c x y in + (modulo_pow2_constant bit_width1 carry_boundary_bit_width xpy, + div_pow2_constant bit_width2 carry_boundary_bit_width xpy). + End with_bitwidth2. Module Export Notations. Export Util.ZRange.Notations. @@ -142,6 +184,11 @@ Module Import Bounds. | Land _ _ T => fun xy => land (bit_width_of_base_type T) (fst xy) (snd xy) | Lor _ _ T => fun xy => lor (bit_width_of_base_type T) (fst xy) (snd xy) | Opp _ T => fun x => opp (bit_width_of_base_type T) x + | Zselect _ _ _ T => fun cxy => let '(c, x, y) := eta3 cxy in zselect (bit_width_of_base_type T) c x y + | AddWithCarry _ _ _ T => fun cxy => let '(c, x, y) := eta3 cxy in add_with_carry (bit_width_of_base_type T) c x y + | AddWithGetCarry carry_boundary_bit_width _ _ _ T1 T2 + => fun cxy => let '(c, x, y) := eta3 cxy in + add_with_get_carry (bit_width_of_base_type T1) (bit_width_of_base_type T2) carry_boundary_bit_width c x y end%bounds. Definition of_Z (z : Z) : t := ZToZRange z. diff --git a/src/Compilers/Z/Bounds/InterpretationLemmas/IsBoundedBy.v b/src/Compilers/Z/Bounds/InterpretationLemmas/IsBoundedBy.v index 415c65406..c74c319d5 100644 --- a/src/Compilers/Z/Bounds/InterpretationLemmas/IsBoundedBy.v +++ b/src/Compilers/Z/Bounds/InterpretationLemmas/IsBoundedBy.v @@ -35,6 +35,65 @@ Proof. | word_arith_t ]. Qed. +Lemma is_bounded_by_compose T1 T2 f_v bs v f_bs fv + (H : Bounds.is_bounded_by (T:=Tbase T1) bs v) + (Hf : forall bs v, Bounds.is_bounded_by (T:=Tbase T1) bs v -> Bounds.is_bounded_by (T:=Tbase T2) (f_bs bs) (f_v v)) + (Hfv : f_v v = fv) + : Bounds.is_bounded_by (T:=Tbase T2) (f_bs bs) fv. +Proof. + subst; eauto. +Qed. + +Lemma monotone_two_corners_genb + (f : Z -> Z) + (R := fun b : bool => if b then Z.le else Basics.flip Z.le) + (Hmonotone : exists b, Proper (R b ==> Z.le) f) + x_bs x + (Hboundedx : ZRange.is_bounded_by' None x_bs x) + : ZRange.is_bounded_by' None (Bounds.two_corners f x_bs) (f x). +Proof. + unfold ZRange.is_bounded_by' in *; split; trivial. + destruct x_bs as [lx ux]; simpl in *. + destruct Hboundedx as [Hboundedx _]. + destruct_head'_ex. + repeat match goal with + | [ H : Proper (R ?b ==> Z.le) f |- _ ] + => unique assert (R b (if b then lx else x) (if b then x else lx) + /\ R b (if b then x else ux) (if b then ux else x)) + by (unfold R, Basics.flip; destruct b; omega) + end. + destruct_head' and. + repeat match goal with + | [ H : Proper (R ?b ==> Z.le) _, H' : R ?b _ _ |- _ ] + => unique pose proof (H _ _ H') + end. + destruct_head bool; split_min_max; omega. +Qed. + +Lemma monotone_two_corners_gen + (f : Z -> Z) + (Hmonotone : Proper (Z.le ==> Z.le) f \/ Proper (Basics.flip Z.le ==> Z.le) f) + x_bs x + (Hboundedx : ZRange.is_bounded_by' None x_bs x) + : ZRange.is_bounded_by' None (Bounds.two_corners f x_bs) (f x). +Proof. + eapply monotone_two_corners_genb; auto. + destruct Hmonotone; [ exists true | exists false ]; assumption. +Qed. +Lemma monotone_two_corners + (b : bool) + (f : Z -> Z) + (R := if b then Z.le else Basics.flip Z.le) + (Hmonotone : Proper (R ==> Z.le) f) + x_bs x + (Hboundedx : ZRange.is_bounded_by' None x_bs x) + : ZRange.is_bounded_by' None (Bounds.two_corners f x_bs) (f x). +Proof. + apply monotone_two_corners_genb; auto; subst R; + exists b. + intros ???; apply Hmonotone; auto. +Qed. + Lemma monotone_four_corners_genb (f : Z -> Z -> Z) (R := fun b : bool => if b then Z.le else Basics.flip Z.le) @@ -45,11 +104,19 @@ Lemma monotone_four_corners_genb (Hboundedy : ZRange.is_bounded_by' None y_bs y) : ZRange.is_bounded_by' None (Bounds.four_corners f x_bs y_bs) (f x y). Proof. - unfold ZRange.is_bounded_by' in *; split; trivial. - destruct x_bs as [lx ux], y_bs as [ly uy]; simpl in *. - destruct Hboundedx as [Hboundedx _], Hboundedy as [Hboundedy _]. - pose proof (Hmonotone1 lx); pose proof (Hmonotone1 x); pose proof (Hmonotone1 ux). - pose proof (Hmonotone2 ly); pose proof (Hmonotone2 y); pose proof (Hmonotone2 uy). + destruct x_bs as [lx ux], y_bs as [ly uy]. + unfold Bounds.four_corners. + pose proof (monotone_two_corners_genb (f lx) (Hmonotone1 _) _ _ Hboundedy) as Hmono_fl. + pose proof (monotone_two_corners_genb (f ux) (Hmonotone1 _) _ _ Hboundedy) as Hmono_fu. + repeat match goal with + | [ |- context[Bounds.two_corners ?x ?y] ] + => let l := fresh "lf" in + let u := fresh "uf" in + generalize dependent (Bounds.two_corners x y); intros [l u]; intros + end. + unfold ZRange.is_bounded_by' in *; simpl in *; split; trivial. + destruct_head'_and; destruct_head' True. + pose proof (Hmonotone2 y). destruct_head'_ex. repeat match goal with | [ H : Proper (R ?b ==> Z.le) (f _) |- _ ] @@ -65,7 +132,7 @@ Proof. repeat match goal with | [ H : Proper (R ?b ==> Z.le) _, H' : R ?b _ _ |- _ ] => unique pose proof (H _ _ H') - end. + end; cbv beta in *. destruct_head bool; split_min_max; omega. Qed. @@ -98,6 +165,88 @@ Proof. | intros ???; apply Hmonotone; auto; destruct b2; reflexivity ]. Qed. +Lemma monotone_eight_corners_genb + (f : Z -> Z -> Z -> Z) + (R := fun b : bool => if b then Z.le else Basics.flip Z.le) + (Hmonotone1 : forall x y, exists b, Proper (R b ==> Z.le) (f x y)) + (Hmonotone2 : forall x z, exists b, Proper (R b ==> Z.le) (fun y => f x y z)) + (Hmonotone3 : forall y z, exists b, Proper (R b ==> Z.le) (fun x => f x y z)) + x_bs y_bs z_bs x y z + (Hboundedx : ZRange.is_bounded_by' None x_bs x) + (Hboundedy : ZRange.is_bounded_by' None y_bs y) + (Hboundedz : ZRange.is_bounded_by' None z_bs z) + : ZRange.is_bounded_by' None (Bounds.eight_corners f x_bs y_bs z_bs) (f x y z). +Proof. + destruct x_bs as [lx ux], y_bs as [ly uy], z_bs as [lz uz]. + unfold Bounds.eight_corners. + pose proof (monotone_four_corners_genb (f lx) (Hmonotone1 _) (Hmonotone2 _) _ _ _ _ Hboundedy Hboundedz) as Hmono_fl. + pose proof (monotone_four_corners_genb (f ux) (Hmonotone1 _) (Hmonotone2 _) _ _ _ _ Hboundedy Hboundedz) as Hmono_fu. + repeat match goal with + | [ |- context[Bounds.four_corners ?x ?y ?z] ] + => let l := fresh "lf" in + let u := fresh "uf" in + generalize dependent (Bounds.four_corners x y z); intros [l u]; intros + end. + unfold ZRange.is_bounded_by' in *; simpl in *; split; trivial. + destruct_head'_and; destruct_head' True. + pose proof (Hmonotone3 y z). + destruct_head'_ex. + repeat match goal with + | [ H : Proper (R ?b ==> Z.le) (f _ _) |- _ ] + => unique assert (R b (if b then lz else z) (if b then z else lz) + /\ R b (if b then z else uz) (if b then uz else z)) + by (unfold R, Basics.flip; destruct b; omega) + | [ H : Proper (R ?b ==> Z.le) (fun y' => f _ y' _) |- _ ] + => unique assert (R b (if b then ly else y) (if b then y else ly) + /\ R b (if b then y else uy) (if b then uy else y)) + by (unfold R, Basics.flip; destruct b; omega) + | [ H : Proper (R ?b ==> Z.le) (fun x' => f x' _ _) |- _ ] + => unique assert (R b (if b then lx else x) (if b then x else lx) + /\ R b (if b then x else ux) (if b then ux else x)) + by (unfold R, Basics.flip; destruct b; omega) + end. + destruct_head' and. + repeat match goal with + | [ H : Proper (R ?b ==> Z.le) _, H' : R ?b _ _ |- _ ] + => unique pose proof (H _ _ H') + end. + destruct_head bool; split_min_max; omega. +Qed. + +Lemma monotone_eight_corners_gen + (f : Z -> Z -> Z -> Z) + (Hmonotone1 : forall x y, Proper (Z.le ==> Z.le) (f x y) \/ Proper (Basics.flip Z.le ==> Z.le) (f x y)) + (Hmonotone2 : forall x z, Proper (Z.le ==> Z.le) (fun y => f x y z) \/ Proper (Basics.flip Z.le ==> Z.le) (fun y => f x y z)) + (Hmonotone3 : forall y z, Proper (Z.le ==> Z.le) (fun x => f x y z) \/ Proper (Basics.flip Z.le ==> Z.le) (fun x => f x y z)) + x_bs y_bs z_bs x y z + (Hboundedx : ZRange.is_bounded_by' None x_bs x) + (Hboundedy : ZRange.is_bounded_by' None y_bs y) + (Hboundedz : ZRange.is_bounded_by' None z_bs z) + : ZRange.is_bounded_by' None (Bounds.eight_corners f x_bs y_bs z_bs) (f x y z). +Proof. + eapply monotone_eight_corners_genb; auto. + { intros x' y'; destruct (Hmonotone1 x' y'); [ exists true | exists false ]; assumption. } + { intros x' y'; destruct (Hmonotone2 x' y'); [ exists true | exists false ]; assumption. } + { intros x' y'; destruct (Hmonotone3 x' y'); [ exists true | exists false ]; assumption. } +Qed. +Lemma monotone_eight_corners + (b1 b2 b3 : bool) + (f : Z -> Z -> Z -> Z) + (R1 := if b1 then Z.le else Basics.flip Z.le) + (R2 := if b2 then Z.le else Basics.flip Z.le) + (R3 := if b3 then Z.le else Basics.flip Z.le) + (Hmonotone : Proper (R1 ==> R2 ==> R3 ==> Z.le) f) + x_bs y_bs z_bs x y z + (Hboundedx : ZRange.is_bounded_by' None x_bs x) + (Hboundedy : ZRange.is_bounded_by' None y_bs y) + (Hboundedz : ZRange.is_bounded_by' None z_bs z) + : ZRange.is_bounded_by' None (Bounds.eight_corners f x_bs y_bs z_bs) (f x y z). +Proof. + apply monotone_eight_corners_genb; auto; intro x'; subst R1 R2 R3; + [ exists b3 | exists b2 | exists b1 ]; + intros ???; apply Hmonotone; break_innermost_match; try reflexivity; trivial. +Qed. + Lemma monotonify2 (f : Z -> Z -> Z) (upper : Z -> Z -> Z) (Hbounded : forall a b, Z.abs (f a b) <= upper (Z.abs a) (Z.abs b)) (Hupper_monotone : Proper (Z.le ==> Z.le ==> Z.le) upper) @@ -173,7 +322,7 @@ Qed. Local Arguments N.ldiff : simpl never. Local Arguments Z.pow : simpl never. Local Arguments Z.add !_ !_. -Local Existing Instances Z.add_le_Proper Z.sub_le_flip_le_Proper Z.log2_up_le_Proper Z.pow_Zpos_le_Proper Z.sub_le_eq_Proper. +Local Existing Instances Z.add_le_Proper Z.sub_le_flip_le_Proper Z.log2_up_le_Proper Z.pow_Zpos_le_Proper Z.sub_le_eq_Proper Z.add_with_carry_le_Proper. Local Hint Extern 1 => progress cbv beta iota : typeclass_instances. Lemma is_bounded_by_interp_op t tR (opc : op t tR) (bs : interp_flat_type Bounds.interp_base_type _) @@ -181,7 +330,14 @@ Lemma is_bounded_by_interp_op t tR (opc : op t tR) (H : Bounds.is_bounded_by bs v) : Bounds.is_bounded_by (Bounds.interp_op opc bs) (Syntax.interp_op _ _ opc v). Proof. - destruct opc; apply is_bounded_by_truncation_bounds; + destruct opc; + [ apply is_bounded_by_truncation_bounds.. + | split; + cbv [Bounds.interp_op Zinterp_op Z.add_with_get_carry SmartFlatTypeMapUnInterp Bounds.add_with_get_carry Z.get_carry cast_const]; cbn [fst snd]; + [ eapply is_bounded_by_compose with (T1:=TZ) (f_v := fun v => ZToInterp (v mod _)) (v:=ZToInterp _); + [ | intros; apply is_bounded_by_truncation_bounds | simpl; reflexivity ] + | eapply is_bounded_by_compose with (T1:=TZ) (f_v := fun v => ZToInterp (v / _)) (v:=ZToInterp _); + [ | intros; apply is_bounded_by_truncation_bounds | simpl; reflexivity ] ] ]; repeat first [ progress simpl in * | progress cbv [interp_op lift_op cast_const Bounds.interp_base_type Bounds.is_bounded_by' ZRange.is_bounded_by'] in * | progress destruct_head'_prod @@ -225,4 +381,11 @@ Proof. | progress simpl in * | progress split_min_max | omega ]. } + { destruct_head Bounds.t; cbv [Bounds.zselect' Z.zselect]. + break_innermost_match; split_min_max; omega. } + { apply (@monotone_eight_corners true true true _ _ _); split; auto. } + { apply (@monotone_eight_corners true true true _ _ _); split; auto. } + { apply Z.mod_bound_min_max; auto. } + { apply (@monotone_eight_corners true true true _ _ _); split; auto. } + { auto with zarith. } Qed. diff --git a/src/Compilers/Z/Bounds/InterpretationLemmas/PullCast.v b/src/Compilers/Z/Bounds/InterpretationLemmas/PullCast.v index 3e38eabdf..38ed60038 100644 --- a/src/Compilers/Z/Bounds/InterpretationLemmas/PullCast.v +++ b/src/Compilers/Z/Bounds/InterpretationLemmas/PullCast.v @@ -160,7 +160,7 @@ Section with_round_up. | progress destruct_head'_and ]; [ | rewrite cast_const_idempotent_small by t_small; reflexivity - | ]; + | .. ]; repeat match goal with | _ => progress change (@cast_const TZ) with @ZToInterp in * | [ |- context[@cast_const ?x TZ] ] diff --git a/src/Compilers/Z/CommonSubexpressionElimination.v b/src/Compilers/Z/CommonSubexpressionElimination.v index 6695d137e..81f6553ba 100644 --- a/src/Compilers/Z/CommonSubexpressionElimination.v +++ b/src/Compilers/Z/CommonSubexpressionElimination.v @@ -20,11 +20,15 @@ Inductive symbolic_op := | SLand | SLor | SOpp +| SZselect +| SAddWithCarry +| SAddWithGetCarry (bitwidth : Z) . Definition symbolic_op_leb (x y : symbolic_op) : bool := match x, y with | SOpConst z1, SOpConst z2 => Z.leb z1 z2 + | SAddWithGetCarry bw1, SAddWithGetCarry bw2 => Z.leb bw1 bw2 | SOpConst _, _ => true | _, SOpConst _ => false | SAdd, _ => true @@ -42,6 +46,13 @@ Definition symbolic_op_leb (x y : symbolic_op) : bool | SLor, _ => true | _, SLor => false | SOpp, _ => true + | _, SOpp => false + | SZselect, _ => true + | _, SZselect => false + | SAddWithCarry, _ => true + | _, SAddWithCarry => false + (*| SAddWithGetCarry _, _ => true + | _, SAddWithGetCarry _ => false*) end. Local Notation symbolic_expr := (@symbolic_expr base_type symbolic_op). @@ -59,6 +70,9 @@ Definition symbolize_op s d (opc : op s d) : symbolic_op | Land T1 T2 Tout => SLand | Lor T1 T2 Tout => SLor | Opp T Tout => SOpp + | Zselect T1 T2 T3 Tout => SZselect + | AddWithCarry T1 T2 T3 Tout => SAddWithCarry + | AddWithGetCarry bitwidth T1 T2 T3 Tout1 Tout2 => SAddWithGetCarry bitwidth end. Definition denote_symbolic_op s d (opc : symbolic_op) : option (op s d) @@ -72,6 +86,10 @@ Definition denote_symbolic_op s d (opc : symbolic_op) : option (op s d) | SLand, Prod (Tbase _) (Tbase _), Tbase _ => Some (Land _ _ _) | SLor, Prod (Tbase _) (Tbase _), Tbase _ => Some (Lor _ _ _) | SOpp, Tbase _, Tbase _ => Some (Opp _ _) + | SZselect, Prod (Prod (Tbase _) (Tbase _)) (Tbase _), Tbase _ => Some (Zselect _ _ _ _) + | SAddWithCarry, Prod (Prod (Tbase _) (Tbase _)) (Tbase _), Tbase _ => Some (AddWithCarry _ _ _ _) + | SAddWithGetCarry bitwidth, Prod (Prod (Tbase _) (Tbase _)) (Tbase _), Prod (Tbase _) (Tbase _) + => Some (AddWithGetCarry bitwidth _ _ _ _ _) | SAdd, _, _ | SSub, _, _ | SMul, _, _ @@ -81,14 +99,17 @@ Definition denote_symbolic_op s d (opc : symbolic_op) : option (op s d) | SLor, _, _ | SOpp, _, _ | SOpConst _, _, _ + | SZselect, _, _ + | SAddWithCarry, _, _ + | SAddWithGetCarry _, _, _ => None end. Lemma symbolic_op_leb_total : forall a1 a2, symbolic_op_leb a1 a2 = true \/ symbolic_op_leb a2 a1 = true. Proof. - induction a1, a2; simpl; auto. - rewrite !Z.leb_le; omega. + induction a1, a2; simpl; auto; + rewrite !Z.leb_le; omega. Qed. Module SymbolicExprOrder <: TotalLeBool. @@ -148,6 +169,9 @@ Definition normalize_symbolic_expr_mod_c (opc : symbolic_op) (args : symbolic_ex | SShl | SShr | SOpp + | SZselect + | SAddWithCarry + | SAddWithGetCarry _ => args end. diff --git a/src/Compilers/Z/Reify.v b/src/Compilers/Z/Reify.v index c5f31c935..6d41df19e 100644 --- a/src/Compilers/Z/Reify.v +++ b/src/Compilers/Z/Reify.v @@ -7,6 +7,7 @@ Require Import Crypto.Compilers.WfReflective. Require Import Crypto.Compilers.Reify. Require Import Crypto.Compilers.Eta. Require Import Crypto.Compilers.EtaInterp. +Require Import Crypto.Util.ZUtil.Definitions. Ltac base_reify_op op op_head extra ::= lazymatch op_head with @@ -18,13 +19,25 @@ Ltac base_reify_op op op_head extra ::= | @Z.land => constr:(reify_op op op_head 2 (Land TZ TZ TZ)) | @Z.lor => constr:(reify_op op op_head 2 (Lor TZ TZ TZ)) | @Z.opp => constr:(reify_op op op_head 1 (Opp TZ TZ)) + | @Z.opp => constr:(reify_op op op_head 1 (Opp TZ TZ)) + | @Z.zselect => constr:(reify_op op op_head 3 (Zselect TZ TZ TZ TZ)) + | @Z.add_with_carry => constr:(reify_op op op_head 3 (AddWithCarry TZ TZ TZ TZ)) + | @Z.add_with_get_carry + => lazymatch extra with + | @Z.add_with_get_carry ?bit_width _ _ _ + => constr:(reify_op op op_head 3 (AddWithGetCarry bit_width TZ TZ TZ TZ TZ)) + | _ => fail 100 "Anomaly: In Reflection.Z.base_reify_op: head is Z.add_with_get_carry but body is wrong:" extra + end end. Ltac base_reify_type T ::= lazymatch T with | Z => TZ end. -Ltac Reify' e := Compilers.Reify.Reify' base_type interp_base_type op e. +Ltac Reify' e := + let e := (eval cbv beta delta [Z.add_get_carry] in e) in + Compilers.Reify.Reify' base_type interp_base_type op e. Ltac Reify e := + let e := (eval cbv beta delta [Z.add_get_carry] in e) in let v := Compilers.Reify.Reify base_type interp_base_type op make_const e in constr:(ExprEta v). Ltac prove_ExprEta_Compile_correct := diff --git a/src/Compilers/Z/Syntax.v b/src/Compilers/Z/Syntax.v index 754b3fb5a..7d0c84421 100644 --- a/src/Compilers/Z/Syntax.v +++ b/src/Compilers/Z/Syntax.v @@ -6,6 +6,7 @@ Require Import Crypto.Compilers.Syntax. Require Import Crypto.Compilers.TypeUtil. Require Import Crypto.Util.FixedWordSizes. Require Import Crypto.Util.Option. +Require Import Crypto.Util.ZUtil.Definitions. Require Import Crypto.Util.NatUtil. (* for nat_beq for equality schemes *) Export Syntax.Notations. @@ -26,6 +27,9 @@ Inductive op : flat_type base_type -> flat_type base_type -> Type := | Land T1 T2 Tout : op (Tbase T1 * Tbase T2) (Tbase Tout) | Lor T1 T2 Tout : op (Tbase T1 * Tbase T2) (Tbase Tout) | Opp T Tout : op (Tbase T) (Tbase Tout) +| Zselect T1 T2 T3 Tout : op (Tbase T1 * Tbase T2 * Tbase T3) (Tbase Tout) +| AddWithCarry T1 T2 T3 Tout : op (Tbase T1 * Tbase T2 * Tbase T3) (Tbase Tout) +| AddWithGetCarry (bitwidth : Z) T1 T2 T3 Tout1 Tout2 : op (Tbase T1 * Tbase T2 * Tbase T3) (Tbase Tout1 * Tbase Tout2) . Definition interp_base_type (v : base_type) : Type := @@ -78,6 +82,9 @@ Definition Zinterp_op src dst (f : op src dst) | Land _ _ _ => fun xy => Z.land (fst xy) (snd xy) | Lor _ _ _ => fun xy => Z.lor (fst xy) (snd xy) | Opp _ _ => fun x => Z.opp x + | Zselect _ _ _ _ => fun ctf => let '(c, t, f) := eta3 ctf in Z.zselect c t f + | AddWithCarry _ _ _ _ => fun cxy => let '(c, x, y) := eta3 cxy in Z.add_with_carry c x y + | AddWithGetCarry bitwidth _ _ _ _ _ => fun cxy => let '(c, x, y) := eta3 cxy in Z.add_with_get_carry bitwidth c x y end%Z. Definition interp_op src dst (f : op src dst) : interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst diff --git a/src/Compilers/Z/Syntax/Equality.v b/src/Compilers/Z/Syntax/Equality.v index 745b01503..b2075c49c 100644 --- a/src/Compilers/Z/Syntax/Equality.v +++ b/src/Compilers/Z/Syntax/Equality.v @@ -41,6 +41,12 @@ Definition op_beq_hetero {t1 tR t1' tR'} (f : op t1 tR) (g : op t1' tR') : bool => base_type_beq T1 T1' && base_type_beq T2 T2' && base_type_beq Tout Tout' | Opp Tin Tout, Opp Tin' Tout' => base_type_beq Tin Tin' && base_type_beq Tout Tout' + | Zselect T1 T2 T3 Tout, Zselect T1' T2' T3' Tout' + => base_type_beq T1 T1' && base_type_beq T2 T2' && base_type_beq T3 T3' && base_type_beq Tout Tout' + | AddWithCarry T1 T2 T3 Tout, AddWithCarry T1' T2' T3' Tout' + => base_type_beq T1 T1' && base_type_beq T2 T2' && base_type_beq T3 T3' && base_type_beq Tout Tout' + | AddWithGetCarry bitwidth T1 T2 T3 Tout1 Tout2, AddWithGetCarry bitwidth' T1' T2' T3' Tout1' Tout2' + => Z.eqb bitwidth bitwidth' && base_type_beq T1 T1' && base_type_beq T2 T2' && base_type_beq T3 T3' && base_type_beq Tout1 Tout1' && base_type_beq Tout2 Tout2' | OpConst _ _, _ | Add _ _ _, _ | Sub _ _ _, _ @@ -50,6 +56,9 @@ Definition op_beq_hetero {t1 tR t1' tR'} (f : op t1 tR) (g : op t1' tR') : bool | Land _ _ _, _ | Lor _ _ _, _ | Opp _ _, _ + | Zselect _ _ _ _, _ + | AddWithCarry _ _ _ _, _ + | AddWithGetCarry _ _ _ _ _ _, _ => false end%bool. diff --git a/src/Compilers/Z/Syntax/Util.v b/src/Compilers/Z/Syntax/Util.v index 9f18f47e9..b84a00dee 100644 --- a/src/Compilers/Z/Syntax/Util.v +++ b/src/Compilers/Z/Syntax/Util.v @@ -61,6 +61,9 @@ Definition genericize_op {var' src dst} (opc : op src dst) {f} | Land _ _ _ => fun _ _ => Land _ _ _ | Lor _ _ _ => fun _ _ => Lor _ _ _ | Opp _ _ => fun _ _ => Opp _ _ + | Zselect _ _ _ _ => fun _ _ => Zselect _ _ _ _ + | AddWithCarry _ _ _ _ => fun _ _ => AddWithCarry _ _ _ _ + | AddWithGetCarry bitwidth _ _ _ _ _ => fun _ _ => AddWithGetCarry bitwidth _ _ _ _ _ end. Lemma cast_const_id {t} v diff --git a/src/Specific/IntegrationTestFreeze.v b/src/Specific/IntegrationTestFreeze.v new file mode 100644 index 000000000..12f107be6 --- /dev/null +++ b/src/Specific/IntegrationTestFreeze.v @@ -0,0 +1,119 @@ +Require Import Coq.ZArith.ZArith. +Require Import Coq.Lists.List. +Local Open Scope Z_scope. + +Require Import Crypto.Arithmetic.Core. +Require Import Crypto.Util.FixedWordSizes. +Require Import Crypto.Specific.ArithmeticSynthesisTest. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.Tuple Crypto.Util.Sigma Crypto.Util.Sigma.MapProjections Crypto.Util.Sigma.Lift Crypto.Util.Notations Crypto.Util.ZRange Crypto.Util.BoundedWord. +Require Import Crypto.Util.Tactics.Head. +Require Import Crypto.Util.Tactics.MoveLetIn. +Require Import Crypto.Util.Tactics.DestructHead. +Import ListNotations. + +Require Import Crypto.Specific.IntegrationTestTemporaryMiscCommon. + +Require Import Crypto.Compilers.Z.Bounds.Pipeline. + +Section BoundedField25p5. + Local Coercion Z.of_nat : nat >-> Z. + + Let limb_widths := Eval vm_compute in (List.map (fun i => Z.log2 (wt (S i) / wt i)) (seq 0 sz)). + Let length_lw := Eval compute in List.length limb_widths. + + Local Notation b_of exp := {| lower := 0 ; upper := 2^exp + 2^(exp-3) |}%Z (only parsing). (* max is [(0, 2^(exp+2) + 2^exp + 2^(exp-1) + 2^(exp-3) + 2^(exp-4) + 2^(exp-5) + 2^(exp-6) + 2^(exp-10) + 2^(exp-12) + 2^(exp-13) + 2^(exp-14) + 2^(exp-15) + 2^(exp-17) + 2^(exp-23) + 2^(exp-24))%Z] *) + (* The definition [bounds_exp] is a tuple-version of the + limb-widths, which are the [exp] argument in [b_of] above, i.e., + the approximate base-2 exponent of the bounds on the limb in that + position. *) + Let bounds_exp : Tuple.tuple Z length_lw + := Eval compute in + Tuple.from_list length_lw limb_widths eq_refl. + Let bounds : Tuple.tuple zrange length_lw + := Eval compute in + Tuple.map (fun e => b_of e) bounds_exp. + + Let lgbitwidth := Eval compute in (Z.to_nat (Z.log2_up (List.fold_right Z.max 0 limb_widths))). + Let bitwidth := Eval compute in (2^lgbitwidth)%nat. + Let feZ : Type := tuple Z sz. + Let feW : Type := tuple (wordT lgbitwidth) sz. + Let feBW : Type := BoundedWord sz bitwidth bounds. + Let phi : feBW -> F m := + fun x => B.Positional.Fdecode wt (BoundedWordToZ _ _ _ x). + + (* TODO : change this to field once field isomorphism happens *) + Definition freeze : + { freeze : feBW -> feBW + | forall a, phi (freeze a) = phi a }. + Proof. + lazymatch goal with + | [ |- { f | forall a, ?phi (f a) = @?rhs a } ] + => apply lift1_sig with (P:=fun a f => phi f = rhs a) + end. + intros. + eexists_sig_etransitivity. all:cbv [phi]. + rewrite <- (proj2_sig freeze_sig). + { set (freezeZ := proj1_sig freeze_sig). + context_to_dlet_in_rhs freezeZ. + cbv beta iota delta [freezeZ proj1_sig freeze_sig fst snd runtime_add runtime_and runtime_mul runtime_opp runtime_shr sz]. + reflexivity. } + { destruct a as [a H]; unfold BoundedWordToZ, proj1_sig. + revert H. + cbv -[Z.le Z.add Z.mul Z.lt fst snd wordToZ wt]; cbn [fst snd]. + repeat match goal with + | [ |- context[wt ?n] ] + => let v := (eval compute in (wt n)) in change (wt n) with v + end. + intro; destruct_head'_and. + omega. } + sig_dlet_in_rhs_to_context. + apply (fun f => proj2_sig_map (fun THIS_NAME_MUST_NOT_BE_UNDERSCORE_TO_WORK_AROUND_CONSTR_MATCHING_ANAOMLIES___BUT_NOTE_THAT_IF_THIS_NAME_IS_LOWERCASE_A___THEN_REIFICATION_STACK_OVERFLOWS___AND_I_HAVE_NO_IDEA_WHATS_GOING_ON p => f_equal f p)). + (* jgross start here! *) + (*Set Ltac Profiling.*) + Time refine_reflectively. (* Finished transaction in 5.792 secs (5.792u,0.004s) (successful) *) + (*Show Ltac Profile.*) + (* total time: 5.680s + + tactic local total calls max +────────────────────────────────────────┴──────┴──────┴───────┴─────────┘ +─refine_reflectively_gen --------------- 0.0% 100.0% 1 5.680s +─ReflectiveTactics.do_reflective_pipelin 0.0% 95.8% 1 5.444s +─ReflectiveTactics.solve_side_conditions 0.4% 95.6% 1 5.428s +─ReflectiveTactics.do_reify ------------ 46.0% 61.7% 1 3.504s +─UnifyAbstractReflexivity.unify_transfor 22.9% 28.4% 7 0.372s +─Reify_rhs_gen ------------------------- 0.7% 8.3% 1 0.472s +─eexact -------------------------------- 7.2% 7.2% 131 0.012s +─ReflectiveTactics.unify_abstract_cbv_in 3.9% 4.9% 1 0.280s +─Glue.refine_to_reflective_glue' ------- 0.0% 4.2% 1 0.236s +─Glue.zrange_to_reflective ------------- 0.1% 3.3% 1 0.188s +─unify (constr) (constr) --------------- 3.2% 3.2% 6 0.052s +─prove_interp_compile_correct ---------- 0.0% 2.7% 1 0.152s +─clear (var_list) ---------------------- 2.7% 2.7% 91 0.028s +─rewrite ?EtaInterp.InterpExprEta ------ 2.5% 2.5% 1 0.140s +─ClearAll.clear_all -------------------- 0.4% 2.5% 7 0.036s +─Glue.zrange_to_reflective_goal -------- 2.0% 2.4% 1 0.136s + + tactic local total calls max +────────────────────────────────────────┴──────┴──────┴───────┴─────────┘ +─refine_reflectively_gen --------------- 0.0% 100.0% 1 5.680s + ├─ReflectiveTactics.do_reflective_pipel 0.0% 95.8% 1 5.444s + │└ReflectiveTactics.solve_side_conditio 0.4% 95.6% 1 5.428s + │ ├─ReflectiveTactics.do_reify -------- 46.0% 61.7% 1 3.504s + │ │ ├─Reify_rhs_gen ------------------- 0.7% 8.3% 1 0.472s + │ │ │└prove_interp_compile_correct ---- 0.0% 2.7% 1 0.152s + │ │ │└rewrite ?EtaInterp.InterpExprEta 2.5% 2.5% 1 0.140s + │ │ └─eexact -------------------------- 7.2% 7.2% 131 0.012s + │ ├─UnifyAbstractReflexivity.unify_tran 22.9% 28.4% 7 0.372s + │ │ ├─ClearAll.clear_all -------------- 0.4% 2.5% 7 0.036s + │ │ │└clear (var_list) ---------------- 2.1% 2.1% 70 0.028s + │ │ └─unify (constr) (constr) --------- 2.3% 2.3% 5 0.044s + │ └─ReflectiveTactics.unify_abstract_cb 3.9% 4.9% 1 0.280s + └─Glue.refine_to_reflective_glue' ----- 0.0% 4.2% 1 0.236s + └Glue.zrange_to_reflective ----------- 0.1% 3.3% 1 0.188s + └Glue.zrange_to_reflective_goal ------ 2.0% 2.4% 1 0.136s + +*) + Time Defined. (* Finished transaction in 3.607 secs (3.607u,0.s) (successful) *) + +End BoundedField25p5. diff --git a/src/Specific/IntegrationTestFreezeDisplay.v b/src/Specific/IntegrationTestFreezeDisplay.v new file mode 100644 index 000000000..c6d13e05a --- /dev/null +++ b/src/Specific/IntegrationTestFreezeDisplay.v @@ -0,0 +1,4 @@ +Require Import Crypto.Specific.IntegrationTestFreeze. +Require Import Crypto.Specific.IntegrationTestDisplayCommon. + +Check display freeze. -- cgit v1.2.3