aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jagro@google.com>2016-08-25 11:48:05 -0700
committerGravatar Jason Gross <jagro@google.com>2016-08-25 11:48:05 -0700
commitbf0d9280ebf806eef8ee3280f7976edb3282ae6e (patch)
treeb7a6e171a03a1912c480da38d1d961e7324f92f3 /src
parent34d53cc72df1a3c31838e0cc7e06f0cf8959d628 (diff)
Integrate suggestions from Andres
Diffstat (limited to 'src')
-rw-r--r--src/BoundedArithmetic/ArchitectureToZLikeProofs.v1
-rw-r--r--src/BoundedArithmetic/DoubleBounded.v11
-rw-r--r--src/BoundedArithmetic/DoubleBoundedProofs.v17
-rw-r--r--src/BoundedArithmetic/Interface.v109
-rw-r--r--src/BoundedArithmetic/InterfaceProofs.v59
-rw-r--r--src/Util/AutoRewrite.v56
-rw-r--r--src/Util/Notations.v1
7 files changed, 146 insertions, 108 deletions
diff --git a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v
index 3060e17bb..804296374 100644
--- a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v
+++ b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v
@@ -93,7 +93,6 @@ Section fancy_machine_p256_montgomery_foundation.
{ abstract t. }
{ abstract t. }
{ abstract t. }
-Hint Resolve Z.div_pos : zarith.
{ abstract t. }
{ abstract t. }
{ abstract t. }
diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v
index 55e46aa2b..b624c5082 100644
--- a/src/BoundedArithmetic/DoubleBounded.v
+++ b/src/BoundedArithmetic/DoubleBounded.v
@@ -7,7 +7,6 @@ Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.ListUtil.
Require Import Crypto.Util.Notations.
-Local Open Scope list_scope.
Local Open Scope nat_scope.
Local Open Scope Z_scope.
Local Open Scope type_scope.
@@ -23,16 +22,6 @@ Global Arguments tuple_decoder : simpl never.
Hint Extern 3 (decoder _ (tuple ?W ?k)) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode => (@tuple_decoder n W decode k : decoder (kv * n) (tuple W k))) : typeclass_instances.
Section ripple_carry_definitions.
- Definition ripple_carry {T} (f : T -> T -> bool -> bool * T)
- (xs ys : list T) (carry : bool) : bool * list T
- := List.fold_right
- (fun x_y carry_zs => let '(x, y) := eta x_y in
- let '(carry, zs) := eta carry_zs in
- let '(carry, z) := eta (f x y carry) in
- (carry, z :: zs))
- (carry, nil)
- (List.combine xs ys).
-
(** tuple is high to low ([to_list] reverses) *)
Fixpoint ripple_carry_tuple' {T} (f : T -> T -> bool -> bool * T) k
: forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k
diff --git a/src/BoundedArithmetic/DoubleBoundedProofs.v b/src/BoundedArithmetic/DoubleBoundedProofs.v
index b69232076..53ac59d00 100644
--- a/src/BoundedArithmetic/DoubleBoundedProofs.v
+++ b/src/BoundedArithmetic/DoubleBoundedProofs.v
@@ -16,16 +16,15 @@ Require Import Crypto.Util.Notations.
Import ListNotations.
Local Open Scope list_scope.
Local Open Scope nat_scope.
-Local Open Scope Z_scope.
Local Open Scope type_scope.
+Local Open Scope Z_scope.
Local Coercion Z.of_nat : nat >-> Z.
Local Coercion Pos.to_nat : positive >-> nat.
Local Notation eta x := (fst x, snd x).
-Local Infix "==" := rewrite_eq.
-Local Infix "=~>" := rewrite_left_to_right_eq.
-Local Infix "<~=" := rewrite_right_to_left_eq.
+Import BoundedRewriteNotations.
+Local Open Scope Z_scope.
Section decode.
Context {n W} {decode : decoder n W}.
@@ -99,12 +98,12 @@ Section decode.
Global Instance tuple_decoder_m1 w : tuple_decoder (k := 0) w =~> 0.
Proof. reflexivity. Qed.
- Global Instance tuple_decoder_2' w : bounded_le_cls 0 n -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << (1%nat * n))%Z.
+ Global Instance tuple_decoder_2' w : (0 <= n)%bounded_rewrite -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << (1%nat * n))%Z.
Proof.
intros; rewrite !tuple_decoder_S, !tuple_decoder_O by assumption.
reflexivity.
Qed.
- Global Instance tuple_decoder_2 w : bounded_le_cls 0 n -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << n)%Z.
+ Global Instance tuple_decoder_2 w : (0 <= n)%bounded_rewrite -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << n)%Z.
Proof.
intros; rewrite !tuple_decoder_S, !tuple_decoder_O by assumption.
autorewrite with zsimplify_const; reflexivity.
@@ -205,8 +204,8 @@ Global Instance decode_is_spread_left_immediate
{isdecode : is_decode decode}
{issprl : is_spread_left_immediate sprl}
: forall r count,
- 0 <= count < n
- -> tuple_decoder (sprl r count) == decode r << count
+ (0 <= count < n)%bounded_rewrite
+ -> tuple_decoder (sprl r count) <~=~> decode r << count
:= proj1 decode_is_spread_left_immediate_iff _.
Lemma decode_mul_double_iff
@@ -233,7 +232,7 @@ Global Instance decode_mul_double
{muldw : multiply_double W}
{isdecode : is_decode decode}
{ismuldw : is_mul_double muldw}
- : forall x y, tuple_decoder (muldw x y) == (decode x * decode y)%Z
+ : forall x y, tuple_decoder (muldw x y) <~=~> (decode x * decode y)%Z
:= proj1 decode_mul_double_iff _.
Lemma ripple_carry_tuple_SS {T} f k xss yss carry
diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v
index 00528a053..152c43cee 100644
--- a/src/BoundedArithmetic/Interface.v
+++ b/src/BoundedArithmetic/Interface.v
@@ -2,10 +2,11 @@
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Util.ZUtil.
Require Import Crypto.Util.Tuple.
+Require Import Crypto.Util.AutoRewrite.
Require Import Crypto.Util.Notations.
-Local Open Scope Z_scope.
Local Open Scope type_scope.
+Local Open Scope Z_scope.
Class decoder (n : Z) W :=
{ decode : W -> Z }.
@@ -15,22 +16,6 @@ Global Arguments decode {n W _} _.
Class is_decode {n W} (decode : decoder n W) :=
decode_range : forall x, 0 <= decode x < 2^n.
-Class rewrite_eq {A} (x y : A)
- := by_rewrite : x = y.
-Arguments by_rewrite {A} _ _ {_}.
-
-Class rewrite_right_to_left_eq {A} (x y : A)
- := by_rewrite_right_to_left : rewrite_eq x y.
-Arguments by_rewrite_right_to_left {A} _ _ {_}.
-Global Instance unfold_rewrite_right_to_left_eq {A x y} (H : @rewrite_eq A x y)
- : @rewrite_right_to_left_eq A x y := H.
-
-Class rewrite_left_to_right_eq {A} (x y : A)
- := by_rewrite_left_to_right : rewrite_eq x y.
-Arguments by_rewrite_left_to_right {A} _ _ {_}.
-Global Instance unfold_rewrite_left_to_right_eq {A x y} (H : @rewrite_eq A x y)
- : @rewrite_left_to_right_eq A x y := H.
-
Class bounded_in_range_cls (x y z : Z) := is_bounded_in_range : x <= y < z.
Ltac bounded_solver_tac :=
solve [ eassumption | typeclasses eauto | omega ].
@@ -44,26 +29,42 @@ Class bounded_le_cls (x y : Z) := is_bounded_le : x <= y.
Hint Extern 0 (bounded_le_cls _ _) => unfold bounded_le_cls; bounded_solver_tac : typeclass_instances.
Global Arguments bounded_le_cls / .
+Inductive bounded_decode_pusher_tag := decode_tag.
+
Ltac push_decode_step :=
match goal with
| [ |- context[@decode ?n ?W ?decoder ?w] ]
- => let lem := constr:(by_rewrite_left_to_right (A := Z) (@decode n W decoder w) _) in
- rewrite (lem : @decode n W decoder w = _)
+ => tc_rewrite (decode_tag) (@decode n W decoder w) ->
| [ |- context[match @fst ?A ?B ?x with true => 1 | false => 0 end] ]
- => let lem := constr:(by_rewrite_left_to_right (A := Z) (match @fst A B x with true => 1 | false => 0 end) _) in
- rewrite (lem : _ = _)
+ => tc_rewrite (decode_tag) (match @fst A B x with true => 1 | false => 0 end) ->
end.
Ltac push_decode := repeat push_decode_step.
Ltac pull_decode_step :=
match goal with
| [ |- context[?E] ]
- => first [ let lem := constr:(by_rewrite_right_to_left (A := Z) _ E) in
- rewrite <- (lem : _ = E)
- | let lem := constr:(by_rewrite_right_to_left (A := bool) _ E) in
- rewrite <- (lem : _ = E) ]
+ => lazymatch type of E with
+ | Z => idtac
+ | bool => idtac
+ end;
+ tc_rewrite (decode_tag) <- E
end.
Ltac pull_decode := repeat pull_decode_step.
+Delimit Scope bounded_rewrite_scope with bounded_rewrite.
+
+Infix "<~=~>" := (rewrite_eq decode_tag) : bounded_rewrite_scope.
+Infix "=~>" := (rewrite_left_to_right_eq decode_tag) : bounded_rewrite_scope.
+Infix "<~=" := (rewrite_right_to_left_eq decode_tag) : bounded_rewrite_scope.
+Notation "x <= y" := (bounded_le_cls x y) : bounded_rewrite_scope.
+Notation "x <= y < z" := (bounded_in_range_cls x y z) : bounded_rewrite_scope.
+
+Module Import BoundedRewriteNotations.
+ Infix "<~=~>" := (rewrite_eq decode_tag) : type_scope.
+ Infix "=~>" := (rewrite_left_to_right_eq decode_tag) : type_scope.
+ Infix "<~=" := (rewrite_right_to_left_eq decode_tag) : type_scope.
+ Open Scope bounded_rewrite_scope.
+End BoundedRewriteNotations.
+
(** This is required for typeclass resolution to be fast. *)
Typeclasses Opaque decode.
@@ -71,10 +72,6 @@ Section InstructionGallery.
Context (n : Z) (* bit-width of width of [W] *)
{W : Type} (* bounded type, [W] for word *)
(Wdecoder : decoder n W).
- Local Infix "==" := rewrite_eq.
- Local Infix "=~>" := rewrite_left_to_right_eq.
- Local Infix "<~=" := rewrite_right_to_left_eq.
- Local Notation "x <= y < z" := (bounded_in_range_cls x y z).
Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *)
Class load_immediate := { ldi : imm -> W }.
@@ -90,21 +87,21 @@ Section InstructionGallery.
decode_shift_right_doubleword :>
forall high low count,
0 <= count < n
- -> decode (shrd high low count) == (((decode high << n) + decode low) >> count) mod 2^n.
+ -> decode (shrd high low count) <~=~> (((decode high << n) + decode low) >> count) mod 2^n.
Class shift_left_immediate := { shl : W -> imm -> W }.
Global Coercion shl : shift_left_immediate >-> Funclass.
Class is_shift_left_immediate (shl : shift_left_immediate) :=
decode_shift_left_immediate :>
- forall r count, 0 <= count < n -> decode (shl r count) == (decode r << count) mod 2^n.
+ forall r count, 0 <= count < n -> decode (shl r count) <~=~> (decode r << count) mod 2^n.
Class shift_right_immediate := { shr : W -> imm -> W }.
Global Coercion shr : shift_right_immediate >-> Funclass.
Class is_shift_right_immediate (shr : shift_right_immediate) :=
decode_shift_right_immediate :>
- forall r count, 0 <= count < n -> decode (shr r count) == (decode r >> count).
+ forall r count, 0 <= count < n -> decode (shr r count) <~=~> (decode r >> count).
Class spread_left_immediate := { sprl : W -> imm -> tuple W 2 (* [(low, high)] *) }.
Global Coercion sprl : spread_left_immediate >-> Funclass.
@@ -120,19 +117,12 @@ Section InstructionGallery.
}.
- Definition Build_is_spread_left_immediate' (sprl : spread_left_immediate)
- (pf : forall r count, 0 <= count < n
- -> decode (fst (sprl r count)) = (decode r << count) mod 2^n
- /\ decode (snd (sprl r count)) = (decode r << count) >> n)
- := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H);
- decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}.
-
Class mask_keep_low := { mkl :> W -> imm -> W }.
Global Coercion mkl : mask_keep_low >-> Funclass.
Class is_mask_keep_low (mkl : mask_keep_low) :=
decode_mask_keep_low :> forall r count,
- 0 <= count < n -> decode (mkl r count) == decode r mod 2^count.
+ 0 <= count < n -> decode (mkl r count) <~=~> decode r mod 2^count.
Local Notation bit b := (if b then 1 else 0).
@@ -141,34 +131,24 @@ Section InstructionGallery.
Class is_add_with_carry (adc : add_with_carry) :=
{
- bit_fst_add_with_carry :> forall x y c, bit (fst (adc x y c)) == (decode x + decode y + bit c) >> n;
- decode_snd_add_with_carry :> forall x y c, decode (snd (adc x y c)) == (decode x + decode y + bit c) mod (2^n)
+ bit_fst_add_with_carry :> forall x y c, bit (fst (adc x y c)) <~=~> (decode x + decode y + bit c) >> n;
+ decode_snd_add_with_carry :> forall x y c, decode (snd (adc x y c)) <~=~> (decode x + decode y + bit c) mod (2^n)
}.
- Definition Build_is_add_with_carry' (adc : add_with_carry)
- (pf : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n /\ decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n))
- := {| bit_fst_add_with_carry x y c := proj1 (pf x y c);
- decode_snd_add_with_carry x y c := proj2 (pf x y c) |}.
-
Class sub_with_carry := { subc : W -> W -> bool -> bool * W }.
Global Coercion subc : sub_with_carry >-> Funclass.
Class is_sub_with_carry (subc:W->W->bool->bool*W) :=
{
- fst_sub_with_carry :> forall x y c, fst (subc x y c) == ((decode x - decode y - bit c) <? 0);
- decode_snd_sub_with_carry :> forall x y c, decode (snd (subc x y c)) == (decode x - decode y - bit c) mod 2^n
+ fst_sub_with_carry :> forall x y c, fst (subc x y c) <~=~> ((decode x - decode y - bit c) <? 0);
+ decode_snd_sub_with_carry :> forall x y c, decode (snd (subc x y c)) <~=~> (decode x - decode y - bit c) mod 2^n
}.
- Definition Build_is_sub_with_carry' (subc : sub_with_carry)
- (pf : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) <? 0) /\ decode (snd (subc x y c)) = (decode x - decode y - bit c) mod 2^n)
- := {| fst_sub_with_carry x y c := proj1 (pf x y c);
- decode_snd_sub_with_carry x y c := proj2 (pf x y c) |}.
-
Class multiply := { mul : W -> W -> W }.
Global Coercion mul : multiply >-> Funclass.
Class is_mul (mul : multiply) :=
- decode_mul :> forall x y, decode (mul x y) == (decode x * decode y)%Z.
+ decode_mul :> forall x y, decode (mul x y) <~=~> (decode x * decode y).
Class multiply_low_low := { mulhwll : W -> W -> W }.
Global Coercion mulhwll : multiply_low_low >-> Funclass.
@@ -181,13 +161,13 @@ Section InstructionGallery.
Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) :=
decode_mul_low_low :>
- forall x y, decode (mulhwll x y) == ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n.
+ forall x y, decode (mulhwll x y) <~=~> ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n.
Class is_mul_high_low (w:Z) (mulhwhl : multiply_high_low) :=
decode_mul_high_low :>
- forall x y, decode (mulhwhl x y) == ((decode x >> w) * (decode y mod 2^w)) mod 2^n.
+ forall x y, decode (mulhwhl x y) <~=~> ((decode x >> w) * (decode y mod 2^w)) mod 2^n.
Class is_mul_high_high (w:Z) (mulhwhh : multiply_high_high) :=
decode_mul_high_high :>
- forall x y, decode (mulhwhh x y) == ((decode x >> w) * (decode y >> w)) mod 2^n.
+ forall x y, decode (mulhwhh x y) <~=~> ((decode x >> w) * (decode y >> w)) mod 2^n.
Class is_mul_double (muldw : multiply_double) :=
{
decode_fst_mul_double :>
@@ -195,27 +175,22 @@ Section InstructionGallery.
decode_snd_mul_double :>
forall x y, decode (snd (muldw x y)) =~> (decode x * decode y) >> n
}.
- Definition Build_is_mul_double' (muldw : multiply_double)
- (pf : forall x y, _ /\ _)
- := {| decode_fst_mul_double x y := proj1 (pf x y);
- decode_snd_mul_double x y := proj2 (pf x y) |}.
-
Class select_conditional := { selc : bool -> W -> W -> W }.
Global Coercion selc : select_conditional >-> Funclass.
Class is_select_conditional (selc : select_conditional) :=
decode_select_conditional :> forall b x y,
- decode (selc b x y) == if b then decode x else decode y.
+ decode (selc b x y) <~=~> if b then decode x else decode y.
Class add_modulo := { addm : W -> W -> W (* modulus *) -> W }.
Global Coercion addm : add_modulo >-> Funclass.
Class is_add_modulo (addm : add_modulo) :=
decode_add_modulo :> forall x y modulus,
- decode (addm x y modulus) == (if (decode x + decode y) <? decode modulus
- then (decode x + decode y)
- else (decode x + decode y) - decode modulus)%Z.
+ decode (addm x y modulus) <~=~> (if (decode x + decode y) <? decode modulus
+ then (decode x + decode y)
+ else (decode x + decode y) - decode modulus).
End InstructionGallery.
Global Arguments load_immediate : clear implicits.
diff --git a/src/BoundedArithmetic/InterfaceProofs.v b/src/BoundedArithmetic/InterfaceProofs.v
index b8e20c607..8256fc23f 100644
--- a/src/BoundedArithmetic/InterfaceProofs.v
+++ b/src/BoundedArithmetic/InterfaceProofs.v
@@ -3,14 +3,13 @@ Require Import Coq.ZArith.ZArith Coq.micromega.Psatz.
Require Import Crypto.BoundedArithmetic.Interface.
Require Import Crypto.Util.ZUtil.
Require Import Crypto.Util.Tuple.
+Require Import Crypto.Util.AutoRewrite.
Require Import Crypto.Util.Notations.
Local Open Scope type_scope.
Local Open Scope Z_scope.
-Local Infix "==" := rewrite_eq.
-Local Infix "=~>" := rewrite_left_to_right_eq.
-Local Infix "<~=" := rewrite_right_to_left_eq.
+Import BoundedRewriteNotations.
Local Notation bit b := (if b then 1 else 0).
Section InstructionGallery.
@@ -19,11 +18,33 @@ Section InstructionGallery.
(Wdecoder : decoder n W).
Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *)
+ Definition Build_is_spread_left_immediate' (sprl : spread_left_immediate W)
+ (pf : forall r count, 0 <= count < n
+ -> _ /\ _)
+ := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H);
+ decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}.
+
+ Definition Build_is_add_with_carry' (adc : add_with_carry W)
+ (pf : forall x y c, _ /\ _)
+ := {| bit_fst_add_with_carry x y c := proj1 (pf x y c);
+ decode_snd_add_with_carry x y c := proj2 (pf x y c) |}.
+
+ Definition Build_is_sub_with_carry' (subc : sub_with_carry W)
+ (pf : forall x y c, _ /\ _)
+ : is_sub_with_carry subc
+ := {| fst_sub_with_carry x y c := proj1 (pf x y c);
+ decode_snd_sub_with_carry x y c := proj2 (pf x y c) |}.
+
+ Definition Build_is_mul_double' (muldw : multiply_double W)
+ (pf : forall x y, _ /\ _)
+ := {| decode_fst_mul_double x y := proj1 (pf x y);
+ decode_snd_mul_double x y := proj2 (pf x y) |}.
+
Lemma is_spread_left_immediate_alt
{sprl : spread_left_immediate W}
{isdecode : is_decode Wdecoder}
: is_spread_left_immediate sprl
- <-> (forall r count, 0 <= count < n -> decode (fst (sprl r count)) + decode (snd (sprl r count)) << n = (decode r << count) mod (2^n*2^n)).
+ <-> (forall r count, 0 <= count < n -> decode (fst (sprl r count)) + decode (snd (sprl r count)) << n = (decode r << count) mod (2^n*2^n))%Z.
Proof.
split; intro H; [ | apply Build_is_spread_left_immediate' ];
intros r count Hc;
@@ -31,8 +52,8 @@ Section InstructionGallery.
unfold bounded_in_range_cls in *;
pose proof (decode_range r);
assert (0 < 2^n) by auto with zarith;
- assert (0 <= 2^count < 2^n) by auto with zarith;
- assert (0 <= decode r * 2^count < 2^n * 2^n) by (generalize dependent (decode r); intros; nia);
+ assert (0 <= 2^count < 2^n)%Z by auto with zarith;
+ assert (0 <= decode r * 2^count < 2^n * 2^n)%Z by (generalize dependent (decode r); intros; nia);
rewrite ?decode_fst_spread_left_immediate, ?decode_snd_spread_left_immediate
by typeclasses eauto with typeclass_instances core;
autorewrite with Zshift_to_pow zsimplify push_Zpow.
@@ -53,7 +74,7 @@ Section InstructionGallery.
pose proof (decode_range x);
pose proof (decode_range y);
assert (0 < 2^n) by auto with zarith;
- assert (0 <= decode x * decode y < 2^n * 2^n) by nia;
+ assert (0 <= decode x * decode y < 2^n * 2^n)%Z by nia;
(destruct (0 <=? n) eqn:?; Z.ltb_to_lt;
[ | assert (2^n = 0) by auto with zarith; exfalso; omega ]);
rewrite ?decode_fst_mul_double, ?decode_snd_mul_double
@@ -65,8 +86,6 @@ Section InstructionGallery.
Qed.
End InstructionGallery.
-Local Notation "x <= y < z" := (bounded_in_range_cls x y z).
-
Global Arguments is_spread_left_immediate_alt {_ _ _ _ _}.
Global Arguments is_mul_double_alt {_ _ _ _ _}.
@@ -95,7 +114,7 @@ Proof. exact _. Qed.
Lemma decode_exponent_nonnegative {n W} (decode : decoder n W) {isdecode : is_decode decode}
(isinhabited : W)
- : 0 <= n.
+ : (0 <= n)%Z.
Proof.
pose proof (decode_range isinhabited).
assert (0 < 2^n) by omega.
@@ -113,13 +132,13 @@ Section adc_subc.
{isadc : is_add_with_carry adc}
{issubc : is_sub_with_carry subc}.
Global Instance bit_fst_add_with_carry_false
- : forall x y, bit (fst (adc x y false)) == (decode x + decode y) >> n.
+ : forall x y, bit (fst (adc x y false)) <~=~> (decode x + decode y) >> n.
Proof.
intros; erewrite bit_fst_add_with_carry by assumption.
autorewrite with zsimplify_const; reflexivity.
Qed.
Global Instance bit_fst_add_with_carry_true
- : forall x y, bit (fst (adc x y true)) == (decode x + decode y + 1) >> n.
+ : forall x y, bit (fst (adc x y true)) <~=~> (decode x + decode y + 1) >> n.
Proof.
intros; erewrite bit_fst_add_with_carry by assumption.
autorewrite with zsimplify_const; reflexivity.
@@ -128,9 +147,9 @@ Section adc_subc.
: forall x y c, fst (adc x y c) <~= (2^n <=? (decode x + decode y + bit c)).
Proof.
intros x y c; hnf.
- assert (0 <= n) by eauto using decode_exponent_nonnegative.
+ assert (0 <= n)%Z by eauto using decode_exponent_nonnegative.
pose proof (decode_range x); pose proof (decode_range y).
- assert (0 <= bit c <= 1) by (destruct c; omega).
+ assert (0 <= bit c <= 1)%Z by (destruct c; omega).
lazymatch goal with
| [ |- fst ?x = (?a <=? ?b) :> bool ]
=> cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z);
@@ -148,30 +167,30 @@ Section adc_subc.
autorewrite with zsimplify_const; reflexivity.
Qed.
Global Instance fst_add_with_carry_true_leb
- : forall x y, fst (adc x y true) == (2^n <=? (decode x + decode y + 1)).
+ : forall x y, fst (adc x y true) <~=~> (2^n <=? (decode x + decode y + 1)).
Proof.
intros; erewrite fst_add_with_carry_leb by assumption.
autorewrite with zsimplify_const; reflexivity.
Qed.
Global Instance fst_sub_with_carry_false
- : forall x y, fst (subc x y false) == ((decode x - decode y) <? 0).
+ : forall x y, fst (subc x y false) <~=~> ((decode x - decode y) <? 0).
Proof.
intros; erewrite fst_sub_with_carry by assumption.
autorewrite with zsimplify_const; reflexivity.
Qed.
Global Instance fst_sub_with_carry_true
- : forall x y, fst (subc x y true) == ((decode x - decode y - 1) <? 0).
+ : forall x y, fst (subc x y true) <~=~> ((decode x - decode y - 1) <? 0).
Proof.
intros; erewrite fst_sub_with_carry by assumption.
autorewrite with zsimplify_const; reflexivity.
Qed.
End adc_subc.
-Hint Extern 2 (rewrite_right_to_left_eq _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y)))
+Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y)))
=> apply @fst_add_with_carry_false_leb : typeclass_instances.
-Hint Extern 2 (rewrite_right_to_left_eq _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + 1)))
+Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + 1)))
=> apply @fst_add_with_carry_true_leb : typeclass_instances.
-Hint Extern 2 (rewrite_right_to_left_eq _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + if ?c then _ else _)))
+Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + if ?c then _ else _)))
=> apply @fst_add_with_carry_leb : typeclass_instances.
diff --git a/src/Util/AutoRewrite.v b/src/Util/AutoRewrite.v
new file mode 100644
index 000000000..b5e276f20
--- /dev/null
+++ b/src/Util/AutoRewrite.v
@@ -0,0 +1,56 @@
+(** * Machinery for reimplementing some bits of [rewrite_strat] with open pattern databases *)
+Require Import Crypto.Util.Notations.
+(** We build classes for rewriting in each direction, and pick lemmas
+ by resolving on tags and the term to be rewritten. *)
+(** Base class, for bidirectional rewriting. *)
+Class rewrite_eq {tagT} (tag : tagT) {A} (x y : A)
+ := by_rewrite : x = y.
+Arguments by_rewrite {tagT} tag {A} _ _ {_}.
+Infix "<~=~>" := (rewrite_eq _) : type_scope.
+
+Class rewrite_right_to_left_eq {tagT} (tag : tagT) {A} (x y : A)
+ := by_rewrite_right_to_left : rewrite_eq tag x y.
+Arguments by_rewrite_right_to_left {tagT} tag {A} _ _ {_}.
+Global Instance unfold_rewrite_right_to_left_eq {tagT tag A x y} (H : @rewrite_eq tagT tag A x y)
+ : @rewrite_right_to_left_eq tagT tag A x y := H.
+Infix "<~=" := (rewrite_right_to_left_eq _) : type_scope.
+
+Class rewrite_left_to_right_eq {tagT} (tag : tagT) {A} (x y : A)
+ := by_rewrite_left_to_right : rewrite_eq tag x y.
+Arguments by_rewrite_left_to_right {tagT} tag {A} _ _ {_}.
+Global Instance unfold_rewrite_left_to_right_eq {tagT tag A x y} (H : @rewrite_eq tagT tag A x y)
+ : @rewrite_left_to_right_eq tagT tag A x y := H.
+Infix "=~>" := (rewrite_left_to_right_eq _) : type_scope.
+
+Ltac typeclass_do_left_to_right tag from tac :=
+ let lem := constr:(by_rewrite_left_to_right tag from _ : from = _) in tac lem.
+Ltac typeclass_do_right_to_left tag from tac :=
+ let lem := constr:(by_rewrite_right_to_left tag _ from : _ = from) in tac lem.
+
+
+Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" :=
+ typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem).
+Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" "*" :=
+ typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in * ).
+Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" hyp_list(H) :=
+ typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in H).
+Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" hyp_list(H) "|-" "*" :=
+ typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in H |- *).
+Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" "*" "|-" "*" :=
+ typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in * |- * ).
+Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" "*" "|-" :=
+ typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in * |- ).
+
+
+Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) :=
+ typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem).
+Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" "*" :=
+ typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in * ).
+Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" hyp_list(H) :=
+ typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in H).
+Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" hyp_list(H) "|-" "*" :=
+ typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in H |- *).
+Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" "*" "|-" "*" :=
+ typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in * |- * ).
+Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" "*" "|-" :=
+ typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in * |- ).
diff --git a/src/Util/Notations.v b/src/Util/Notations.v
index 686443829..8795da399 100644
--- a/src/Util/Notations.v
+++ b/src/Util/Notations.v
@@ -27,6 +27,7 @@ Reserved Infix "~=" (at level 70).
Reserved Infix "==" (at level 70, no associativity).
Reserved Infix "=~>" (at level 70, no associativity).
Reserved Infix "<~=" (at level 70, no associativity).
+Reserved Infix "<~=~>" (at level 70, no associativity).
Reserved Infix "≡" (at level 70, no associativity).
Reserved Infix "≢" (at level 70, no associativity).
Reserved Infix "≡_n" (at level 70, no associativity).