From 9d16ae4ecb6f24ae7eefabd056902b00bf2fe001 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Thu, 19 Oct 2017 13:57:48 -0400 Subject: Update ZUtil cps definitions This will hopefully allow the compiler to reflective-land an easier job. --- src/Util/ZUtil/AddGetCarry.v | 4 +++- src/Util/ZUtil/CPS.v | 27 +++++++++++++++------------ src/Util/ZUtil/Definitions.v | 8 ++++---- 3 files changed, 22 insertions(+), 17 deletions(-) (limited to 'src/Util/ZUtil') diff --git a/src/Util/ZUtil/AddGetCarry.v b/src/Util/ZUtil/AddGetCarry.v index e8897431a..85e9ece74 100644 --- a/src/Util/ZUtil/AddGetCarry.v +++ b/src/Util/ZUtil/AddGetCarry.v @@ -5,8 +5,10 @@ Require Import Crypto.Util.ZUtil.Hints.ZArith. Require Import Crypto.Util.Prod. Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. Require Import Crypto.Util.LetIn. Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.RewriteHyp. Local Open Scope Z_scope. Local Notation eta x := (fst x, snd x). @@ -55,7 +57,7 @@ Module Z. Local Ltac easypeasy := repeat progress autounfold; - break_match; autorewrite with cancel_pair zsimplify; + break_match; Z.ltb_to_lt; rewrite_hyp ?*; autorewrite with cancel_pair zsimplify; solve [repeat (f_equal; try ring)]. Local Hint Unfold Z.get_carry Z.get_borrow diff --git a/src/Util/ZUtil/CPS.v b/src/Util/ZUtil/CPS.v index c113191e8..a875247a6 100644 --- a/src/Util/ZUtil/CPS.v +++ b/src/Util/ZUtil/CPS.v @@ -1,5 +1,6 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. Require Import Crypto.Util.LetIn. Require Import Crypto.Util.Tactics.BreakMatch. Require Import Crypto.Util.Tactics.Head. @@ -28,6 +29,8 @@ Module Z. end; repeat first [ reflexivity | progress cbv [Decidable.dec Decidable.dec_eq_Z] in * + | progress Z.ltb_to_lt + | congruence | progress autorewrite with uncps | break_innermost_match_step ]. @@ -38,16 +41,16 @@ Module Z. := eq_refl. Hint Rewrite @get_carry_cps_correct : uncps. Definition add_with_get_carry_cps {T} (bitwidth : Z) (c : Z) (x y : Z) (f : Z * Z -> T) : T - := get_carry_cps bitwidth (Z.add_with_carry c x y) f. - Lemma add_with_get_carry_cps_correct {T} bitwidth c x y f - : @add_with_get_carry_cps T bitwidth c x y f = f (Z.add_with_get_carry bitwidth c x y). - Proof. prove_cps_correct (). Qed. + := let '(v, c) := Z.add_with_get_carry bitwidth c x y in f (v, c). + Definition add_with_get_carry_cps_correct {T} bitwidth c x y f + : @add_with_get_carry_cps T bitwidth c x y f = f (Z.add_with_get_carry bitwidth c x y) + := eq_refl. Hint Rewrite @add_with_get_carry_cps_correct : uncps. Definition add_get_carry_cps {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : T - := add_with_get_carry_cps bitwidth 0 x y f. + := let '(v, c) := Z.add_get_carry bitwidth x y in f (v, c). Definition add_get_carry_cps_correct {T} bitwidth x y f : @add_get_carry_cps T bitwidth x y f = f (Z.add_get_carry bitwidth x y) - := add_with_get_carry_cps_correct _ _ _ _ _. + := eq_refl. Hint Rewrite @add_get_carry_cps_correct : uncps. Definition get_borrow_cps {T} (bitwidth : Z) (v : Z) (f : Z * Z -> T) @@ -57,13 +60,13 @@ Module Z. := eq_refl. Hint Rewrite @get_borrow_cps_correct : uncps. Definition sub_with_get_borrow_cps {T} (bitwidth : Z) (c : Z) (x y : Z) (f : Z * Z -> T) : T - := get_borrow_cps bitwidth (Z.sub_with_borrow c x y) f. + := let '(v, c) := Z.sub_with_get_borrow bitwidth c x y in f (v, c). Definition sub_with_get_borrow_cps_correct {T} (bitwidth : Z) (c : Z) (x y : Z) (f : Z * Z -> T) : @sub_with_get_borrow_cps T bitwidth c x y f = f (Z.sub_with_get_borrow bitwidth c x y) := eq_refl. Hint Rewrite @sub_with_get_borrow_cps_correct : uncps. Definition sub_get_borrow_cps {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : T - := sub_with_get_borrow_cps bitwidth 0 x y f. + := let '(v, c) := Z.sub_get_borrow bitwidth x y in f (v, c). Definition sub_get_borrow_cps_correct {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : @sub_get_borrow_cps T bitwidth x y f = f (Z.sub_get_borrow bitwidth x y) := eq_refl. @@ -72,7 +75,7 @@ Module Z. (* splits at [bound], not [2^bitwidth]; wrapper to make add_getcarry work if input is not known to be a power of 2 *) Definition add_get_carry_full_cps {T} (bound : Z) (x y : Z) (f : Z * Z -> T) : T - := eq_dec_cps + := eqb_cps (2 ^ (Z.log2 bound)) bound (fun eqb => if eqb @@ -83,7 +86,7 @@ Module Z. Proof. prove_cps_correct (). Qed. Hint Rewrite @add_get_carry_full_cps_correct : uncps. Definition add_with_get_carry_full_cps {T} (bound : Z) (c x y : Z) (f : Z * Z -> T) : T - := eq_dec_cps + := eqb_cps (2 ^ (Z.log2 bound)) bound (fun eqb => if eqb @@ -94,7 +97,7 @@ Module Z. Proof. prove_cps_correct (). Qed. Hint Rewrite @add_with_get_carry_full_cps_correct : uncps. Definition sub_get_borrow_full_cps {T} (bound : Z) (x y : Z) (f : Z * Z -> T) : T - := eq_dec_cps + := eqb_cps (2 ^ (Z.log2 bound)) bound (fun eqb => if eqb @@ -105,7 +108,7 @@ Module Z. Proof. prove_cps_correct (). Qed. Hint Rewrite @sub_get_borrow_full_cps_correct : uncps. Definition sub_with_get_borrow_full_cps {T} (bound : Z) (c x y : Z) (f : Z * Z -> T) : T - := eq_dec_cps + := eqb_cps (2 ^ (Z.log2 bound)) bound (fun eqb => if eqb diff --git a/src/Util/ZUtil/Definitions.v b/src/Util/ZUtil/Definitions.v index d80b2bde5..760651a94 100644 --- a/src/Util/ZUtil/Definitions.v +++ b/src/Util/ZUtil/Definitions.v @@ -32,19 +32,19 @@ Module Z. (* splits at [bound], not [2^bitwidth]; wrapper to make add_getcarry work if input is not known to be a power of 2 *) Definition add_get_carry_full (bound : Z) (x y : Z) : Z * Z - := if dec (2 ^ (Z.log2 bound) = bound) + := if 2 ^ (Z.log2 bound) =? bound then add_get_carry (Z.log2 bound) x y else ((x + y) mod bound, (x + y) / bound). Definition add_with_get_carry_full (bound : Z) (c x y : Z) : Z * Z - := if dec (2 ^ (Z.log2 bound) = bound) + := if 2 ^ (Z.log2 bound) =? bound then add_with_get_carry (Z.log2 bound) c x y else ((c + x + y) mod bound, (c + x + y) / bound). Definition sub_get_borrow_full (bound : Z) (x y : Z) : Z * Z - := if dec (2 ^ (Z.log2 bound) = bound) + := if 2 ^ (Z.log2 bound) =? bound then sub_get_borrow (Z.log2 bound) x y else ((x - y) mod bound, -((x - y) / bound)). Definition sub_with_get_borrow_full (bound : Z) (c x y : Z) : Z * Z - := if dec (2 ^ (Z.log2 bound) = bound) + := if 2 ^ (Z.log2 bound) =? bound then sub_with_get_borrow (Z.log2 bound) c x y else ((x - y - c) mod bound, -((x - y - c) / bound)). -- cgit v1.2.3