aboutsummaryrefslogtreecommitdiff
path: root/src/Compilers/CountLets.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Compilers/CountLets.v')
-rw-r--r--src/Compilers/CountLets.v66
1 files changed, 66 insertions, 0 deletions
diff --git a/src/Compilers/CountLets.v b/src/Compilers/CountLets.v
new file mode 100644
index 000000000..4810162c8
--- /dev/null
+++ b/src/Compilers/CountLets.v
@@ -0,0 +1,66 @@
+(** * Counts how many binders there are *)
+Require Import Crypto.Compilers.Syntax.
+Require Import Crypto.Compilers.SmartMap.
+
+Local Open Scope ctype_scope.
+Section language.
+ Context {base_type_code : Type}
+ {op : flat_type base_type_code -> flat_type base_type_code -> Type}.
+
+ Local Notation flat_type := (flat_type base_type_code).
+ Local Notation type := (type base_type_code).
+ Local Notation Expr := (@Expr base_type_code op).
+
+ Fixpoint count_pairs (t : flat_type) : nat
+ := match t with
+ | Tbase _ => 1
+ | Unit => 0
+ | Prod A B => count_pairs A + count_pairs B
+ end%nat.
+
+ Section with_var.
+ Context {var : base_type_code -> Type}
+ (mkVar : forall t, var t).
+
+ Local Notation exprf := (@exprf base_type_code op var).
+ Local Notation expr := (@expr base_type_code op var).
+
+ Section gen.
+ Context (count_type_let : flat_type -> nat).
+ Context (count_type_abs : flat_type -> nat).
+
+ Fixpoint count_lets_genf {t} (e : exprf t) : nat
+ := match e with
+ | LetIn tx _ _ eC
+ => count_type_let tx + @count_lets_genf _ (eC (SmartValf var mkVar tx))
+ | Op _ _ _ e => @count_lets_genf _ e
+ | Pair _ ex _ ey => @count_lets_genf _ ex + @count_lets_genf _ ey
+ | _ => 0
+ end.
+ Definition count_lets_gen {t} (e : expr t) : nat
+ := match e with
+ | Abs tx _ f => count_type_abs tx + @count_lets_genf _ (f (SmartValf _ mkVar tx))
+ end.
+ End gen.
+
+ Definition count_let_bindersf {t} (e : exprf t) : nat
+ := count_lets_genf count_pairs e.
+ Definition count_letsf {t} (e : exprf t) : nat
+ := count_lets_genf (fun _ => 1) e.
+ Definition count_let_binders {t} (e : expr t) : nat
+ := count_lets_gen count_pairs (fun _ => 0) e.
+ Definition count_lets {t} (e : expr t) : nat
+ := count_lets_gen (fun _ => 1) (fun _ => 0) e.
+ Definition count_binders {t} (e : expr t) : nat
+ := count_lets_gen count_pairs count_pairs e.
+ End with_var.
+
+ Definition CountLetsGen (count_type_let : flat_type -> nat) (count_type_abs : flat_type -> nat) {t} (e : Expr t) : nat
+ := count_lets_gen (fun _ => tt) count_type_let count_type_abs (e _).
+ Definition CountLetBinders {t} (e : Expr t) : nat
+ := count_let_binders (fun _ => tt) (e _).
+ Definition CountLets {t} (e : Expr t) : nat
+ := count_lets (fun _ => tt) (e _).
+ Definition CountBinders {t} (e : Expr t) : nat
+ := count_binders (fun _ => tt) (e _).
+End language.