diff options
Diffstat (limited to 'src/Rewriter.v')
-rw-r--r-- | src/Rewriter.v | 386 |
1 files changed, 131 insertions, 255 deletions
diff --git a/src/Rewriter.v b/src/Rewriter.v index 05a802ddd..34107b2c2 100644 --- a/src/Rewriter.v +++ b/src/Rewriter.v @@ -13,7 +13,6 @@ Require Import Crypto.Util.ZUtil.Notations. Require Import Crypto.Util.Tactics.ConstrFail. Require Crypto.Util.PrimitiveProd. Require Crypto.Util.PrimitiveHList. -Require Import Crypto.RewriterRules. Require Import Crypto.Language. Require Import Crypto.UnderLets. Require Import Crypto.GENERATEDIdentifiersWithoutTypes. @@ -1716,7 +1715,11 @@ Module Compilers. Ltac equation_to_parts' lem side_conditions := lazymatch lem with | ?H -> ?P - => let H := prop_to_bool H in + => let __ := lazymatch type of H with + | Prop => constr:(I) + | ?T => constr_fail_with ltac:(fun _ => fail 1 "Invalid non-Prop non-dependent hypothesis of type" H ":" T "when reifying a lemma of type" lem) + end in + let H := prop_to_bool H in let side_conditions := push_side_conditions H side_conditions in equation_to_parts' P side_conditions | forall x : ?T, ?P @@ -2435,12 +2438,16 @@ Module Compilers. | _ => check_debug_level_then_Set () end. + Definition pident_unify_unknown := @pattern.ident.unify. + Definition invert_bind_args_unknown := @pattern.Raw.ident.invert_bind_args. + Module Export GoalType. Record rewriter_dataT := Build_rewriter_dataT' { rewrite_rules_specs : list (bool * Prop); dummy_count : nat; + dtree : @Compile.decision_tree pattern.Raw.ident; rewrite_rules : forall var, @Compile.rewrite_rulesT ident var pattern.ident (@pattern.ident.arg_types) ; all_rewrite_rules (* adjusted version *) : _; @@ -2448,9 +2455,11 @@ Module Compilers. default_fuel : nat; - rewrite_head0 : forall var (do_again : forall t, @defaults.expr (@Compile.value _ ident var) (type.base t) -> @UnderLets.UnderLets _ ident var (@defaults.expr var (type.base t))) + rewrite_head0 + := (fun var + => @Compile.assemble_identifier_rewriters ident var (@pattern.ident.eta_ident_cps) (@pattern.ident) (@pattern.ident.arg_types) (@pattern.ident.unify) pident_unify_unknown pattern.Raw.ident (@pattern.Raw.ident.full_types) (@pattern.Raw.ident.invert_bind_args) invert_bind_args_unknown (@pattern.Raw.ident.type_of) (@pattern.Raw.ident.to_typed) pattern.Raw.ident.is_simple dtree (all_rewrite_rules var)); + rewrite_head (* adjusted version *) : forall var (do_again : forall t, @defaults.expr (@Compile.value _ ident var) (type.base t) -> @UnderLets.UnderLets _ ident var (@defaults.expr var (type.base t))) t (idc : ident t), @Compile.value_with_lets base.type ident var t; - rewrite_head (* adjusted version *) : _; rewrite_head_eq : rewrite_head = rewrite_head0 }. End GoalType. @@ -2473,94 +2482,95 @@ Module Compilers. | _ => constr_fail_with ltac:(fun _ => fail 1 "Invalid value for include_interp (must be either true or false):" include_interp) end. - Definition pident_unify_unknown := @pattern.ident.unify. - Definition invert_bind_args_unknown := @pattern.Raw.ident.invert_bind_args. + Ltac time_if_debug1 := + let lvl := rewriter_assembly_debug_level in + lazymatch lvl with + | O => ltac:(fun tac => tac ()) + | S _ => ltac:(fun tac => time tac ()) + | ?v => ltac:(fun tac => fail 0 "Invalid non-nat rewriter_assembly_debug_level" v) + end. + Ltac time_tac_in_constr_if_debug1 tac := + constr:(ltac:(time_if_debug1 ltac:(fun _ => idtac; let v := tac () in exact v))). Ltac make_rewrite_head1 rewrite_head0 pr2_rewrite_rules := - let rewrite_head1 - := (eval cbv -[pr2_rewrite_rules - base.interp base.try_make_transport_cps - type.try_make_transport_cps - pattern.type.unify_extracted - Compile.option_type_type_beq - Let_In Option.sequence Option.sequence_return - UnderLets.splice UnderLets.to_expr - Compile.option_bind' pident_unify_unknown invert_bind_args_unknown Compile.normalize_deep_rewrite_rule - Compile.reflect UnderLets.reify_and_let_binds_base_cps Compile.reify Compile.reify_and_let_binds_cps - Compile.value' - SubstVarLike.is_var_fst_snd_pair_opp_cast - ] in rewrite_head0) in - let rewrite_head1 - := (eval cbn [type.try_make_transport_cps base.try_make_transport_cps base.try_make_base_transport_cps] - in rewrite_head1) in - rewrite_head1. - Ltac timed_make_rewrite_head1 rewrite_head0 pr2_rewrite_rules := - constr:(ltac:(time (idtac; let v := make_rewrite_head1 rewrite_head0 pr2_rewrite_rules in exact v))). + time_tac_in_constr_if_debug1 + ltac:(fun _ + => let rewrite_head1 + := (eval cbv -[pr2_rewrite_rules + base.interp base.try_make_transport_cps + type.try_make_transport_cps + pattern.type.unify_extracted + Compile.option_type_type_beq + Let_In Option.sequence Option.sequence_return + UnderLets.splice UnderLets.to_expr + Compile.option_bind' pident_unify_unknown invert_bind_args_unknown Compile.normalize_deep_rewrite_rule + Compile.reflect UnderLets.reify_and_let_binds_base_cps Compile.reify Compile.reify_and_let_binds_cps + Compile.value' + SubstVarLike.is_var_fst_snd_pair_opp_cast + ] in rewrite_head0) in + let rewrite_head1 + := (eval cbn [type.try_make_transport_cps base.try_make_transport_cps base.try_make_base_transport_cps] + in rewrite_head1) in + rewrite_head1). Ltac make_rewrite_head2 rewrite_head1 pr2_rewrite_rules := - (eval cbv [id - pr2_rewrite_rules - projT1 projT2 - cpsbind cpscall cps_option_bind cpsreturn - PrimitiveProd.Primitive.fst PrimitiveProd.Primitive.snd - pattern.type.subst_default pattern.base.subst_default pattern.base.lookup_default - PositiveMap.add PositiveMap.find PositiveMap.empty - PositiveSet.rev PositiveSet.rev_append - pattern.ident.arg_types - Compile.eval_decision_tree - Compile.eval_rewrite_rules - Compile.expr_of_rawexpr - Compile.normalize_deep_rewrite_rule - Compile.option_bind' pident_unify_unknown invert_bind_args_unknown Compile.normalize_deep_rewrite_rule - (*Compile.reflect*) - (*Compile.reify*) - Compile.reveal_rawexpr_cps - Compile.reveal_rawexpr_cps_gen - Compile.rew_should_do_again - Compile.rew_with_opt - Compile.rew_under_lets - Compile.rew_replacement - Compile.rValueOrExpr - Compile.swap_list - Compile.type_of_rawexpr - Compile.option_type_type_beq - Compile.value - (*Compile.value'*) - Compile.value_of_rawexpr - Compile.value_with_lets - ident.smart_Literal - type.try_transport_cps - (*rlist_rect rwhen rwhenl*) - ] in rewrite_head1). - Ltac timed_make_rewrite_head2 rewrite_head1 pr2_rewrite_rules := - constr:(ltac:(time (idtac; let v := make_rewrite_head2 rewrite_head1 pr2_rewrite_rules in exact v))). + time_tac_in_constr_if_debug1 + ltac:(fun _ + => (eval cbv [id + pr2_rewrite_rules + projT1 projT2 + cpsbind cpscall cps_option_bind cpsreturn + PrimitiveProd.Primitive.fst PrimitiveProd.Primitive.snd + pattern.type.subst_default pattern.base.subst_default pattern.base.lookup_default + PositiveMap.add PositiveMap.find PositiveMap.empty + PositiveSet.rev PositiveSet.rev_append + pattern.ident.arg_types + Compile.eval_decision_tree + Compile.eval_rewrite_rules + Compile.expr_of_rawexpr + Compile.normalize_deep_rewrite_rule + Compile.option_bind' pident_unify_unknown invert_bind_args_unknown Compile.normalize_deep_rewrite_rule + (*Compile.reflect*) + (*Compile.reify*) + Compile.reveal_rawexpr_cps + Compile.reveal_rawexpr_cps_gen + Compile.rew_should_do_again + Compile.rew_with_opt + Compile.rew_under_lets + Compile.rew_replacement + Compile.rValueOrExpr + Compile.swap_list + Compile.type_of_rawexpr + Compile.option_type_type_beq + Compile.value + (*Compile.value'*) + Compile.value_of_rawexpr + Compile.value_with_lets + ident.smart_Literal + type.try_transport_cps + (*rlist_rect rwhen rwhenl*) + ] in rewrite_head1)). Ltac make_rewrite_head3 rewrite_head2 := - (eval cbn [id - cpsbind cpscall cps_option_bind cpsreturn - Compile.reify Compile.reify_and_let_binds_cps Compile.reflect Compile.value' - Option.sequence Option.sequence_return Option.bind - UnderLets.reify_and_let_binds_base_cps - UnderLets.splice UnderLets.splice_list UnderLets.to_expr - base.interp base.base_interp - base.type.base_beq option_beq - type.try_make_transport_cps base.try_make_transport_cps base.try_make_base_transport_cps - Datatypes.fst Datatypes.snd - ] in rewrite_head2). - Ltac timed_make_rewrite_head3 rewrite_head2 := - constr:(ltac:(time (idtac; let v := make_rewrite_head3 rewrite_head2 in exact v))). + time_tac_in_constr_if_debug1 + ltac:(fun _ + => (eval cbn [id + cpsbind cpscall cps_option_bind cpsreturn + Compile.reify Compile.reify_and_let_binds_cps Compile.reflect Compile.value' + Option.sequence Option.sequence_return Option.bind + UnderLets.reify_and_let_binds_base_cps + UnderLets.splice UnderLets.splice_list UnderLets.to_expr + base.interp base.base_interp + base.type.base_beq option_beq + type.try_make_transport_cps base.try_make_transport_cps base.try_make_base_transport_cps + Datatypes.fst Datatypes.snd + ] in rewrite_head2)). Ltac make_rewrite_head' rewrite_head0 pr2_rewrite_rules := let rewrite_head1 := make_rewrite_head1 rewrite_head0 pr2_rewrite_rules in let rewrite_head2 := make_rewrite_head2 rewrite_head1 pr2_rewrite_rules in let rewrite_head3 := make_rewrite_head3 rewrite_head2 in rewrite_head3. - Ltac timed_make_rewrite_head' rewrite_head0 pr2_rewrite_rules := - let rewrite_head1 := timed_make_rewrite_head1 rewrite_head0 pr2_rewrite_rules in - let rewrite_head2 := timed_make_rewrite_head2 rewrite_head1 pr2_rewrite_rules in - let rewrite_head3 := timed_make_rewrite_head3 rewrite_head2 in - rewrite_head3. Ltac make_rewrite_head rewrite_head0 pr2_rewrite_rules := let rewrite_head := fresh "rewrite_head" in - let lvl := rewriter_assembly_debug_level in let var := fresh "var" in let do_again := fresh "do_again" in let t := fresh "t" in @@ -2572,10 +2582,7 @@ Module Compilers. => ltac:( let rewrite_head0 := constr:(rewrite_head0 var do_again t idc) in let pr2_rewrite_rules := head pr2_rewrite_rules in - let v := lazymatch lvl with - | O => make_rewrite_head' rewrite_head0 pr2_rewrite_rules - | S _ => timed_make_rewrite_head' rewrite_head0 pr2_rewrite_rules - end in + let v := make_rewrite_head' rewrite_head0 pr2_rewrite_rules in exact v)) in cache_term v rewrite_head. @@ -2605,187 +2612,56 @@ Module Compilers. rewrite_head0 in let __ := debug1 ltac:(fun _ => idtac "Reducing rewrite_head...") in let rewrite_head := make_rewrite_head rewrite_head0 pr2_rewrite_rules in - refine (@Build_rewriter_dataT' - specs dummy_count + constr:(@Build_rewriter_dataT' + specs dummy_count dtree rewrite_rules all_rewrite_rules eq_refl default_fuel - rewrite_head0 rewrite_head eq_refl). + (*rewrite_head0*) rewrite_head eq_refl). Module Export Tactic. - Global Arguments base.try_make_base_transport_cps _ !_ !_. - Global Arguments base.try_make_transport_cps _ !_ !_. - Global Arguments type.try_make_transport_cps _ _ _ !_ !_. - Global Arguments Option.sequence A !v1 v2. - Global Arguments Option.sequence_return A !v1 v2. - Global Arguments Option.bind A B !_ _. - Global Arguments pattern.Raw.ident.invert_bind_args t !_ !_. - Global Arguments base.type.base_beq !_ !_. - Global Arguments id / . + Module Export Settings. + Global Arguments base.try_make_base_transport_cps _ !_ !_. + Global Arguments base.try_make_transport_cps _ !_ !_. + Global Arguments type.try_make_transport_cps _ _ _ !_ !_. + Global Arguments Option.sequence A !v1 v2. + Global Arguments Option.sequence_return A !v1 v2. + Global Arguments Option.bind A B !_ _. + Global Arguments pattern.Raw.ident.invert_bind_args t !_ !_. + Global Arguments base.type.base_beq !_ !_. + Global Arguments id / . + End Settings. Tactic Notation "make_rewriter_data" constr(include_interp) constr(specs) := - Build_rewriter_dataT include_interp specs. + let res := Build_rewriter_dataT include_interp specs in refine res. End Tactic. End Make. Export Make.GoalType. Import Make.Tactic. - Definition nbe_rewriter_data : rewriter_dataT. - Proof. make_rewriter_data true nbe_rewrite_rulesT. Defined. - - Definition arith_rewriter_data (max_const_val : Z) : rewriter_dataT. - Proof. make_rewriter_data false (arith_rewrite_rulesT max_const_val). Defined. - - Definition arith_with_casts_rewriter_data : rewriter_dataT. - Proof. make_rewriter_data false arith_with_casts_rewrite_rulesT. Defined. - - Definition strip_literal_casts_rewriter_data : rewriter_dataT. - Proof. make_rewriter_data false strip_literal_casts_rewrite_rulesT. Defined. - - Definition fancy_rewriter_data - (invert_low invert_high : Z (*log2wordmax*) -> Z -> option Z) - : rewriter_dataT. - Proof. make_rewriter_data false fancy_rewrite_rulesT. Defined. - - Definition fancy_with_casts_rewriter_data - (invert_low invert_high : Z (*log2wordmax*) -> Z -> option Z) - (value_range flag_range : zrange) - : rewriter_dataT. - Proof. make_rewriter_data false (fancy_with_casts_rewrite_rulesT invert_low invert_high value_range flag_range). Defined. - - Module RewriterPrintingNotations. - Arguments base.try_make_transport_cps {P} t1 t2 {_} _. - Arguments type.try_make_transport_cps {base_type _ P} t1 t2 {_} _. - Export pattern.Raw.ident. - Export GENERATEDIdentifiersWithoutTypes.Compilers.pattern.Raw. - Export GENERATEDIdentifiersWithoutTypes.Compilers.pattern. - Export UnderLets. - Export Compilers.ident. - Export Language.Compilers. - Export Language.Compilers.defaults. - Export PrimitiveSigma.Primitive. - Notation "'llet' x := v 'in' f" := (Let_In v (fun x => f)). - Notation "x <- 'type.try_make_transport_cps' t1 t2 ; f" := (type.try_make_transport_cps t1 t2 (fun y => match y with Datatypes.Some x => f | Datatypes.None => Datatypes.None end)) (at level 70, t1 at next level, t2 at next level, right associativity, format "'[v' x <- 'type.try_make_transport_cps' t1 t2 ; '/' f ']'"). - Notation "x <- 'base.try_make_transport_cps' t1 t2 ; f" := (base.try_make_transport_cps t1 t2 (fun y => match y with Datatypes.Some x => f | Datatypes.None => Datatypes.None end)) (at level 70, t1 at next level, t2 at next level, right associativity, format "'[v' x <- 'base.try_make_transport_cps' t1 t2 ; '/' f ']'"). - End RewriterPrintingNotations. - - (* For printing *) - Local Arguments base.try_make_transport_cps {P} t1 t2 {_} _. - Local Arguments type.try_make_transport_cps {base_type _ P} t1 t2 {_} _. - Local Arguments option {_}. - Local Arguments UnderLets.UnderLets {_ _ _}. - Local Arguments expr.expr {_ _ _}. - Local Notation ℤ := base.type.Z. - Local Notation ℕ := base.type.nat. - Local Notation bool := base.type.bool. - Local Notation unit := base.type.unit. - Local Notation list := base.type.list. - Local Notation "x" := (type.base x) (only printing, at level 9). - - Section red_fancy. - Context (invert_low invert_high : Z (*log2wordmax*) -> Z -> @option Z). - - Local Definition fancy_rewrite_head - := Eval hnf in rewrite_head (fancy_rewriter_data invert_low invert_high). - - Local Set Printing Depth 1000000. - Local Set Printing Width 200. - Import RewriterPrintingNotations. - Redirect "fancy_rewrite_head" Print fancy_rewrite_head. - End red_fancy. - Section red_fancy_with_casts. - Context (invert_low invert_high : Z (*log2wordmax*) -> Z -> @option Z) - (value_range flag_range : zrange). - - Local Definition fancy_with_casts_rewrite_head - := Eval hnf in rewrite_head (fancy_with_casts_rewriter_data invert_low invert_high value_range flag_range). - - Local Set Printing Depth 1000000. - Local Set Printing Width 200. - Import RewriterPrintingNotations. - Redirect "fancy_with_casts_rewrite_head" Print fancy_with_casts_rewrite_head. - End red_fancy_with_casts. - Section red_nbe. - Local Definition nbe_rewrite_head - := Eval hnf in rewrite_head nbe_rewriter_data. - - Local Set Printing Depth 1000000. - Local Set Printing Width 200. - Import RewriterPrintingNotations. - Redirect "nbe_rewrite_head" Print nbe_rewrite_head. - End red_nbe. - - Section red_arith. - Context (max_const_val : Z). - - Local Definition arith_rewrite_head - := Eval hnf in rewrite_head (arith_rewriter_data max_const_val). - - Local Set Printing Depth 1000000. - Local Set Printing Width 200. - Import RewriterPrintingNotations. - Redirect "arith_rewrite_head" Print arith_rewrite_head. - End red_arith. - - Section red_arith_with_casts. - Local Definition arith_with_casts_rewrite_head - := Eval hnf in rewrite_head arith_with_casts_rewriter_data. - - Local Set Printing Depth 1000000. - Local Set Printing Width 200. - Import RewriterPrintingNotations. - Redirect "arith_with_casts_rewrite_head" Print arith_with_casts_rewrite_head. - End red_arith_with_casts. - - Section red_strip_literal_casts. - Local Definition strip_literal_casts_rewrite_head - := Eval hnf in rewrite_head strip_literal_casts_rewriter_data. - - Local Set Printing Depth 1000000. - Local Set Printing Width 200. - Import RewriterPrintingNotations. - Redirect "strip_literal_casts_rewrite_head" Print strip_literal_casts_rewrite_head. - End red_strip_literal_casts. - - Local Ltac unfold_Rewrite Rewrite := - let h := head Rewrite in - let Rewrite := (eval cbv [h] in Rewrite) in - let data := lazymatch Rewrite with context[@Make.Rewrite ?data] => head data end in - (eval cbv [Make.Rewrite rewrite_head default_fuel data] in Rewrite). - Local Notation unfold_Rewrite Rewrite := - (ltac:(let v := unfold_Rewrite Rewrite in - exact v)) (only parsing). - - Definition RewriteNBE_folded := @Make.Rewrite nbe_rewriter_data. - Definition RewriteNBE {t} (e : expr.Expr (ident:=ident) t) : expr.Expr (ident:=ident) t - := unfold_Rewrite (@RewriteNBE_folded t e). - Definition RewriteArith_folded (max_const_val : Z) := @Make.Rewrite (arith_rewriter_data max_const_val). - Definition RewriteArith (max_const_val : Z) {t} (e : expr.Expr (ident:=ident) t) : expr.Expr (ident:=ident) t - := unfold_Rewrite (@RewriteArith_folded max_const_val t e). - Definition RewriteArithWithCasts_folded := @Make.Rewrite arith_with_casts_rewriter_data. - Definition RewriteArithWithCasts {t} (e : expr.Expr (ident:=ident) t) : expr.Expr (ident:=ident) t - := unfold_Rewrite (@RewriteArithWithCasts_folded t e). - Definition RewriteStripLiteralCasts_folded := @Make.Rewrite strip_literal_casts_rewriter_data. - Definition RewriteStripLiteralCasts {t} (e : expr.Expr (ident:=ident) t) : expr.Expr (ident:=ident) t - := unfold_Rewrite (@RewriteStripLiteralCasts_folded t e). - Definition RewriteToFancy_folded - (invert_low invert_high : Z (*log2wordmax*) -> Z -> @option Z) - := @Make.Rewrite (fancy_rewriter_data invert_low invert_high). - Definition RewriteToFancy - (invert_low invert_high : Z (*log2wordmax*) -> Z -> @option Z) - {t} (e : expr.Expr (ident:=ident) t) : expr.Expr (ident:=ident) t - := unfold_Rewrite (@RewriteToFancy_folded invert_low invert_high t e). - Definition RewriteToFancyWithCasts_folded - (invert_low invert_high : Z (*log2wordmax*) -> Z -> @option Z) - (value_range flag_range : zrange) - := @Make.Rewrite (fancy_with_casts_rewriter_data invert_low invert_high value_range flag_range). - Definition RewriteToFancyWithCasts - (invert_low invert_high : Z (*log2wordmax*) -> Z -> @option Z) - (value_range flag_range : zrange) - {t} (e : expr.Expr (ident:=ident) t) : expr.Expr (ident:=ident) t - := unfold_Rewrite (@RewriteToFancyWithCasts_folded invert_low invert_high value_range flag_range t e). + Module Export GoalType. + Record RewriterT := + { + Rewriter_data : rewriter_dataT; + Rewrite : forall {t} (e : expr.Expr (ident:=ident) t), expr.Expr (ident:=ident) t; + Rewrite_eq : @Rewrite = @Make.Rewrite Rewriter_data + }. + End GoalType. + + Ltac Build_RewriterT include_interp specs := + let rewriter_data := fresh "rewriter_data" in + let data := Make.Build_rewriter_dataT include_interp specs in + let Rewrite_name := fresh "Rewriter" in + let Rewrite := (eval cbv [Make.Rewrite rewrite_head default_fuel] in (@Make.Rewrite data)) in + let Rewrite := cache_term Rewrite Rewrite_name in + constr:(@Build_RewriterT data Rewrite eq_refl). + + Module Export Tactic. + Module Export Settings. + Export Make.Tactic.Settings. + End Settings. + + Tactic Notation "make_Rewriter" constr(include_interp) constr(specs) := + let res := Build_RewriterT include_interp specs in refine res. + End Tactic. End RewriteRules. - - Import defaults. - - Definition PartialEvaluate {t} (e : Expr t) : Expr t := RewriteRules.RewriteNBE e. End Compilers. |