aboutsummaryrefslogtreecommitdiff
path: root/src/Util/ZUtil
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-10-19 13:57:48 -0400
committerGravatar Jason Gross <jgross@mit.edu>2017-10-19 13:57:48 -0400
commit9d16ae4ecb6f24ae7eefabd056902b00bf2fe001 (patch)
tree0ad36b87d1d6670ff7d0c5e57d818104bcdf911f /src/Util/ZUtil
parentb0d6dfe21d669be0b12a76d37c9bd78c20f788f9 (diff)
Update ZUtil cps definitions
This will hopefully allow the compiler to reflective-land an easier job.
Diffstat (limited to 'src/Util/ZUtil')
-rw-r--r--src/Util/ZUtil/AddGetCarry.v4
-rw-r--r--src/Util/ZUtil/CPS.v27
-rw-r--r--src/Util/ZUtil/Definitions.v8
3 files changed, 22 insertions, 17 deletions
diff --git a/src/Util/ZUtil/AddGetCarry.v b/src/Util/ZUtil/AddGetCarry.v
index e8897431a..85e9ece74 100644
--- a/src/Util/ZUtil/AddGetCarry.v
+++ b/src/Util/ZUtil/AddGetCarry.v
@@ -5,8 +5,10 @@ Require Import Crypto.Util.ZUtil.Hints.ZArith.
Require Import Crypto.Util.Prod.
Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo.
Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem.
+Require Import Crypto.Util.ZUtil.Tactics.LtbToLt.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.Tactics.BreakMatch.
+Require Import Crypto.Util.Tactics.RewriteHyp.
Local Open Scope Z_scope.
Local Notation eta x := (fst x, snd x).
@@ -55,7 +57,7 @@ Module Z.
Local Ltac easypeasy :=
repeat progress autounfold;
- break_match; autorewrite with cancel_pair zsimplify;
+ break_match; Z.ltb_to_lt; rewrite_hyp ?*; autorewrite with cancel_pair zsimplify;
solve [repeat (f_equal; try ring)].
Local Hint Unfold Z.get_carry Z.get_borrow
diff --git a/src/Util/ZUtil/CPS.v b/src/Util/ZUtil/CPS.v
index c113191e8..a875247a6 100644
--- a/src/Util/ZUtil/CPS.v
+++ b/src/Util/ZUtil/CPS.v
@@ -1,5 +1,6 @@
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.ZUtil.Tactics.LtbToLt.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Tactics.Head.
@@ -28,6 +29,8 @@ Module Z.
end;
repeat first [ reflexivity
| progress cbv [Decidable.dec Decidable.dec_eq_Z] in *
+ | progress Z.ltb_to_lt
+ | congruence
| progress autorewrite with uncps
| break_innermost_match_step ].
@@ -38,16 +41,16 @@ Module Z.
:= eq_refl.
Hint Rewrite @get_carry_cps_correct : uncps.
Definition add_with_get_carry_cps {T} (bitwidth : Z) (c : Z) (x y : Z) (f : Z * Z -> T) : T
- := get_carry_cps bitwidth (Z.add_with_carry c x y) f.
- Lemma add_with_get_carry_cps_correct {T} bitwidth c x y f
- : @add_with_get_carry_cps T bitwidth c x y f = f (Z.add_with_get_carry bitwidth c x y).
- Proof. prove_cps_correct (). Qed.
+ := let '(v, c) := Z.add_with_get_carry bitwidth c x y in f (v, c).
+ Definition add_with_get_carry_cps_correct {T} bitwidth c x y f
+ : @add_with_get_carry_cps T bitwidth c x y f = f (Z.add_with_get_carry bitwidth c x y)
+ := eq_refl.
Hint Rewrite @add_with_get_carry_cps_correct : uncps.
Definition add_get_carry_cps {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : T
- := add_with_get_carry_cps bitwidth 0 x y f.
+ := let '(v, c) := Z.add_get_carry bitwidth x y in f (v, c).
Definition add_get_carry_cps_correct {T} bitwidth x y f
: @add_get_carry_cps T bitwidth x y f = f (Z.add_get_carry bitwidth x y)
- := add_with_get_carry_cps_correct _ _ _ _ _.
+ := eq_refl.
Hint Rewrite @add_get_carry_cps_correct : uncps.
Definition get_borrow_cps {T} (bitwidth : Z) (v : Z) (f : Z * Z -> T)
@@ -57,13 +60,13 @@ Module Z.
:= eq_refl.
Hint Rewrite @get_borrow_cps_correct : uncps.
Definition sub_with_get_borrow_cps {T} (bitwidth : Z) (c : Z) (x y : Z) (f : Z * Z -> T) : T
- := get_borrow_cps bitwidth (Z.sub_with_borrow c x y) f.
+ := let '(v, c) := Z.sub_with_get_borrow bitwidth c x y in f (v, c).
Definition sub_with_get_borrow_cps_correct {T} (bitwidth : Z) (c : Z) (x y : Z) (f : Z * Z -> T)
: @sub_with_get_borrow_cps T bitwidth c x y f = f (Z.sub_with_get_borrow bitwidth c x y)
:= eq_refl.
Hint Rewrite @sub_with_get_borrow_cps_correct : uncps.
Definition sub_get_borrow_cps {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : T
- := sub_with_get_borrow_cps bitwidth 0 x y f.
+ := let '(v, c) := Z.sub_get_borrow bitwidth x y in f (v, c).
Definition sub_get_borrow_cps_correct {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T)
: @sub_get_borrow_cps T bitwidth x y f = f (Z.sub_get_borrow bitwidth x y)
:= eq_refl.
@@ -72,7 +75,7 @@ Module Z.
(* splits at [bound], not [2^bitwidth]; wrapper to make add_getcarry
work if input is not known to be a power of 2 *)
Definition add_get_carry_full_cps {T} (bound : Z) (x y : Z) (f : Z * Z -> T) : T
- := eq_dec_cps
+ := eqb_cps
(2 ^ (Z.log2 bound)) bound
(fun eqb
=> if eqb
@@ -83,7 +86,7 @@ Module Z.
Proof. prove_cps_correct (). Qed.
Hint Rewrite @add_get_carry_full_cps_correct : uncps.
Definition add_with_get_carry_full_cps {T} (bound : Z) (c x y : Z) (f : Z * Z -> T) : T
- := eq_dec_cps
+ := eqb_cps
(2 ^ (Z.log2 bound)) bound
(fun eqb
=> if eqb
@@ -94,7 +97,7 @@ Module Z.
Proof. prove_cps_correct (). Qed.
Hint Rewrite @add_with_get_carry_full_cps_correct : uncps.
Definition sub_get_borrow_full_cps {T} (bound : Z) (x y : Z) (f : Z * Z -> T) : T
- := eq_dec_cps
+ := eqb_cps
(2 ^ (Z.log2 bound)) bound
(fun eqb
=> if eqb
@@ -105,7 +108,7 @@ Module Z.
Proof. prove_cps_correct (). Qed.
Hint Rewrite @sub_get_borrow_full_cps_correct : uncps.
Definition sub_with_get_borrow_full_cps {T} (bound : Z) (c x y : Z) (f : Z * Z -> T) : T
- := eq_dec_cps
+ := eqb_cps
(2 ^ (Z.log2 bound)) bound
(fun eqb
=> if eqb
diff --git a/src/Util/ZUtil/Definitions.v b/src/Util/ZUtil/Definitions.v
index d80b2bde5..760651a94 100644
--- a/src/Util/ZUtil/Definitions.v
+++ b/src/Util/ZUtil/Definitions.v
@@ -32,19 +32,19 @@ Module Z.
(* splits at [bound], not [2^bitwidth]; wrapper to make add_getcarry
work if input is not known to be a power of 2 *)
Definition add_get_carry_full (bound : Z) (x y : Z) : Z * Z
- := if dec (2 ^ (Z.log2 bound) = bound)
+ := if 2 ^ (Z.log2 bound) =? bound
then add_get_carry (Z.log2 bound) x y
else ((x + y) mod bound, (x + y) / bound).
Definition add_with_get_carry_full (bound : Z) (c x y : Z) : Z * Z
- := if dec (2 ^ (Z.log2 bound) = bound)
+ := if 2 ^ (Z.log2 bound) =? bound
then add_with_get_carry (Z.log2 bound) c x y
else ((c + x + y) mod bound, (c + x + y) / bound).
Definition sub_get_borrow_full (bound : Z) (x y : Z) : Z * Z
- := if dec (2 ^ (Z.log2 bound) = bound)
+ := if 2 ^ (Z.log2 bound) =? bound
then sub_get_borrow (Z.log2 bound) x y
else ((x - y) mod bound, -((x - y) / bound)).
Definition sub_with_get_borrow_full (bound : Z) (c x y : Z) : Z * Z
- := if dec (2 ^ (Z.log2 bound) = bound)
+ := if 2 ^ (Z.log2 bound) =? bound
then sub_with_get_borrow (Z.log2 bound) c x y
else ((x - y - c) mod bound, -((x - y - c) / bound)).