aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/NewPipeline/Rewriter.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Experiments/NewPipeline/Rewriter.v')
-rw-r--r--src/Experiments/NewPipeline/Rewriter.v134
1 files changed, 53 insertions, 81 deletions
diff --git a/src/Experiments/NewPipeline/Rewriter.v b/src/Experiments/NewPipeline/Rewriter.v
index dd898bc3e..f9f6e22ac 100644
--- a/src/Experiments/NewPipeline/Rewriter.v
+++ b/src/Experiments/NewPipeline/Rewriter.v
@@ -8,7 +8,6 @@ Require Crypto.Util.PrimitiveHList.
Require Import Crypto.Experiments.NewPipeline.Language.
Require Import Crypto.Experiments.NewPipeline.UnderLets.
Require Import Crypto.Experiments.NewPipeline.GENERATEDIdentifiersWithoutTypes.
-Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.Notations.
Import ListNotations. Local Open Scope bool_scope. Local Open Scope Z_scope.
@@ -426,49 +425,42 @@ Module Compilers.
=> match ctx with
| nil => cont None ctx None
| ctx0 :: ctx'
- => let default := fun 'tt => @eval_decision_tree T ctx default_case cont in
- let bind_default_in f
- := match default_case with
- | Failure => f default
- | _ => (dlet default := default in f default)
- end in
- bind_default_in
- (fun default
- => reveal_rawexpr_cps
- ctx0 _
- (fun ctx0'
- => match ctx0' with
- | rIdent t idc t' alt
- => fold_right
- (fun '(pidc, icase) default 'tt
- => match invert_bind_args _ idc pidc with
- | Some args
- => @eval_decision_tree
- T ctx' icase
- (fun k ctx''
- => cont k (rIdent (pident_to_typed pidc args) alt :: ctx''))
- | None => default tt
- end)
- default
- icases
- tt
- | rApp f x t alt
- => match app_case with
- | Some app_case
+ => let default _ := @eval_decision_tree T ctx default_case cont in
+ reveal_rawexpr_cps
+ ctx0 _
+ (fun ctx0'
+ => match ctx0' with
+ | rIdent t idc t' alt
+ => fold_right
+ (fun '(pidc, icase) default 'tt
+ => match invert_bind_args _ idc pidc with
+ | Some args
=> @eval_decision_tree
- T (f :: x :: ctx') app_case
+ T ctx' icase
(fun k ctx''
- => match ctx'' with
- | f' :: x' :: ctx'''
- => cont k (rApp f' x' alt :: ctx''')
- | _ => cont None ctx
- end)
+ => cont k (rIdent (pident_to_typed pidc args) alt :: ctx''))
| None => default tt
- end
- | rExpr t e
- | rValue t e
- => default tt
- end))
+ end)
+ default
+ icases
+ tt
+ | rApp f x t alt
+ => match app_case with
+ | Some app_case
+ => @eval_decision_tree
+ T (f :: x :: ctx') app_case
+ (fun k ctx''
+ => match ctx'' with
+ | f' :: x' :: ctx'''
+ => cont k (rApp f' x' alt :: ctx''')
+ | _ => cont None ctx
+ end)
+ | None => default tt
+ end
+ | rExpr t e
+ | rValue t e
+ => default tt
+ end)
end
| Swap i d'
=> match swap_list 0 i ctx with
@@ -507,8 +499,7 @@ Module Compilers.
(rew : rewrite_rulesT)
(e : rawexpr)
: UnderLets (expr (type_of_rawexpr e))
- := dlet default := UnderLets.Base (expr_of_rawexpr e) in
- eval_decision_tree
+ := eval_decision_tree
(e::nil) d
(fun k ctx default_on_rewrite_failure
=> match k, ctx return UnderLets (expr (type_of_rawexpr e)) with
@@ -532,18 +523,18 @@ Module Compilers.
=> match fv', default_on_rewrite_failure with
| Some fv'', _ => UnderLets.Base fv''
| None, Some default => default tt
- | None, None => default
+ | None, None => UnderLets.Base (expr_of_rawexpr e)
end))%under_lets
| None => match default_on_rewrite_failure with
| Some default => default tt
- | None => default
+ | None => UnderLets.Base (expr_of_rawexpr e)
end
end)
- | None => default
+ | None => UnderLets.Base (expr_of_rawexpr e)
end)
- | None => default
+ | None => UnderLets.Base (expr_of_rawexpr e)
end
- | _, _ => default
+ | _, _ => UnderLets.Base (expr_of_rawexpr e)
end).
Local Notation enumerate ls
@@ -567,20 +558,12 @@ Module Compilers.
end)
(enumerate p).
- Definition starts_with_wildcard : nat * list pattern -> bool
- := fun '(_, p) => match p with
- | pattern.Wildcard _::_ => true
- | _ => false
- end.
-
- Definition not_starts_with_wildcard : nat * list pattern -> bool
- := fun p => negb (starts_with_wildcard p).
-
Definition filter_pattern_wildcard (p : list (nat * list pattern)) : list (nat * list pattern)
- := filter starts_with_wildcard p.
-
- Definition split_at_first_pattern_wildcard (p : list (nat * list pattern)) : list (nat * list pattern) * list (nat * list pattern)
- := (take_while not_starts_with_wildcard p, drop_while not_starts_with_wildcard p).
+ := filter (fun '(_, p) => match p with
+ | pattern.Wildcard _::_ => true
+ | _ => false
+ end)
+ p.
Fixpoint get_unique_pattern_ident' (p : list (nat * list pattern)) (so_far : list pident) : list pident
:= match p with
@@ -609,23 +592,25 @@ Module Compilers.
end)
p.
- Definition filter_pattern_app (p : nat * list pattern) : option (nat * list pattern)
+ Definition refine_pattern_app (p : nat * list pattern) : option (nat * list pattern)
:= match p with
+ | (n, pattern.Wildcard d::ps)
+ => Some (n, (??{?? -> d} :: ?? :: ps)%list%pattern)
| (n, pattern.App f x :: ps)
=> Some (n, f :: x :: ps)
| (_, pattern.Ident _::_)
- | (_, pattern.Wildcard _::_)
| (_, nil)
=> None
end.
- Definition filter_pattern_pident (pidc : pident) (p : nat * list pattern) : option (nat * list pattern)
+ Definition refine_pattern_pident (pidc : pident) (p : nat * list pattern) : option (nat * list pattern)
:= match p with
+ | (n, pattern.Wildcard _::ps)
+ => Some (n, ps)
| (n, pattern.Ident pidc'::ps)
=> if pident_beq pidc pidc'
then Some (n, ps)
else None
- | (_, pattern.Wildcard _::_)
| (_, pattern.App _ _::_)
| (_, nil)
=> None
@@ -643,14 +628,13 @@ Module Compilers.
=> (onfailure <- compile_rewrites ps;
Some (TryLeaf n1 onfailure))
| Some Datatypes.O
- => let '(pattern_matrix, default_pattern_matrix) := split_at_first_pattern_wildcard pattern_matrix in
- default_case <- compile_rewrites default_pattern_matrix;
+ => default_case <- compile_rewrites (filter_pattern_wildcard pattern_matrix);
app_case <- (if contains_pattern_app pattern_matrix
- then option_map Some (compile_rewrites (Option.List.map filter_pattern_app pattern_matrix))
+ then option_map Some (compile_rewrites (Option.List.map refine_pattern_app pattern_matrix))
else Some None);
let pidcs := get_unique_pattern_ident pattern_matrix in
let icases := Option.List.map
- (fun pidc => option_map (pair pidc) (compile_rewrites (Option.List.map (filter_pattern_pident pidc) pattern_matrix)))
+ (fun pidc => option_map (pair pidc) (compile_rewrites (Option.List.map (refine_pattern_pident pidc) pattern_matrix)))
pidcs in
Some (Switch icases app_case default_case)
| Some i
@@ -1740,7 +1724,6 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y)
:= Eval cbv -[fancy_pr2_rewrite_rules
base.interp base.try_make_transport_cps
type.try_make_transport_cps type.try_transport_cps
- Let_In
UnderLets.splice UnderLets.to_expr
Compile.reflect Compile.reify Compile.reify_and_let_binds_cps UnderLets.reify_and_let_binds_base_cps
Compile.value' SubstVarLike.is_var_fst_snd_pair_opp
@@ -1804,9 +1787,6 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y)
] in fancy_rewrite_head2.
(* Finished transaction in 13.298 secs (13.283u,0.s) (successful) *)
- Local Set Printing Depth 1000000.
- Local Set Printing Width 200.
- Local Notation "'llet' x := v 'in' f" := (Let_In v (fun x => f)).
Redirect "/tmp/fancy_rewrite_head" Print fancy_rewrite_head.
End red_fancy.
@@ -1820,7 +1800,6 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y)
:= Eval cbv -[nbe_pr2_rewrite_rules
base.interp base.try_make_transport_cps
type.try_make_transport_cps type.try_transport_cps
- Let_In
UnderLets.splice UnderLets.to_expr
Compile.reflect UnderLets.reify_and_let_binds_base_cps Compile.reify Compile.reify_and_let_binds_cps
Compile.value'
@@ -1885,9 +1864,6 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y)
] in nbe_rewrite_head2.
(* Finished transaction in 16.561 secs (16.54u,0.s) (successful) *)
- Local Set Printing Depth 1000000.
- Local Set Printing Width 200.
- Local Notation "'llet' x := v 'in' f" := (Let_In v (fun x => f)).
Redirect "/tmp/nbe_rewrite_head" Print nbe_rewrite_head.
End red_nbe.
@@ -1901,7 +1877,6 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y)
:= Eval cbv -[arith_pr2_rewrite_rules
base.interp base.try_make_transport_cps
type.try_make_transport_cps type.try_transport_cps
- Let_In
UnderLets.splice UnderLets.to_expr
Compile.reflect UnderLets.reify_and_let_binds_base_cps Compile.reify Compile.reify_and_let_binds_cps
Compile.value'
@@ -1966,9 +1941,6 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y)
] in arith_rewrite_head2.
(* Finished transaction in 16.561 secs (16.54u,0.s) (successful) *)
- Local Set Printing Depth 1000000.
- Local Set Printing Width 200.
- Local Notation "'llet' x := v 'in' f" := (Let_In v (fun x => f)).
Redirect "/tmp/arith_rewrite_head" Print arith_rewrite_head.
End red_arith.