aboutsummaryrefslogtreecommitdiff
path: root/src/Reflection/Linearize.v
blob: b3ce3249b07ad3d9ce2063002eddb89eebfb8d40 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
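(** * Linearize: Place all and only operations in let binders *)
(** Besides operations, [Linearize] may also let-bind constants that it
    needs to name as variables (see [let_bind_const]); [InlineConst]
    substitutes let-bound constants back into their use sites. *)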
Require Import Crypto.Reflection.Syntax.
Require Import Crypto.Util.Tactics.

Local Open Scope ctype_scope.
(** Linearization and constant inlining for the PHOAS-style syntax of
    [Crypto.Reflection.Syntax]: binders are Coq functions (see [eC v] and
    [f x] below), and closed expressions ([Expr]) are parametric in the
    variable representation [var]. *)
Section language.
  (** Index type of base (leaf) types of the object language. *)
  Context (base_type_code : Type).
  (** Semantic type of each base type; constant values carried by [Const]
      live in [interp_flat_type_gen interp_base_type]. *)
  Context (interp_base_type : base_type_code -> Type).
  (** Primitive operations, indexed by two flat types (presumably the
      argument type and the result type -- confirm against [Syntax]). *)
  Context (op : flat_type base_type_code -> flat_type base_type_code -> Type).

  (* Abbreviations specialised to this section's parameters. *)
  Local Notation flat_type := (flat_type base_type_code).
  Local Notation type := (type base_type_code).
  Let Tbase := @Tbase base_type_code.
  (* Allows a [base_type_code] to be used where a [flat_type] is expected. *)
  Local Coercion Tbase : base_type_code >-> Syntax.flat_type.
  (* NOTE(review): [interp_type] does not appear to be used in this file. *)
  Let interp_type := interp_type interp_base_type.
  Let interp_flat_type := interp_flat_type_gen interp_base_type.
  Local Notation Expr := (@Expr base_type_code interp_base_type op).

  (** Linearization, for a fixed variable representation [var]. *)
  Section with_var.
    Context {var : base_type_code -> Type}.
    Local Notation exprf := (@exprf base_type_code interp_base_type op var).
    Local Notation expr := (@expr base_type_code interp_base_type op var).

    (** Continuation-passing helpers: each takes a continuation [C] that
        receives a tuple of variables standing for the value being bound. *)
    Section under_lets.
      (** [let_bind_const e C] let-binds the constant [e] of flat type [t].
          It recurses through [Prod] (left component first) and emits one
          [Let (Const _)] per [Tbase] leaf.  It then calls [C] with the
          resulting tuple of bound variables.  Recursion is structural on
          the type [t], so [e] is re-abstracted in the [return] clause in
          order for its type to follow the refinement of [t]. *)
      Fixpoint let_bind_const {t} (e : interp_flat_type t) {struct t}
        : forall {tC} (C : interp_flat_type_gen var t -> exprf tC), exprf tC
        := match t return forall (e : interp_flat_type t) {tC} (C : interp_flat_type_gen var t -> exprf tC), exprf tC with
           | Prod A B => fun e _ C => @let_bind_const A (fst e) _ (fun x =>
                                      @let_bind_const B (snd e) _ (fun y =>
                                      C (x, y)))
           | Syntax.Tbase _ => fun e _ C => Let (Const e) C
           end e.

      (** [under_letsf e C] turns [e] into a sequence of [Let]s.  The final
          body is [C] applied to a tuple of variables standing for the value
          of [e].
          - Nested [Let]s are hoisted: the bindings of [ex] are emitted
            before those of its body [eC v].
          - Constants are let-bound leaf by leaf via [let_bind_const].
          - Variables are passed to [C] directly, without a new binding.
          - An [Op] application is let-bound as a whole; its arguments are
            not traversed here.
          - A [Pair] is processed left component first. *)
      Fixpoint under_letsf {t} (e : exprf t)
        : forall {tC} (C : interp_flat_type_gen var t -> exprf tC), exprf tC
        := match e in Syntax.exprf _ _ _ t return forall {tC} (C : interp_flat_type_gen var t -> exprf tC), exprf tC with
           | Let _ ex _ eC
             => fun _ C => @under_letsf _ ex _ (fun v => @under_letsf _ (eC v) _ C)
           | Const _ x => fun _ C => let_bind_const x C
           | Var _ x => fun _ C => C x
           (* [as e] names the whole [Op] term (shadowing the argument [e]);
              the pattern-bound [op] and [args] are not used. *)
           | Op _ _ op args as e => fun _ C => Let e C
           | Pair A x B y => fun _ C => @under_letsf A x _ (fun x =>
                                        @under_letsf B y _ (fun y =>
                                        C (x, y)))
           end.
    End under_lets.

    (** [linearizef e] rebuilds [e] with its subterms let-bound:
        - [Let]: linearize the bound expression, bind its value with
          [under_letsf], and linearize the body under those variables.
        - [Const] and [Var] are returned unchanged.
        - [Op]: linearize and bind the arguments, then let-bind the
          operation applied to those variables (via [SmartVar]), and return
          the bound result as an expression.
        - [Pair]: linearize and bind both components, and return the pair
          of variables as an expression. *)
    Fixpoint linearizef {t} (e : exprf t) : exprf t
      := match e in Syntax.exprf _ _ _ t return exprf t with
         | Let _ ex _ eC
           => under_letsf (@linearizef _ ex) (fun x => @linearizef _ (eC x))
         | Const _ x => Const x
         | Var _ x => Var x
         (* [op] is the pattern-bound operation, shadowing the section
            variable [op]. *)
         | Op _ _ op args
           => under_letsf (@linearizef _ args) (fun args => Let (Op op (SmartVar args)) SmartVar)
         | Pair A ex B ey
           => under_letsf (@linearizef _ ex) (fun x =>
              under_letsf (@linearizef _ ey) (fun y =>
              SmartVar (t:=Prod A B) (x, y)))
         end.

    (** Lift [linearizef] through the [Abs] binders of an [expr].
        NOTE(review): the [Return] arm yields an [exprf] where an [expr] is
        expected, so it relies on an implicit coercion, presumably
        [Return], declared in [Syntax]. *)
    Fixpoint linearize {t} (e : expr t) : expr t
      := match e in Syntax.expr _ _ _ t return expr t with
         | Return _ x => linearizef x
         | Abs _ _ f => Abs (fun x => @linearize _ (f x))
         end.
  End with_var.

  (** Constant inlining.  The input is instantiated with expressions as its
      variables ([@exprf (@exprf var)]).  Substituting a constant for a
      bound variable is therefore just applying the binder's body to that
      constant. *)
  Section inline.
    Context {var : base_type_code -> Type}.
    Local Notation exprf := (@exprf base_type_code interp_base_type op).
    Local Notation expr := (@expr base_type_code interp_base_type op).

    (** [inline_constf e] drops every [Let] whose bound expression, after
        inlining, is a bare [Const].  The constant is substituted into the
        body via [SmartConst].  Any other bound expression keeps its [Let],
        and the new variable is wrapped with [SmartVarVar] before it is
        passed to the body.  Only a top-level [Const] is inlined; for
        example, a [Pair] of constants is still let-bound.  The inner
        [match] threads the body through its [return] clause (the convoy
        pattern), so that the body's argument type is refined along with
        [t']. *)
    Fixpoint inline_constf {t} (e : @exprf (@exprf var) t) : @exprf var t
      := match e in Syntax.exprf _ _ _ t return @exprf var t with
         | Let _ ex tC eC
           => match @inline_constf _ ex in Syntax.exprf _ _ _ t' return (interp_flat_type_gen _ t' -> @exprf var tC) -> @exprf var tC with
              | Const _ x => fun eC => eC (SmartConst (op:=op) (var:=var) x)
              | ex => fun eC => Let ex (fun x => eC (SmartVarVar x))
              end (fun x => @inline_constf _ (eC x))
         (* Variables are already expressions over [var]; use them as is. *)
         | Var _ x => x
         | Const _ x => Const x
         | Pair _ ex _ ey => Pair (@inline_constf _ ex) (@inline_constf _ ey)
         | Op _ _ op args => Op op (@inline_constf _ args)
         end.

    (** Lift [inline_constf] through [Abs].  Each abstracted variable [x]
        is passed to the body as the expression [Var x]. *)
    Fixpoint inline_const {t} (e : @expr (@exprf var) t) : @expr var t
      := match e in Syntax.expr _ _ _ t return @expr var t with
         | Return _ x => Return (inline_constf x)
         | Abs _ _ f => Abs (fun x => @inline_const _ (f (Var x)))
         end.
  End inline.

  (** Wrappers on closed expressions.  Each instantiates [e] at the
      appropriate variable type and applies [linearize] or
      [inline_const]. *)
  Definition Linearize {t} (e : Expr t) : Expr t
    := fun var => linearize (e _).
  Definition InlineConst {t} (e : Expr t) : Expr t
    := fun var => inline_const (e _).
End language.

(* Both wrappers: the section parameters and the type index [t] are
   implicit; the expression [e] and the variable representation [var]
   stay explicit. *)
Arguments Linearize {base_type_code interp_base_type op t} e var.
Arguments InlineConst {base_type_code interp_base_type op t} e var.