diff options
Diffstat (limited to 'src/Experiments/NewPipeline/Rewriter.v')
-rw-r--r-- | src/Experiments/NewPipeline/Rewriter.v | 134 |
1 files changed, 53 insertions, 81 deletions
diff --git a/src/Experiments/NewPipeline/Rewriter.v b/src/Experiments/NewPipeline/Rewriter.v index dd898bc3e..f9f6e22ac 100644 --- a/src/Experiments/NewPipeline/Rewriter.v +++ b/src/Experiments/NewPipeline/Rewriter.v @@ -8,7 +8,6 @@ Require Crypto.Util.PrimitiveHList. Require Import Crypto.Experiments.NewPipeline.Language. Require Import Crypto.Experiments.NewPipeline.UnderLets. Require Import Crypto.Experiments.NewPipeline.GENERATEDIdentifiersWithoutTypes. -Require Import Crypto.Util.LetIn. Require Import Crypto.Util.Notations. Import ListNotations. Local Open Scope bool_scope. Local Open Scope Z_scope. @@ -426,49 +425,42 @@ Module Compilers. => match ctx with | nil => cont None ctx None | ctx0 :: ctx' - => let default := fun 'tt => @eval_decision_tree T ctx default_case cont in - let bind_default_in f - := match default_case with - | Failure => f default - | _ => (dlet default := default in f default) - end in - bind_default_in - (fun default - => reveal_rawexpr_cps - ctx0 _ - (fun ctx0' - => match ctx0' with - | rIdent t idc t' alt - => fold_right - (fun '(pidc, icase) default 'tt - => match invert_bind_args _ idc pidc with - | Some args - => @eval_decision_tree - T ctx' icase - (fun k ctx'' - => cont k (rIdent (pident_to_typed pidc args) alt :: ctx'')) - | None => default tt - end) - default - icases - tt - | rApp f x t alt - => match app_case with - | Some app_case + => let default _ := @eval_decision_tree T ctx default_case cont in + reveal_rawexpr_cps + ctx0 _ + (fun ctx0' + => match ctx0' with + | rIdent t idc t' alt + => fold_right + (fun '(pidc, icase) default 'tt + => match invert_bind_args _ idc pidc with + | Some args => @eval_decision_tree - T (f :: x :: ctx') app_case + T ctx' icase (fun k ctx'' - => match ctx'' with - | f' :: x' :: ctx''' - => cont k (rApp f' x' alt :: ctx''') - | _ => cont None ctx - end) + => cont k (rIdent (pident_to_typed pidc args) alt :: ctx'')) | None => default tt - end - | rExpr t e - | rValue t e - => default tt - end)) + end) + default + icases + tt + | rApp f x t alt + => match app_case with + | Some app_case + => @eval_decision_tree + T (f :: x :: ctx') app_case + (fun k ctx'' + => match ctx'' with + | f' :: x' :: ctx''' + => cont k (rApp f' x' alt :: ctx''') + | _ => cont None ctx + end) + | None => default tt + end + | rExpr t e + | rValue t e + => default tt + end) end | Swap i d' => match swap_list 0 i ctx with @@ -507,8 +499,7 @@ Module Compilers. (rew : rewrite_rulesT) (e : rawexpr) : UnderLets (expr (type_of_rawexpr e)) - := dlet default := UnderLets.Base (expr_of_rawexpr e) in - eval_decision_tree + := eval_decision_tree (e::nil) d (fun k ctx default_on_rewrite_failure => match k, ctx return UnderLets (expr (type_of_rawexpr e)) with @@ -532,18 +523,18 @@ Module Compilers. => match fv', default_on_rewrite_failure with | Some fv'', _ => UnderLets.Base fv'' | None, Some default => default tt - | None, None => default + | None, None => UnderLets.Base (expr_of_rawexpr e) end))%under_lets | None => match default_on_rewrite_failure with | Some default => default tt - | None => default + | None => UnderLets.Base (expr_of_rawexpr e) end end) - | None => default + | None => UnderLets.Base (expr_of_rawexpr e) end) - | None => default + | None => UnderLets.Base (expr_of_rawexpr e) end - | _, _ => default + | _, _ => UnderLets.Base (expr_of_rawexpr e) end). Local Notation enumerate ls @@ -567,20 +558,12 @@ Module Compilers. end) (enumerate p). - Definition starts_with_wildcard : nat * list pattern -> bool - := fun '(_, p) => match p with - | pattern.Wildcard _::_ => true - | _ => false - end. - - Definition not_starts_with_wildcard : nat * list pattern -> bool - := fun p => negb (starts_with_wildcard p). - Definition filter_pattern_wildcard (p : list (nat * list pattern)) : list (nat * list pattern) - := filter starts_with_wildcard p. - - Definition split_at_first_pattern_wildcard (p : list (nat * list pattern)) : list (nat * list pattern) * list (nat * list pattern) - := (take_while not_starts_with_wildcard p, drop_while not_starts_with_wildcard p). + := filter (fun '(_, p) => match p with + | pattern.Wildcard _::_ => true + | _ => false + end) + p. Fixpoint get_unique_pattern_ident' (p : list (nat * list pattern)) (so_far : list pident) : list pident := match p with @@ -609,23 +592,25 @@ Module Compilers. end) p. - Definition filter_pattern_app (p : nat * list pattern) : option (nat * list pattern) + Definition refine_pattern_app (p : nat * list pattern) : option (nat * list pattern) := match p with + | (n, pattern.Wildcard d::ps) + => Some (n, (??{?? -> d} :: ?? :: ps)%list%pattern) | (n, pattern.App f x :: ps) => Some (n, f :: x :: ps) | (_, pattern.Ident _::_) - | (_, pattern.Wildcard _::_) | (_, nil) => None end. - Definition filter_pattern_pident (pidc : pident) (p : nat * list pattern) : option (nat * list pattern) + Definition refine_pattern_pident (pidc : pident) (p : nat * list pattern) : option (nat * list pattern) := match p with + | (n, pattern.Wildcard _::ps) + => Some (n, ps) | (n, pattern.Ident pidc'::ps) => if pident_beq pidc pidc' then Some (n, ps) else None - | (_, pattern.Wildcard _::_) | (_, pattern.App _ _::_) | (_, nil) => None @@ -643,14 +628,13 @@ Module Compilers. => (onfailure <- compile_rewrites ps; Some (TryLeaf n1 onfailure)) | Some Datatypes.O - => let '(pattern_matrix, default_pattern_matrix) := split_at_first_pattern_wildcard pattern_matrix in - default_case <- compile_rewrites default_pattern_matrix; + => default_case <- compile_rewrites (filter_pattern_wildcard pattern_matrix); app_case <- (if contains_pattern_app pattern_matrix - then option_map Some (compile_rewrites (Option.List.map filter_pattern_app pattern_matrix)) + then option_map Some (compile_rewrites (Option.List.map refine_pattern_app pattern_matrix)) else Some None); let pidcs := get_unique_pattern_ident pattern_matrix in let icases := Option.List.map - (fun pidc => option_map (pair pidc) (compile_rewrites (Option.List.map (filter_pattern_pident pidc) pattern_matrix))) + (fun pidc => option_map (pair pidc) (compile_rewrites (Option.List.map (refine_pattern_pident pidc) pattern_matrix))) pidcs in Some (Switch icases app_case default_case) | Some i @@ -1740,7 +1724,6 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) := Eval cbv -[fancy_pr2_rewrite_rules base.interp base.try_make_transport_cps type.try_make_transport_cps type.try_transport_cps - Let_In UnderLets.splice UnderLets.to_expr Compile.reflect Compile.reify Compile.reify_and_let_binds_cps UnderLets.reify_and_let_binds_base_cps Compile.value' SubstVarLike.is_var_fst_snd_pair_opp @@ -1804,9 +1787,6 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) ] in fancy_rewrite_head2. (* Finished transaction in 13.298 secs (13.283u,0.s) (successful) *) - Local Set Printing Depth 1000000. - Local Set Printing Width 200. - Local Notation "'llet' x := v 'in' f" := (Let_In v (fun x => f)). Redirect "/tmp/fancy_rewrite_head" Print fancy_rewrite_head. End red_fancy. @@ -1820,7 +1800,6 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) := Eval cbv -[nbe_pr2_rewrite_rules base.interp base.try_make_transport_cps type.try_make_transport_cps type.try_transport_cps - Let_In UnderLets.splice UnderLets.to_expr Compile.reflect UnderLets.reify_and_let_binds_base_cps Compile.reify Compile.reify_and_let_binds_cps Compile.value' @@ -1885,9 +1864,6 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) ] in nbe_rewrite_head2. (* Finished transaction in 16.561 secs (16.54u,0.s) (successful) *) - Local Set Printing Depth 1000000. - Local Set Printing Width 200. - Local Notation "'llet' x := v 'in' f" := (Let_In v (fun x => f)). Redirect "/tmp/nbe_rewrite_head" Print nbe_rewrite_head. End red_nbe. @@ -1901,7 +1877,6 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) := Eval cbv -[arith_pr2_rewrite_rules base.interp base.try_make_transport_cps type.try_make_transport_cps type.try_transport_cps - Let_In UnderLets.splice UnderLets.to_expr Compile.reflect UnderLets.reify_and_let_binds_base_cps Compile.reify Compile.reify_and_let_binds_cps Compile.value' @@ -1966,9 +1941,6 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) ] in arith_rewrite_head2. (* Finished transaction in 16.561 secs (16.54u,0.s) (successful) *) - Local Set Printing Depth 1000000. - Local Set Printing Width 200. - Local Notation "'llet' x := v 'in' f" := (Let_In v (fun x => f)). Redirect "/tmp/arith_rewrite_head" Print arith_rewrite_head. End red_arith. |