(* Source file: root/src/Compilers/Z/Named/RewriteAddToAdc.v
   (blob c64663c8b999eb8215d53dc991409ef19ad9885c) *)
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Compilers.Named.Context.
Require Import Crypto.Compilers.Syntax.
Require Import Crypto.Compilers.Z.Syntax.
Require Import Crypto.Compilers.Named.GetNames.
Require Crypto.Compilers.Named.Syntax.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.Notations.

Local Open Scope Z_scope.

(** Rewriting pass over named expressions of the [Z] compiler syntax.
    It looks for the binding pattern
<<
    nlet (a2, c1) := ADD bw a b in
    nlet (s, c2)  := ADD bw c0 a2 in
    nlet c        := c1 + c2 in
    P
>>
    and, when the side conditions checked in [do_rewrite] hold, inserts a
    single [nlet (s, c) := ADC bw c0 a b] just before [P]. *)
Section named.
  (** [Name] is the type of variable names.  [name_beq] is used as a
      boolean equality test on names; it is assumed to be a sound
      equality test, which is not checked here. *)
  Context {Name : Type}
          (name_beq : Name -> Name -> bool).
  Import Named.Syntax.
  Local Notation flat_type := (flat_type base_type).
  Local Notation type := (type base_type).
  Local Notation exprf := (@exprf base_type op Name).
  Local Notation expr := (@expr base_type op Name).

  (** Local shorthands, specialised to [Z]-typed named expressions:
      - [ADC bw c x y] is the [AddWithGetCarry bw] operator applied to
        the nested tuple [((c, x), y)];
      - [ADD bw x y] is [ADC] with the constant [0] as carry-in;
      - [ADX x y] is the plain [Add] operator on two [Z] values. *)
  Local Notation tZ := (Tbase TZ).
  Local Notation ADC bw c x y := (Op (@AddWithGetCarry bw TZ TZ TZ TZ TZ)
                                     (Pair (Pair (t1:=tZ) c (t2:=tZ) x) (t2:=tZ) y)).
  Local Notation ADD bw x y := (ADC bw (Op (OpConst 0) TT) x y).
  Local Notation ADX x y := (Op (@Add TZ TZ TZ) (Pair (t1:=tZ) x (t2:=tZ) y)).
  (* Boolean equality on bit widths and on names, respectively. *)
  Local Infix "=Z?" := Z.eqb.
  Local Infix "=n?" := name_beq.

  (** [is_const_or_var v] is [true] exactly when [v] is a variable, or a
      constant operator ([OpConst]) applied to [TT].  It is [false] for
      every other expression shape. *)
  Definition is_const_or_var {t} (v : exprf t)
    := match v with
       | Var _ _ => true
       | Op _ _ (OpConst _ _) TT => true
       | _ => false
       end.

  (** [name_list_has_duplicate ls] is [true] when some element of [ls]
      is related by [name_beq] to an element that occurs later in the
      list.  Each element is compared with every later one, so this
      makes a quadratic number of [name_beq] calls. *)
  Fixpoint name_list_has_duplicate (ls : list Name) : bool
    := match ls with
       | nil => false
       | cons n ns
         => orb (name_list_has_duplicate ns)
                (* [true] iff [name_beq n] holds for some element of [ns] *)
                (List.fold_left orb (List.map (name_beq n) ns) false)
       end.

  (** [do_rewrite e] returns [Some e'] when [e] has the shape
<<
      nlet (a2, c1) := ADD bw1 a b in
      nlet (s, c2)  := ADD bw2 c0 (Var a2') in
      nlet c        := ADX (Var c1') (Var c2') in
      P
>>
      and all side conditions hold.  Otherwise it returns [None].  The
      result [e'] keeps the three original bindings and adds, just before
      [P], a new binding [nlet (s, c) := ADC bw1 c0 a b].  That binding
      shadows the earlier [s] and [c] as seen by [P].

      Side conditions:
      - [bw1] equals [bw2].  The variables consumed by the second and
        third bindings are exactly [a2], [c1] and [c2].
      - [c0], [a] and [b] are each a variable or a constant
        ([is_const_or_var]).  Presumably this ensures that copying them
        into the new [ADC] duplicates no real computation.
      - The names [a2], [c1], [s], [c2] and [c], together with the names
        returned by [get_namesf] for [c0], [a] and [b], contain no
        duplicates.  This prevents [c0], [a] and [b] from referring to any
        name bound by the three lets.  The check is conservative: it also
        rejects cases where, for example, [a] and [b] are the same
        variable.

      Soundness presumably relies on [AddWithGetCarry bw] returning the
      sum modulo [2^bw] together with its carry.  Under that reading,
      [c0 + a + b] gives the same [s] and a carry equal to [c1 + c2].
      NOTE(review): confirm this against the operator's interpretation.
      The original bindings, now redundant, are not removed here. *)
  Definition do_rewrite
             {t} (e : exprf t)
    : option (exprf t)
    := match e in Named.exprf _ _ _ t return option (exprf t) with
       |           (nlet (a2, c1) : tZ * tZ := (ADD bw1 a b as ex0) in P0)
         => match P0 in Named.exprf _ _ _ t return option (exprf t) with
            |      (nlet (s , c2) : tZ * tZ := (ADD bw2 c0 (Var _ a2') as ex1) in P1)
              => match P1 in Named.exprf _ _ _ t return option (exprf t) with
                 | (nlet c        : tZ      := (ADX (Var _ c1') (Var _ c2') as ex2) in P)
                   => if (((bw1 =Z? bw2) && (a2 =n? a2') && (c1 =n? c1') && (c2 =n? c2'))
                            && (is_const_or_var c0 && is_const_or_var a && is_const_or_var b)
                            && negb (name_list_has_duplicate (a2::c1::s::c2::c::nil ++ get_namesf c0 ++ get_namesf a ++ get_namesf b)%list))
                      then Some (nlet (a2, c1) : tZ * tZ := ex0 in
                                 nlet (s , c2) : tZ * tZ := ex1 in
                                 nlet c        : tZ      := ex2 in
                                 nlet (s, c)   : tZ * tZ := ADC bw1 c0 a b in
                                 P)
                      else None
                 (* the third binding does not match the pattern *)
                 | P1' => None
                 end
            (* the second binding does not match the pattern *)
            | P0' => None
            end
       (* the outermost binding does not match the pattern *)
       | _ => None
       end%core%nexpr%bool.

  (** [do_rewriteo e] is the rewritten expression when [do_rewrite e]
      succeeds, and [e] unchanged otherwise. *)
  Definition do_rewriteo {t} (e : exprf t) : exprf t
    := match do_rewrite e with
       | Some e' => e'
       | None => e
       end.

  (** One layer of structural traversal.  [rewrite_exprf_prestep f e]
      rebuilds the top constructor of [e], applying [f] (the
      [rewrite_exprf] argument) to each immediate subexpression.  [TT]
      and variables are returned as they are.

      The recursive function is taken as an argument so that
      [rewrite_exprf] can tie the knot.  Coq presumably accepts that
      [Fixpoint] because this definition is unfolded during the guard
      check and only ever passes subterms of [e] to [f]. *)
  Definition rewrite_exprf_prestep
             (rewrite_exprf : forall {t} (e : exprf t), exprf t)
             {t} (e : exprf t)
    : exprf t
    := match e in Named.exprf _ _ _ t return exprf t with
       | TT => TT
       | Var _ n => Var n
       | Op _ _ opc args
         => Op opc (@rewrite_exprf _ args)
       | (nlet nx := ex in eC)
         => (nlet nx := @rewrite_exprf _ ex in @rewrite_exprf _ eC)
       | Pair tx ex ty ey
         => Pair (@rewrite_exprf tx ex) (@rewrite_exprf ty ey)
       end%nexpr.

  (** Bottom-up rewriting.  First rewrite every subexpression of [e],
      then try [do_rewrite] once at the root of the result. *)
  Fixpoint rewrite_exprf {t} (e : exprf t) : exprf t
    := do_rewriteo (@rewrite_exprf_prestep (@rewrite_exprf) t e).

  (** Apply [rewrite_exprf] to the body of a function ([Abs]), keeping
      the bound name [n]. *)
  Definition rewrite_expr {t} (e : expr t) : expr t
    := match e in Named.expr _ _ _ t return expr t with
       | Abs _ _ n f => Abs n (rewrite_exprf f)
       end.
End named.