diff options
Diffstat (limited to 'src/Experiments/NewPipeline/Rewriter.v')
-rw-r--r-- | src/Experiments/NewPipeline/Rewriter.v | 134 |
1 files changed, 81 insertions, 53 deletions
diff --git a/src/Experiments/NewPipeline/Rewriter.v b/src/Experiments/NewPipeline/Rewriter.v index f9f6e22ac..dd898bc3e 100644 --- a/src/Experiments/NewPipeline/Rewriter.v +++ b/src/Experiments/NewPipeline/Rewriter.v @@ -8,6 +8,7 @@ Require Crypto.Util.PrimitiveHList. Require Import Crypto.Experiments.NewPipeline.Language. Require Import Crypto.Experiments.NewPipeline.UnderLets. Require Import Crypto.Experiments.NewPipeline.GENERATEDIdentifiersWithoutTypes. +Require Import Crypto.Util.LetIn. Require Import Crypto.Util.Notations. Import ListNotations. Local Open Scope bool_scope. Local Open Scope Z_scope. @@ -425,42 +426,49 @@ Module Compilers. => match ctx with | nil => cont None ctx None | ctx0 :: ctx' - => let default _ := @eval_decision_tree T ctx default_case cont in - reveal_rawexpr_cps - ctx0 _ - (fun ctx0' - => match ctx0' with - | rIdent t idc t' alt - => fold_right - (fun '(pidc, icase) default 'tt - => match invert_bind_args _ idc pidc with - | Some args + => let default := fun 'tt => @eval_decision_tree T ctx default_case cont in + let bind_default_in f + := match default_case with + | Failure => f default + | _ => (dlet default := default in f default) + end in + bind_default_in + (fun default + => reveal_rawexpr_cps + ctx0 _ + (fun ctx0' + => match ctx0' with + | rIdent t idc t' alt + => fold_right + (fun '(pidc, icase) default 'tt + => match invert_bind_args _ idc pidc with + | Some args + => @eval_decision_tree + T ctx' icase + (fun k ctx'' + => cont k (rIdent (pident_to_typed pidc args) alt :: ctx'')) + | None => default tt + end) + default + icases + tt + | rApp f x t alt + => match app_case with + | Some app_case => @eval_decision_tree - T ctx' icase + T (f :: x :: ctx') app_case (fun k ctx'' - => cont k (rIdent (pident_to_typed pidc args) alt :: ctx'')) + => match ctx'' with + | f' :: x' :: ctx''' + => cont k (rApp f' x' alt :: ctx''') + | _ => cont None ctx + end) | None => default tt - end) - default - icases - tt - | rApp f x t alt - => match app_case with - | Some app_case - => @eval_decision_tree - T (f :: x :: ctx') app_case - (fun k ctx'' - => match ctx'' with - | f' :: x' :: ctx''' - => cont k (rApp f' x' alt :: ctx''') - | _ => cont None ctx - end) - | None => default tt - end - | rExpr t e - | rValue t e - => default tt - end) + end + | rExpr t e + | rValue t e + => default tt + end)) end | Swap i d' => match swap_list 0 i ctx with @@ -499,7 +507,8 @@ Module Compilers. (rew : rewrite_rulesT) (e : rawexpr) : UnderLets (expr (type_of_rawexpr e)) - := eval_decision_tree + := dlet default := UnderLets.Base (expr_of_rawexpr e) in + eval_decision_tree (e::nil) d (fun k ctx default_on_rewrite_failure => match k, ctx return UnderLets (expr (type_of_rawexpr e)) with @@ -523,18 +532,18 @@ Module Compilers. => match fv', default_on_rewrite_failure with | Some fv'', _ => UnderLets.Base fv'' | None, Some default => default tt - | None, None => UnderLets.Base (expr_of_rawexpr e) + | None, None => default end))%under_lets | None => match default_on_rewrite_failure with | Some default => default tt - | None => UnderLets.Base (expr_of_rawexpr e) + | None => default end end) - | None => UnderLets.Base (expr_of_rawexpr e) + | None => default end) - | None => UnderLets.Base (expr_of_rawexpr e) + | None => default end - | _, _ => UnderLets.Base (expr_of_rawexpr e) + | _, _ => default end). Local Notation enumerate ls @@ -558,12 +567,20 @@ Module Compilers. end) (enumerate p). + Definition starts_with_wildcard : nat * list pattern -> bool + := fun '(_, p) => match p with + | pattern.Wildcard _::_ => true + | _ => false + end. + + Definition not_starts_with_wildcard : nat * list pattern -> bool + := fun p => negb (starts_with_wildcard p). + Definition filter_pattern_wildcard (p : list (nat * list pattern)) : list (nat * list pattern) - := filter (fun '(_, p) => match p with - | pattern.Wildcard _::_ => true - | _ => false - end) - p. + := filter starts_with_wildcard p. + + Definition split_at_first_pattern_wildcard (p : list (nat * list pattern)) : list (nat * list pattern) * list (nat * list pattern) + := (take_while not_starts_with_wildcard p, drop_while not_starts_with_wildcard p). Fixpoint get_unique_pattern_ident' (p : list (nat * list pattern)) (so_far : list pident) : list pident := match p with @@ -592,25 +609,23 @@ Module Compilers. end) p. - Definition refine_pattern_app (p : nat * list pattern) : option (nat * list pattern) + Definition filter_pattern_app (p : nat * list pattern) : option (nat * list pattern) := match p with - | (n, pattern.Wildcard d::ps) - => Some (n, (??{?? -> d} :: ?? :: ps)%list%pattern) | (n, pattern.App f x :: ps) => Some (n, f :: x :: ps) | (_, pattern.Ident _::_) + | (_, pattern.Wildcard _::_) | (_, nil) => None end. - Definition refine_pattern_pident (pidc : pident) (p : nat * list pattern) : option (nat * list pattern) + Definition filter_pattern_pident (pidc : pident) (p : nat * list pattern) : option (nat * list pattern) := match p with - | (n, pattern.Wildcard _::ps) - => Some (n, ps) | (n, pattern.Ident pidc'::ps) => if pident_beq pidc pidc' then Some (n, ps) else None + | (_, pattern.Wildcard _::_) | (_, pattern.App _ _::_) | (_, nil) => None @@ -628,13 +643,14 @@ Module Compilers. => (onfailure <- compile_rewrites ps; Some (TryLeaf n1 onfailure)) | Some Datatypes.O - => default_case <- compile_rewrites (filter_pattern_wildcard pattern_matrix); + => let '(pattern_matrix, default_pattern_matrix) := split_at_first_pattern_wildcard pattern_matrix in + default_case <- compile_rewrites default_pattern_matrix; app_case <- (if contains_pattern_app pattern_matrix - then option_map Some (compile_rewrites (Option.List.map refine_pattern_app pattern_matrix)) + then option_map Some (compile_rewrites (Option.List.map filter_pattern_app pattern_matrix)) else Some None); let pidcs := get_unique_pattern_ident pattern_matrix in let icases := Option.List.map - (fun pidc => option_map (pair pidc) (compile_rewrites (Option.List.map (refine_pattern_pident pidc) pattern_matrix))) + (fun pidc => option_map (pair pidc) (compile_rewrites (Option.List.map (filter_pattern_pident pidc) pattern_matrix))) pidcs in Some (Switch icases app_case default_case) | Some i @@ -1724,6 +1740,7 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) := Eval cbv -[fancy_pr2_rewrite_rules base.interp base.try_make_transport_cps type.try_make_transport_cps type.try_transport_cps + Let_In UnderLets.splice UnderLets.to_expr Compile.reflect Compile.reify Compile.reify_and_let_binds_cps UnderLets.reify_and_let_binds_base_cps Compile.value' SubstVarLike.is_var_fst_snd_pair_opp @@ -1787,6 +1804,9 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) ] in fancy_rewrite_head2. (* Finished transaction in 13.298 secs (13.283u,0.s) (successful) *) + Local Set Printing Depth 1000000. + Local Set Printing Width 200. + Local Notation "'llet' x := v 'in' f" := (Let_In v (fun x => f)). Redirect "/tmp/fancy_rewrite_head" Print fancy_rewrite_head. End red_fancy. @@ -1800,6 +1820,7 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) := Eval cbv -[nbe_pr2_rewrite_rules base.interp base.try_make_transport_cps type.try_make_transport_cps type.try_transport_cps + Let_In UnderLets.splice UnderLets.to_expr Compile.reflect UnderLets.reify_and_let_binds_base_cps Compile.reify Compile.reify_and_let_binds_cps Compile.value' @@ -1864,6 +1885,9 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) ] in nbe_rewrite_head2. (* Finished transaction in 16.561 secs (16.54u,0.s) (successful) *) + Local Set Printing Depth 1000000. + Local Set Printing Width 200. + Local Notation "'llet' x := v 'in' f" := (Let_In v (fun x => f)). Redirect "/tmp/nbe_rewrite_head" Print nbe_rewrite_head. End red_nbe. @@ -1877,6 +1901,7 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) := Eval cbv -[arith_pr2_rewrite_rules base.interp base.try_make_transport_cps type.try_make_transport_cps type.try_transport_cps + Let_In UnderLets.splice UnderLets.to_expr Compile.reflect UnderLets.reify_and_let_binds_base_cps Compile.reify Compile.reify_and_let_binds_cps Compile.value' @@ -1941,6 +1966,9 @@ Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) ] in arith_rewrite_head2. (* Finished transaction in 16.561 secs (16.54u,0.s) (successful) *) + Local Set Printing Depth 1000000. + Local Set Printing Width 200. + Local Notation "'llet' x := v 'in' f" := (Let_In v (fun x => f)). Redirect "/tmp/arith_rewrite_head" Print arith_rewrite_head. End red_arith. |