(* src/Compilers/Z/ArithmeticSimplifier.v (blob b2621c625bf125850d10f8f171414b4628add043) *)
(** * SimplifyArith: Remove things like (_ * 1), (_ + 0), etc., and fold constant arithmetic *)
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Compilers.Syntax.
Require Import Crypto.Compilers.Rewriter.
Require Import Crypto.Compilers.Z.Syntax.

Section language.
  Local Notation exprf := (@exprf base_type op).
  Local Notation Expr := (@Expr base_type op).

  Section with_var.
    Context {var : base_type -> Type}.

    (** A shallow view of an expression of base type [t], used to
        pattern-match on the shape of operator arguments.  As built by
        [interp_as_expr_or_const]:
        - [const_of v]: the expression is the constant operator [OpConst v];
        - [gen_expr e]: the expression is [e] itself, with no special shape;
        - [neg_expr e]: the expression is [Opp e] at [TZ].  Note that [e]
          is the operand of the negation, not the negation itself. *)
    Inductive inverted_expr t :=
    | const_of (v : Z)
    | gen_expr (e : exprf (var:=var) (Tbase t))
    | neg_expr (e : exprf (var:=var) (Tbase t)).

    (** View [x] as an [inverted_expr], or as a (possibly nested) tuple of
        them, following the flat type [t].  Returns [None] if [x] has a
        shape this view does not cover.
        - An [Op] whose result is a base type is classified:
          [OpConst z] gives [const_of z]; [Opp] at [TZ] gives [neg_expr]
          of its operand; any other operator gives [gen_expr] of the whole
          application.
        - [Var v] gives [gen_expr (Var v)], and [TT] gives [tt].
        - [Pair] is handled componentwise and fails if either side fails.
        - [LetIn], and [Op]s whose result is not a base type, give [None]. *)
    Fixpoint interp_as_expr_or_const {t} (x : exprf (var:=var) t)
      : option (interp_flat_type inverted_expr t)
      := match x in Syntax.exprf _ _ t return option (interp_flat_type _ t) with
         | Op t1 (Tbase _) opc args
           (* The whole expression and its arguments are passed into the
              inner match (convoy pattern), so their types are refined
              together with [opc].  The [_ => TZ] arm of the return clause
              only makes that clause total; [dst] is a [Tbase] here. *)
           => Some (match opc in op src dst return exprf dst -> exprf src -> inverted_expr match dst with Tbase t => t | _ => TZ end with
                    | OpConst _ z => fun _ _ => const_of _ z
                    | Opp TZ TZ => fun _ args => neg_expr _ args
                    | _ => fun e _ => gen_expr _ e
                    end (Op opc args) args)
         | TT => Some tt
         | Var t v => Some (gen_expr _ (Var v))
         (* Operators producing a non-base (e.g. tuple) type, and let-bindings,
            are not inspected. *)
         | Op _ _ _ _
         | LetIn _ _ _ _
           => None
         | Pair tx ex ty ey
           => match @interp_as_expr_or_const tx ex, @interp_as_expr_or_const ty ey with
              | Some vx, Some vy => Some (vx, vy)
              | _, None | None, _ => None
              end
         end.

    (** Simplify a single operator application [Op opc args].  The argument
        is inspected with [interp_as_expr_or_const], so [neg_expr e] below
        stands for [Opp e].  Only the [TZ] instances of the operators are
        rewritten.  Whenever no rule applies, the result is [Op opc args],
        unchanged.

        Rules (both operand orders are matched where noted below):
        - any binary operator on two constants, and [Opp] of a constant,
          is folded with [interp_op];
        - [e + 0], [e - 0], [e << 0], [e >> 0] and [e lor 0] become [e],
          and likewise with a constant [0] on the left for [+] and [lor];
          a negated operand [-e] becomes [Opp e] instead;
        - [e * 0] and [e land 0] become [0]; [e * 1] becomes [e];
          [e * -1] becomes [Opp e], with the dual results for [-e];
        - [e1 + -e2] becomes [e1 - e2]; [e1 - -e2] becomes [e1 + e2];
          [-e1 - -e2] becomes [e2 - e1];
        - a product with exactly one negated factor becomes [Opp (e1 * e2)];
          a product of two negated factors becomes [e1 * e2];
        - [Opp (Opp e)] becomes [e].

        NOTE(review): correctness proofs elsewhere presumably case-split on
        the exact shape of this match; keep them in sync when adding rules
        (e.g. [0 - e], [-1 land e], [-e1 + -e2] are not handled) - TODO
        confirm. *)
    Definition simplify_op_expr {src dst} (opc : op src dst)
      : exprf (var:=var) src -> exprf (var:=var) dst
      := match opc in op src dst return exprf src -> exprf dst with
         (* Addition over Z. *)
         | Add TZ TZ TZ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (const_of l, const_of r)
                   => Op (OpConst (interp_op _ _ opc (l, r))) TT
                 | Some (const_of v, gen_expr e)
                 | Some (gen_expr e, const_of v)
                   => if (v =? 0)%Z
                      then e
                      else Op opc args
                 | Some (const_of v, neg_expr e)
                 | Some (neg_expr e, const_of v)
                   => if (v =? 0)%Z
                      then Op (Opp _ _) e
                      else Op opc args
                 (* [ep + -en], in either order, is [ep - en]. *)
                 | Some (gen_expr ep, neg_expr en)
                 | Some (neg_expr en, gen_expr ep)
                   => Op (Sub _ _ _) (Pair ep en)
                 | _ => Op opc args
                 end
         (* Subtraction over Z; only a constant on the right is removed. *)
         | Sub TZ TZ TZ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (const_of l, const_of r)
                   => Op (OpConst (interp_op _ _ opc (l, r))) TT
                 | Some (gen_expr e, const_of v)
                   => if (v =? 0)%Z
                      then e
                      else Op opc args
                 | Some (neg_expr e, const_of v)
                   => if (v =? 0)%Z
                      then Op (Opp _ _) e
                      else Op opc args
                 (* [e1 - -e2] is [e1 + e2]. *)
                 | Some (gen_expr e1, neg_expr e2)
                   => Op (Add _ _ _) (Pair e1 e2)
                 (* [-e1 - -e2] is [e2 - e1]; note the swapped operands. *)
                 | Some (neg_expr e1, neg_expr e2)
                   => Op (Sub _ _ _) (Pair e2 e1)
                 | _ => Op opc args
                 end
         (* Multiplication over Z: constants 0, 1 and -1 are simplified, and
            negations are pulled out of or cancelled in the product. *)
         | Mul TZ TZ TZ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (const_of l, const_of r)
                   => Op (OpConst (interp_op _ _ opc (l, r))) TT
                 | Some (const_of v, gen_expr e)
                 | Some (gen_expr e, const_of v)
                   => if (v =? 0)%Z
                      then Op (OpConst 0%Z) TT
                      else if (v =? 1)%Z
                           then e
                           else if (v =? -1)%Z
                                then Op (Opp _ _) e
                                else Op opc args
                 (* Same constants, but the other factor is [-e]. *)
                 | Some (const_of v, neg_expr e)
                 | Some (neg_expr e, const_of v)
                   => if (v =? 0)%Z
                      then Op (OpConst 0%Z) TT
                      else if (v =? 1)%Z
                           then Op (Opp _ _) e
                           else if (v =? -1)%Z
                                then e
                                else Op opc args
                 (* Exactly one negated factor: negate the plain product. *)
                 | Some (gen_expr e1, neg_expr e2)
                 | Some (neg_expr e1, gen_expr e2)
                   => Op (Opp _ _) (Op (Mul _ _ TZ) (Pair e1 e2))
                 (* Two negated factors cancel. *)
                 | Some (neg_expr e1, neg_expr e2)
                   => Op (Mul _ _ _) (Pair e1 e2)
                 | _ => Op opc args
                 end
         (* Shifts over Z: apart from constant folding, only a constant 0
            shift amount (the right operand) is removed. *)
         | Shl TZ TZ TZ as opc
         | Shr TZ TZ TZ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (const_of l, const_of r)
                   => Op (OpConst (interp_op _ _ opc (l, r))) TT
                 | Some (gen_expr e, const_of v)
                   => if (v =? 0)%Z
                      then e
                      else Op opc args
                 | Some (neg_expr e, const_of v)
                   => if (v =? 0)%Z
                      then Op (Opp _ _) e
                      else Op opc args
                 | _ => Op opc args
                 end
         (* Bitwise and over Z: a constant 0 operand, on either side, gives 0. *)
         | Land TZ TZ TZ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (const_of l, const_of r)
                   => Op (OpConst (interp_op _ _ opc (l, r))) TT
                 | Some (const_of v, gen_expr _)
                 | Some (gen_expr _, const_of v)
                 | Some (const_of v, neg_expr _)
                 | Some (neg_expr _, const_of v)
                   => if (v =? 0)%Z
                      then Op (OpConst 0%Z) TT
                      else Op opc args
                 | _ => Op opc args
                 end
         (* Bitwise or over Z: a constant 0 operand, on either side, is dropped. *)
         | Lor TZ TZ TZ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (const_of l, const_of r)
                   => Op (OpConst (interp_op _ _ opc (l, r))) TT
                 | Some (const_of v, gen_expr e)
                 | Some (gen_expr e, const_of v)
                   => if (v =? 0)%Z
                      then e
                      else Op opc args
                 | Some (const_of v, neg_expr e)
                 | Some (neg_expr e, const_of v)
                   => if (v =? 0)%Z
                      then Op (Opp _ _) e
                      else Op opc args
                 | _ => Op opc args
                 end
         (* Negation over Z: fold constants and cancel a double negation. *)
         | Opp TZ TZ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (const_of v)
                   => Op (OpConst (interp_op _ _ opc v)) TT
                 | Some (neg_expr e)
                   => e
                 | _
                   => Op opc args
                 end
         (* All remaining instances (operators at types other than the [TZ]
            ones above, and [OpConst]) are rebuilt unchanged. *)
         | Add _ _ _ as opc
         | Sub _ _ _ as opc
         | Mul _ _ _ as opc
         | Shl _ _ _ as opc
         | Shr _ _ _ as opc
         | Land _ _ _ as opc
         | Lor _ _ _ as opc
         | OpConst _ _ as opc
         | Opp _ _ as opc
           => Op opc
         end.
  End with_var.

  (** Simplify arithmetic in a whole expression by passing
      [simplify_op_expr] to [RewriteOp] from [Crypto.Compilers.Rewriter].
      Presumably [RewriteOp] applies it at every [Op] node; see that module
      to confirm the traversal order. *)
  Definition SimplifyArith {t} (e : Expr t) : Expr t
    := @RewriteOp base_type op (@simplify_op_expr) t e.
End language.