aboutsummaryrefslogtreecommitdiff
path: root/src/Compilers/Linearize.v
blob: e5d136f5f40ac40fe55c4e02596cb4ec1f780fb5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
(** * Linearize: Place all and only operations in let binders *)
Require Import Crypto.Compilers.Syntax.
Require Import Crypto.Compilers.SmartMap.
(*Require Import Crypto.Util.Tactics.*)

Local Open Scope ctype_scope.
Section language.
  Context (let_bind_op_args : bool)
          {base_type_code : Type}
          {op : flat_type base_type_code -> flat_type base_type_code -> Type}.

  Local Notation flat_type := (flat_type base_type_code).
  Local Notation type := (type base_type_code).
  Local Notation Tbase := (@Tbase base_type_code).
  Local Notation Expr := (@Expr base_type_code op).

  Section with_var.
    Context {var : base_type_code -> Type}.
    Local Notation exprf := (@exprf base_type_code op var).
    Local Notation expr := (@expr base_type_code op var).

    Section under_lets.
      Fixpoint under_letsf' (bind_pairs : bool) {t} (e : exprf t)
        : forall {tC} (C : interp_flat_type var t + exprf t -> exprf tC), exprf tC
        := match e in Syntax.exprf _ _ t return forall {tC} (C : interp_flat_type var t + exprf t -> exprf tC), exprf tC with
           | LetIn _ ex _ eC
             => fun _ C => @under_letsf' true _ ex _ (fun v =>
                           match v with
                           | inl v => @under_letsf' false _ (eC v) _ C
                           | inr v => LetIn v (fun v => @under_letsf' false _ (eC v) _ C)
                           end)
           | TT => fun _ C => C (inl tt)
           | Var _ x => fun _ C => C (inl x)
           | Op _ _ op args as e
             => if let_bind_op_args
                then fun _ C => LetIn e (fun v => C (inl v))
                else fun _ C => C (inr e)
           | Pair A x B y => fun _ C => @under_letsf' bind_pairs A x _ (fun x =>
                                        @under_letsf' bind_pairs B y _ (fun y =>
                                        match x, y, bind_pairs with
                                        | inl x, inl y, _ => C (inl (x, y))
                                        | inl x, inr y, false => C (inr (Pair (SmartVarf x) y))
                                        | inr x, inl y, false => C (inr (Pair x (SmartVarf y)))
                                        | inr x, inr y, false => C (inr (Pair x y))
                                        | inl x, inr y, true => LetIn y (fun y => C (inl (x, y)))
                                        | inr x, inl y, true => LetIn x (fun x => C (inl (x, y)))
                                        | inr x, inr y, true => LetIn x (fun x => LetIn y (fun y => C (inl (x, y))))
                                        end))
           end.
      Definition under_letsf {t} (e : exprf t)
                 {tC} (C : exprf t -> exprf tC)
        : exprf tC
        := under_letsf' false e (fun v => match v with inl v => C (SmartVarf v) | inr v => C v end).
    End under_lets.

    Fixpoint linearizef_gen {t} (e : exprf t) : exprf t
      := match e in Syntax.exprf _ _ t return exprf t with
         | LetIn _ ex _ eC
           => under_letsf (LetIn (@linearizef_gen _ ex) (fun x => @linearizef_gen _ (eC x))) (fun x => x)
         | TT => TT
         | Var _ x => Var x
         | Op _ _ op args
           => if let_bind_op_args
              then under_letsf (@linearizef_gen _ args) (fun args => LetIn (Op op args) SmartVarf)
              else under_letsf (@linearizef_gen _ args) (fun args => Op op args)
         | Pair A ex B ey
           => under_letsf (Pair (@linearizef_gen _ ex) (@linearizef_gen _ ey)) (fun v => v)
         end.

    Definition linearize_gen {t} (e : expr t) : expr t
      := match e in Syntax.expr _ _ t return expr t with
         | Abs _ _ f => Abs (fun x => linearizef_gen (f x))
         end.
  End with_var.

  Definition Linearize_gen {t} (e : Expr t) : Expr t
    := fun var => linearize_gen (e _).
End language.

Global Arguments under_letsf' _ {_ _ _} bind_pairs {_} _ {tC} _.
Global Arguments under_letsf _ {_ _ _ _} _ {tC} _.
Global Arguments linearizef_gen _ {_ _ _ _} _.
Global Arguments linearize_gen _ {_ _ _ _} _.
Global Arguments Linearize_gen _ {_ _ _} _ var.

Definition linearizef {base_type_code op var t} e
  := @linearizef_gen false base_type_code op var t e.
Definition a_normalf {base_type_code op var t} e
  := @linearizef_gen true base_type_code op var t e.
Definition linearize {base_type_code op var t} e
  := @linearize_gen false base_type_code op var t e.
Definition a_normal {base_type_code op var t} e
  := @linearize_gen true base_type_code op var t e.
Definition Linearize {base_type_code op t} e
  := @Linearize_gen false base_type_code op t e.
Definition ANormal {base_type_code op t} e
  := @Linearize_gen true base_type_code op t e.