diff options
author | Jason Gross <jgross@mit.edu> | 2017-06-18 21:51:05 -0400 |
---|---|---|
committer | Jason Gross <jgross@mit.edu> | 2017-06-18 21:51:05 -0400 |
commit | 7673940e6896e358a6ebc543fd92be89bd1e6d20 (patch) | |
tree | 02e17696905e9d1e9fde91260b98845eb437dac3 | |
parent | c2436f0f5b314c4765b4087ebe6d2001d459b402 (diff) |
Better simplification of mulsplit
-rw-r--r-- | src/Compilers/Z/ArithmeticSimplifier.v | 36 | ||||
-rw-r--r-- | src/Compilers/Z/ArithmeticSimplifierInterp.v | 140 |
2 files changed, 123 insertions, 53 deletions
diff --git a/src/Compilers/Z/ArithmeticSimplifier.v b/src/Compilers/Z/ArithmeticSimplifier.v index 6732d6c25..8f5ca7c9e 100644 --- a/src/Compilers/Z/ArithmeticSimplifier.v +++ b/src/Compilers/Z/ArithmeticSimplifier.v @@ -5,6 +5,7 @@ Require Import Crypto.Compilers.ExprInversion. Require Import Crypto.Compilers.Rewriter. Require Import Crypto.Compilers.Z.Syntax. Require Import Crypto.Compilers.Z.Syntax.Equality. +Require Import Crypto.Util.ZUtil.Definitions. Section language. Context (convert_adc_to_sbb : bool). @@ -240,6 +241,41 @@ Section language. | _ => Op opc args end + | MulSplit bitwidth (TWord bw1 as T1) (TWord bw2 as T2) (TWord bwout1 as Tout1) (TWord bwout2 as Tout2) as opc + => fun args + => let sz1 := (2^Z.of_nat (2^bw1))%Z in + let sz2 := (2^Z.of_nat (2^bw2))%Z in + let szout1 := (2^Z.of_nat (2^bwout1))%Z in + let szout2 := (2^Z.of_nat (2^bwout2))%Z in + match interp_as_expr_or_const args with + | Some (const_of l, const_of r) + => let '(a, b) := Z.mul_split_at_bitwidth bitwidth (Z.max 0 l mod sz1) (Z.max 0 r mod sz2) in + Pair (Op (OpConst (a mod szout1)%Z) TT) + (Op (OpConst (b mod szout2)%Z) TT) + | Some (const_of v, gen_expr e) + => let v' := (Z.max 0 v mod sz1)%Z in + if (v' =? 0)%Z + then Pair (Op (OpConst 0%Z) TT) (Op (OpConst 0%Z) TT) + else if ((v' =? 1) && (2^Z.of_nat (2^bw2) <=? 2^bitwidth))%Z%bool + then match base_type_eq_semidec_transparent T2 Tout1 with + | Some pf => Pair (eq_rect _ (fun t => exprf (Tbase t)) e _ pf) + (Op (OpConst 0%Z) TT) + | None => Op opc args + end + else Op opc args + | Some (gen_expr e, const_of v) + => let v' := (Z.max 0 v mod sz2)%Z in + if (v' =? 0)%Z + then Pair (Op (OpConst 0%Z) TT) (Op (OpConst 0%Z) TT) + else if ((v' =? 1) && (2^Z.of_nat (2^bw1) <=? 2^bitwidth))%Z%bool + then match base_type_eq_semidec_transparent T1 Tout1 with + | Some pf => Pair (eq_rect _ (fun t => exprf (Tbase t)) e _ pf) + (Op (OpConst 0%Z) TT) + | None => Op opc args + end + else Op opc args + | _ => Op opc args + end | IdWithAlt (TWord _ as T1) _ (TWord _ as Tout) as opc => fun args => match base_type_eq_semidec_transparent T1 Tout with diff --git a/src/Compilers/Z/ArithmeticSimplifierInterp.v b/src/Compilers/Z/ArithmeticSimplifierInterp.v index 77ec72038..b55f069e1 100644 --- a/src/Compilers/Z/ArithmeticSimplifierInterp.v +++ b/src/Compilers/Z/ArithmeticSimplifierInterp.v @@ -100,6 +100,7 @@ Local Arguments Z.sub !_ !_. Local Arguments Z.opp !_. Local Arguments interp_op _ _ !_ _ / . Local Arguments lift_op / . +Local Opaque Z.pow. Lemma InterpSimplifyArith {convert_adc_to_sbb} {t} (e : Expr t) : forall x, Interp interp_op (SimplifyArith convert_adc_to_sbb e) x = Interp interp_op e x. @@ -143,60 +144,93 @@ Proof. | rewrite FixedWordSizesEquality.ZToWord_wordToZ_ZToWord by reflexivity | rewrite FixedWordSizesEquality.wordToZ_ZToWord_0 | rewrite !FixedWordSizesEquality.wordToZ_ZToWord_mod_full ]. + all:repeat match goal with + | [ H : _ = Some eq_refl |- _ ] => clear H + | [ H : interp_as_expr_or_const _ = Some _ |- _ ] => clear H + | [ H : interpf _ _ = _ |- _ ] => clear H + | [ H : Syntax.exprf _ _ _ |- _ ] => clear H + | [ H : Expr _ |- _ ] => clear H + | [ H : type _ |- _ ] => clear H + | [ H : bool |- _ ] => clear H + | [ |- context[FixedWordSizes.wordToZ ?e] ] + => pose proof (FixedWordSizesEquality.wordToZ_range e); + lazymatch e with + | interpf interp_op ?e' + => generalize dependent (FixedWordSizes.wordToZ e); clear e'; intros + | _ => is_var e; generalize dependent (FixedWordSizes.wordToZ e); + clear e; intros + end + | [ |- context[interpf interp_op ?e] ] + => is_var e; generalize dependent (interpf interp_op e); clear e; intros + | [ |- context[Z.of_nat (2^?e)] ] + => is_var e; assert ((0 < Z.of_nat (2^e))%Z) + by (rewrite Z.pow_Zpow; simpl Z.of_nat; Z.zero_bounds); + generalize dependent (Z.of_nat (2^e)); clear e; intros + end. all:try nia. - all:repeat first [ reflexivity - | omega - | progress rewrite ?Z.land_0_l, ?Z.land_0_r, ?Z.lor_0_l, ?Z.lor_0_r, ?Z.opp_involutive - | rewrite !Z.sub_with_borrow_to_add_get_carry - | progress cbv [Z.add_with_carry] - | rewrite Z.mod_mod by Z.zero_bounds - | match goal with - | [ |- context[FixedWordSizes.wordToZ_gen ?x] ] - => lazymatch goal with - | [ H : (0 <= FixedWordSizes.wordToZ_gen x)%Z |- _ ] => fail - | [ H : (0 <= FixedWordSizes.wordToZ_gen x < _)%Z |- _ ] => fail - | _ => pose proof (FixedWordSizesEquality.wordToZ_gen_range x) - end - | [ |- context[FixedWordSizes.wordToZ ?e] ] - => lazymatch goal with - | [ H : (0 <= FixedWordSizes.wordToZ e)%Z |- _ ] => fail - | [ H : (0 <= FixedWordSizes.wordToZ e < _)%Z |- _ ] => fail - | _ => pose proof (FixedWordSizesEquality.wordToZ_range e) - end - | [ |- context[(?x mod ?y)%Z] ] - => lazymatch goal with - | [ H : (0 <= x mod y)%Z |- _ ] => fail - | [ H : (0 <= x mod y < _)%Z |- _ ] => fail - | _ => assert (0 <= x mod y < y)%Z by (apply Z.mod_pos_bound; Z.zero_bounds) - end - | [ H : (2^Z.of_nat ?bw <= ?bw')%Z |- context[(2^?bw')%Z] ] - => unique assert ((2^Z.of_nat (2^bw) <= 2^bw')%Z) - by (rewrite Z.pow_Zpow; simpl @Z.of_nat; auto with zarith) - end - | progress autorewrite with zsimplify_const in * - | match goal with - | [ H : (0 <= ?x < _)%Z |- context[Z.max 0 ?x] ] - => rewrite (Z.max_r 0 x) in * by apply H - | [ H : (0 <= ?x < _)%Z, H' : (0 <= ?y < _)%Z |- context[Z.max 0 (?x * ?y)] ] - => rewrite (Z.max_r 0 (x * y)) in * by (apply Z.mul_nonneg_nonneg; first [ apply H | apply H' ]) - | [ H : (0 <= ?x < _)%Z, H' : (0 <= ?y < _)%Z |- context[Z.max 0 (?x + ?y)] ] - => rewrite (Z.max_r 0 (x + y)) in * by (apply Z.add_nonneg_nonneg; first [ apply H | apply H' ]) - | [ H : ?x = 0%Z |- context[?x] ] => rewrite H - | [ H : (?x mod ?y = 0)%Z |- context[((?x * _) mod ?y)%Z] ] - => rewrite (Z.mul_mod_full x _ y) - | [ H : (?x mod ?y = 0)%Z |- context[((_ * ?x) mod ?y)%Z] ] - => rewrite (Z.mul_mod_full _ x y) - | [ H : ?x = Z.pos _ |- context[?x] ] => rewrite H - | [ H : (?x mod ?y = Z.pos _)%Z |- context[((?x * _) mod ?y)%Z] ] - => rewrite (Z.mul_mod_full x _ y) - | [ H : (?x mod ?y = Z.pos _)%Z |- context[((_ * ?x) mod ?y)%Z] ] - => rewrite (Z.mul_mod_full _ x y) - | [ |- context[(?x mod ?m)%Z] ] - => rewrite (Z.mod_small x m) by Z.rewrite_mod_small_solver - | [ |- context[(?x / ?m)%Z] ] - => rewrite (Z.div_small x m) by Z.rewrite_mod_small_solver - end - | progress pull_Zmod ]. + Time + all:repeat first [ reflexivity + | omega + | progress change (2^0)%Z with 1%Z in * + | progress change (2^1)%Z with 2%Z in * + | progress rewrite ?Z.land_0_l, ?Z.land_0_r, ?Z.lor_0_l, ?Z.lor_0_r, ?Z.opp_involutive, ?Z.shiftr_0_r + | progress rewrite ?Z.land_ones by lia + | progress autorewrite with Zshift_to_pow in * + | rewrite !Z.sub_with_borrow_to_add_get_carry + | progress cbv [Z.add_with_carry] + | rewrite Z.mod_mod by Z.zero_bounds + | match goal with + | [ |- context[(?x mod ?y)%Z] ] + => lazymatch goal with + | [ H : (0 <= x mod y)%Z |- _ ] => fail + | [ H : (0 <= x mod y < _)%Z |- _ ] => fail + | _ => assert (0 <= x mod y < y)%Z by (apply Z.mod_pos_bound; Z.zero_bounds; lia) + end + | [ |- context[(?x / ?y)%Z] ] + => lazymatch goal with + | [ H : (0 <= x / y)%Z |- _ ] => fail + | _ => assert (0 <= x / y)%Z by Z.zero_bounds + end + | [ H : (2^Z.of_nat ?bw <= ?bw')%Z |- context[(2^?bw')%Z] ] + => unique assert ((2^Z.of_nat (2^bw) <= 2^bw')%Z) + by (rewrite Z.pow_Zpow; simpl @Z.of_nat; auto with zarith) + end + | progress autorewrite with zsimplify_const in * + | match goal with + | [ H : (2^?x <= 1)%Z, H' : (0 < ?x)%Z |- _ ] + => lazymatch goal with + | [ |- False ] => fail + | _ => exfalso; clear -H H'; assert (2^1 <= 2^x)%Z by auto with zarith + end + | [ H : (0 <= ?x < _)%Z |- context[Z.max 0 ?x] ] + => rewrite (Z.max_r 0 x) in * by apply H + | [ H : (0 <= ?x)%Z |- context[Z.max 0 ?x] ] + => rewrite (Z.max_r 0 x) in * by apply H + | [ H : (0 <= ?x < _)%Z, H' : (0 <= ?y < _)%Z |- context[Z.max 0 (?x * ?y)] ] + => rewrite (Z.max_r 0 (x * y)) in * by (apply Z.mul_nonneg_nonneg; first [ apply H | apply H' ]) + | [ H : (0 <= ?x < _)%Z, H' : (0 <= ?y < _)%Z |- context[Z.max 0 (?x + ?y)] ] + => rewrite (Z.max_r 0 (x + y)) in * by (apply Z.add_nonneg_nonneg; first [ apply H | apply H' ]) + | [ H : ?x = 0%Z |- context[?x] ] => rewrite H + | [ H : ?x = 0%Z, H' : context[?x] |- _ ] => rewrite H in H' + | [ H : ?x = Z.pos _ |- context[?x] ] => rewrite H + | [ H : ?x = Z.pos _, H' : context[?x] |- _ ] => rewrite H in H' + | [ H : context[(_^Z.neg ?p)%Z] |- _ ] + => rewrite (Z.pow_neg_r _ (Z.neg p)) in H by lia + | [ H : (?x mod ?y = 0)%Z |- context[((?x * _) mod ?y)%Z] ] + => rewrite (Z.mul_mod_full x _ y) + | [ H : (?x mod ?y = 0)%Z |- context[((_ * ?x) mod ?y)%Z] ] + => rewrite (Z.mul_mod_full _ x y) + | [ H : ?x = Z.pos _ |- context[?x] ] => rewrite H + | [ H : (?x mod ?y = Z.pos _)%Z |- context[((?x * _) mod ?y)%Z] ] + => rewrite (Z.mul_mod_full x _ y) + | [ H : (?x mod ?y = Z.pos _)%Z |- context[((_ * ?x) mod ?y)%Z] ] + => rewrite (Z.mul_mod_full _ x y) + | [ |- context[(?x mod ?m)%Z] ] + => rewrite (Z.mod_small x m) by Z.rewrite_mod_small_solver + | [ |- context[(?x / ?m)%Z] ] + => rewrite (Z.div_small x m) by Z.rewrite_mod_small_solver + end + | progress pull_Zmod ]. Qed. Hint Rewrite @InterpSimplifyArith : reflective_interp. |