From 6d3702edad1a69a08565a288f1153b4853ba3b25 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 26 Jun 2018 13:23:40 -0400 Subject: Slightly better definitions of some ZUtil functions This way we can just directly reify most of the primitives we care about. --- src/Util/ZUtil/AddGetCarry.v | 3 +-- src/Util/ZUtil/CPS.v | 14 ++++++-------- src/Util/ZUtil/Definitions.v | 18 ++++++++---------- 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/Util/ZUtil/AddGetCarry.v b/src/Util/ZUtil/AddGetCarry.v index e8f8d883d..5b8cbf13e 100644 --- a/src/Util/ZUtil/AddGetCarry.v +++ b/src/Util/ZUtil/AddGetCarry.v @@ -1,5 +1,4 @@ Require Import Coq.ZArith.ZArith Coq.micromega.Lia. -Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.ZUtil.Definitions. Require Import Crypto.Util.ZUtil.Hints.ZArith. Require Import Crypto.Util.Prod. @@ -64,7 +63,7 @@ Module Z. Z.add_get_carry_full Z.add_with_get_carry_full Z.add_get_carry Z.add_with_get_carry Z.add_with_carry Z.sub_get_borrow_full Z.sub_with_get_borrow_full - Z.sub_get_borrow Z.sub_with_get_borrow Z.sub_with_borrow. + Z.sub_get_borrow Z.sub_with_get_borrow Z.sub_with_borrow Let_In. Lemma add_get_carry_full_mod s x y : fst (Z.add_get_carry_full s x y) = (x + y) mod s. diff --git a/src/Util/ZUtil/CPS.v b/src/Util/ZUtil/CPS.v index 3c0007c88..e2b21933b 100644 --- a/src/Util/ZUtil/CPS.v +++ b/src/Util/ZUtil/CPS.v @@ -121,14 +121,12 @@ Module Z. Definition mul_split_at_bitwidth_cps {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : T := dlet xy := x * y in - f (match bitwidth with - | Z.pos _ | Z0 => Z.land xy (Z.ones bitwidth) - | Z.neg _ => xy mod 2^bitwidth - end, - match bitwidth with - | Z.pos _ | Z0 => Z.shiftr xy bitwidth - | Z.neg _ => xy / 2^bitwidth - end). + f (if Z.geb bitwidth 0 + then Z.land xy (Z.ones bitwidth) + else xy mod 2^bitwidth, + if Z.geb bitwidth 0 + then Z.shiftr xy bitwidth + else xy / 2^bitwidth). Definition mul_split_at_bitwidth_cps_correct {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : @mul_split_at_bitwidth_cps T bitwidth x y f = f (Z.mul_split_at_bitwidth bitwidth x y) := eq_refl. diff --git a/src/Util/ZUtil/Definitions.v b/src/Util/ZUtil/Definitions.v index 67ccbc772..373c71763 100644 --- a/src/Util/ZUtil/Definitions.v +++ b/src/Util/ZUtil/Definitions.v @@ -31,7 +31,7 @@ Module Z. Definition add_with_carry (c : Z) (x y : Z) : Z := c + x + y. Definition add_with_get_carry (bitwidth : Z) (c : Z) (x y : Z) : Z * Z - := get_carry bitwidth (add_with_carry c x y). + := dlet v := add_with_carry c x y in get_carry bitwidth v. Definition add_get_carry (bitwidth : Z) (x y : Z) : Z * Z := add_with_get_carry bitwidth 0 x y. @@ -41,7 +41,7 @@ Module Z. Definition sub_with_borrow (c : Z) (x y : Z) : Z := add_with_carry (-c) x (-y). Definition sub_with_get_borrow (bitwidth : Z) (c : Z) (x y : Z) : Z * Z - := get_borrow bitwidth (sub_with_borrow c x y). + := dlet v := sub_with_borrow c x y in get_borrow bitwidth v. Definition sub_get_borrow (bitwidth : Z) (x y : Z) : Z * Z := sub_with_get_borrow bitwidth 0 x y. @@ -66,14 +66,12 @@ Module Z. Definition mul_split_at_bitwidth (bitwidth : Z) (x y : Z) : Z * Z := dlet xy := x * y in - (match bitwidth with - | Z.pos _ | Z0 => xy &' Z.ones bitwidth - | Z.neg _ => xy mod 2^bitwidth - end, - match bitwidth with - | Z.pos _ | Z0 => xy >> bitwidth - | Z.neg _ => xy / 2^bitwidth - end). + (if Z.geb bitwidth 0 + then xy &' Z.ones bitwidth + else xy mod 2^bitwidth, + if Z.geb bitwidth 0 + then xy >> bitwidth + else xy / 2^bitwidth). Definition mul_split (s x y : Z) : Z * Z := if s =? 2^Z.log2 s then mul_split_at_bitwidth (Z.log2 s) x y -- cgit v1.2.3