diff options
Diffstat (limited to 'src/Experiments/NewPipeline/Rewriter.v')
-rw-r--r-- | src/Experiments/NewPipeline/Rewriter.v | 1780 |
1 files changed, 1780 insertions, 0 deletions
diff --git a/src/Experiments/NewPipeline/Rewriter.v b/src/Experiments/NewPipeline/Rewriter.v new file mode 100644 index 000000000..a055c4735 --- /dev/null +++ b/src/Experiments/NewPipeline/Rewriter.v @@ -0,0 +1,1780 @@ +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.ListUtil.FoldBool. +Require Import Crypto.Util.ZRange. +Require Import Crypto.Util.ZRange.Operations. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.OptionList. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.CPSNotations. +Require Crypto.Util.PrimitiveProd. +Require Crypto.Util.PrimitiveHList. +Require Import Crypto.Experiments.NewPipeline.Language. +Require Import Crypto.Experiments.NewPipeline.UnderLets. +Require Import Crypto.Experiments.NewPipeline.GENERATEDIdentifiersWithoutTypes. +Require Import Crypto.Util.Notations. +Import ListNotations. Local Open Scope bool_scope. Local Open Scope Z_scope. + +Module Compilers. + Export Language.Compilers. + Export UnderLets.Compilers. + Export GENERATEDIdentifiersWithoutTypes.Compilers. + Import invert_expr. + + Module pattern. + Export GENERATEDIdentifiersWithoutTypes.Compilers.pattern. + + Module base. + Local Notation einterp := type.interp. + Module type. + Inductive type := any | type_base (t : Compilers.base.type.base) | prod (A B : type) | list (A : type). + End type. + Notation type := type.type. + + Module Notations. + Global Coercion type.type_base : Compilers.base.type.base >-> type.type. + Bind Scope pbtype_scope with type.type. + (*Bind Scope ptype_scope with Compilers.type.type type.type.*) (* COQBUG(https://github.com/coq/coq/issues/7699) *) + Delimit Scope ptype_scope with ptype. + Delimit Scope pbtype_scope with pbtype. + Notation "A * B" := (type.prod A%ptype B%ptype) : ptype_scope. + Notation "A * B" := (type.prod A%pbtype B%pbtype) : pbtype_scope. + Notation "()" := (type.type_base base.type.unit) : pbtype_scope. + Notation "()" := (type.base (type.type_base base.type.unit)) : ptype_scope. + Notation "A -> B" := (type.arrow A%ptype B%ptype) : ptype_scope. + Notation "??" := type.any : pbtype_scope. + Notation "??" := (type.base type.any) : ptype_scope. + End Notations. + End base. + Notation type := (type.type base.type). + Export base.Notations. + + Inductive pattern {ident : Type} := + | Wildcard (t : type) + | Ident (idc : ident) + | App (f x : pattern). + + Global Arguments Wildcard {ident%type} t%ptype. + + Notation ident := ident.ident. + + Module Export Notations. + Export base.Notations. + Delimit Scope pattern_scope with pattern. + Bind Scope pattern_scope with pattern. + Local Open Scope pattern_scope. + Notation "#?()" := (Ident ident.LiteralUnit) : pattern_scope. + Notation "#?N" := (Ident ident.LiteralNat) : pattern_scope. + Notation "#?ℕ" := (Ident ident.LiteralNat) : pattern_scope. + Notation "#?Z" := (Ident ident.LiteralZ) : pattern_scope. + Notation "#?ℤ" := (Ident ident.LiteralZ) : pattern_scope. + Notation "#?B" := (Ident ident.LiteralBool) : pattern_scope. + Notation "#?𝔹" := (Ident ident.LiteralBool) : pattern_scope. + Notation "??{ t }" := (Wildcard t) (format "??{ t }") : pattern_scope. + Notation "??" := (??{??})%pattern : pattern_scope. + Notation "# idc" := (Ident idc) : pattern_scope. + Infix "@" := App : pattern_scope. + Notation "( x , y , .. , z )" := (#ident.pair @ .. (#ident.pair @ x @ y) .. @ z) : pattern_scope. + Notation "x :: xs" := (#ident.cons @ x @ xs) : pattern_scope. + Notation "xs ++ ys" := (#ident.List_app @ xs @ ys) : pattern_scope. + Notation "[ ]" := (#ident.nil) : pattern_scope. + Notation "[ x ]" := (x :: []) : pattern_scope. + Notation "[ x ; y ; .. ; z ]" := (x :: (y :: .. (z :: []) ..)) : pattern_scope. + Notation "x - y" := (#ident.Z_sub @ x @ y) : pattern_scope. + Notation "x + y" := (#ident.Z_add @ x @ y) : pattern_scope. + Notation "x / y" := (#ident.Z_div @ x @ y) : pattern_scope. + Notation "x * y" := (#ident.Z_mul @ x @ y) : pattern_scope. + Notation "x 'mod' y" := (#ident.Z_modulo @ x @ y)%pattern : pattern_scope. + Notation "- x" := (#ident.Z_opp @ x) : pattern_scope. + End Notations. + End pattern. + Export pattern.Notations. + Notation pattern := (@pattern.pattern pattern.ident). + + Module RewriteRules. + Module Import AnyExpr. + Record anyexpr {base_type} {ident var : type.type base_type -> Type} + := wrap { anyexpr_ty : base_type ; unwrap :> @expr.expr base_type ident var (type.base anyexpr_ty) }. + Global Arguments wrap {base_type ident var _} _. + End AnyExpr. + + Module Compile. + Section with_var0. + Context {base_type} {ident var : type.type base_type -> Type}. + Local Notation type := (type.type base_type). + Local Notation expr := (@expr.expr base_type ident var). + Local Notation UnderLets := (@UnderLets.UnderLets base_type ident var). + Let type_base (t : base_type) : type := type.base t. + Coercion type_base : base_type >-> type. + + Fixpoint value' (with_lets : bool) (t : type) + := match t with + | type.base t + => if with_lets then UnderLets (expr t) else expr t + | type.arrow s d + => value' false s -> value' true d + end. + Definition value := value' false. + Definition value_with_lets := value' true. + + Definition Base_value {t} : value t -> value_with_lets t + := match t with + | type.base t => fun v => UnderLets.Base v + | type.arrow _ _ => fun v => v + end. + + Fixpoint splice_under_lets_with_value {T t} (x : UnderLets T) : (T -> value_with_lets t) -> value_with_lets t + := match t return (T -> value_with_lets t) -> value_with_lets t with + | type.arrow s d + => fun k v => @splice_under_lets_with_value T d x (fun x' => k x' v) + | type.base _ => fun k => x <-- x; k x + end%under_lets. + Local Notation "x <--- v ; f" := (splice_under_lets_with_value x (fun v => f%under_lets)) : under_lets_scope. + Definition splice_value_with_lets {t t'} : value_with_lets t -> (value t -> value_with_lets t') -> value_with_lets t' + := match t return value_with_lets t -> (value t -> value_with_lets t') -> value_with_lets t' with + | type.arrow _ _ + => fun e k => k e + | type.base _ => fun e k => e <--- e; k e + end%under_lets. + End with_var0. + Section with_var. + Context {ident var : type.type base.type -> Type} + {pident : Type} + (*(invert_Literal_cps : forall t, ident t ~> option (type.interp base.interp t))*) + (*(beq_typed : forall t (X : pident) (Y : ident t), bool)*) + (full_types : pident -> Type) + (invert_bind_args : forall t (idc : ident t) (pidc : pident), option (full_types pidc)) + (type_of_pident : forall (pidc : pident), full_types pidc -> type.type base.type) + (pident_to_typed : forall (pidc : pident) (args : full_types pidc), ident (type_of_pident pidc args)) + (eta_ident_cps : forall {T : type.type base.type -> Type} {t} (idc : ident t) + (f : forall t', ident t' -> T t'), + T t) + (of_typed_ident : forall {t}, ident t -> pident) + (arg_types : pident -> option Type) + (bind_args : forall {t} (idc : ident t), match arg_types (of_typed_ident idc) return Type with Some t => t | None => unit end) + (pident_beq : pident -> pident -> bool) + (try_make_transport_ident_cps : forall (P : pident -> Type) (idc1 idc2 : pident), ~> option (P idc1 -> P idc2)). + Local Notation type := (type.type base.type). + Local Notation expr := (@expr.expr base.type ident var). + Local Notation anyexpr := (@anyexpr ident var). + Local Notation pattern := (@pattern.pattern pident). + Local Notation UnderLets := (@UnderLets.UnderLets base.type ident var). + Local Notation ptype := (type.type pattern.base.type). + Local Notation value' := (@value' base.type ident var). + Local Notation value := (@value base.type ident var). + Local Notation value_with_lets := (@value_with_lets base.type ident var). + Local Notation Base_value := (@Base_value base.type ident var). + Local Notation splice_under_lets_with_value := (@splice_under_lets_with_value base.type ident var). + Local Notation splice_value_with_lets := (@splice_value_with_lets base.type ident var). + Let type_base (t : base.type) : type := type.base t. + Coercion type_base : base.type >-> type. + + Context (reify_and_let_binds_base_cps : forall (t : base.type), expr t -> forall T, (expr t -> UnderLets T) -> UnderLets T). + + Local Notation "e <---- e' ; f" := (splice_value_with_lets e' (fun e => f%under_lets)) : under_lets_scope. + Local Notation "e <----- e' ; f" := (splice_under_lets_with_value e' (fun e => f%under_lets)) : under_lets_scope. + + Fixpoint reify {with_lets} {t} : value' with_lets t -> expr t + := match t, with_lets return value' with_lets t -> expr t with + | type.base _, false => fun v => v + | type.base _, true => fun v => UnderLets.to_expr v + | type.arrow s d, _ + => fun f + => λ x , @reify _ d (f (@reflect _ s ($x))) + end%expr%under_lets%cps + with reflect {with_lets} {t} : expr t -> value' with_lets t + := match t, with_lets return expr t -> value' with_lets t with + | type.base _, false => fun v => v + | type.base _, true => fun v => UnderLets.Base v + | type.arrow s d, _ + => fun f (x : value' _ _) => @reflect _ d (f @ (@reify _ s x)) + end%expr%under_lets. + + Definition reify_and_let_binds_cps {with_lets} {t} : value' with_lets t -> forall T, (expr t -> UnderLets T) -> UnderLets T + := match t, with_lets return value' with_lets t -> forall T, (expr t -> UnderLets T) -> UnderLets T with + | type.base _, false => reify_and_let_binds_base_cps _ + | type.base _, true => fun v => fun T k => v' <-- v; reify_and_let_binds_base_cps _ v' T k + | type.arrow s d, _ + => fun f T k => k (reify f) + end%expr%under_lets%cps. + + Inductive rawexpr : Type := + | rIdent {t} (idc : ident t) {t'} (alt : expr t') + | rApp (f x : rawexpr) {t} (alt : expr t) + | rExpr {t} (e : expr t) + | rValue {t} (e : value t). + + Definition type_of_rawexpr (e : rawexpr) : type + := match e with + | rIdent t idc t' alt => t' + | rApp f x t alt => t + | rExpr t e => t + | rValue t e => t + end. + Definition expr_of_rawexpr (e : rawexpr) : expr (type_of_rawexpr e) + := match e with + | rIdent t idc t' alt => alt + | rApp f x t alt => alt + | rExpr t e => e + | rValue t e => reify e + end. + Definition value_of_rawexpr (e : rawexpr) : value (type_of_rawexpr e) + := Eval cbv [expr_of_rawexpr] in + match e with + | rValue t e => e + | e => reflect (expr_of_rawexpr e) + end. + Definition rValueOrExpr {t} : value t -> rawexpr + := match t with + | type.base _ => @rExpr _ + | type.arrow _ _ => @rValue _ + end. + Definition rValueOrExpr2 {t} : value t -> expr t -> rawexpr + := match t with + | type.base _ => fun v e => @rExpr _ e + | type.arrow _ _ => fun v e => @rValue _ v + end. + + Definition try_rExpr_cps {T t} (k : option rawexpr -> T) : expr t -> T + := match t with + | type.base _ => fun e => k (Some (rExpr e)) + | type.arrow _ _ => fun _ => k None + end. + + Definition reveal_rawexpr_cps (e : rawexpr) : ~> rawexpr + := fun T k + => match e with + | rExpr _ e as r + | rValue (type.base _) e as r + => match e with + | expr.Ident t idc => k (rIdent idc e) + | expr.App s d f x => k (rApp (rExpr f) (rExpr x) e) + | _ => k r + end + | e' => k e' + end. + + Inductive quant_type := qforall | qexists. + + (* p for pattern *) + Fixpoint pbase_type_interp_cps (quant : quant_type) (t : pattern.base.type) (K : base.type -> Type) : Type + := match t with + | pattern.base.type.any + => match quant with + | qforall => forall t : base.type, K t + | qexists => { t : base.type & K t } + end + | pattern.base.type.type_base t => K t + | pattern.base.type.prod A B + => @pbase_type_interp_cps + quant A + (fun A' + => @pbase_type_interp_cps + quant B (fun B' => K (A' * B')%etype)) + | pattern.base.type.list A + => @pbase_type_interp_cps + quant A (fun A' => K (base.type.list A')) + end. + + Fixpoint ptype_interp_cps (quant : quant_type) (t : ptype) (K : type -> Type) {struct t} : Type + := match t with + | type.base t + => pbase_type_interp_cps quant t (fun t => K (type.base t)) + | type.arrow s d + => @ptype_interp_cps + quant s + (fun s => @ptype_interp_cps + quant d (fun d => K (type.arrow s d))) + end. + + Definition ptype_interp (quant : quant_type) (t : ptype) (K : Type -> Type) : Type + := ptype_interp_cps quant t (fun t => K (value t)). + + Fixpoint binding_dataT (p : pattern) : Type + := match p return Type with + | pattern.Wildcard t => ptype_interp qexists t id + | pattern.Ident idc => match arg_types idc return Type with + | Some t => t + | None => unit + end + | pattern.App f x => binding_dataT f * binding_dataT x + end%type. + + Fixpoint bind_base_cps {t1 t2} + (K : base.type -> Type) + (v : K t2) + {struct t1} + : ~> option (pbase_type_interp_cps qexists t1 K) + := match t1 return ~> option (pbase_type_interp_cps qexists t1 K) with + | pattern.base.type.any + => (return (Some (existT K t2 v))) + | pattern.base.type.type_base t + => (tr <-- base.try_make_transport_cps _ _ _; + return (Some (tr v))) + | pattern.base.type.prod A B + => fun T k + => match t2 return K t2 -> T with + | base.type.prod A' B' + => fun v + => (v' <-- @bind_base_cps B B' (fun B' => K (A' * B')%etype) v; + v'' <-- @bind_base_cps A A' (fun A' => pbase_type_interp_cps qexists B (fun B' => K (A' * B')%etype)) v'; + return (Some v'')) + T k + | _ => fun _ => k None + end v + | pattern.base.type.list A + => fun T k + => match t2 return K t2 -> T with + | base.type.list A' + => fun v => @bind_base_cps A A' (fun A' => K (base.type.list A')) v T k + | _ => fun _ => k None + end v + end%cps. + + Fixpoint bind_value_cps {t1 t2} + (K : type -> Type) + (v : K t2) + {struct t1} + : ~> option (ptype_interp_cps qexists t1 K) + := match t1 return ~> option (ptype_interp_cps qexists t1 K) with + | type.base t1 + => fun T k + => match t2 return K t2 -> T with + | type.base t2 + => fun v => bind_base_cps (fun t => K (type.base t)) v T k + | _ => fun _ => k None + end v + | type.arrow A B + => fun T k + => match t2 return K t2 -> T with + | type.arrow A' B' + => fun v + => (v' <-- @bind_value_cps B B' (fun B' => K (A' -> B')%etype) v; + v'' <-- @bind_value_cps A A' (fun A' => ptype_interp_cps qexists B (fun B' => K (A' -> B')%etype)) v'; + return (Some v'')) + T k + | _ => fun _ => k None + end v + end%cps. + + Fixpoint bind_data_cps (e : rawexpr) (p : pattern) + : ~> option (binding_dataT p) + := match p, e return ~> option (binding_dataT p) with + | pattern.Wildcard t, _ + => bind_value_cps value (value_of_rawexpr e) + | pattern.Ident pidc, rIdent _ idc _ _ + => (tr <-- (try_make_transport_ident_cps + (fun idc => match arg_types idc with + | Some t1 => t1 + | None => unit + end) _ _); + return (Some (tr (bind_args _ idc)))) + | pattern.App pf px, rApp f x _ _ + => (f' <-- bind_data_cps f pf; + x' <-- bind_data_cps x px; + return (Some (f', x'))) + | pattern.Ident _, _ + | pattern.App _ _, _ + => (return None) + end%cps. + + (** We follow + http://moscova.inria.fr/~maranget/papers/ml05e-maranget.pdf, + "Compiling Pattern Matching to Good Decision Trees" by Luc + Maranget. A [decision_tree] describes how to match a + vector (or list) of patterns against a vector of + expressions. The cases of a [decision_tree] are: + + - [TryLeaf k onfailure]: Try the kth rewrite rule; if it + fails, keep going with [onfailure] + + - [Failure]: Abort; nothing left to try + + - [Switch icases app_case default]: With the first element + of the vector, match on its kind; if it is an identifier + matching something in [icases], remove the first element + of the vector run that decision tree; if it is an + application and [app_case] is not [None], try the + [app_case] decision_tree, replacing the first element of + each vector with the two elements of the function and + the argument its applied to; otherwise, don't modify the + vectors, and use the [default] decision tree. + + - [Swap i cont]: Swap the first element of the vector with + the ith element, and keep going with [cont] *) + Inductive decision_tree := + | TryLeaf (k : nat) (onfailure : decision_tree) + | Failure + | Switch (icases : list (pident * decision_tree)) + (app_case : option decision_tree) + (default : decision_tree) + | Swap (i : nat) (cont : decision_tree). + + Definition swap_list {A} (i j : nat) (ls : list A) : option (list A) + := match nth_error ls i, nth_error ls j with + | Some vi, Some vj => Some (set_nth i vj (set_nth j vi ls)) + | _, _ => None + end. + + Fixpoint eval_decision_tree {T} (ctx : list rawexpr) (d : decision_tree) (cont : option nat -> list rawexpr -> option (unit -> T) -> T) {struct d} : T + := match d with + | TryLeaf k onfailure + => cont (Some k) ctx + (Some (fun 'tt => @eval_decision_tree T ctx onfailure cont)) + | Failure => cont None ctx None + | Switch icases app_case default_case + => match ctx with + | nil => cont None ctx None + | ctx0 :: ctx' + => let default _ := @eval_decision_tree T ctx default_case cont in + reveal_rawexpr_cps + ctx0 _ + (fun ctx0' + => match ctx0' with + | rIdent t idc t' alt + => fold_right + (fun '(pidc, icase) default 'tt + => match invert_bind_args _ idc pidc with + | Some args + => @eval_decision_tree + T ctx' icase + (fun k ctx'' + => cont k (rIdent (pident_to_typed pidc args) alt :: ctx'')) + | None => default tt + end) + default + icases + tt + | rApp f x t alt + => match app_case with + | Some app_case + => @eval_decision_tree + T (f :: x :: ctx') app_case + (fun k ctx'' + => match ctx'' with + | f' :: x' :: ctx''' + => cont k (rApp f' x' alt :: ctx''') + | _ => cont None ctx + end) + | None => default tt + end + | rExpr t e + | rValue t e + => default tt + end) + end + | Swap i d' + => match swap_list 0 i ctx with + | Some ctx' + => @eval_decision_tree + T ctx' d' + (fun k ctx'' + => match swap_list 0 i ctx'' with + | Some ctx''' => cont k ctx''' + | None => cont None ctx + end) + | None => cont None ctx None + end + end. + + Local Notation opt_anyexprP ivar + := (fun should_do_again : bool => UnderLets (@AnyExpr.anyexpr base.type ident (if should_do_again then ivar else var))) + (only parsing). + Local Notation opt_anyexpr ivar + := (option (sigT (opt_anyexprP ivar))) (only parsing). + + Definition rewrite_ruleTP + := (fun p : pattern => binding_dataT p -> forall T, (opt_anyexpr value -> T) -> T). + Definition rewrite_ruleT := sigT rewrite_ruleTP. + Definition rewrite_rulesT + := (list rewrite_ruleT). + + Definition eval_rewrite_rules + (do_again : forall t : base.type, @expr.expr base.type ident value t -> UnderLets (expr t)) + (maybe_do_again + := fun (should_do_again : bool) (t : base.type) + => if should_do_again return ((@expr.expr base.type ident (if should_do_again then value else var) t) -> UnderLets (expr t)) + then do_again t + else UnderLets.Base) + (d : decision_tree) + (rew : rewrite_rulesT) + (e : rawexpr) + : UnderLets (expr (type_of_rawexpr e)) + := eval_decision_tree + (e::nil) d + (fun k ctx default_on_rewrite_failure + => match k, ctx return UnderLets (expr (type_of_rawexpr e)) with + | Some k', e'::nil + => match nth_error rew k' return UnderLets (expr (type_of_rawexpr e)) with + | Some (existT p f) + => bind_data_cps + e' p _ + (fun v + => match v with + | Some v + => f v _ + (fun fv + => match fv return UnderLets (expr (type_of_rawexpr e)) with + | Some (existT should_do_again fv) + => (fv <-- fv; + fv <-- maybe_do_again should_do_again _ fv; + type.try_transport_cps + base.try_make_transport_cps _ _ _ fv _ + (fun fv' + => match fv', default_on_rewrite_failure with + | Some fv'', _ => UnderLets.Base fv'' + | None, Some default => default tt + | None, None => UnderLets.Base (expr_of_rawexpr e) + end))%under_lets + | None => match default_on_rewrite_failure with + | Some default => default tt + | None => UnderLets.Base (expr_of_rawexpr e) + end + end) + | None => UnderLets.Base (expr_of_rawexpr e) + end) + | None => UnderLets.Base (expr_of_rawexpr e) + end + | _, _ => UnderLets.Base (expr_of_rawexpr e) + end). + + Local Notation enumerate ls + := (List.combine (List.seq 0 (List.length ls)) ls). + + Fixpoint first_satisfying_helper {A B} (f : A -> option B) (ls : list A) : option B + := match ls with + | nil => None + | cons x xs + => match f x with + | Some v => Some v + | None => first_satisfying_helper f xs + end + end. + + Definition get_index_of_first_non_wildcard (p : list pattern) : option nat + := first_satisfying_helper + (fun '(n, x) => match x with + | pattern.Wildcard _ => None + | _ => Some n + end) + (enumerate p). + + Definition filter_pattern_wildcard (p : list (nat * list pattern)) : list (nat * list pattern) + := filter (fun '(_, p) => match p with + | pattern.Wildcard _::_ => true + | _ => false + end) + p. + + Fixpoint get_unique_pattern_ident' (p : list (nat * list pattern)) (so_far : list pident) : list pident + := match p with + | nil => List.rev so_far + | (_, pattern.Ident pidc :: _) :: ps + => let so_far' := if existsb (pident_beq pidc) so_far + then so_far + else pidc :: so_far in + get_unique_pattern_ident' ps so_far' + | _ :: ps => get_unique_pattern_ident' ps so_far + end. + + Definition get_unique_pattern_ident p : list pident := get_unique_pattern_ident' p nil. + + Definition contains_pattern_pident (pidc : pident) (p : list (nat * list pattern)) : bool + := existsb (fun '(n, p) => match p with + | pattern.Ident pidc'::_ => pident_beq pidc pidc' + | _ => false + end) + p. + + Definition contains_pattern_app (p : list (nat * list pattern)) : bool + := existsb (fun '(n, p) => match p with + | pattern.App _ _::_ => true + | _ => false + end) + p. + + Definition refine_pattern_app (p : nat * list pattern) : option (nat * list pattern) + := match p with + | (n, pattern.Wildcard d::ps) + => Some (n, (??{?? -> d} :: ?? :: ps)%list%pattern) + | (n, pattern.App f x :: ps) + => Some (n, f :: x :: ps) + | (_, pattern.Ident _::_) + | (_, nil) + => None + end. + + Definition refine_pattern_pident (pidc : pident) (p : nat * list pattern) : option (nat * list pattern) + := match p with + | (n, pattern.Wildcard _::ps) + => Some (n, ps) + | (n, pattern.Ident pidc'::ps) + => if pident_beq pidc pidc' + then Some (n, ps) + else None + | (_, pattern.App _ _::_) + | (_, nil) + => None + end. + + Definition compile_rewrites_step + (compile_rewrites : list (nat * list pattern) -> option decision_tree) + (pattern_matrix : list (nat * list pattern)) + : option decision_tree + := match pattern_matrix with + | nil => Some Failure + | (n1, p1) :: ps + => match get_index_of_first_non_wildcard p1 with + | None (* p1 is all wildcards *) + => (onfailure <- compile_rewrites ps; + Some (TryLeaf n1 onfailure)) + | Some Datatypes.O + => default_case <- compile_rewrites (filter_pattern_wildcard pattern_matrix); + app_case <- (if contains_pattern_app pattern_matrix + then option_map Some (compile_rewrites (Option.List.map refine_pattern_app pattern_matrix)) + else Some None); + let pidcs := get_unique_pattern_ident pattern_matrix in + let icases := Option.List.map + (fun pidc => option_map (pair pidc) (compile_rewrites (Option.List.map (refine_pattern_pident pidc) pattern_matrix))) + pidcs in + Some (Switch icases app_case default_case) + | Some i + => let pattern_matrix' + := List.map + (fun '(n, ps) + => (n, + match swap_list 0 i ps with + | Some ps' => ps' + | None => nil (* should be impossible *) + end)) + pattern_matrix in + d <- compile_rewrites pattern_matrix'; + Some (Swap i d) + end + end%option. + + Fixpoint compile_rewrites' (fuel : nat) (pattern_matrix : list (nat * list pattern)) + : option decision_tree + := match fuel with + | Datatypes.O => None + | Datatypes.S fuel' => compile_rewrites_step (@compile_rewrites' fuel') pattern_matrix + end. + + Definition compile_rewrites (fuel : nat) (ps : rewrite_rulesT) + := compile_rewrites' fuel (enumerate (List.map (fun p => projT1 p :: nil) ps)). + + + Fixpoint with_bindingsT (p : pattern) (T : Type) + := match p return Type with + | pattern.Wildcard t => ptype_interp qforall t (fun eT => eT -> T) + | pattern.Ident idc + => match arg_types idc with + | Some t => t -> T + | None => T + end + | pattern.App f x => with_bindingsT f (with_bindingsT x T) + end. + + Fixpoint lift_pbase_type_interp_cps {K1 K2} {quant} (F : forall t : base.type, K1 t -> K2 t) {t} + : pbase_type_interp_cps quant t K1 + -> pbase_type_interp_cps quant t K2 + := match t, quant return pbase_type_interp_cps quant t K1 + -> pbase_type_interp_cps quant t K2 with + | pattern.base.type.any, qforall + => fun f t => F t (f t) + | pattern.base.type.any, qexists + => fun tf => existT _ _ (F _ (projT2 tf)) + | pattern.base.type.type_base t, _ + => F _ + | pattern.base.type.prod A B, _ + => @lift_pbase_type_interp_cps + _ _ quant + (fun A' + => @lift_pbase_type_interp_cps + _ _ quant (fun _ => F _) B) + A + | pattern.base.type.list A, _ + => @lift_pbase_type_interp_cps + _ _ quant (fun _ => F _) A + end. + + Fixpoint lift_ptype_interp_cps {K1 K2} {quant} (F : forall t : type.type base.type, K1 t -> K2 t) {t} + : ptype_interp_cps quant t K1 + -> ptype_interp_cps quant t K2 + := match t return ptype_interp_cps quant t K1 + -> ptype_interp_cps quant t K2 with + | type.base t + => lift_pbase_type_interp_cps F + | type.arrow A B + => @lift_ptype_interp_cps + _ _ quant + (fun A' + => @lift_ptype_interp_cps + _ _ quant (fun _ => F _) B) + A + end. + + Fixpoint lift_with_bindings {p} {A B : Type} (F : A -> B) {struct p} : with_bindingsT p A -> with_bindingsT p B + := match p return with_bindingsT p A -> with_bindingsT p B with + | pattern.Wildcard t + => lift_ptype_interp_cps + (K1:=fun t => value t -> A) + (K2:=fun t => value t -> B) + (fun _ f v => F (f v)) + | pattern.Ident idc + => match arg_types idc as ty + return match ty with + | Some t => t -> A + | None => A + end -> match ty with + | Some t => t -> B + | None => B + end + with + | Some _ => fun f v => F (f v) + | None => F + end + | pattern.App f x + => @lift_with_bindings + f _ _ + (@lift_with_bindings x _ _ F) + end. + + Fixpoint app_pbase_type_interp_cps {T : Type} {K1 K2 : base.type -> Type} + (F : forall t, K1 t -> K2 t -> T) + {t} + : pbase_type_interp_cps qforall t K1 + -> pbase_type_interp_cps qexists t K2 -> T + := match t return pbase_type_interp_cps qforall t K1 + -> pbase_type_interp_cps qexists t K2 -> T with + | pattern.base.type.any + => fun f tv => F _ (f _) (projT2 tv) + | pattern.base.type.type_base t + => fun f v => F _ f v + | pattern.base.type.prod A B + => @app_pbase_type_interp_cps + _ + (fun A' => pbase_type_interp_cps qforall B (fun B' => K1 (A' * B')%etype)) + (fun A' => pbase_type_interp_cps qexists B (fun B' => K2 (A' * B')%etype)) + (fun A' + => @app_pbase_type_interp_cps + _ + (fun B' => K1 (A' * B')%etype) + (fun B' => K2 (A' * B')%etype) + (fun _ => F _) + B) + A + | pattern.base.type.list A + => @app_pbase_type_interp_cps T (fun A' => K1 (base.type.list A')) (fun A' => K2 (base.type.list A')) (fun _ => F _) A + end. + + Fixpoint app_ptype_interp_cps {T : Type} {K1 K2 : type -> Type} + (F : forall t, K1 t -> K2 t -> T) + {t} + : ptype_interp_cps qforall t K1 + -> ptype_interp_cps qexists t K2 -> T + := match t return ptype_interp_cps qforall t K1 + -> ptype_interp_cps qexists t K2 -> T with + | type.base t => app_pbase_type_interp_cps F + | type.arrow A B + => @app_ptype_interp_cps + _ + (fun A' => ptype_interp_cps qforall B (fun B' => K1 (A' -> B')%etype)) + (fun A' => ptype_interp_cps qexists B (fun B' => K2 (A' -> B')%etype)) + (fun A' + => @app_ptype_interp_cps + _ + (fun B' => K1 (A' -> B')%etype) + (fun B' => K2 (A' -> B')%etype) + (fun _ => F _) + B) + A + end. + + Fixpoint app_binding_data {T p} : forall (f : with_bindingsT p T) (v : binding_dataT p), T + := match p return forall (f : with_bindingsT p T) (v : binding_dataT p), T with + | pattern.Wildcard t + => app_ptype_interp_cps + (K1:=fun t => value t -> T) + (K2:=fun t => value t) + (fun _ f v => f v) + | pattern.Ident idc + => match arg_types idc as ty + return match ty with + | Some t => t -> T + | None => T + end -> match ty return Type with + | Some t => t + | None => unit + end -> T + with + | Some t => fun f x => f x + | None => fun v 'tt => v + end + | pattern.App f x + => fun F '(vf, vx) + => @app_binding_data _ x (@app_binding_data _ f F vf) vx + end. + + (** XXX MOVEME? *) + Definition mkcast {P : type -> Type} {t1 t2 : type} : ~> (option (P t1 -> P t2)) + := fun T k => type.try_make_transport_cps base.try_make_transport_cps P t1 t2 _ k. + Definition cast {P : type -> Type} {t1 t2 : type} (v : P t1) : ~> (option (P t2)) + := fun T k => type.try_transport_cps base.try_make_transport_cps P t1 t2 v _ k. + Definition castb {P : base.type -> Type} {t1 t2 : base.type} (v : P t1) : ~> (option (P t2)) + := fun T k => base.try_transport_cps P t1 t2 v _ k. + Definition castbe {t1 t2 : base.type} (v : expr t1) : ~> (option (expr t2)) + := @castb expr t1 t2 v. + Definition castv {t1 t2} (v : value t1) : ~> (option (value t2)) + := fun T k => type.try_transport_cps base.try_make_transport_cps value t1 t2 v _ k. + + Section with_do_again. + Context (dtree : decision_tree) + (rewrite_rules : rewrite_rulesT) + (default_fuel : nat) + (do_again : forall t : base.type, @expr.expr base.type ident value t -> UnderLets (expr t)). + + Let dorewrite1 (e : rawexpr) : UnderLets (expr (type_of_rawexpr e)) + := eval_rewrite_rules do_again dtree rewrite_rules e. + + Fixpoint assemble_identifier_rewriters' (t : type) : forall e : rawexpr, (forall P, P (type_of_rawexpr e) -> P t) -> value_with_lets t + := match t return forall e : rawexpr, (forall P, P (type_of_rawexpr e) -> P t) -> value_with_lets t with + | type.base _ + => fun e k => k (fun t => UnderLets (expr t)) (dorewrite1 e) + | type.arrow s d + => fun f k (x : value' _ _) + => let x' := reify x in + @assemble_identifier_rewriters' d (rApp f (rValueOrExpr2 x x') (k _ (expr_of_rawexpr f) @ x'))%expr (fun _ => id) + end%under_lets. + + Definition assemble_identifier_rewriters {t} (idc : ident t) : value_with_lets t + := eta_ident_cps _ _ idc (fun t' idc' => assemble_identifier_rewriters' t' (rIdent idc' #idc') (fun _ => id)). + End with_do_again. + End with_var. + + Section full. + Context {var : type.type base.type -> Type}. + Local Notation expr := (@expr base.type ident). + Local Notation value := (@Compile.value base.type ident var). + Local Notation value_with_lets := (@Compile.value_with_lets base.type ident var). + Local Notation UnderLets := (UnderLets.UnderLets base.type ident var). + Local Notation reify_and_let_binds_cps := (@Compile.reify_and_let_binds_cps ident var (@UnderLets.reify_and_let_binds_base_cps var)). + Local Notation reflect := (@Compile.reflect ident var). + Section with_rewrite_head. + Context (rewrite_head : forall t (idc : ident t), value_with_lets t). + + Local Notation "e <---- e' ; f" := (Compile.splice_value_with_lets e' (fun e => f%under_lets)) : under_lets_scope. + Local Notation "e <----- e' ; f" := (Compile.splice_under_lets_with_value e' (fun e => f%under_lets)) : under_lets_scope. + + Fixpoint rewrite_bottomup {t} (e : @expr value t) : value_with_lets t + := match e in expr.expr t return value_with_lets t with + | expr.Ident t idc + => rewrite_head _ idc + | expr.App s d f x => let f : value s -> value_with_lets d := @rewrite_bottomup _ f in x <---- @rewrite_bottomup _ x; f x + | expr.LetIn A B x f => x <---- @rewrite_bottomup A x; + xv <----- reify_and_let_binds_cps x _ UnderLets.Base; + @rewrite_bottomup B (f (reflect xv)) + | expr.Var t v => Compile.Base_value v + | expr.Abs s d f => fun x : value s => @rewrite_bottomup d (f x) + end%under_lets. + End with_rewrite_head. + + Notation nbe := (@rewrite_bottomup (fun t idc => reflect (expr.Ident idc))). + + Fixpoint repeat_rewrite + (rewrite_head : forall (do_again : forall t : base.type, @expr value (type.base t) -> UnderLets (@expr var (type.base t))) + t (idc : ident t), value_with_lets t) + (fuel : nat) {t} e : value_with_lets t + := @rewrite_bottomup + (rewrite_head + (fun t' e' + => match fuel with + | Datatypes.O => nbe e' + | Datatypes.S fuel' => @repeat_rewrite rewrite_head fuel' (type.base t') e' + end%under_lets)) + t e. + + Definition rewrite rewrite_head fuel {t} e : expr t + := reify (@repeat_rewrite rewrite_head fuel t e). + End full. + + Definition Rewrite rewrite_head fuel {t} (e : expr.Expr (ident:=ident) t) : expr.Expr (ident:=ident) t + := fun var => @rewrite var (rewrite_head var) fuel t (e _). + End Compile. + + Module pident := pattern.ident. + + Module Make. + Section make_rewrite_rules. + Import Compile. + Context {var : type.type base.type -> Type}. + Local Notation type := (type.type base.type). + Local Notation expr := (@expr.expr base.type ident var). + Local Notation value := (@value base.type ident var). + Local Notation anyexpr := (@anyexpr ident var). + Local Notation pattern := (@pattern.pattern pattern.ident). + Local Notation UnderLets := (@UnderLets.UnderLets base.type ident var). + Local Notation ptype := (type.type pattern.base.type). + Let type_base (t : base.type) : type := type.base t. + Let ptype_base (t : pattern.base.type) : ptype := type.base t. + Let ptype_base' (t : base.type.base) : ptype := @type.base pattern.base.type t. + Coercion ptype_base' : base.type.base >-> ptype. + Coercion type_base : base.type >-> type. + Coercion ptype_base : pattern.base.type >-> ptype. + Local Notation opt_anyexprP ivar + := (fun should_do_again : bool => UnderLets (@AnyExpr.anyexpr base.type ident (if should_do_again then ivar else var))). + Local Notation opt_anyexpr ivar + := (option (sigT (opt_anyexprP ivar))). + Local Notation binding_dataT := (@binding_dataT ident var pattern.ident pattern.ident.arg_types). + Local Notation lift_with_bindings := (@lift_with_bindings ident var pattern.ident pattern.ident.arg_types). + Local Notation app_binding_data := (@app_binding_data ident var pattern.ident pattern.ident.arg_types). + Local Notation rewrite_rulesT := (@rewrite_rulesT ident var pattern.ident pattern.ident.arg_types). + Local Notation rewrite_ruleT := (@rewrite_ruleT ident var pattern.ident pattern.ident.arg_types). + Local Notation castv := (@castv ident var). + + Definition make_base_Literal_pattern (t : base.type.base) : pattern + := Eval cbv [pident.of_typed_ident] in + pattern.Ident (pident.of_typed_ident (@ident.Literal t DefaultValue.type.base.default)). + + Definition bind_base_Literal_pattern (t : base.type.base) : binding_dataT (make_base_Literal_pattern t) ~> base.interp t + := match t return binding_dataT (make_base_Literal_pattern t) ~> base.interp t with + | base.type.unit + | base.type.Z + | base.type.bool + | base.type.nat + => fun v => (return v) + end%cps. + + Fixpoint make_Literal_pattern (t : base.type) : option { p : pattern & binding_dataT p ~> base.interp t } + := match t return option { p : pattern & binding_dataT p ~> base.interp t } with + | base.type.type_base t => Some (existT _ (make_base_Literal_pattern t) (bind_base_Literal_pattern t)) + | base.type.prod A B + => (a <- make_Literal_pattern A; + b <- make_Literal_pattern B; + Some (existT + (fun p : pattern => binding_dataT p ~> base.interp (A * B)) + (#pident.pair @ (projT1 a) @ (projT1 b))%pattern + (fun '(args : unit * binding_dataT (projT1 a) * binding_dataT (projT1 b)) + => (av <--- projT2 a (snd (fst args)); + bv <--- projT2 b (snd args); + return (av, bv))))) + | base.type.list A => None + end%option%cps. + + Fixpoint make_interp_rewrite' (t : type) (p : pattern) (rew : binding_dataT p ~> type.interp base.interp t) {struct t} + : option rewrite_ruleT + := match t return (_ ~> type.interp base.interp t) -> _ with + | type.base t + => fun rew + => Some (existT _ p (fun args => v <--- rew args; + return (Some (existT _ false (UnderLets.Base (AnyExpr.wrap (ident.smart_Literal v))))))) + | type.arrow (type.base s) d + => fun rew + => (lit_s <- make_Literal_pattern s; + @make_interp_rewrite' + d + (pattern.App p (projT1 lit_s)) + (fun (args : binding_dataT p * binding_dataT (projT1 lit_s)) + => (rewp <--- rew (fst args); + sv <--- projT2 lit_s (snd args); + return (rewp sv)))) + | type.arrow _ _ => fun _ => None + end%option%cps rew. + + Definition make_interp_rewrite'' {t} (idc : ident t) : option rewrite_ruleT + := make_interp_rewrite' + t + (pattern.Ident (pident.of_typed_ident idc)) + (fun iargs => return (ident.interp (pident.retype_ident idc iargs)))%cps. + (* + Definition make_interp_rewrite {t} (idc : ident t) + := invert_Some (make_interp_rewrite'' idc). + *) + + Local Ltac get_all_valid_interp_rules_from body so_far := + let next := match body with + | context[@Some (sigT (fun x : pattern => binding_dataT x ~> opt_anyexpr value)) ?rew] + => lazymatch so_far with + | context[cons rew _] => constr:(I : I) + | _ => lazymatch rew with + | existT _ _ _ => constr:(Some rew) + | _ => constr:(I : I) + end + end + | _ => constr:(@None unit) + end in + lazymatch next with + | Some ?rew => get_all_valid_interp_rules_from body (cons rew so_far) + | None => (eval cbv [List.rev List.app] in (List.rev so_far)) + end. + Local Ltac make_valid_interp_rules := + let body := constr:(fun t idc => @pident.eta_ident_cps _ t idc (@make_interp_rewrite'')) in + let body := (eval cbv [pident.eta_ident_cps make_interp_rewrite'' make_interp_rewrite' make_Literal_pattern pident.of_typed_ident Option.bind projT1 projT2 cpsbind cpsreturn cpscall ident.interp pident.retype_ident ident.gen_interp bind_base_Literal_pattern make_base_Literal_pattern] in body) in + let body := (eval cbn [base.interp binding_dataT pattern.ident.arg_types base.base_interp ident.smart_Literal fold_right map] in body) in + let retv := get_all_valid_interp_rules_from body (@nil rewrite_ruleT) in + exact retv. + Definition interp_rewrite_rules : rewrite_rulesT + := ltac:(make_valid_interp_rules). + End make_rewrite_rules. + End Make. + + Section with_var. + Import Compile. + Context {var : type.type base.type -> Type}. + Local Notation type := (type.type base.type). + Local Notation expr := (@expr.expr base.type ident var). + Local Notation value := (@value base.type ident var). + Local Notation anyexpr := (@anyexpr ident var). + Local Notation pattern := (@pattern.pattern pattern.ident). + Local Notation UnderLets := (@UnderLets.UnderLets base.type ident var). + Local Notation ptype := (type.type pattern.base.type). + Let type_base (t : base.type) : type := type.base t. + Let ptype_base (t : pattern.base.type) : ptype := type.base t. + Let ptype_base' (t : base.type.base) : ptype := @type.base pattern.base.type t. + Coercion ptype_base' : base.type.base >-> ptype. + Coercion type_base : base.type >-> type. + Coercion ptype_base : pattern.base.type >-> ptype. + Local Notation opt_anyexprP ivar + := (fun should_do_again : bool => UnderLets (@AnyExpr.anyexpr base.type ident (if should_do_again then ivar else var))). + Local Notation opt_anyexpr ivar + := (option (sigT (opt_anyexprP ivar))). + Local Notation binding_dataT := (@binding_dataT ident var pattern.ident pattern.ident.arg_types). + Local Notation lift_with_bindings := (@lift_with_bindings ident var pattern.ident pattern.ident.arg_types). + Local Notation app_binding_data := (@app_binding_data ident var pattern.ident pattern.ident.arg_types). + Local Notation rewrite_ruleTP := (@rewrite_ruleTP ident var pattern.ident pattern.ident.arg_types). + Local Notation rewrite_rulesT := (@rewrite_rulesT ident var pattern.ident pattern.ident.arg_types). + Local Notation castv := (@castv ident var). + Local Notation assemble_identifier_rewriters := (@assemble_identifier_rewriters ident var pattern.ident pattern.ident.full_types (@pattern.ident.invert_bind_args) pattern.ident.type_of pattern.ident.to_typed (@pattern.ident.eta_ident_cps) (@pattern.ident.of_typed_ident) pattern.ident.arg_types (@pattern.ident.bind_args) pattern.ident.try_make_transport_ident_cps). + + Let UnderLetsExpr {btype bident ivar} t := @UnderLets.UnderLets base.type ident var (@expr.expr btype bident ivar t). + Let UnderLetsAnyExpr {btype ident ivar} := @UnderLets.UnderLets btype ident ivar (@AnyExpr.anyexpr btype ident ivar). + Let UnderLetsAnyExprCpsOpt {btype bident ivar} := ~> option (@UnderLets.UnderLets base.type ident var (@AnyExpr.anyexpr btype bident ivar)). + (*Let UnderLetsAnyAnyExpr {btype ident ivar} := @UnderLets.UnderLets btype ident ivar (@AnyAnyExpr.anyexpr btype ident ivar).*) + Let BaseWrapUnderLetsAnyExpr {btype bident ivar t} : @UnderLetsExpr btype bident ivar t -> @UnderLetsAnyExprCpsOpt btype bident ivar + := fun e T k + => k (match t return @UnderLets.UnderLets _ _ _ (@expr.expr _ _ _ t) -> _ with + | type.base _ => fun e => Some (e <-- e; UnderLets.Base (AnyExpr.wrap e))%under_lets + | type.arrow _ _ => fun _ => None + end e)%cps. + Let BaseExpr {btype ident ivar t} : @expr.expr btype ident ivar t -> @UnderLetsExpr btype ident ivar t := UnderLets.Base. + (*Let BaseAnyAnyExpr {btype ident ivar t} : @expr.expr btype ident ivar t -> @UnderLets.UnderLets btype ident ivar (@expr.expr btype ident ivar t) := UnderLets.Base.*) + Coercion BaseWrapUnderLetsAnyExpr : UnderLetsExpr >-> UnderLetsAnyExprCpsOpt. + Coercion BaseExpr : expr >-> UnderLetsExpr. + Notation ret v := ((v : UnderLetsExpr _) : UnderLetsAnyExprCpsOpt). + Notation oret v := (fun T k => k (Some v)). + (*Coercion BaseExpr : expr >-> UnderLets.*) + Notation make_rewrite'_cps p f + := (existT + (fun p' : pattern => binding_dataT p' ~> (opt_anyexpr value)) + p%pattern + (fun v T (k : opt_anyexpr value -> T) + => @app_binding_data _ p%pattern f%expr v T k)). + Notation make_rewrite' p f + := (existT + (fun p' : pattern => binding_dataT p' ~> (opt_anyexpr value)) + p%pattern + (fun v T (k : opt_anyexpr value -> T) + => k (@app_binding_data _ p%pattern f%expr v))). + Notation make_rewrite p f + := (let f' := (@lift_with_bindings p _ _ (fun x:@UnderLetsAnyExprCpsOpt base.type ident var => (x' <-- x; oret (existT (opt_anyexprP value) false x'))%cps) f%expr) in + make_rewrite'_cps p f'). + Notation make_rewrite_step p f + := (let f' := (@lift_with_bindings p _ _ (fun x:@UnderLetsAnyExprCpsOpt base.type ident value => (x' <-- x; oret (existT (opt_anyexprP value) true x'))%cps) f%expr) in + make_rewrite'_cps p f'). + + Local Notation "x' <- v ; C" := (fun T k => v%cps T (fun x' => match x' with Some x' => (C%cps : UnderLetsAnyExprCpsOpt) T k | None => k None end)) : cps_scope. + Local Notation "x <-- y ; f" := (UnderLets.splice y (fun x => (f%cps : UnderLetsExpr _))) : cps_scope. + Local Notation "x <--- y ; f" := (UnderLets.splice_list y (fun x => (f%cps : UnderLetsExpr _))) : cps_scope. + Local Notation "x <---- y ; f" := (fun T k => match y with Some x => (f%cps : UnderLetsAnyExprCpsOpt) T k | None => k None end) : cps_scope. + + Definition rlist_rect {A P} + {ivar} + (Pnil : @UnderLetsExpr base.type ident ivar (type.base P)) + (Pcons : expr (type.base A) -> list (expr (type.base A)) -> @expr.expr base.type ident ivar (type.base P) -> @UnderLetsExpr base.type ident ivar (type.base P)) + (e : expr (type.base (base.type.list A))) + : @UnderLetsAnyExprCpsOpt base.type ident ivar + := (ls <- reflect_list_cps e; + list_rect + (fun _ => UnderLetsExpr (type.base P)) + Pnil + (fun x xs rec => rec' <-- rec; Pcons x xs rec') + ls)%cps. + + Definition rlist_rect_cast {A A' P} + {ivar} + (Pnil : @UnderLetsExpr base.type ident ivar (type.base P)) + (Pcons : expr (type.base A) -> list (expr (type.base A)) -> @expr.expr base.type ident ivar (type.base P) -> @UnderLetsExpr base.type ident ivar (type.base P)) + (e : expr (type.base A')) + : @UnderLetsAnyExprCpsOpt base.type ident ivar + := (e <- castbe e; rlist_rect Pnil Pcons e)%cps. + + Definition rwhen {ivar} (v : @UnderLetsAnyExprCpsOpt base.type ident ivar) (cond : bool) + : @UnderLetsAnyExprCpsOpt base.type ident ivar + := fun T k => if cond then v T k else k None. + + Local Notation "e 'when' cond" := (rwhen e%cps cond) (only parsing, at level 100). + + Local Notation ℤ := base.type.Z. + Local Notation ℕ := base.type.nat. + Local Notation bool := base.type.bool. + Local Notation list := pattern.base.type.list. + + Local Arguments Make.interp_rewrite_rules / . + + (** + The follow are rules for rewriting expressions. On the left is a pattern to match: + ??: any expression whose type contains no arrows. + ??{x}: any expression whose type is x. + ??{pattern.base.type.list ??}: for example, a list with elements of a captured type. (The captured type does not match a type with arrows.) + x @ y: x applied to y. + #?x: a value, know at compile time, with type x. (Where x is one of {ℕ or N (nat), 𝔹 or B (bool), ℤ or Z (integers)}.) + #x: the identifer x. + + A matched expression is replaced with the right-hand-side, which is a function that returns a syntax tree, or None to indicate that the match didn't really match. The syntax tree is under three monads: continuation, option, and custom UnderLets monad. + + The function takes the elements that where matched on the LHS as arguments. The arguments are given in the same order as on the LHS, but where wildcards in a type appear before the outer wildcard for that element. So ??{??} results in two arguments, the second wildcard comes first, and ??{?? -> ??} gives arguments in the order 2, 3, 1. + + Sometimes matching an identifer will also result in arguments. Depends on the identifer. Good luck! + +In the RHS, the follow notation applies: + ##x: the literal value x + #x: the identifier x + x @ y: x applied to y + $x: PHOAS variable named x + λ: PHOAS abstraction / functions + + On the RHS, since we're returning a value under three monads, there's some fun notion for dealing with different levels of the monad stack in a single expression: + ret: return something of type [UnderLets expr] + <-: bind, under the CPS+Option monad. + <--: bind, under the UnderLets monad + <---: bind, under the UnderLets+List monad + <----: bind, under the Option monad. + + If you have an expression of type expr or UnderLetsExpr or UnderLetsAnyExprCpsOpt, coercions will handle it; if you have an expression of type [UnderLets expr], you will need [ret]. + + If stuck, email Jason. + *) + Definition rewrite_rules : rewrite_rulesT + := Eval cbn [Make.interp_rewrite_rules List.app] in + Make.interp_rewrite_rules + ++ [ + make_rewrite (#pident.fst @ (??, ??)) (fun _ x _ y => x) + ; make_rewrite (#pident.snd @ (??, ??)) (fun _ x _ y => y) + ; make_rewrite (#pident.List_repeat @ ?? @ #?ℕ) (fun _ x n => reify_list (repeat x n)) + ; make_rewrite + (#pident.bool_rect @ ??{() -> ??} @ ??{() -> ??} @ #?𝔹) + (fun _ t _ f b + => if b return UnderLetsExpr (type.base (if b then _ else _)) + then t ##tt + else f ##tt) + ; make_rewrite + (#pident.pair_rect @ ??{?? -> ?? -> ??} @ (??, ??)) + (fun _ _ _ f _ x _ y + => x <- castbe x; y <- castbe y; ret (f x y)) + ; make_rewrite + (??{list ??} ++ ??{list ??}) + (fun _ xs _ ys => rlist_rect_cast ys (fun x _ xs_ys => x :: xs_ys) xs) + ; make_rewrite + (#pident.List_rev @ ??{list ??}) + (fun _ xs + => xs <- reflect_list_cps xs; + reify_list (List.rev xs)) + ; make_rewrite_step + (#pident.List_flat_map @ ??{?? -> list ??} @ ??{list ??}) + (fun _ B f _ xs + => rlist_rect_cast + [] + (fun x _ flat_map_tl => fx <-- f x; UnderLets.Base ($fx ++ flat_map_tl)) + xs) + ; make_rewrite_step + (#pident.List_partition @ ??{?? -> base.type.bool} @ ??{list ??}) + (fun _ f _ xs + => rlist_rect_cast + ([], []) + (fun x tl partition_tl + => fx <-- f x; + (#ident.pair_rect + @ (λ g d, #ident.bool_rect + @ (λ _, ($x :: $g, $d)) + @ (λ _, ($g, $x :: $d)) + @ $fx) + @ partition_tl)) + xs) + ; make_rewrite + (#pident.List_fold_right @ ??{?? -> ?? -> ??} @ ?? @ ??{list ??}) + (fun _ _ _ f B init A xs + => f <- @castv _ (A -> B -> B)%etype f; + rlist_rect + init + (fun x _ y => f x y) + xs) + ; make_rewrite + (#pident.list_rect @ ??{() -> ??} @ ??{?? -> ?? -> ?? -> ??} @ ??{list ??}) + (fun P Pnil _ _ _ _ Pcons A xs + => Pcons <- @castv _ (A -> base.type.list A -> P -> P) Pcons; + rlist_rect + (Pnil ##tt) + (fun x' xs' rec => Pcons x' (reify_list xs') rec) + xs) + ; make_rewrite + (#pident.list_case @ ??{() -> ??} @ ??{?? -> ?? -> ??} @ []) (fun _ Pnil _ _ _ Pcons => ret (Pnil ##tt)) + ; make_rewrite + (#pident.list_case @ ??{() -> ??} @ ??{?? -> ?? -> ??} @ (?? :: ??)) + (fun _ Pnil _ _ _ Pcons _ x _ xs + => x <- castbe x; xs <- castbe xs; ret (Pcons x xs)) + ; make_rewrite + (#pident.List_map @ ??{?? -> ??} @ ??{list ??}) + (fun _ _ f _ xs + => rlist_rect_cast + [] + (fun x _ fxs => fx <-- f x; fx :: fxs) + xs) + ; make_rewrite + (#pident.List_nth_default @ ?? @ ??{list ??} @ #?ℕ) + (fun _ default _ ls n + => default <- castbe default; + ls <- reflect_list_cps ls; + nth_default default ls n) + ; make_rewrite + (#pident.nat_rect @ ??{() -> ??} @ ??{base.type.nat -> ?? -> ??} @ #?ℕ) + (fun P O_case _ _ S_case n + => S_case <- @castv _ (@type.base base.type base.type.nat -> type.base P -> type.base P) S_case; + ret (nat_rect _ (O_case ##tt) (fun n' rec => rec <-- rec; S_case ##n' rec) n)) + ; make_rewrite + (#pident.List_length @ ??{list ??}) + (fun _ xs => xs <- reflect_list_cps xs; ##(List.length xs)) + ; make_rewrite + (#pident.List_combine @ ??{list ??} @ ??{list ??}) + (fun _ xs _ ys + => xs <- reflect_list_cps xs; + ys <- reflect_list_cps ys; + reify_list (List.map (fun '((x, y)%core) => (x, y)) (List.combine xs ys))) + ; make_rewrite + (#pident.List_update_nth @ #?ℕ @ ??{?? -> ??} @ ??{list ??}) + (fun n _ _ f A ls + => f <- @castv _ (A -> A) f; + ls <- reflect_list_cps ls; + ret + (retv <--- (update_nth + n + (fun x => x <-- x; f x) + (List.map UnderLets.Base ls)); + reify_list retv)) + ; make_rewrite (#?ℤ + ??{ℤ}) (fun z v => v when Z.eqb z 0) + ; make_rewrite (??{ℤ} + #?ℤ ) (fun v z => v when Z.eqb z 0) + ; make_rewrite (#?ℤ + (-??{ℤ})) (fun z v => ##z - v when Z.gtb z 0) + ; make_rewrite ((-??{ℤ}) + #?ℤ ) (fun v z => ##z - v when Z.gtb z 0) + ; make_rewrite (#?ℤ + (-??{ℤ})) (fun z v => -(##((-z)%Z) + v) when Z.ltb z 0) + ; make_rewrite ((-??{ℤ}) + #?ℤ ) (fun v z => -(v + ##((-z)%Z)) when Z.ltb z 0) + ; make_rewrite ((-??{ℤ}) + (-??{ℤ})) (fun x y => -(x + y)) + ; make_rewrite ((-??{ℤ}) + ??{ℤ} ) (fun x y => y - x) + ; make_rewrite ( ??{ℤ} + (-??{ℤ})) (fun x y => x - y) + + ; make_rewrite (#?ℤ - (-??{ℤ})) (fun z v => v when Z.eqb z 0) + ; make_rewrite (#?ℤ - ??{ℤ} ) (fun z v => -v when Z.eqb z 0) + ; make_rewrite (??{ℤ} - #?ℤ ) (fun v z => v when Z.eqb z 0) + ; make_rewrite (#?ℤ - (-??{ℤ})) (fun z v => ##z + v when Z.gtb z 0) + ; make_rewrite (#?ℤ - (-??{ℤ})) (fun z v => v - ##((-z)%Z) when Z.ltb z 0) + ; make_rewrite (#?ℤ - ??{ℤ} ) (fun z v => -(##((-z)%Z) + v) when Z.ltb z 0) + ; make_rewrite ((-??{ℤ}) - #?ℤ ) (fun v z => -(v + ##((-z)%Z)) when Z.gtb z 0) + ; make_rewrite ((-??{ℤ}) - #?ℤ ) (fun v z => ##((-z)%Z) - v when Z.ltb z 0) + ; make_rewrite ( ??{ℤ} - #?ℤ ) (fun v z => v + ##((-z)%Z) when Z.ltb z 0) + ; make_rewrite ((-??{ℤ}) - (-??{ℤ})) (fun x y => y - x) + ; make_rewrite ((-??{ℤ}) - ??{ℤ} ) (fun x y => -(x + y)) + ; make_rewrite ( ??{ℤ} - (-??{ℤ})) (fun x y => x + y) + + ; make_rewrite (#?ℤ * ??{ℤ}) (fun z v => ##0 when Z.eqb z 0) + ; make_rewrite (??{ℤ} * #?ℤ ) (fun v z => ##0 when Z.eqb z 0) + ; make_rewrite (#?ℤ * ??{ℤ}) (fun z v => v when Z.eqb z 1) + ; make_rewrite (??{ℤ} * #?ℤ ) (fun v z => v when Z.eqb z 1) + ; make_rewrite (#?ℤ * (-??{ℤ})) (fun z v => v when Z.eqb z (-1)) + ; make_rewrite ((-??{ℤ}) * #?ℤ ) (fun v z => v when Z.eqb z (-1)) + ; make_rewrite (#?ℤ * ??{ℤ} ) (fun z v => -v when Z.eqb z (-1)) + ; make_rewrite (??{ℤ} * #?ℤ ) (fun v z => -v when Z.eqb z (-1)) + ; make_rewrite (#?ℤ * ??{ℤ} ) (fun z v => -(##((-z)%Z) * v) when Z.ltb z 0) + ; make_rewrite (??{ℤ} * #?ℤ ) (fun v z => -(v * ##((-z)%Z)) when Z.ltb z 0) + ; make_rewrite ((-??{ℤ}) * (-??{ℤ})) (fun x y => x * y) + ; make_rewrite ((-??{ℤ}) * ??{ℤ} ) (fun x y => -(x * y)) + ; make_rewrite ( ??{ℤ} * (-??{ℤ})) (fun x y => -(x * y)) + + ; make_rewrite (??{ℤ} * #?ℤ) (fun x y => x << (Z.log2 y) when Z.eqb y (2^Z.log2 y)) + ; make_rewrite (#?ℤ * ??{ℤ}) (fun y x => x << (Z.log2 y) when Z.eqb y (2^Z.log2 y)) + ; make_rewrite (??{ℤ} / #?ℤ) (fun x y => x >> (Z.log2 y) when Z.eqb y (2^Z.log2 y)) + ; make_rewrite (??{ℤ} mod #?ℤ) (fun x y => #(ident.Z_land (y-1)) @ x when Z.eqb y (2^Z.log2 y)) + ; make_rewrite (-(-??{ℤ})) (fun v => v) + + (** TODO(jadep): These next two are only here for demonstration purposes; remove them once you no longer need it as a template *) + (* if it's a concrete pair, we can opp the second value *) + ; make_rewrite (#pident.Z_neg_snd @ (??{ℤ}, ??{ℤ})) (fun x y => (x, -y)) + (* if it's not a concrete pair, let-bind the pair and negate the second element *) + ; make_rewrite + (#pident.Z_neg_snd @ ??{ℤ * ℤ}) + (fun xy => ret (UnderLets.UnderLet xy (fun xyv => UnderLets.Base (#ident.fst @ $xyv, -(#ident.snd @ $xyv))))) + + ; make_rewrite (#pident.Z_mul_split @ #?ℤ @ #?ℤ @ ??{ℤ}) (fun s xx y => (##0, ##0)%Z when Z.eqb xx 0) + ; make_rewrite (#pident.Z_mul_split @ #?ℤ @ ??{ℤ} @ #?ℤ) (fun s y xx => (##0, ##0)%Z when Z.eqb xx 0) + ; make_rewrite (#pident.Z_mul_split @ #?ℤ @ #?ℤ @ ??{ℤ}) (fun s xx y => (y, ##0)%Z when Z.eqb xx 1) + ; make_rewrite (#pident.Z_mul_split @ #?ℤ @ ??{ℤ} @ #?ℤ) (fun s y xx => (y, ##0)%Z when Z.eqb xx 1) + ; make_rewrite (#pident.Z_mul_split @ #?ℤ @ #?ℤ @ ??{ℤ}) (fun s xx y => (-y, ##0%Z) when Z.eqb xx (-1)) + ; make_rewrite (#pident.Z_mul_split @ #?ℤ @ ??{ℤ} @ #?ℤ) (fun s y xx => (-y, ##0%Z) when Z.eqb xx (-1)) + + ; make_rewrite (#pident.Z_add_get_carry @ #?ℤ @ #?ℤ @ ??{ℤ}) (fun s xx y => (y, ##0%Z) when Z.eqb xx 0) + ; make_rewrite (#pident.Z_add_get_carry @ #?ℤ @ ??{ℤ} @ #?ℤ) (fun s y xx => (y, ##0%Z) when Z.eqb xx 0) + + ; make_rewrite (#pident.Z_add_with_carry @ #?ℤ @ ??{ℤ} @ ??{ℤ}) (fun c x y => x + y when Z.eqb c 0) + + + ; make_rewrite + (#pident.Z_add_with_get_carry @ #?ℤ @ #?ℤ @ #?ℤ @ ??{ℤ}) (fun s cc xx y => (y, ##0) when (cc =? 0) && (xx =? 0)) + ; make_rewrite + (#pident.Z_add_with_get_carry @ #?ℤ @ #?ℤ @ ??{ℤ} @ #?ℤ) (fun s cc y xx => (y, ##0) when (cc =? 0) && (xx =? 0)) + ; make_rewrite (* carry = 0: ADC x y -> ADD x y *) + (#pident.Z_add_with_get_carry @ #?ℤ @ #?ℤ @ ??{ℤ} @ ??{ℤ}) + (fun s cc x y => #(ident.Z_add_get_carry_concrete s) @ x @ y when cc =? 0) + ; make_rewrite (* ADC 0 0 -> (ADX 0 0, 0) *) + (#pident.Z_add_with_get_carry @ #?ℤ @ ??{ℤ} @ #?ℤ @ #?ℤ) + (fun s c xx yy => #ident.Z_add_with_carry @ ##s @ ##xx @ ##yy when (xx =? 0) && (yy =? 0)) + + ; make_rewrite + (#pident.Z_add_get_carry @ #?ℤ @ (-??{ℤ}) @ ??{ℤ}) + (fun s y x => ret (UnderLets.UnderLet + (#(ident.Z_sub_get_borrow_concrete s) @ x @ y) + (fun vc => UnderLets.Base (#ident.fst @ $vc, -(#ident.snd @ $vc))))) + ; make_rewrite + (#pident.Z_add_get_carry @ #?ℤ @ ??{ℤ} @ (-??{ℤ})) + (fun s x y => ret (UnderLets.UnderLet + (#(ident.Z_sub_get_borrow_concrete s) @ x @ y) + (fun vc => UnderLets.Base (#ident.fst @ $vc, -(#ident.snd @ $vc))))) + + + ; make_rewrite + (#pident.Z_add_with_get_carry @ #?ℤ @ (-??{ℤ}) @ (-??{ℤ}) @ ??{ℤ}) + (fun s c y x => ret (UnderLets.UnderLet + (#(ident.Z_sub_with_get_borrow_concrete s) @ c @ x @ y) + (fun vc => UnderLets.Base (#ident.fst @ $vc, -(#ident.snd @ $vc))))) + ; make_rewrite + (#pident.Z_add_with_get_carry @ #?ℤ @ (-??{ℤ}) @ ??{ℤ} @ (-??{ℤ})) + (fun s c x y => ret (UnderLets.UnderLet + (#(ident.Z_sub_with_get_borrow_concrete s) @ c @ x @ y) + (fun vc => UnderLets.Base (#ident.fst @ $vc, -(#ident.snd @ $vc))))) + + ; make_rewrite + (#pident.Z_add_get_carry_concrete @ (-??{ℤ}) @ ??{ℤ}) + (fun s y x => ret (UnderLets.UnderLet + (#(ident.Z_sub_get_borrow_concrete s) @ x @ y) + (fun vc => UnderLets.Base (#ident.fst @ $vc, -(#ident.snd @ $vc))))) + ; make_rewrite + (#pident.Z_add_get_carry_concrete @ ??{ℤ} @ (-??{ℤ})) + (fun s x y => ret (UnderLets.UnderLet + (#(ident.Z_sub_get_borrow_concrete s) @ x @ y) + (fun vc => UnderLets.Base (#ident.fst @ $vc, -(#ident.snd @ $vc))))) + ; make_rewrite + (#pident.Z_add_get_carry_concrete @ #?ℤ @ ??{ℤ}) + (fun s yy x => ret (UnderLets.UnderLet + (#(ident.Z_sub_get_borrow_concrete s) @ x @ ##(-yy)%Z) + (fun vc => UnderLets.Base (#ident.fst @ $vc, -(#ident.snd @ $vc)))) + when yy <=? 0) + ; make_rewrite + (#pident.Z_add_get_carry_concrete @ ??{ℤ} @ #?ℤ) + (fun s x yy => ret (UnderLets.UnderLet + (#(ident.Z_sub_get_borrow_concrete s) @ x @ ##(-yy)%Z) + (fun vc => UnderLets.Base (#ident.fst @ $vc, -(#ident.snd @ $vc)))) + when yy <=? 0) + + + ; make_rewrite + (#pident.Z_add_with_get_carry_concrete @ (-??{ℤ}) @ (-??{ℤ}) @ ??{ℤ}) + (fun s c y x => ret (UnderLets.UnderLet + (#(ident.Z_sub_with_get_borrow_concrete s) @ c @ x @ y) + (fun vc => UnderLets.Base (#ident.fst @ $vc, -(#ident.snd @ $vc))))) + ; make_rewrite + (#pident.Z_add_with_get_carry_concrete @ (-??{ℤ}) @ ??{ℤ} @ (-??{ℤ})) + (fun s c x y => ret (UnderLets.UnderLet + (#(ident.Z_sub_with_get_borrow_concrete s) @ c @ x @ y) + (fun vc => UnderLets.Base (#ident.fst @ $vc, -(#ident.snd @ $vc))))) + ; make_rewrite + (#pident.Z_add_with_get_carry_concrete @ (-??{ℤ}) @ #?ℤ @ ??{ℤ}) + (fun s c yy x => ret (UnderLets.UnderLet + (#(ident.Z_sub_with_get_borrow_concrete s) @ c @ x @ ##(-yy)%Z) + (fun vc => UnderLets.Base (#ident.fst @ $vc, -(#ident.snd @ $vc)))) + when yy <=? 0) + ; make_rewrite + (#pident.Z_add_with_get_carry_concrete @ (-??{ℤ}) @ ??{ℤ} @ #?ℤ) + (fun s c x yy => ret (UnderLets.UnderLet + (#(ident.Z_sub_with_get_borrow_concrete s) @ c @ x @ ##(-yy)%Z) + (fun vc => UnderLets.Base (#ident.fst @ $vc, -(#ident.snd @ $vc)))) + when yy <=? 0) + + ; make_rewrite (#pident.Z_add_get_carry_concrete @ #?ℤ @ ??{ℤ}) (fun s xx y => (y, ##0) when xx =? 0) + ; make_rewrite (#pident.Z_add_get_carry_concrete @ ??{ℤ} @ #?ℤ) (fun s y xx => (y, ##0) when xx =? 0) + + (** XXX TODO: Do we still need the _concrete versions? *) + ; make_rewrite (#pident.Z_mul_split @ #?ℤ @ ??{ℤ} @ ??{ℤ}) (fun s x y => #(ident.Z_mul_split_concrete s) @ x @ y) + ; make_rewrite (#pident.Z_rshi @ #?ℤ @ ??{ℤ} @ ??{ℤ} @ #?ℤ) (fun x y z a => #(ident.Z_rshi_concrete x a) @ y @ z) + ; make_rewrite (#pident.Z_cc_m @ #?ℤ @ ??{ℤ}) (fun x y => #(ident.Z_cc_m_concrete x) @ y) + ; make_rewrite (#pident.Z_add_get_carry @ #?ℤ @ ??{ℤ} @ ??{ℤ}) (fun s x y => #(ident.Z_add_get_carry_concrete s) @ x @ y) + ; make_rewrite (#pident.Z_add_with_get_carry @ #?ℤ @ ??{ℤ} @ ??{ℤ} @ ??{ℤ}) (fun s c x y => #(ident.Z_add_with_get_carry_concrete s) @ c @ x @ y) + ; make_rewrite (#pident.Z_sub_get_borrow @ #?ℤ @ ??{ℤ} @ ??{ℤ}) (fun s x y => #(ident.Z_sub_get_borrow_concrete s) @ x @ y) + ; make_rewrite (#pident.Z_sub_with_get_borrow @ #?ℤ @ ??{ℤ} @ ??{ℤ} @ ??{ℤ}) (fun s x y b => #(ident.Z_sub_with_get_borrow_concrete s) @ x @ y @ b) + + ; make_rewrite_step (* _step, so that if one of the arguments is concrete, we automatically get the rewrite rule for [Z_cast] applying to it *) + (#pident.Z_cast2 @ (??{ℤ}, ??{ℤ})) (fun r x y => (#(ident.Z_cast (fst r)) @ $x, #(ident.Z_cast (snd r)) @ $y)) + ]%list%pattern%cps%option%under_lets%Z%bool. + + Definition dtree' + := Eval compute in @compile_rewrites ident var pattern.ident pattern.ident.arg_types pattern.ident.ident_beq 100 rewrite_rules. + Definition dtree : decision_tree + := Eval compute in invert_Some dtree'. + Definition default_fuel := Eval compute in List.length rewrite_rules. + + Import PrimitiveHList. + (* N.B. The [combine_hlist] call MUST eta-expand + [pr2_rewrite_rules]. That is, it MUST NOT block reduction of + the resulting list of cons cells on the pair-structure of + [pr2_rewrite_rules]. This is required so that we can use + [cbv -] to unfold the entire discrimination tree evaluation, + including choosing which rewrite rule to apply and binding + its arguments, without unfolding any of the identifiers used + to define the replacement value. (The symptom of messing + this up is that the [cbv -] will run out of memory when + trying to reduce things.) We accomplish this by making + [hlist] based on a primitive [prod] type with judgmental η, + so that matching on its structure never blocks reduction. *) + Definition split_rewrite_rules := Eval cbv [split_list projT1 projT2 rewrite_rules] in split_list rewrite_rules. + Definition pr1_rewrite_rules := Eval hnf in projT1 split_rewrite_rules. + Definition pr2_rewrite_rules := Eval hnf in projT2 split_rewrite_rules. + Definition all_rewrite_rules := combine_hlist (P:=rewrite_ruleTP) pr1_rewrite_rules pr2_rewrite_rules. + + Definition rewrite_head0 do_again {t} (idc : ident t) : value_with_lets t + := @assemble_identifier_rewriters dtree all_rewrite_rules do_again t idc. + + Section fancy. + Context (invert_low invert_high : Z (*log2wordmax*) -> Z -> option Z). + Definition fancy_rewrite_rules : rewrite_rulesT + := [ + (* +(Z.add_get_carry_concrete 2^256) @@ (?x, ?y << 128) --> (add 128) @@ (x, y) +(Z.add_get_carry_concrete 2^256) @@ (?x << 128, ?y) --> (add 128) @@ (y, x) +(Z.add_get_carry_concrete 2^256) @@ (?x, ?y >> 128) --> (add (- 128)) @@ (x, y) +(Z.add_get_carry_concrete 2^256) @@ (?x >> 128, ?y) --> (add (- 128)) @@ (y, x) +(Z.add_get_carry_concrete 2^256) @@ (?x, ?y) --> (add 0) @@ (y, x) +*) + make_rewrite + (#pident.Z_add_get_carry_concrete @ ??{ℤ} @ (#pident.Z_shiftl @ ??{ℤ})) + (fun s x offset y => #(ident.fancy_add (Z.log2 s) offset) @ (x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_add_get_carry_concrete @ (#pident.Z_shiftl @ ??{ℤ}) @ ??{ℤ}) + (fun s offset y x => #(ident.fancy_add (Z.log2 s) offset) @ (x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_add_get_carry_concrete @ ??{ℤ} @ (#pident.Z_shiftr @ ??{ℤ})) + (fun s x offset y => #(ident.fancy_add (Z.log2 s) (-offset)) @ (x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_add_get_carry_concrete @ (#pident.Z_shiftr @ ??{ℤ}) @ ??{ℤ}) + (fun s offset y x => #(ident.fancy_add (Z.log2 s) (-offset)) @ (x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_add_get_carry_concrete @ ??{ℤ} @ ??{ℤ}) + (fun s x y => #(ident.fancy_add (Z.log2 s) 0) @ (x, y) when s =? 2^Z.log2 s) +(* +(Z.add_with_get_carry_concrete 2^256) @@ (?c, ?x, ?y << 128) --> (addc 128) @@ (c, x, y) +(Z.add_with_get_carry_concrete 2^256) @@ (?c, ?x << 128, ?y) --> (addc 128) @@ (c, y, x) +(Z.add_with_get_carry_concrete 2^256) @@ (?c, ?x, ?y >> 128) --> (addc (- 128)) @@ (c, x, y) +(Z.add_with_get_carry_concrete 2^256) @@ (?c, ?x >> 128, ?y) --> (addc (- 128)) @@ (c, y, x) +(Z.add_with_get_carry_concrete 2^256) @@ (?c, ?x, ?y) --> (addc 0) @@ (c, y, x) + *) + ; make_rewrite + (#pident.Z_add_with_get_carry_concrete @ ??{ℤ} @ ??{ℤ} @ (#pident.Z_shiftl @ ??{ℤ})) + (fun s c x offset y => #(ident.fancy_addc (Z.log2 s) offset) @ (c, x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_add_with_get_carry_concrete @ ??{ℤ} @ (#pident.Z_shiftl @ ??{ℤ}) @ ??{ℤ}) + (fun s c offset y x => #(ident.fancy_addc (Z.log2 s) offset) @ (c, x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_add_with_get_carry_concrete @ ??{ℤ} @ ??{ℤ} @ (#pident.Z_shiftr @ ??{ℤ})) + (fun s c x offset y => #(ident.fancy_addc (Z.log2 s) (-offset)) @ (c, x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_add_with_get_carry_concrete @ ??{ℤ} @ (#pident.Z_shiftr @ ??{ℤ}) @ ??{ℤ}) + (fun s c offset y x => #(ident.fancy_addc (Z.log2 s) (-offset)) @ (c, x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_add_with_get_carry_concrete @ ??{ℤ} @ ??{ℤ} @ ??{ℤ}) + (fun s c x y => #(ident.fancy_addc (Z.log2 s) 0) @ (c, x, y) when s =? 2^Z.log2 s) +(* +(Z.sub_get_borrow_concrete 2^256) @@ (?x, ?y << 128) --> (sub 128) @@ (x, y) +(Z.sub_get_borrow_concrete 2^256) @@ (?x, ?y >> 128) --> (sub (- 128)) @@ (x, y) +(Z.sub_get_borrow_concrete 2^256) @@ (?x, ?y) --> (sub 0) @@ (y, x) + *) + ; make_rewrite + (#pident.Z_sub_get_borrow_concrete @ ??{ℤ} @ (#pident.Z_shiftl @ ??{ℤ})) + (fun s x offset y => #(ident.fancy_sub (Z.log2 s) offset) @ (x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_sub_get_borrow_concrete @ ??{ℤ} @ (#pident.Z_shiftr @ ??{ℤ})) + (fun s x offset y => #(ident.fancy_sub (Z.log2 s) (-offset)) @ (x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_sub_get_borrow_concrete @ ??{ℤ} @ ??{ℤ}) + (fun s x y => #(ident.fancy_sub (Z.log2 s) 0) @ (x, y) when s =? 2^Z.log2 s) +(* +(Z.sub_with_get_borrow_concrete 2^256) @@ (?c, ?x, ?y << 128) --> (subb 128) @@ (c, x, y) +(Z.sub_with_get_borrow_concrete 2^256) @@ (?c, ?x, ?y >> 128) --> (subb (- 128)) @@ (c, x, y) +(Z.sub_with_get_borrow_concrete 2^256) @@ (?c, ?x, ?y) --> (subb 0) @@ (c, y, x) + *) + ; make_rewrite + (#pident.Z_sub_with_get_borrow_concrete @ ??{ℤ} @ ??{ℤ} @ (#pident.Z_shiftl @ ??{ℤ})) + (fun s b x offset y => #(ident.fancy_subb (Z.log2 s) offset) @ (b, x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_sub_with_get_borrow_concrete @ ??{ℤ} @ ??{ℤ} @ (#pident.Z_shiftr @ ??{ℤ})) + (fun s b x offset y => #(ident.fancy_subb (Z.log2 s) (-offset)) @ (b, x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_sub_with_get_borrow_concrete @ ??{ℤ} @ ??{ℤ} @ ??{ℤ}) + (fun s b x y => #(ident.fancy_subb (Z.log2 s) 0) @ (b, x, y) when s =? 2^Z.log2 s) + (*(Z.rshi_concrete 2^256 ?n) @@ (?c, ?x, ?y) --> (rshi n) @@ (x, y)*) + ; make_rewrite + (#pident.Z_rshi_concrete @ ??{ℤ} @ ??{ℤ}) + (fun '((s, n)%core) x y => #(ident.fancy_rshi (Z.log2 s) n) @ (x, y) when s =? 2^Z.log2 s) +(* +Z.zselect @@ (Z.cc_m_concrete 2^256 ?c, ?x, ?y) --> selm @@ (c, x, y) +Z.zselect @@ (?c &' 1, ?x, ?y) --> sell @@ (c, x, y) +Z.zselect @@ (?c, ?x, ?y) --> selc @@ (c, x, y) + *) + ; make_rewrite + (#pident.Z_zselect @ (#pident.Z_cc_m_concrete @ ??{ℤ}) @ ??{ℤ} @ ??{ℤ}) + (fun s c x y => #(ident.fancy_selm (Z.log2 s)) @ (c, x, y) when s =? 2^Z.log2 s) + ; make_rewrite + (#pident.Z_zselect @ (#pident.Z_land @ ??{ℤ}) @ ??{ℤ} @ ??{ℤ}) + (fun mask c x y => #ident.fancy_sell @ (c, x, y) when mask =? 1) + ; make_rewrite + (#pident.Z_zselect @ ??{ℤ} @ ??{ℤ} @ ??{ℤ}) + (fun c x y => #ident.fancy_selc @ (c, x, y)) +(*Z.add_modulo @@ (?x, ?y, ?m) --> addm @@ (x, y, m)*) + ; make_rewrite + (#pident.Z_add_modulo @ ??{ℤ} @ ??{ℤ} @ ??{ℤ}) + (fun x y m => #ident.fancy_addm @ (x, y, m)) +(* +Z.mul @@ (?x &' (2^128-1), ?y &' (2^128-1)) --> mulll @@ (x, y) +Z.mul @@ (?x &' (2^128-1), ?y >> 128) --> mullh @@ (x, y) +Z.mul @@ (?x >> 128, ?y &' (2^128-1)) --> mulhl @@ (x, y) +Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) + *) + (* literal on left *) + ; make_rewrite + (#?ℤ * (#pident.Z_land @ ??{ℤ})) + (fun x mask y => let s := (2*Z.log2_up mask)%Z in x <---- invert_low s x; #(ident.fancy_mulll s) @ (##x, y) when (mask =? 2^(s/2)-1)) + ; make_rewrite + (#?ℤ * (#pident.Z_shiftr @ ??{ℤ})) + (fun x offset y => let s := (2*offset)%Z in x <---- invert_low s x; #(ident.fancy_mullh s) @ (##x, y)) + ; make_rewrite + (#?ℤ * (#pident.Z_land @ ??{ℤ})) + (fun x mask y => let s := (2*Z.log2_up mask)%Z in x <---- invert_high s x; #(ident.fancy_mulhl s) @ (##x, y) when mask =? 2^(s/2)-1) + ; make_rewrite + (#?ℤ * (#pident.Z_shiftr @ ??{ℤ})) + (fun x offset y => let s := (2*offset)%Z in x <---- invert_high s x; #(ident.fancy_mulhh s) @ (##x, y)) + + (* literal on right *) + ; make_rewrite + ((#pident.Z_land @ ??{ℤ}) * #?ℤ) + (fun mask x y => let s := (2*Z.log2_up mask)%Z in y <---- invert_low s y; #(ident.fancy_mulll s) @ (x, ##y) when (mask =? 2^(s/2)-1)) + ; make_rewrite + ((#pident.Z_land @ ??{ℤ}) * #?ℤ) + (fun mask x y => let s := (2*Z.log2_up mask)%Z in y <---- invert_high s y; #(ident.fancy_mullh s) @ (x, ##y) when mask =? 2^(s/2)-1) + ; make_rewrite + ((#pident.Z_shiftr @ ??{ℤ}) * #?ℤ) + (fun offset x y => let s := (2*offset)%Z in y <---- invert_low s y; #(ident.fancy_mulhl s) @ (x, ##y)) + ; make_rewrite + ((#pident.Z_shiftr @ ??{ℤ}) * #?ℤ) + (fun offset x y => let s := (2*offset)%Z in y <---- invert_high s y; #(ident.fancy_mulhh s) @ (x, ##y)) + + (* no literal *) + ; make_rewrite + ((#pident.Z_land @ ??{ℤ}) * (#pident.Z_land @ ??{ℤ})) + (fun mask1 x mask2 y => let s := (2*Z.log2_up mask1)%Z in #(ident.fancy_mulll s) @ (x, y) when (mask1 =? 2^(s/2)-1) && (mask2 =? 2^(s/2)-1)) + ; make_rewrite + ((#pident.Z_land @ ??{ℤ}) * (#pident.Z_shiftr @ ??{ℤ})) + (fun mask x offset y => let s := (2*offset)%Z in #(ident.fancy_mullh s) @ (x, y) when mask =? 2^(s/2)-1) + ; make_rewrite + ((#pident.Z_shiftr @ ??{ℤ}) * (#pident.Z_land @ ??{ℤ})) + (fun offset x mask y => let s := (2*offset)%Z in #(ident.fancy_mulhl s) @ (x, y) when mask =? 2^(s/2)-1) + ; make_rewrite + ((#pident.Z_shiftr @ ??{ℤ}) * (#pident.Z_shiftr @ ??{ℤ})) + (fun offset1 x offset2 y => let s := (2*offset1)%Z in #(ident.fancy_mulhh s) @ (x, y) when offset1 =? offset2) + + ]%list%pattern%cps%option%under_lets%Z%bool. + + Definition fancy_dtree' + := Eval compute in @compile_rewrites ident var pattern.ident pattern.ident.arg_types pattern.ident.ident_beq 100 fancy_rewrite_rules. + Definition fancy_dtree : decision_tree + := Eval compute in invert_Some fancy_dtree'. + Definition fancy_default_fuel := Eval compute in List.length fancy_rewrite_rules. + + Import PrimitiveHList. + Definition fancy_split_rewrite_rules := Eval cbv [split_list projT1 projT2 fancy_rewrite_rules] in split_list fancy_rewrite_rules. + Definition fancy_pr1_rewrite_rules := Eval hnf in projT1 fancy_split_rewrite_rules. + Definition fancy_pr2_rewrite_rules := Eval hnf in projT2 fancy_split_rewrite_rules. + Definition fancy_all_rewrite_rules := combine_hlist (P:=rewrite_ruleTP) fancy_pr1_rewrite_rules fancy_pr2_rewrite_rules. + + Definition fancy_rewrite_head0 do_again {t} (idc : ident t) : value_with_lets t + := @assemble_identifier_rewriters fancy_dtree fancy_all_rewrite_rules do_again t idc. + End fancy. + End with_var. + + Section red_fancy. + Context (invert_low invert_high : Z (*log2wordmax*) -> Z -> option Z) + {var : type.type base.type -> Type} + (do_again : forall t : base.type, @expr base.type ident (@Compile.value base.type ident var) (type.base t) + -> UnderLets.UnderLets base.type ident var (@expr base.type ident var (type.base t))) + {t} (idc : ident t). + + Time Let rewrite_head1 + := Eval cbv -[fancy_pr2_rewrite_rules + base.interp base.try_make_transport_cps + type.try_make_transport_cps type.try_transport_cps + UnderLets.splice UnderLets.to_expr + Compile.reflect Compile.reify Compile.reify_and_let_binds_cps UnderLets.reify_and_let_binds_base_cps + Compile.value' SubstVarLike.is_var_fst_snd_pair_opp + ] in @fancy_rewrite_head0 var invert_low invert_high do_again t idc. + (* Finished transaction in 1.434 secs (1.432u,0.s) (successful) *) + + Time Local Definition fancy_rewrite_head2 + := Eval cbv [id + rewrite_head1 fancy_pr2_rewrite_rules + projT1 projT2 + cpsbind cpscall cps_option_bind cpsreturn + pattern.ident.arg_types + Compile.app_binding_data + Compile.app_pbase_type_interp_cps + Compile.app_ptype_interp_cps + Compile.bind_base_cps + Compile.bind_data_cps + Compile.binding_dataT + Compile.bind_value_cps + Compile.eval_decision_tree + Compile.eval_rewrite_rules + Compile.expr_of_rawexpr + Compile.lift_pbase_type_interp_cps + Compile.lift_ptype_interp_cps + Compile.lift_with_bindings + Compile.pbase_type_interp_cps + Compile.ptype_interp + Compile.ptype_interp_cps + (*Compile.reflect*) + (*Compile.reify*) + Compile.reveal_rawexpr_cps + Compile.rValueOrExpr + Compile.swap_list + Compile.type_of_rawexpr + Compile.value + (*Compile.value'*) + Compile.value_of_rawexpr + Compile.value_with_lets + Compile.with_bindingsT + ident.smart_Literal + type.try_transport_cps + rlist_rect rlist_rect_cast rwhen + ] in rewrite_head1. + (* Finished transaction in 1.347 secs (1.343u,0.s) (successful) *) + + Local Arguments base.try_make_base_transport_cps _ !_ !_. + Local Arguments base.try_make_transport_cps _ !_ !_. + Local Arguments type.try_make_transport_cps _ _ _ !_ !_. + Local Arguments fancy_rewrite_head2 / . + + Time Definition fancy_rewrite_head + := Eval cbn [id + fancy_rewrite_head2 + cpsbind cpscall cps_option_bind cpsreturn + Compile.reify Compile.reify_and_let_binds_cps Compile.reflect Compile.value' + UnderLets.reify_and_let_binds_base_cps + UnderLets.splice UnderLets.splice_list UnderLets.to_expr + base.interp base.base_interp + type.try_make_transport_cps base.try_make_transport_cps base.try_make_base_transport_cps + PrimitiveProd.Primitive.fst PrimitiveProd.Primitive.snd Datatypes.fst Datatypes.snd + ] in fancy_rewrite_head2. + (* Finished transaction in 13.298 secs (13.283u,0.s) (successful) *) + + Redirect "/tmp/fancy_rewrite_head" Print fancy_rewrite_head. + End red_fancy. + + Section red. + Context {var : type.type base.type -> Type} + (do_again : forall t : base.type, @expr base.type ident (@Compile.value base.type ident var) (type.base t) + -> UnderLets.UnderLets base.type ident var (@expr base.type ident var (type.base t))) + {t} (idc : ident t). + + Time Let rewrite_head1 + := Eval cbv -[pr2_rewrite_rules + base.interp base.try_make_transport_cps + type.try_make_transport_cps type.try_transport_cps + UnderLets.splice UnderLets.to_expr + Compile.reflect UnderLets.reify_and_let_binds_base_cps Compile.reify Compile.reify_and_let_binds_cps + Compile.value' + SubstVarLike.is_var_fst_snd_pair_opp + ] in @rewrite_head0 var do_again t idc. + (* Finished transaction in 16.593 secs (16.567u,0.s) (successful) *) + + Time Local Definition rewrite_head2 + := Eval cbv [id + rewrite_head1 pr2_rewrite_rules + projT1 projT2 + cpsbind cpscall cps_option_bind cpsreturn + pattern.ident.arg_types + Compile.app_binding_data + Compile.app_pbase_type_interp_cps + Compile.app_ptype_interp_cps + Compile.bind_base_cps + Compile.bind_data_cps + Compile.binding_dataT + Compile.bind_value_cps + Compile.eval_decision_tree + Compile.eval_rewrite_rules + Compile.expr_of_rawexpr + Compile.lift_pbase_type_interp_cps + Compile.lift_ptype_interp_cps + Compile.lift_with_bindings + Compile.pbase_type_interp_cps + Compile.ptype_interp + Compile.ptype_interp_cps + (*Compile.reflect*) + (*Compile.reify*) + Compile.reveal_rawexpr_cps + Compile.rValueOrExpr + Compile.swap_list + Compile.type_of_rawexpr + Compile.value + (*Compile.value'*) + Compile.value_of_rawexpr + Compile.value_with_lets + Compile.with_bindingsT + ident.smart_Literal + type.try_transport_cps + rlist_rect rlist_rect_cast rwhen + ] in rewrite_head1. + (* Finished transaction in 29.683 secs (29.592u,0.048s) (successful) *) + + Local Arguments base.try_make_base_transport_cps _ !_ !_. + Local Arguments base.try_make_transport_cps _ !_ !_. + Local Arguments type.try_make_transport_cps _ _ _ !_ !_. + Local Arguments rewrite_head2 / . + + Time Definition rewrite_head + := Eval cbn [id + rewrite_head2 + cpsbind cpscall cps_option_bind cpsreturn + Compile.reify Compile.reify_and_let_binds_cps Compile.reflect Compile.value' + UnderLets.reify_and_let_binds_base_cps + UnderLets.splice UnderLets.splice_list UnderLets.to_expr + base.interp base.base_interp + type.try_make_transport_cps base.try_make_transport_cps base.try_make_base_transport_cps + PrimitiveProd.Primitive.fst PrimitiveProd.Primitive.snd Datatypes.fst Datatypes.snd + ] in rewrite_head2. + (* Finished transaction in 16.561 secs (16.54u,0.s) (successful) *) + + Redirect "/tmp/rewrite_head" Print rewrite_head. + End red. + + Definition Rewrite {t} (e : expr.Expr (ident:=ident) t) : expr.Expr (ident:=ident) t + := @Compile.Rewrite (@rewrite_head) default_fuel t e. + Definition RewriteToFancy + (invert_low invert_high : Z (*log2wordmax*) -> Z -> option Z) + {t} (e : expr.Expr (ident:=ident) t) : expr.Expr (ident:=ident) t + := @Compile.Rewrite (fun var _ => @fancy_rewrite_head invert_low invert_high var) fancy_default_fuel t e. + End RewriteRules. + + Import defaults. + + Definition PartialEvaluate {t} (e : Expr t) : Expr t := RewriteRules.Rewrite e. +End Compilers. |