diff options
author | Jason Gross <jgross@mit.edu> | 2017-06-17 21:42:03 -0400 |
---|---|---|
committer | Jason Gross <jgross@mit.edu> | 2017-06-17 21:42:03 -0400 |
commit | 78a236f14d1e64f12719154947725a53a48262c9 (patch) | |
tree | a4ef19b83894a93dc39796e08bb1df0f93e617e0 | |
parent | 855c990570ac90cbdff5ca88e987cab96bcf0a00 (diff) |
More arithmetic simplification for adc, mul
-rw-r--r-- | src/Compilers/Z/ArithmeticSimplifier.v | 99 | ||||
-rw-r--r-- | src/Compilers/Z/ArithmeticSimplifierInterp.v | 32 |
2 files changed, 99 insertions, 32 deletions
diff --git a/src/Compilers/Z/ArithmeticSimplifier.v b/src/Compilers/Z/ArithmeticSimplifier.v index a0dc75bf5..8f231e415 100644 --- a/src/Compilers/Z/ArithmeticSimplifier.v +++ b/src/Compilers/Z/ArithmeticSimplifier.v @@ -155,6 +155,33 @@ Section language. => Op (Mul _ _ _) (Pair e1 e2) | _ => Op opc args end + | Mul (TWord bw1 as T1) (TWord bw2 as T2) (TWord bwout as Tout) as opc + => fun args + => let sz1 := (2^Z.of_nat (2^bw1))%Z in + let sz2 := (2^Z.of_nat (2^bw2))%Z in + match interp_as_expr_or_const args with + | Some (const_of l, const_of r) + => Op (OpConst ((Z.max 0 l mod sz1) * (Z.max 0 r mod sz2))%Z) TT + | Some (const_of v, gen_expr e) + => if (Z.max 0 v mod sz1 =? 0)%Z + then Op (OpConst 0%Z) TT + else if (Z.max 0 v mod sz1 =? 1)%Z + then match base_type_eq_semidec_transparent T2 Tout with + | Some pf => eq_rect _ (fun t => exprf (Tbase t)) e _ pf + | None => Op opc args + end + else Op opc args + | Some (gen_expr e, const_of v) + => if (Z.max 0 v mod sz2 =? 0)%Z + then Op (OpConst 0%Z) TT + else if (Z.max 0 v mod sz2 =? 1)%Z + then match base_type_eq_semidec_transparent T1 Tout with + | Some pf => eq_rect _ (fun t => exprf (Tbase t)) e _ pf + | None => Op opc args + end + else Op opc args + | _ => Op opc args + end | Shl TZ TZ TZ as opc | Shr TZ TZ TZ as opc => fun args @@ -371,39 +398,49 @@ Section language. else Op opc args | AddWithGetCarry bw (TWord bw1 as T1) (TWord bw2 as T2) (TWord bw3 as T3) (TWord bwout as Tout) Tout2 as opc => fun args - => match interp_as_expr_or_const args with - | Some (const_of c, const_of x, const_of y) - => if ((c =? 0) && (x =? 0) && (y =? 0))%Z%bool - then Pair (Op (OpConst 0) TT) (Op (OpConst 0) TT) - else Op opc args - | Some (gen_expr e, const_of c1, const_of c2) - => match base_type_eq_semidec_transparent T1 Tout with - | Some pf - => if ((c1 =? 0) && (c2 =? 0) && (2^Z.of_nat bw1 <=? bw))%Z%bool - then Pair (eq_rect _ (fun t => exprf (Tbase t)) e _ pf) (Op (OpConst 0) TT) - else Op opc args - | None - => Op opc args - end - | Some (const_of c1, gen_expr e, const_of c2) - => match base_type_eq_semidec_transparent T2 Tout with - | Some pf - => if ((c1 =? 0) && (c2 =? 0) && (2^Z.of_nat bw2 <=? bw))%Z%bool - then Pair (eq_rect _ (fun t => exprf (Tbase t)) e _ pf) (Op (OpConst 0) TT) - else Op opc args - | None - => Op opc args - end - | Some (const_of c1, const_of c2, gen_expr e) - => match base_type_eq_semidec_transparent T3 Tout with - | Some pf - => if ((c1 =? 0) && (c2 =? 0) && (2^Z.of_nat bw3 <=? bw))%Z%bool - then Pair (eq_rect _ (fun t => exprf (Tbase t)) e _ pf) (Op (OpConst 0) TT) + => let pass0 + := if ((bw1 =? 0) && (bw2 =? 0) && (bw3 =? 0) && (0 <? bwout) && (1 <? bw)%Z)%nat%bool + then Some (Pair (LetIn args (fun '(a, b, c) => Op (Add _ _ _) (Pair (Op (Add _ _ Tout) (Pair (Var a) (Var b))) (Var c)))) + (Op (OpConst 0) TT)) + else None + in + match pass0 with + | Some e => e + | None + => match interp_as_expr_or_const args with + | Some (const_of c, const_of x, const_of y) + => if ((c =? 0) && (x =? 0) && (y =? 0))%Z%bool + then Pair (Op (OpConst 0) TT) (Op (OpConst 0) TT) else Op opc args - | None - => Op opc args + | Some (gen_expr e, const_of c1, const_of c2) + => match base_type_eq_semidec_transparent T1 Tout with + | Some pf + => if ((c1 =? 0) && (c2 =? 0) && (2^Z.of_nat bw1 <=? bw))%Z%bool + then Pair (eq_rect _ (fun t => exprf (Tbase t)) e _ pf) (Op (OpConst 0) TT) + else Op opc args + | None + => Op opc args + end + | Some (const_of c1, gen_expr e, const_of c2) + => match base_type_eq_semidec_transparent T2 Tout with + | Some pf + => if ((c1 =? 0) && (c2 =? 0) && (2^Z.of_nat bw2 <=? bw))%Z%bool + then Pair (eq_rect _ (fun t => exprf (Tbase t)) e _ pf) (Op (OpConst 0) TT) + else Op opc args + | None + => Op opc args + end + | Some (const_of c1, const_of c2, gen_expr e) + => match base_type_eq_semidec_transparent T3 Tout with + | Some pf + => if ((c1 =? 0) && (c2 =? 0) && (2^Z.of_nat bw3 <=? bw))%Z%bool + then Pair (eq_rect _ (fun t => exprf (Tbase t)) e _ pf) (Op (OpConst 0) TT) + else Op opc args + | None + => Op opc args + end + | _ => Op opc args end - | _ => Op opc args end | SubWithBorrow TZ TZ TZ TZ as opc => fun args diff --git a/src/Compilers/Z/ArithmeticSimplifierInterp.v b/src/Compilers/Z/ArithmeticSimplifierInterp.v index e736f2631..124c93c95 100644 --- a/src/Compilers/Z/ArithmeticSimplifierInterp.v +++ b/src/Compilers/Z/ArithmeticSimplifierInterp.v @@ -138,6 +138,8 @@ Proof. | rewrite !Z.sub_with_borrow_to_add_get_carry | progress autorewrite with zsimplify_fast | progress cbv [cast_const ZToInterp interpToZ] + | progress change (Z.pow_pos 2 1) with 2%Z in * + | progress change (Z.pow_pos 2 2) with 4%Z in * | nia | progress cbv [Z.add_with_carry] | match goal with @@ -145,7 +147,35 @@ Proof. => rewrite (Z.mod_small x m) by solve_word_small () | [ |- context[(?x / ?m)%Z] ] => rewrite (Z.div_small x m) by solve_word_small () - end ]. + | [ H : ?x = 0%Z |- context[?x] ] => rewrite H + | [ H : ?x = 1%Z |- context[?x] ] => rewrite H + | [ H : (_ =? _)%nat = true |- _ ] => apply beq_nat_true in H + | [ H : (_ <? _)%nat = true |- _ ] => apply NPeano.Nat.ltb_lt in H + | [ |- context[FixedWordSizes.wordToZ_gen ?x] ] + => lazymatch goal with + | [ H : (0 <= FixedWordSizes.wordToZ_gen x)%Z |- _ ] => fail + | [ H : (0 <= FixedWordSizes.wordToZ_gen x < _)%Z |- _ ] => fail + | _ => pose proof (FixedWordSizesEquality.wordToZ_gen_range x) + end + | [ |- context[Z.max ?x ?y] ] + => first [ rewrite (Z.max_r x y) by omega + | rewrite (Z.max_l x y) by omega ] + | [ H : 0 < ?e |- context[(_ mod (2^Z.of_nat (2^?e)))%Z] ] + => lazymatch goal with + | [ H : (_ <= 2^Z.of_nat (2^e))%Z |- _ ] => fail + | _ => assert (2^Z.of_nat (2^1) <= 2^Z.of_nat (2^e))%Z + by (rewrite !Z.pow_Zpow; simpl Z.of_nat; auto with zarith) + end + | [ H : (1 < ?e)%Z |- context[(_ mod (2^?e))%Z] ] + => lazymatch goal with + | [ H : (_ <= 2^e)%Z |- _ ] => fail + | _ => assert (2^2 <= 2^e)%Z + by auto with zarith + end + end + | rewrite !FixedWordSizesEquality.wordToZ_ZToWord_mod_full + | progress Z.rewrite_mod_small + | rewrite Z.div_small by omega ]. Qed. Hint Rewrite @InterpSimplifyArith : reflective_interp. |