From ef3ce824f87fd11eed78a13a884579499a8ffc53 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 2 Apr 2019 17:51:17 -0400 Subject: Add Z.combine_at_bitwidth --- _CoqProject | 1 + src/AbstractInterpretation.v | 6 ++++++ src/AbstractInterpretationZRangeProofs.v | 1 + src/CStringification.v | 4 ++++ src/GENERATEDIdentifiersWithoutTypes.v | 26 ++++++++++++++++++++++++ src/Language.v | 3 +++ src/Rewriter.v | 1 + src/Util/ZUtil/Combine.v | 34 ++++++++++++++++++++++++++++++++ src/Util/ZUtil/Definitions.v | 3 +++ 9 files changed, 79 insertions(+) create mode 100644 src/Util/ZUtil/Combine.v diff --git a/_CoqProject b/_CoqProject index 8c26d122b..cc94c6619 100644 --- a/_CoqProject +++ b/_CoqProject @@ -273,6 +273,7 @@ src/Util/ZUtil/AddGetCarry.v src/Util/ZUtil/AddModulo.v src/Util/ZUtil/CC.v src/Util/ZUtil/CPS.v +src/Util/ZUtil/Combine.v src/Util/ZUtil/Definitions.v src/Util/ZUtil/DistrIf.v src/Util/ZUtil/Div.v diff --git a/src/AbstractInterpretation.v b/src/AbstractInterpretation.v index e74582ae6..0489f4dc8 100644 --- a/src/AbstractInterpretation.v +++ b/src/AbstractInterpretation.v @@ -685,6 +685,12 @@ Module Compilers. (ZRange.four_corners Z.add x y) (ZRange.eight_corners (fun x y m => Z.max 0 (x + y - m)) x y m))) + | ident.Z_combine_at_bitwidth as idc + => fun bitwidth lo hi + => bitwidth <- to_literal bitwidth; + lo <- lo; + hi <- hi; + Some (ZRange.four_corners (ident.interp idc bitwidth) lo hi) | ident.Z_cast range => fun r : option zrange => interp_Z_cast range r diff --git a/src/AbstractInterpretationZRangeProofs.v b/src/AbstractInterpretationZRangeProofs.v index 2ca2ea140..7fb9f4072 100644 --- a/src/AbstractInterpretationZRangeProofs.v +++ b/src/AbstractInterpretationZRangeProofs.v @@ -22,6 +22,7 @@ Require Import Crypto.Util.ZUtil.CC. Require Import Crypto.Util.ZUtil.MulSplit. Require Import Crypto.Util.ZUtil.Rshi. Require Import Crypto.Util.ZUtil.Zselect. +Require Import Crypto.Util.ZUtil.Combine. Require Import Crypto.Util.ZUtil.Le. Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. Require Import Crypto.Util.ZUtil.Tactics.SplitMinMax. diff --git a/src/CStringification.v b/src/CStringification.v index 0862aec64..8926f3d97 100644 --- a/src/CStringification.v +++ b/src/CStringification.v @@ -339,6 +339,8 @@ Module Compilers. => fun args => (show_application with_casts (fun _ => "Z.rshi") args, ZRange.type.base.option.None) | ident.Z_cc_m => fun args => (show_application with_casts (fun _ => "Z.cc_m") args, ZRange.type.base.option.None) + | ident.Z_combine_at_bitwidth + => fun args => (show_application with_casts (fun _ => "Z.combine_at_bitwidth") args, ZRange.type.base.option.None) | ident.Z_cast range => fun '((x, xr), tt) => (x, Some range) | ident.Z_cast2 (r1, r2) @@ -453,6 +455,7 @@ Module Compilers. | ident.Z_add_modulo => "Z.add_modulo" | ident.Z_rshi => "Z.rshi" | ident.Z_cc_m => "Z.cc_m" + | ident.Z_combine_at_bitwidth => "Z.combine_at_bitwidth" | ident.Z_cast range => "(" ++ show_range_or_ctype range ++ ")" | ident.Z_cast2 (r1, r2) => "(" ++ show_range_or_ctype r1 ++ ", " ++ show_range_or_ctype r2 ++ ")" | ident.Build_zrange => "Build_zrange" @@ -1348,6 +1351,7 @@ Module Compilers. | ident.Z_add_modulo | ident.Z_rshi | ident.Z_cc_m + | ident.Z_combine_at_bitwidth | ident.Z_cast _ | ident.Z_cast2 _ | ident.Build_zrange diff --git a/src/GENERATEDIdentifiersWithoutTypes.v b/src/GENERATEDIdentifiersWithoutTypes.v index c84540ba1..3dfa4713b 100644 --- a/src/GENERATEDIdentifiersWithoutTypes.v +++ b/src/GENERATEDIdentifiersWithoutTypes.v @@ -716,6 +716,15 @@ print_ident = r"""Inductive ident : defaults.type -> Set := (Compilers.base.type.type_base base.type.Z) -> (fun x : Compilers.base.type => type.base x) (Compilers.base.type.type_base base.type.Z))%ptype + | Z_combine_at_bitwidth : ident + ((fun x : Compilers.base.type => type.base x) + (Compilers.base.type.type_base base.type.Z) -> + (fun x : Compilers.base.type => type.base x) + (Compilers.base.type.type_base base.type.Z) -> + (fun x : Compilers.base.type => type.base x) + (Compilers.base.type.type_base base.type.Z) -> + (fun x : Compilers.base.type => type.base x) + (Compilers.base.type.type_base base.type.Z))%ptype | Z_cast : zrange -> ident ((fun x : Compilers.base.type => type.base x) @@ -944,6 +953,7 @@ show_match_ident = r"""match # with | ident.Z_add_modulo => | ident.Z_rshi => | ident.Z_cc_m => + | ident.Z_combine_at_bitwidth => | ident.Z_cast range => | ident.Z_cast2 range => | ident.option_Some A => @@ -1556,6 +1566,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Z_add_modulo | Z_rshi | Z_cc_m + | Z_combine_at_bitwidth | Z_cast | Z_cast2 | option_Some @@ -1652,6 +1663,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Z_add_modulo => unit | Z_rshi => unit | Z_cc_m => unit + | Z_combine_at_bitwidth => unit | Z_cast => zrange | Z_cast2 => zrange * zrange | option_Some => Compilers.base.type @@ -1749,6 +1761,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Z_add_modulo => true | Z_rshi => true | Z_cc_m => true + | Z_combine_at_bitwidth => true | Z_cast => true | Z_cast2 => true | option_Some => false @@ -1846,6 +1859,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Z_add_modulo, Compilers.ident.Z_add_modulo => Datatypes.Some tt | Z_rshi, Compilers.ident.Z_rshi => Datatypes.Some tt | Z_cc_m, Compilers.ident.Z_cc_m => Datatypes.Some tt + | Z_combine_at_bitwidth, Compilers.ident.Z_combine_at_bitwidth => Datatypes.Some tt | Z_cast, Compilers.ident.Z_cast range => Datatypes.Some range | Z_cast2, Compilers.ident.Z_cast2 range => Datatypes.Some range | option_Some, Compilers.ident.option_Some A => Datatypes.Some A @@ -1939,6 +1953,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Z_add_modulo, _ | Z_rshi, _ | Z_cc_m, _ + | Z_combine_at_bitwidth, _ | Z_cast, _ | Z_cast2, _ | option_Some, _ @@ -2037,6 +2052,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Z_add_modulo => fun _ => (type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z))%ptype | Z_rshi => fun _ => (type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z))%ptype | Z_cc_m => fun _ => (type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z))%ptype + | Z_combine_at_bitwidth => fun _ => (type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z))%ptype | Z_cast => fun range => (type.base (Compilers.base.type.type_base base.type.Z) -> type.base (Compilers.base.type.type_base base.type.Z))%ptype | Z_cast2 => fun range => (type.base (Compilers.base.type.type_base base.type.Z * Compilers.base.type.type_base base.type.Z)%etype -> type.base (Compilers.base.type.type_base base.type.Z * Compilers.base.type.type_base base.type.Z)%etype)%ptype | option_Some => fun A => (type.base A -> type.base (Compilers.base.type.option A))%ptype @@ -2134,6 +2150,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Z_add_modulo => fun _ => @Compilers.ident.Z_add_modulo | Z_rshi => fun _ => @Compilers.ident.Z_rshi | Z_cc_m => fun _ => @Compilers.ident.Z_cc_m + | Z_combine_at_bitwidth => fun _ => @Compilers.ident.Z_combine_at_bitwidth | Z_cast => fun range => @Compilers.ident.Z_cast range | Z_cast2 => fun range => @Compilers.ident.Z_cast2 range | option_Some => fun A => @Compilers.ident.option_Some A @@ -2238,6 +2255,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Compilers.ident.Z_add_modulo => f _ Compilers.ident.Z_add_modulo | Compilers.ident.Z_rshi => f _ Compilers.ident.Z_rshi | Compilers.ident.Z_cc_m => f _ Compilers.ident.Z_cc_m + | Compilers.ident.Z_combine_at_bitwidth => f _ Compilers.ident.Z_combine_at_bitwidth | Compilers.ident.Z_cast range => f _ (@Compilers.ident.Z_cast range) | Compilers.ident.Z_cast2 range => f _ (@Compilers.ident.Z_cast2 range) | Compilers.ident.option_Some A => f _ (@Compilers.ident.option_Some A) @@ -2334,6 +2352,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Z_add_modulo : ident (type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z))%ptype | Z_rshi : ident (type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z))%ptype | Z_cc_m : ident (type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z))%ptype + | Z_combine_at_bitwidth : ident (type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z))%ptype | Z_cast : ident (type.base (base.type.type_base base.type.Z) -> type.base (base.type.type_base base.type.Z))%ptype | Z_cast2 : ident (type.base (base.type.type_base base.type.Z * base.type.type_base base.type.Z)%pbtype -> type.base (base.type.type_base base.type.Z * base.type.type_base base.type.Z)%pbtype)%ptype | option_Some {A : base.type} : ident (type.base A -> type.base (base.type.option A))%ptype @@ -2430,6 +2449,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | @Z_add_modulo => Raw.ident.Z_add_modulo | @Z_rshi => Raw.ident.Z_rshi | @Z_cc_m => Raw.ident.Z_cc_m + | @Z_combine_at_bitwidth => Raw.ident.Z_combine_at_bitwidth | @Z_cast => Raw.ident.Z_cast | @Z_cast2 => Raw.ident.Z_cast2 | @option_Some A => Raw.ident.option_Some @@ -2527,6 +2547,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | @Z_add_modulo => [] | @Z_rshi => [] | @Z_cc_m => [] + | @Z_combine_at_bitwidth => [] | @Z_cast => [zrange : Type] | @Z_cast2 => [zrange * zrange : Type] | @option_Some A => [] @@ -2624,6 +2645,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | @Z_add_modulo => fun _ => @Compilers.ident.Z_add_modulo | @Z_rshi => fun _ => @Compilers.ident.Z_rshi | @Z_cc_m => fun _ => @Compilers.ident.Z_cc_m + | @Z_combine_at_bitwidth => fun _ => @Compilers.ident.Z_combine_at_bitwidth | @Z_cast => fun arg => let '(range, _) := eta2r arg in @Compilers.ident.Z_cast range | @Z_cast2 => fun arg => let '(range, _) := eta2r arg in @Compilers.ident.Z_cast2 range | @option_Some A => fun _ => @Compilers.ident.option_Some _ @@ -2725,6 +2747,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | @Z_add_modulo, Compilers.ident.Z_add_modulo => Datatypes.Some tt | @Z_rshi, Compilers.ident.Z_rshi => Datatypes.Some tt | @Z_cc_m, Compilers.ident.Z_cc_m => Datatypes.Some tt + | @Z_combine_at_bitwidth, Compilers.ident.Z_combine_at_bitwidth => Datatypes.Some tt | @Z_cast, Compilers.ident.Z_cast range' => Datatypes.Some (range', tt) | @Z_cast2, Compilers.ident.Z_cast2 range' => Datatypes.Some (range', tt) | @option_Some A, Compilers.ident.option_Some A' => Datatypes.Some tt @@ -2818,6 +2841,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | @Z_add_modulo, _ | @Z_rshi, _ | @Z_cc_m, _ + | @Z_combine_at_bitwidth, _ | @Z_cast, _ | @Z_cast2, _ | @option_Some _, _ @@ -2916,6 +2940,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Compilers.ident.Z_add_modulo => @Z_add_modulo | Compilers.ident.Z_rshi => @Z_rshi | Compilers.ident.Z_cc_m => @Z_cc_m + | Compilers.ident.Z_combine_at_bitwidth => @Z_combine_at_bitwidth | Compilers.ident.Z_cast range => @Z_cast | Compilers.ident.Z_cast2 range => @Z_cast2 | Compilers.ident.option_Some A => @option_Some (base.relax A) @@ -3013,6 +3038,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Compilers.ident.Z_add_modulo => tt | Compilers.ident.Z_rshi => tt | Compilers.ident.Z_cc_m => tt + | Compilers.ident.Z_combine_at_bitwidth => tt | Compilers.ident.Z_cast range => (range, tt) | Compilers.ident.Z_cast2 range => (range, tt) | Compilers.ident.option_Some A => tt diff --git a/src/Language.v b/src/Language.v index 4b112dfdb..2e29b4d77 100644 --- a/src/Language.v +++ b/src/Language.v @@ -925,6 +925,7 @@ Module Compilers. | Z_add_modulo : ident (Z -> Z -> Z -> Z) | Z_rshi : ident (Z -> Z -> Z -> Z -> Z) | Z_cc_m : ident (Z -> Z -> Z) + | Z_combine_at_bitwidth : ident (Z -> Z -> Z -> Z) | Z_cast (range : ZRange.zrange) : ident (Z -> Z) | Z_cast2 (range : ZRange.zrange * ZRange.zrange) : ident ((Z * Z) -> (Z * Z)) | option_Some {A:base.type} : ident (A -> option A) @@ -1114,6 +1115,7 @@ Module Compilers. | Z_lnot_modulo => Z.lnot_modulo | Z_rshi => Z.rshi | Z_cc_m => Z.cc_m + | Z_combine_at_bitwidth => Z.combine_at_bitwidth | Z_cast r => cast r | Z_cast2 (r1, r2) => fun '(x1, x2) => (cast r1 x1, cast r2 x2) | Some A => @Datatypes.Some _ @@ -1453,6 +1455,7 @@ Module Compilers. | Z.add_modulo => then_tac ident.Z_add_modulo | Z.rshi => then_tac ident.Z_rshi | Z.cc_m => then_tac ident.Z_cc_m + | Z.combine_at_bitwidth => then_tac ident.Z_combine_at_bitwidth | ident.cast _ ?r => then_tac (ident.Z_cast r) | ident.cast2 _ ?r => then_tac (ident.Z_cast2 r) | @Some ?A diff --git a/src/Rewriter.v b/src/Rewriter.v index 1d4910a5f..8b074aca6 100644 --- a/src/Rewriter.v +++ b/src/Rewriter.v @@ -1431,6 +1431,7 @@ Module Compilers. | ident.Z_add_modulo | ident.Z_rshi | ident.Z_cc_m + | ident.Z_combine_at_bitwidth | ident.Z_cast _ | ident.Z_cast2 _ | ident.option_Some _ diff --git a/src/Util/ZUtil/Combine.v b/src/Util/ZUtil/Combine.v new file mode 100644 index 000000000..3ac4179c5 --- /dev/null +++ b/src/Util/ZUtil/Combine.v @@ -0,0 +1,34 @@ +Require Import Coq.Classes.Morphisms. +Require Import Coq.ZArith.ZArith. +Require Import Coq.micromega.Lia. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. +Require Import Crypto.Util.ZUtil.Notations. +Local Open Scope Z_scope. + +Module Z. + Lemma combine_at_bitwidth_correct bitwidth lo hi + : Z.combine_at_bitwidth bitwidth lo hi = lo + (hi << bitwidth). + Proof. reflexivity. Qed. + + Lemma combine_at_bitwidth_Proper bitwidth + : Proper (Z.le ==> Z.le ==> Z.le) (Z.combine_at_bitwidth bitwidth). + Proof. + cbv [Proper respectful]; intros; rewrite !combine_at_bitwidth_correct. + destruct bitwidth as [|bitwidth|bitwidth]; + [ | assert (0 <= 2^Z.pos bitwidth) by (apply Z.pow_nonneg; lia).. ]; + rewrite ?Z.shiftl_mul_pow2, ?Z.shiftl_div_pow2, ?Z.pow_0_r by lia; cbn [Z.opp]; try nia; + Z.div_mod_to_quot_rem; nia. + Qed. + Hint Resolve combine_at_bitwidth_Proper : zarith. + + Lemma combine_at_bitwidth_Proper1 bitwidth x + : Proper (Z.le ==> Z.le) (Z.combine_at_bitwidth bitwidth x). + Proof. repeat intro; eapply combine_at_bitwidth_Proper; (eassumption + reflexivity). Qed. + Hint Resolve combine_at_bitwidth_Proper1 : zarith. + + Lemma combine_at_bitwidth_Proper2 bitwidth x + : Proper (Z.le ==> Z.le) (fun y => Z.combine_at_bitwidth bitwidth y x). + Proof. repeat intro; eapply combine_at_bitwidth_Proper; (eassumption + reflexivity). Qed. + Hint Resolve combine_at_bitwidth_Proper2 : zarith. +End Z. diff --git a/src/Util/ZUtil/Definitions.v b/src/Util/ZUtil/Definitions.v index 8fe5772f5..a5ac1a95b 100644 --- a/src/Util/ZUtil/Definitions.v +++ b/src/Util/ZUtil/Definitions.v @@ -85,6 +85,9 @@ Module Z. then mul_split_at_bitwidth (Z.log2 s) x y else ((x * y) mod s, (x * y) / s). + Definition combine_at_bitwidth (bitwidth lo hi : Z) : Z + := lo + (hi << bitwidth). + (** if positive, round up to 2^k-1 (0b11111....); if negative, round down to -2^k (0b...111000000...) *) Definition round_lor_land_bound (x : Z) : Z := if (0 <=? x)%Z -- cgit v1.2.3