diff options
author | Jason Gross <jagro@google.com> | 2016-08-25 11:48:05 -0700 |
---|---|---|
committer | Jason Gross <jagro@google.com> | 2016-08-25 11:48:05 -0700 |
commit | bf0d9280ebf806eef8ee3280f7976edb3282ae6e (patch) | |
tree | b7a6e171a03a1912c480da38d1d961e7324f92f3 | |
parent | 34d53cc72df1a3c31838e0cc7e06f0cf8959d628 (diff) |
Integrate suggestions from Andres
-rw-r--r-- | _CoqProject | 1 | ||||
-rw-r--r-- | src/BoundedArithmetic/ArchitectureToZLikeProofs.v | 1 | ||||
-rw-r--r-- | src/BoundedArithmetic/DoubleBounded.v | 11 | ||||
-rw-r--r-- | src/BoundedArithmetic/DoubleBoundedProofs.v | 17 | ||||
-rw-r--r-- | src/BoundedArithmetic/Interface.v | 109 | ||||
-rw-r--r-- | src/BoundedArithmetic/InterfaceProofs.v | 59 | ||||
-rw-r--r-- | src/Util/AutoRewrite.v | 56 | ||||
-rw-r--r-- | src/Util/Notations.v | 1 |
8 files changed, 147 insertions, 108 deletions
diff --git a/_CoqProject b/_CoqProject index 9ed483860..40804dcd9 100644 --- a/_CoqProject +++ b/_CoqProject @@ -74,6 +74,7 @@ src/Specific/GF25519.v src/Tactics/VerdiTactics.v src/Tactics/Algebra_syntax/Nsatz.v src/Util/AdditionChainExponentiation.v +src/Util/AutoRewrite.v src/Util/Bool.v src/Util/CaseUtil.v src/Util/Decidable.v diff --git a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v index 3060e17bb..804296374 100644 --- a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v +++ b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v @@ -93,7 +93,6 @@ Section fancy_machine_p256_montgomery_foundation. { abstract t. } { abstract t. } { abstract t. } -Hint Resolve Z.div_pos : zarith. { abstract t. } { abstract t. } { abstract t. } diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index 55e46aa2b..b624c5082 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -7,7 +7,6 @@ Require Import Crypto.Util.Tuple. Require Import Crypto.Util.ListUtil. Require Import Crypto.Util.Notations. -Local Open Scope list_scope. Local Open Scope nat_scope. Local Open Scope Z_scope. Local Open Scope type_scope. @@ -23,16 +22,6 @@ Global Arguments tuple_decoder : simpl never. Hint Extern 3 (decoder _ (tuple ?W ?k)) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode => (@tuple_decoder n W decode k : decoder (kv * n) (tuple W k))) : typeclass_instances. Section ripple_carry_definitions. - Definition ripple_carry {T} (f : T -> T -> bool -> bool * T) - (xs ys : list T) (carry : bool) : bool * list T - := List.fold_right - (fun x_y carry_zs => let '(x, y) := eta x_y in - let '(carry, zs) := eta carry_zs in - let '(carry, z) := eta (f x y carry) in - (carry, z :: zs)) - (carry, nil) - (List.combine xs ys). - (** tuple is high to low ([to_list] reverses) *) Fixpoint ripple_carry_tuple' {T} (f : T -> T -> bool -> bool * T) k : forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k diff --git a/src/BoundedArithmetic/DoubleBoundedProofs.v b/src/BoundedArithmetic/DoubleBoundedProofs.v index b69232076..53ac59d00 100644 --- a/src/BoundedArithmetic/DoubleBoundedProofs.v +++ b/src/BoundedArithmetic/DoubleBoundedProofs.v @@ -16,16 +16,15 @@ Require Import Crypto.Util.Notations. Import ListNotations. Local Open Scope list_scope. Local Open Scope nat_scope. -Local Open Scope Z_scope. Local Open Scope type_scope. +Local Open Scope Z_scope. Local Coercion Z.of_nat : nat >-> Z. Local Coercion Pos.to_nat : positive >-> nat. Local Notation eta x := (fst x, snd x). -Local Infix "==" := rewrite_eq. -Local Infix "=~>" := rewrite_left_to_right_eq. -Local Infix "<~=" := rewrite_right_to_left_eq. +Import BoundedRewriteNotations. +Local Open Scope Z_scope. Section decode. Context {n W} {decode : decoder n W}. @@ -99,12 +98,12 @@ Section decode. Global Instance tuple_decoder_m1 w : tuple_decoder (k := 0) w =~> 0. Proof. reflexivity. Qed. - Global Instance tuple_decoder_2' w : bounded_le_cls 0 n -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << (1%nat * n))%Z. + Global Instance tuple_decoder_2' w : (0 <= n)%bounded_rewrite -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << (1%nat * n))%Z. Proof. intros; rewrite !tuple_decoder_S, !tuple_decoder_O by assumption. reflexivity. Qed. - Global Instance tuple_decoder_2 w : bounded_le_cls 0 n -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << n)%Z. + Global Instance tuple_decoder_2 w : (0 <= n)%bounded_rewrite -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << n)%Z. Proof. intros; rewrite !tuple_decoder_S, !tuple_decoder_O by assumption. autorewrite with zsimplify_const; reflexivity. @@ -205,8 +204,8 @@ Global Instance decode_is_spread_left_immediate {isdecode : is_decode decode} {issprl : is_spread_left_immediate sprl} : forall r count, - 0 <= count < n - -> tuple_decoder (sprl r count) == decode r << count + (0 <= count < n)%bounded_rewrite + -> tuple_decoder (sprl r count) <~=~> decode r << count := proj1 decode_is_spread_left_immediate_iff _. Lemma decode_mul_double_iff @@ -233,7 +232,7 @@ Global Instance decode_mul_double {muldw : multiply_double W} {isdecode : is_decode decode} {ismuldw : is_mul_double muldw} - : forall x y, tuple_decoder (muldw x y) == (decode x * decode y)%Z + : forall x y, tuple_decoder (muldw x y) <~=~> (decode x * decode y)%Z := proj1 decode_mul_double_iff _. Lemma ripple_carry_tuple_SS {T} f k xss yss carry diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index 00528a053..152c43cee 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -2,10 +2,11 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.AutoRewrite. Require Import Crypto.Util.Notations. -Local Open Scope Z_scope. Local Open Scope type_scope. +Local Open Scope Z_scope. Class decoder (n : Z) W := { decode : W -> Z }. @@ -15,22 +16,6 @@ Global Arguments decode {n W _} _. Class is_decode {n W} (decode : decoder n W) := decode_range : forall x, 0 <= decode x < 2^n. -Class rewrite_eq {A} (x y : A) - := by_rewrite : x = y. -Arguments by_rewrite {A} _ _ {_}. - -Class rewrite_right_to_left_eq {A} (x y : A) - := by_rewrite_right_to_left : rewrite_eq x y. -Arguments by_rewrite_right_to_left {A} _ _ {_}. -Global Instance unfold_rewrite_right_to_left_eq {A x y} (H : @rewrite_eq A x y) - : @rewrite_right_to_left_eq A x y := H. - -Class rewrite_left_to_right_eq {A} (x y : A) - := by_rewrite_left_to_right : rewrite_eq x y. -Arguments by_rewrite_left_to_right {A} _ _ {_}. -Global Instance unfold_rewrite_left_to_right_eq {A x y} (H : @rewrite_eq A x y) - : @rewrite_left_to_right_eq A x y := H. - Class bounded_in_range_cls (x y z : Z) := is_bounded_in_range : x <= y < z. Ltac bounded_solver_tac := solve [ eassumption | typeclasses eauto | omega ]. @@ -44,26 +29,42 @@ Class bounded_le_cls (x y : Z) := is_bounded_le : x <= y. Hint Extern 0 (bounded_le_cls _ _) => unfold bounded_le_cls; bounded_solver_tac : typeclass_instances. Global Arguments bounded_le_cls / . +Inductive bounded_decode_pusher_tag := decode_tag. + Ltac push_decode_step := match goal with | [ |- context[@decode ?n ?W ?decoder ?w] ] - => let lem := constr:(by_rewrite_left_to_right (A := Z) (@decode n W decoder w) _) in - rewrite (lem : @decode n W decoder w = _) + => tc_rewrite (decode_tag) (@decode n W decoder w) -> | [ |- context[match @fst ?A ?B ?x with true => 1 | false => 0 end] ] - => let lem := constr:(by_rewrite_left_to_right (A := Z) (match @fst A B x with true => 1 | false => 0 end) _) in - rewrite (lem : _ = _) + => tc_rewrite (decode_tag) (match @fst A B x with true => 1 | false => 0 end) -> end. Ltac push_decode := repeat push_decode_step. Ltac pull_decode_step := match goal with | [ |- context[?E] ] - => first [ let lem := constr:(by_rewrite_right_to_left (A := Z) _ E) in - rewrite <- (lem : _ = E) - | let lem := constr:(by_rewrite_right_to_left (A := bool) _ E) in - rewrite <- (lem : _ = E) ] + => lazymatch type of E with + | Z => idtac + | bool => idtac + end; + tc_rewrite (decode_tag) <- E end. Ltac pull_decode := repeat pull_decode_step. +Delimit Scope bounded_rewrite_scope with bounded_rewrite. + +Infix "<~=~>" := (rewrite_eq decode_tag) : bounded_rewrite_scope. +Infix "=~>" := (rewrite_left_to_right_eq decode_tag) : bounded_rewrite_scope. +Infix "<~=" := (rewrite_right_to_left_eq decode_tag) : bounded_rewrite_scope. +Notation "x <= y" := (bounded_le_cls x y) : bounded_rewrite_scope. +Notation "x <= y < z" := (bounded_in_range_cls x y z) : bounded_rewrite_scope. + +Module Import BoundedRewriteNotations. + Infix "<~=~>" := (rewrite_eq decode_tag) : type_scope. + Infix "=~>" := (rewrite_left_to_right_eq decode_tag) : type_scope. + Infix "<~=" := (rewrite_right_to_left_eq decode_tag) : type_scope. + Open Scope bounded_rewrite_scope. +End BoundedRewriteNotations. + (** This is required for typeclass resolution to be fast. *) Typeclasses Opaque decode. @@ -71,10 +72,6 @@ Section InstructionGallery. Context (n : Z) (* bit-width of width of [W] *) {W : Type} (* bounded type, [W] for word *) (Wdecoder : decoder n W). - Local Infix "==" := rewrite_eq. - Local Infix "=~>" := rewrite_left_to_right_eq. - Local Infix "<~=" := rewrite_right_to_left_eq. - Local Notation "x <= y < z" := (bounded_in_range_cls x y z). Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) Class load_immediate := { ldi : imm -> W }. @@ -90,21 +87,21 @@ Section InstructionGallery. decode_shift_right_doubleword :> forall high low count, 0 <= count < n - -> decode (shrd high low count) == (((decode high << n) + decode low) >> count) mod 2^n. + -> decode (shrd high low count) <~=~> (((decode high << n) + decode low) >> count) mod 2^n. Class shift_left_immediate := { shl : W -> imm -> W }. Global Coercion shl : shift_left_immediate >-> Funclass. Class is_shift_left_immediate (shl : shift_left_immediate) := decode_shift_left_immediate :> - forall r count, 0 <= count < n -> decode (shl r count) == (decode r << count) mod 2^n. + forall r count, 0 <= count < n -> decode (shl r count) <~=~> (decode r << count) mod 2^n. Class shift_right_immediate := { shr : W -> imm -> W }. Global Coercion shr : shift_right_immediate >-> Funclass. Class is_shift_right_immediate (shr : shift_right_immediate) := decode_shift_right_immediate :> - forall r count, 0 <= count < n -> decode (shr r count) == (decode r >> count). + forall r count, 0 <= count < n -> decode (shr r count) <~=~> (decode r >> count). Class spread_left_immediate := { sprl : W -> imm -> tuple W 2 (* [(low, high)] *) }. Global Coercion sprl : spread_left_immediate >-> Funclass. @@ -120,19 +117,12 @@ Section InstructionGallery. }. - Definition Build_is_spread_left_immediate' (sprl : spread_left_immediate) - (pf : forall r count, 0 <= count < n - -> decode (fst (sprl r count)) = (decode r << count) mod 2^n - /\ decode (snd (sprl r count)) = (decode r << count) >> n) - := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H); - decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}. - Class mask_keep_low := { mkl :> W -> imm -> W }. Global Coercion mkl : mask_keep_low >-> Funclass. Class is_mask_keep_low (mkl : mask_keep_low) := decode_mask_keep_low :> forall r count, - 0 <= count < n -> decode (mkl r count) == decode r mod 2^count. + 0 <= count < n -> decode (mkl r count) <~=~> decode r mod 2^count. Local Notation bit b := (if b then 1 else 0). @@ -141,34 +131,24 @@ Section InstructionGallery. Class is_add_with_carry (adc : add_with_carry) := { - bit_fst_add_with_carry :> forall x y c, bit (fst (adc x y c)) == (decode x + decode y + bit c) >> n; - decode_snd_add_with_carry :> forall x y c, decode (snd (adc x y c)) == (decode x + decode y + bit c) mod (2^n) + bit_fst_add_with_carry :> forall x y c, bit (fst (adc x y c)) <~=~> (decode x + decode y + bit c) >> n; + decode_snd_add_with_carry :> forall x y c, decode (snd (adc x y c)) <~=~> (decode x + decode y + bit c) mod (2^n) }. - Definition Build_is_add_with_carry' (adc : add_with_carry) - (pf : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n /\ decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n)) - := {| bit_fst_add_with_carry x y c := proj1 (pf x y c); - decode_snd_add_with_carry x y c := proj2 (pf x y c) |}. - Class sub_with_carry := { subc : W -> W -> bool -> bool * W }. Global Coercion subc : sub_with_carry >-> Funclass. Class is_sub_with_carry (subc:W->W->bool->bool*W) := { - fst_sub_with_carry :> forall x y c, fst (subc x y c) == ((decode x - decode y - bit c) <? 0); - decode_snd_sub_with_carry :> forall x y c, decode (snd (subc x y c)) == (decode x - decode y - bit c) mod 2^n + fst_sub_with_carry :> forall x y c, fst (subc x y c) <~=~> ((decode x - decode y - bit c) <? 0); + decode_snd_sub_with_carry :> forall x y c, decode (snd (subc x y c)) <~=~> (decode x - decode y - bit c) mod 2^n }. - Definition Build_is_sub_with_carry' (subc : sub_with_carry) - (pf : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) <? 0) /\ decode (snd (subc x y c)) = (decode x - decode y - bit c) mod 2^n) - := {| fst_sub_with_carry x y c := proj1 (pf x y c); - decode_snd_sub_with_carry x y c := proj2 (pf x y c) |}. - Class multiply := { mul : W -> W -> W }. Global Coercion mul : multiply >-> Funclass. Class is_mul (mul : multiply) := - decode_mul :> forall x y, decode (mul x y) == (decode x * decode y)%Z. + decode_mul :> forall x y, decode (mul x y) <~=~> (decode x * decode y). Class multiply_low_low := { mulhwll : W -> W -> W }. Global Coercion mulhwll : multiply_low_low >-> Funclass. @@ -181,13 +161,13 @@ Section InstructionGallery. Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) := decode_mul_low_low :> - forall x y, decode (mulhwll x y) == ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n. + forall x y, decode (mulhwll x y) <~=~> ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n. Class is_mul_high_low (w:Z) (mulhwhl : multiply_high_low) := decode_mul_high_low :> - forall x y, decode (mulhwhl x y) == ((decode x >> w) * (decode y mod 2^w)) mod 2^n. + forall x y, decode (mulhwhl x y) <~=~> ((decode x >> w) * (decode y mod 2^w)) mod 2^n. Class is_mul_high_high (w:Z) (mulhwhh : multiply_high_high) := decode_mul_high_high :> - forall x y, decode (mulhwhh x y) == ((decode x >> w) * (decode y >> w)) mod 2^n. + forall x y, decode (mulhwhh x y) <~=~> ((decode x >> w) * (decode y >> w)) mod 2^n. Class is_mul_double (muldw : multiply_double) := { decode_fst_mul_double :> @@ -195,27 +175,22 @@ Section InstructionGallery. decode_snd_mul_double :> forall x y, decode (snd (muldw x y)) =~> (decode x * decode y) >> n }. - Definition Build_is_mul_double' (muldw : multiply_double) - (pf : forall x y, _ /\ _) - := {| decode_fst_mul_double x y := proj1 (pf x y); - decode_snd_mul_double x y := proj2 (pf x y) |}. - Class select_conditional := { selc : bool -> W -> W -> W }. Global Coercion selc : select_conditional >-> Funclass. Class is_select_conditional (selc : select_conditional) := decode_select_conditional :> forall b x y, - decode (selc b x y) == if b then decode x else decode y. + decode (selc b x y) <~=~> if b then decode x else decode y. Class add_modulo := { addm : W -> W -> W (* modulus *) -> W }. Global Coercion addm : add_modulo >-> Funclass. Class is_add_modulo (addm : add_modulo) := decode_add_modulo :> forall x y modulus, - decode (addm x y modulus) == (if (decode x + decode y) <? decode modulus - then (decode x + decode y) - else (decode x + decode y) - decode modulus)%Z. + decode (addm x y modulus) <~=~> (if (decode x + decode y) <? decode modulus + then (decode x + decode y) + else (decode x + decode y) - decode modulus). End InstructionGallery. Global Arguments load_immediate : clear implicits. diff --git a/src/BoundedArithmetic/InterfaceProofs.v b/src/BoundedArithmetic/InterfaceProofs.v index b8e20c607..8256fc23f 100644 --- a/src/BoundedArithmetic/InterfaceProofs.v +++ b/src/BoundedArithmetic/InterfaceProofs.v @@ -3,14 +3,13 @@ Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. Require Import Crypto.BoundedArithmetic.Interface. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.AutoRewrite. Require Import Crypto.Util.Notations. Local Open Scope type_scope. Local Open Scope Z_scope. -Local Infix "==" := rewrite_eq. -Local Infix "=~>" := rewrite_left_to_right_eq. -Local Infix "<~=" := rewrite_right_to_left_eq. +Import BoundedRewriteNotations. Local Notation bit b := (if b then 1 else 0). Section InstructionGallery. @@ -19,11 +18,33 @@ Section InstructionGallery. (Wdecoder : decoder n W). Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) + Definition Build_is_spread_left_immediate' (sprl : spread_left_immediate W) + (pf : forall r count, 0 <= count < n + -> _ /\ _) + := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H); + decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}. + + Definition Build_is_add_with_carry' (adc : add_with_carry W) + (pf : forall x y c, _ /\ _) + := {| bit_fst_add_with_carry x y c := proj1 (pf x y c); + decode_snd_add_with_carry x y c := proj2 (pf x y c) |}. + + Definition Build_is_sub_with_carry' (subc : sub_with_carry W) + (pf : forall x y c, _ /\ _) + : is_sub_with_carry subc + := {| fst_sub_with_carry x y c := proj1 (pf x y c); + decode_snd_sub_with_carry x y c := proj2 (pf x y c) |}. + + Definition Build_is_mul_double' (muldw : multiply_double W) + (pf : forall x y, _ /\ _) + := {| decode_fst_mul_double x y := proj1 (pf x y); + decode_snd_mul_double x y := proj2 (pf x y) |}. + Lemma is_spread_left_immediate_alt {sprl : spread_left_immediate W} {isdecode : is_decode Wdecoder} : is_spread_left_immediate sprl - <-> (forall r count, 0 <= count < n -> decode (fst (sprl r count)) + decode (snd (sprl r count)) << n = (decode r << count) mod (2^n*2^n)). + <-> (forall r count, 0 <= count < n -> decode (fst (sprl r count)) + decode (snd (sprl r count)) << n = (decode r << count) mod (2^n*2^n))%Z. Proof. split; intro H; [ | apply Build_is_spread_left_immediate' ]; intros r count Hc; @@ -31,8 +52,8 @@ Section InstructionGallery. unfold bounded_in_range_cls in *; pose proof (decode_range r); assert (0 < 2^n) by auto with zarith; - assert (0 <= 2^count < 2^n) by auto with zarith; - assert (0 <= decode r * 2^count < 2^n * 2^n) by (generalize dependent (decode r); intros; nia); + assert (0 <= 2^count < 2^n)%Z by auto with zarith; + assert (0 <= decode r * 2^count < 2^n * 2^n)%Z by (generalize dependent (decode r); intros; nia); rewrite ?decode_fst_spread_left_immediate, ?decode_snd_spread_left_immediate by typeclasses eauto with typeclass_instances core; autorewrite with Zshift_to_pow zsimplify push_Zpow. @@ -53,7 +74,7 @@ Section InstructionGallery. pose proof (decode_range x); pose proof (decode_range y); assert (0 < 2^n) by auto with zarith; - assert (0 <= decode x * decode y < 2^n * 2^n) by nia; + assert (0 <= decode x * decode y < 2^n * 2^n)%Z by nia; (destruct (0 <=? n) eqn:?; Z.ltb_to_lt; [ | assert (2^n = 0) by auto with zarith; exfalso; omega ]); rewrite ?decode_fst_mul_double, ?decode_snd_mul_double @@ -65,8 +86,6 @@ Section InstructionGallery. Qed. End InstructionGallery. -Local Notation "x <= y < z" := (bounded_in_range_cls x y z). - Global Arguments is_spread_left_immediate_alt {_ _ _ _ _}. Global Arguments is_mul_double_alt {_ _ _ _ _}. @@ -95,7 +114,7 @@ Proof. exact _. Qed. Lemma decode_exponent_nonnegative {n W} (decode : decoder n W) {isdecode : is_decode decode} (isinhabited : W) - : 0 <= n. + : (0 <= n)%Z. Proof. pose proof (decode_range isinhabited). assert (0 < 2^n) by omega. @@ -113,13 +132,13 @@ Section adc_subc. {isadc : is_add_with_carry adc} {issubc : is_sub_with_carry subc}. Global Instance bit_fst_add_with_carry_false - : forall x y, bit (fst (adc x y false)) == (decode x + decode y) >> n. + : forall x y, bit (fst (adc x y false)) <~=~> (decode x + decode y) >> n. Proof. intros; erewrite bit_fst_add_with_carry by assumption. autorewrite with zsimplify_const; reflexivity. Qed. Global Instance bit_fst_add_with_carry_true - : forall x y, bit (fst (adc x y true)) == (decode x + decode y + 1) >> n. + : forall x y, bit (fst (adc x y true)) <~=~> (decode x + decode y + 1) >> n. Proof. intros; erewrite bit_fst_add_with_carry by assumption. autorewrite with zsimplify_const; reflexivity. @@ -128,9 +147,9 @@ Section adc_subc. : forall x y c, fst (adc x y c) <~= (2^n <=? (decode x + decode y + bit c)). Proof. intros x y c; hnf. - assert (0 <= n) by eauto using decode_exponent_nonnegative. + assert (0 <= n)%Z by eauto using decode_exponent_nonnegative. pose proof (decode_range x); pose proof (decode_range y). - assert (0 <= bit c <= 1) by (destruct c; omega). + assert (0 <= bit c <= 1)%Z by (destruct c; omega). lazymatch goal with | [ |- fst ?x = (?a <=? ?b) :> bool ] => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z); @@ -148,30 +167,30 @@ Section adc_subc. autorewrite with zsimplify_const; reflexivity. Qed. Global Instance fst_add_with_carry_true_leb - : forall x y, fst (adc x y true) == (2^n <=? (decode x + decode y + 1)). + : forall x y, fst (adc x y true) <~=~> (2^n <=? (decode x + decode y + 1)). Proof. intros; erewrite fst_add_with_carry_leb by assumption. autorewrite with zsimplify_const; reflexivity. Qed. Global Instance fst_sub_with_carry_false - : forall x y, fst (subc x y false) == ((decode x - decode y) <? 0). + : forall x y, fst (subc x y false) <~=~> ((decode x - decode y) <? 0). Proof. intros; erewrite fst_sub_with_carry by assumption. autorewrite with zsimplify_const; reflexivity. Qed. Global Instance fst_sub_with_carry_true - : forall x y, fst (subc x y true) == ((decode x - decode y - 1) <? 0). + : forall x y, fst (subc x y true) <~=~> ((decode x - decode y - 1) <? 0). Proof. intros; erewrite fst_sub_with_carry by assumption. autorewrite with zsimplify_const; reflexivity. Qed. End adc_subc. -Hint Extern 2 (rewrite_right_to_left_eq _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y))) +Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y))) => apply @fst_add_with_carry_false_leb : typeclass_instances. -Hint Extern 2 (rewrite_right_to_left_eq _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + 1))) +Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + 1))) => apply @fst_add_with_carry_true_leb : typeclass_instances. -Hint Extern 2 (rewrite_right_to_left_eq _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + if ?c then _ else _))) +Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + if ?c then _ else _))) => apply @fst_add_with_carry_leb : typeclass_instances. diff --git a/src/Util/AutoRewrite.v b/src/Util/AutoRewrite.v new file mode 100644 index 000000000..b5e276f20 --- /dev/null +++ b/src/Util/AutoRewrite.v @@ -0,0 +1,56 @@ +(** * Machinery for reimplementing some bits of [rewrite_strat] with open pattern databases *) +Require Import Crypto.Util.Notations. +(** We build classes for rewriting in each direction, and pick lemmas + by resolving on tags and the term to be rewritten. *) +(** Base class, for bidirectional rewriting. *) +Class rewrite_eq {tagT} (tag : tagT) {A} (x y : A) + := by_rewrite : x = y. +Arguments by_rewrite {tagT} tag {A} _ _ {_}. +Infix "<~=~>" := (rewrite_eq _) : type_scope. + +Class rewrite_right_to_left_eq {tagT} (tag : tagT) {A} (x y : A) + := by_rewrite_right_to_left : rewrite_eq tag x y. +Arguments by_rewrite_right_to_left {tagT} tag {A} _ _ {_}. +Global Instance unfold_rewrite_right_to_left_eq {tagT tag A x y} (H : @rewrite_eq tagT tag A x y) + : @rewrite_right_to_left_eq tagT tag A x y := H. +Infix "<~=" := (rewrite_right_to_left_eq _) : type_scope. + +Class rewrite_left_to_right_eq {tagT} (tag : tagT) {A} (x y : A) + := by_rewrite_left_to_right : rewrite_eq tag x y. +Arguments by_rewrite_left_to_right {tagT} tag {A} _ _ {_}. +Global Instance unfold_rewrite_left_to_right_eq {tagT tag A x y} (H : @rewrite_eq tagT tag A x y) + : @rewrite_left_to_right_eq tagT tag A x y := H. +Infix "=~>" := (rewrite_left_to_right_eq _) : type_scope. + +Ltac typeclass_do_left_to_right tag from tac := + let lem := constr:(by_rewrite_left_to_right tag from _ : from = _) in tac lem. +Ltac typeclass_do_right_to_left tag from tac := + let lem := constr:(by_rewrite_right_to_left tag _ from : _ = from) in tac lem. + + +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" "*" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in * ). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" hyp_list(H) := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in H). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" hyp_list(H) "|-" "*" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in H |- *). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" "*" "|-" "*" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in * |- * ). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" "*" "|-" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in * |- ). + + +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" "*" := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in * ). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" hyp_list(H) := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in H). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" hyp_list(H) "|-" "*" := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in H |- *). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" "*" "|-" "*" := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in * |- * ). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" "*" "|-" := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in * |- ). diff --git a/src/Util/Notations.v b/src/Util/Notations.v index 686443829..8795da399 100644 --- a/src/Util/Notations.v +++ b/src/Util/Notations.v @@ -27,6 +27,7 @@ Reserved Infix "~=" (at level 70). Reserved Infix "==" (at level 70, no associativity). Reserved Infix "=~>" (at level 70, no associativity). Reserved Infix "<~=" (at level 70, no associativity). +Reserved Infix "<~=~>" (at level 70, no associativity). Reserved Infix "≡" (at level 70, no associativity). Reserved Infix "≢" (at level 70, no associativity). Reserved Infix "≡_n" (at level 70, no associativity). |