aboutsummaryrefslogtreecommitdiff
path: root/src/Reflection/Z/ArithmeticSimplifier.v
blob: 68731f7b6528df90804dd4aa2e19eb53bdeb1bfd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
(** * SimplifyArith: Remove things like (_ * 1), (_ + 0), etc., and fold operations whose arguments are all constants *)
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Reflection.Syntax.
Require Import Crypto.Reflection.Rewriter.
Require Import Crypto.Reflection.Z.Syntax.

(** Local arithmetic simplification for the [base_type]/[op] expression
    language: constant folding of binary operators plus removal of
    identity/absorbing constant arguments. *)
Section language.
  Local Notation exprf := (@exprf base_type op).
  Local Notation Expr := (@Expr base_type op).

  Section with_var.
    Context {var : base_type -> Type}.

    (** [interp_as_expr_or_const x] tries to view [x] as a tuple that follows
        the structure of its flat type [t], where each base-type leaf is
        either a literal constant ([inl z], read off an [OpConst] node) or an
        arbitrary base-type expression ([inr e]).

        Only [Pair] nodes are descended into; the arguments of an [Op] leaf
        are not inspected.  Returns [None] if, while descending through
        [Pair] nodes, it reaches a [LetIn] or an [Op] whose result type is
        not a single [Tbase] type. *)
    Fixpoint interp_as_expr_or_const {t} (x : exprf (var:=var) t)
      : option (interp_flat_type (fun t => Z + (exprf (var:=var) (Tbase t)))%type t)
      := match x in Syntax.exprf _ _ t return option (interp_flat_type _ t) with
         (* A base-type operator application is a leaf: the constant [z] if
            the operator is [OpConst _ z], otherwise the [Op] node itself.
            The default value [inr (Op opc args)] is built outside the inner
            [match] and passed in as an argument, presumably so the branches
            never need to re-type [args] after [opc] is refined. *)
         | Op t1 (Tbase _) opc args
           => Some (match opc with
                    | OpConst _ z => fun _ => inl z
                    | _ => fun x => x
                    end (inr (Op opc args)))
         | TT => Some tt
         | Var t v => Some (inr (Var v))
         (* Operators whose result is not a single base type, and let
            bindings, are not decomposed. *)
         | Op _ _ _ _
         | LetIn _ _ _ _
           => None
         (* A pair decomposes only if both of its components do. *)
         | Pair tx ex ty ey
           => match @interp_as_expr_or_const tx ex, @interp_as_expr_or_const ty ey with
              | Some vx, Some vy => Some (vx, vy)
              | _, None | None, _ => None
              end
         end.

    (** [simplify_op_expr opc args] builds an expression for applying [opc]
        to [args].  For the binary operators below, it first asks
        [interp_as_expr_or_const] whether either argument is a literal
        constant, and simplifies as follows:
        - both arguments constant: a single [OpConst] node holding the value
          computed by [interp_op] on the two constants;
        - [Add]: a constant [0] in either position returns the other argument;
        - [Sub]: a constant [0] as the second argument returns the first
          (a constant [0] as the first argument is not simplified);
        - [Mul]: a constant [0] in either position gives the constant [0];
          otherwise a constant [1] in either position returns the other
          argument;
        - [Shl], [Shr]: a constant [0] as the second argument returns the
          first (a constant first argument is not simplified);
        - [Land]: a constant [0] in either position gives the constant [0];
        - [Lor]: a constant [0] in either position returns the other argument.

        In every other case, including all uses of [Cast], [OpConst],
        [Neg], [Cmovne] and [Cmovle], the result is just [Op opc args]. *)
    Definition simplify_op_expr {src dst} (opc : op src dst)
      : exprf (var:=var) src -> exprf (var:=var) dst
      := match opc in op src dst return exprf src -> exprf dst with
         (* In each branch, [as opc] rebinds [opc] at the type refined by
            the branch, presumably so that [interp_op _ _ opc] and
            [Op opc args] typecheck there. *)
         (* Add: fold constants; drop a constant 0 on either side. *)
         | Add _ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (inl l, inl r)
                   => Op (OpConst (interp_op _ _ opc (l, r))) TT
                 | Some (inl v, inr e)
                 | Some (inr e, inl v)
                   => if (v =? 0)%Z
                      then e
                      else Op opc args
                 | _ => Op opc args
                 end
         (* Sub: fold constants; drop a constant 0 on the right only. *)
         | Sub _ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (inl l, inl r)
                   => Op (OpConst (interp_op _ _ opc (l, r))) TT
                 | Some (inr e, inl v)
                   => if (v =? 0)%Z
                      then e
                      else Op opc args
                 | _ => Op opc args
                 end
         (* Mul: fold constants; a constant 0 on either side gives 0, and a
            constant 1 on either side is dropped.  The 0 test runs first. *)
         | Mul _ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (inl l, inl r)
                   => Op (OpConst (interp_op _ _ opc (l, r))) TT
                 | Some (inl v, inr e)
                 | Some (inr e, inl v)
                   => if (v =? 0)%Z
                      then Op (OpConst 0%Z) TT
                      else if (v =? 1)%Z
                           then e
                           else Op opc args
                 | _ => Op opc args
                 end
         (* Shl, Shr: fold constants; a constant 0 as the second argument
            returns the first argument. *)
         | Shl _ as opc
         | Shr _ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (inl l, inl r)
                   => Op (OpConst (interp_op _ _ opc (l, r))) TT
                 | Some (inr e, inl v)
                   => if (v =? 0)%Z
                      then e
                      else Op opc args
                 | _ => Op opc args
                 end
         (* Land: fold constants; a constant 0 on either side gives 0. *)
         | Land _ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (inl l, inl r)
                   => Op (OpConst (interp_op _ _ opc (l, r))) TT
                 | Some (inl v, inr e)
                 | Some (inr e, inl v)
                   => if (v =? 0)%Z
                      then Op (OpConst 0%Z) TT
                      else Op opc args
                 | _ => Op opc args
                 end
         (* Lor: fold constants; drop a constant 0 on either side. *)
         | Lor _ as opc
           => fun args
              => match interp_as_expr_or_const args with
                 | Some (inl l, inl r)
                   => Op (OpConst (interp_op _ _ opc (l, r))) TT
                 | Some (inl v, inr e)
                 | Some (inr e, inl v)
                   => if (v =? 0)%Z
                      then e
                      else Op opc args
                 | _ => Op opc args
                 end
         (* No simplification for these operators: [Op opc] is [Op]
            partially applied, so any [args] becomes [Op opc args]. *)
         | Cast _ _ as opc
         | OpConst _ _ as opc
         | Neg _ _ as opc
         | Cmovne _ as opc
         | Cmovle _ as opc
           => Op opc
         end.
  End with_var.

  (** [SimplifyArith e] runs [RewriteOp] (from [Crypto.Reflection.Rewriter])
      on [e] with [simplify_op_expr] as the operator-rebuilding function.
      NOTE(review): presumably [RewriteOp] applies it at every [Op] node of
      [e]; confirm against the definition of [RewriteOp]. *)
  Definition SimplifyArith {t} (e : Expr t) : Expr t
    := @RewriteOp base_type op (@simplify_op_expr) t e.
End language.