aboutsummaryrefslogtreecommitdiff
path: root/src/Compilers
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-06-18 21:51:05 -0400
committerGravatar Jason Gross <jgross@mit.edu>2017-06-18 21:51:05 -0400
commit7673940e6896e358a6ebc543fd92be89bd1e6d20 (patch)
tree02e17696905e9d1e9fde91260b98845eb437dac3 /src/Compilers
parentc2436f0f5b314c4765b4087ebe6d2001d459b402 (diff)
Better simplification of mulsplit
Diffstat (limited to 'src/Compilers')
-rw-r--r--src/Compilers/Z/ArithmeticSimplifier.v36
-rw-r--r--src/Compilers/Z/ArithmeticSimplifierInterp.v140
2 files changed, 123 insertions, 53 deletions
diff --git a/src/Compilers/Z/ArithmeticSimplifier.v b/src/Compilers/Z/ArithmeticSimplifier.v
index 6732d6c25..8f5ca7c9e 100644
--- a/src/Compilers/Z/ArithmeticSimplifier.v
+++ b/src/Compilers/Z/ArithmeticSimplifier.v
@@ -5,6 +5,7 @@ Require Import Crypto.Compilers.ExprInversion.
Require Import Crypto.Compilers.Rewriter.
Require Import Crypto.Compilers.Z.Syntax.
Require Import Crypto.Compilers.Z.Syntax.Equality.
+Require Import Crypto.Util.ZUtil.Definitions.
Section language.
Context (convert_adc_to_sbb : bool).
@@ -240,6 +241,41 @@ Section language.
| _
=> Op opc args
end
+ | MulSplit bitwidth (TWord bw1 as T1) (TWord bw2 as T2) (TWord bwout1 as Tout1) (TWord bwout2 as Tout2) as opc
+ => fun args
+ => let sz1 := (2^Z.of_nat (2^bw1))%Z in
+ let sz2 := (2^Z.of_nat (2^bw2))%Z in
+ let szout1 := (2^Z.of_nat (2^bwout1))%Z in
+ let szout2 := (2^Z.of_nat (2^bwout2))%Z in
+ match interp_as_expr_or_const args with
+ | Some (const_of l, const_of r)
+ => let '(a, b) := Z.mul_split_at_bitwidth bitwidth (Z.max 0 l mod sz1) (Z.max 0 r mod sz2) in
+ Pair (Op (OpConst (a mod szout1)%Z) TT)
+ (Op (OpConst (b mod szout2)%Z) TT)
+ | Some (const_of v, gen_expr e)
+ => let v' := (Z.max 0 v mod sz1)%Z in
+ if (v' =? 0)%Z
+ then Pair (Op (OpConst 0%Z) TT) (Op (OpConst 0%Z) TT)
+ else if ((v' =? 1) && (2^Z.of_nat (2^bw2) <=? 2^bitwidth))%Z%bool
+ then match base_type_eq_semidec_transparent T2 Tout1 with
+ | Some pf => Pair (eq_rect _ (fun t => exprf (Tbase t)) e _ pf)
+ (Op (OpConst 0%Z) TT)
+ | None => Op opc args
+ end
+ else Op opc args
+ | Some (gen_expr e, const_of v)
+ => let v' := (Z.max 0 v mod sz2)%Z in
+ if (v' =? 0)%Z
+ then Pair (Op (OpConst 0%Z) TT) (Op (OpConst 0%Z) TT)
+ else if ((v' =? 1) && (2^Z.of_nat (2^bw1) <=? 2^bitwidth))%Z%bool
+ then match base_type_eq_semidec_transparent T1 Tout1 with
+ | Some pf => Pair (eq_rect _ (fun t => exprf (Tbase t)) e _ pf)
+ (Op (OpConst 0%Z) TT)
+ | None => Op opc args
+ end
+ else Op opc args
+ | _ => Op opc args
+ end
| IdWithAlt (TWord _ as T1) _ (TWord _ as Tout) as opc
=> fun args
=> match base_type_eq_semidec_transparent T1 Tout with
diff --git a/src/Compilers/Z/ArithmeticSimplifierInterp.v b/src/Compilers/Z/ArithmeticSimplifierInterp.v
index 77ec72038..b55f069e1 100644
--- a/src/Compilers/Z/ArithmeticSimplifierInterp.v
+++ b/src/Compilers/Z/ArithmeticSimplifierInterp.v
@@ -100,6 +100,7 @@ Local Arguments Z.sub !_ !_.
Local Arguments Z.opp !_.
Local Arguments interp_op _ _ !_ _ / .
Local Arguments lift_op / .
+Local Opaque Z.pow.
Lemma InterpSimplifyArith {convert_adc_to_sbb} {t} (e : Expr t)
: forall x, Interp interp_op (SimplifyArith convert_adc_to_sbb e) x = Interp interp_op e x.
@@ -143,60 +144,93 @@ Proof.
| rewrite FixedWordSizesEquality.ZToWord_wordToZ_ZToWord by reflexivity
| rewrite FixedWordSizesEquality.wordToZ_ZToWord_0
| rewrite !FixedWordSizesEquality.wordToZ_ZToWord_mod_full ].
+ all:repeat match goal with
+ | [ H : _ = Some eq_refl |- _ ] => clear H
+ | [ H : interp_as_expr_or_const _ = Some _ |- _ ] => clear H
+ | [ H : interpf _ _ = _ |- _ ] => clear H
+ | [ H : Syntax.exprf _ _ _ |- _ ] => clear H
+ | [ H : Expr _ |- _ ] => clear H
+ | [ H : type _ |- _ ] => clear H
+ | [ H : bool |- _ ] => clear H
+ | [ |- context[FixedWordSizes.wordToZ ?e] ]
+ => pose proof (FixedWordSizesEquality.wordToZ_range e);
+ lazymatch e with
+ | interpf interp_op ?e'
+ => generalize dependent (FixedWordSizes.wordToZ e); clear e'; intros
+ | _ => is_var e; generalize dependent (FixedWordSizes.wordToZ e);
+ clear e; intros
+ end
+ | [ |- context[interpf interp_op ?e] ]
+ => is_var e; generalize dependent (interpf interp_op e); clear e; intros
+ | [ |- context[Z.of_nat (2^?e)] ]
+ => is_var e; assert ((0 < Z.of_nat (2^e))%Z)
+ by (rewrite Z.pow_Zpow; simpl Z.of_nat; Z.zero_bounds);
+ generalize dependent (Z.of_nat (2^e)); clear e; intros
+ end.
all:try nia.
- all:repeat first [ reflexivity
- | omega
- | progress rewrite ?Z.land_0_l, ?Z.land_0_r, ?Z.lor_0_l, ?Z.lor_0_r, ?Z.opp_involutive
- | rewrite !Z.sub_with_borrow_to_add_get_carry
- | progress cbv [Z.add_with_carry]
- | rewrite Z.mod_mod by Z.zero_bounds
- | match goal with
- | [ |- context[FixedWordSizes.wordToZ_gen ?x] ]
- => lazymatch goal with
- | [ H : (0 <= FixedWordSizes.wordToZ_gen x)%Z |- _ ] => fail
- | [ H : (0 <= FixedWordSizes.wordToZ_gen x < _)%Z |- _ ] => fail
- | _ => pose proof (FixedWordSizesEquality.wordToZ_gen_range x)
- end
- | [ |- context[FixedWordSizes.wordToZ ?e] ]
- => lazymatch goal with
- | [ H : (0 <= FixedWordSizes.wordToZ e)%Z |- _ ] => fail
- | [ H : (0 <= FixedWordSizes.wordToZ e < _)%Z |- _ ] => fail
- | _ => pose proof (FixedWordSizesEquality.wordToZ_range e)
- end
- | [ |- context[(?x mod ?y)%Z] ]
- => lazymatch goal with
- | [ H : (0 <= x mod y)%Z |- _ ] => fail
- | [ H : (0 <= x mod y < _)%Z |- _ ] => fail
- | _ => assert (0 <= x mod y < y)%Z by (apply Z.mod_pos_bound; Z.zero_bounds)
- end
- | [ H : (2^Z.of_nat ?bw <= ?bw')%Z |- context[(2^?bw')%Z] ]
- => unique assert ((2^Z.of_nat (2^bw) <= 2^bw')%Z)
- by (rewrite Z.pow_Zpow; simpl @Z.of_nat; auto with zarith)
- end
- | progress autorewrite with zsimplify_const in *
- | match goal with
- | [ H : (0 <= ?x < _)%Z |- context[Z.max 0 ?x] ]
- => rewrite (Z.max_r 0 x) in * by apply H
- | [ H : (0 <= ?x < _)%Z, H' : (0 <= ?y < _)%Z |- context[Z.max 0 (?x * ?y)] ]
- => rewrite (Z.max_r 0 (x * y)) in * by (apply Z.mul_nonneg_nonneg; first [ apply H | apply H' ])
- | [ H : (0 <= ?x < _)%Z, H' : (0 <= ?y < _)%Z |- context[Z.max 0 (?x + ?y)] ]
- => rewrite (Z.max_r 0 (x + y)) in * by (apply Z.add_nonneg_nonneg; first [ apply H | apply H' ])
- | [ H : ?x = 0%Z |- context[?x] ] => rewrite H
- | [ H : (?x mod ?y = 0)%Z |- context[((?x * _) mod ?y)%Z] ]
- => rewrite (Z.mul_mod_full x _ y)
- | [ H : (?x mod ?y = 0)%Z |- context[((_ * ?x) mod ?y)%Z] ]
- => rewrite (Z.mul_mod_full _ x y)
- | [ H : ?x = Z.pos _ |- context[?x] ] => rewrite H
- | [ H : (?x mod ?y = Z.pos _)%Z |- context[((?x * _) mod ?y)%Z] ]
- => rewrite (Z.mul_mod_full x _ y)
- | [ H : (?x mod ?y = Z.pos _)%Z |- context[((_ * ?x) mod ?y)%Z] ]
- => rewrite (Z.mul_mod_full _ x y)
- | [ |- context[(?x mod ?m)%Z] ]
- => rewrite (Z.mod_small x m) by Z.rewrite_mod_small_solver
- | [ |- context[(?x / ?m)%Z] ]
- => rewrite (Z.div_small x m) by Z.rewrite_mod_small_solver
- end
- | progress pull_Zmod ].
+ Time
+ all:repeat first [ reflexivity
+ | omega
+ | progress change (2^0)%Z with 1%Z in *
+ | progress change (2^1)%Z with 2%Z in *
+ | progress rewrite ?Z.land_0_l, ?Z.land_0_r, ?Z.lor_0_l, ?Z.lor_0_r, ?Z.opp_involutive, ?Z.shiftr_0_r
+ | progress rewrite ?Z.land_ones by lia
+ | progress autorewrite with Zshift_to_pow in *
+ | rewrite !Z.sub_with_borrow_to_add_get_carry
+ | progress cbv [Z.add_with_carry]
+ | rewrite Z.mod_mod by Z.zero_bounds
+ | match goal with
+ | [ |- context[(?x mod ?y)%Z] ]
+ => lazymatch goal with
+ | [ H : (0 <= x mod y)%Z |- _ ] => fail
+ | [ H : (0 <= x mod y < _)%Z |- _ ] => fail
+ | _ => assert (0 <= x mod y < y)%Z by (apply Z.mod_pos_bound; Z.zero_bounds; lia)
+ end
+ | [ |- context[(?x / ?y)%Z] ]
+ => lazymatch goal with
+ | [ H : (0 <= x / y)%Z |- _ ] => fail
+ | _ => assert (0 <= x / y)%Z by Z.zero_bounds
+ end
+ | [ H : (2^Z.of_nat ?bw <= ?bw')%Z |- context[(2^?bw')%Z] ]
+ => unique assert ((2^Z.of_nat (2^bw) <= 2^bw')%Z)
+ by (rewrite Z.pow_Zpow; simpl @Z.of_nat; auto with zarith)
+ end
+ | progress autorewrite with zsimplify_const in *
+ | match goal with
+ | [ H : (2^?x <= 1)%Z, H' : (0 < ?x)%Z |- _ ]
+ => lazymatch goal with
+ | [ |- False ] => fail
+ | _ => exfalso; clear -H H'; assert (2^1 <= 2^x)%Z by auto with zarith
+ end
+ | [ H : (0 <= ?x < _)%Z |- context[Z.max 0 ?x] ]
+ => rewrite (Z.max_r 0 x) in * by apply H
+ | [ H : (0 <= ?x)%Z |- context[Z.max 0 ?x] ]
+ => rewrite (Z.max_r 0 x) in * by apply H
+ | [ H : (0 <= ?x < _)%Z, H' : (0 <= ?y < _)%Z |- context[Z.max 0 (?x * ?y)] ]
+ => rewrite (Z.max_r 0 (x * y)) in * by (apply Z.mul_nonneg_nonneg; first [ apply H | apply H' ])
+ | [ H : (0 <= ?x < _)%Z, H' : (0 <= ?y < _)%Z |- context[Z.max 0 (?x + ?y)] ]
+ => rewrite (Z.max_r 0 (x + y)) in * by (apply Z.add_nonneg_nonneg; first [ apply H | apply H' ])
+ | [ H : ?x = 0%Z |- context[?x] ] => rewrite H
+ | [ H : ?x = 0%Z, H' : context[?x] |- _ ] => rewrite H in H'
+ | [ H : ?x = Z.pos _ |- context[?x] ] => rewrite H
+ | [ H : ?x = Z.pos _, H' : context[?x] |- _ ] => rewrite H in H'
+ | [ H : context[(_^Z.neg ?p)%Z] |- _ ]
+ => rewrite (Z.pow_neg_r _ (Z.neg p)) in H by lia
+ | [ H : (?x mod ?y = 0)%Z |- context[((?x * _) mod ?y)%Z] ]
+ => rewrite (Z.mul_mod_full x _ y)
+ | [ H : (?x mod ?y = 0)%Z |- context[((_ * ?x) mod ?y)%Z] ]
+ => rewrite (Z.mul_mod_full _ x y)
+ | [ H : ?x = Z.pos _ |- context[?x] ] => rewrite H
+ | [ H : (?x mod ?y = Z.pos _)%Z |- context[((?x * _) mod ?y)%Z] ]
+ => rewrite (Z.mul_mod_full x _ y)
+ | [ H : (?x mod ?y = Z.pos _)%Z |- context[((_ * ?x) mod ?y)%Z] ]
+ => rewrite (Z.mul_mod_full _ x y)
+ | [ |- context[(?x mod ?m)%Z] ]
+ => rewrite (Z.mod_small x m) by Z.rewrite_mod_small_solver
+ | [ |- context[(?x / ?m)%Z] ]
+ => rewrite (Z.div_small x m) by Z.rewrite_mod_small_solver
+ end
+ | progress pull_Zmod ].
Qed.
Hint Rewrite @InterpSimplifyArith : reflective_interp.