aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/NewPipeline/RewriterRulesGood.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Experiments/NewPipeline/RewriterRulesGood.v')
-rw-r--r--src/Experiments/NewPipeline/RewriterRulesGood.v207
1 files changed, 207 insertions, 0 deletions
diff --git a/src/Experiments/NewPipeline/RewriterRulesGood.v b/src/Experiments/NewPipeline/RewriterRulesGood.v
new file mode 100644
index 000000000..03502fb9c
--- /dev/null
+++ b/src/Experiments/NewPipeline/RewriterRulesGood.v
@@ -0,0 +1,207 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Coq.micromega.Lia.
+Require Import Coq.Lists.List.
+Require Import Coq.Classes.Morphisms.
+Require Import Coq.MSets.MSetPositive.
+Require Import Coq.FSets.FMapPositive.
+Require Import Crypto.Experiments.NewPipeline.Language.
+Require Import Crypto.Experiments.NewPipeline.LanguageInversion.
+Require Import Crypto.Experiments.NewPipeline.LanguageWf.
+Require Import Crypto.Experiments.NewPipeline.UnderLetsProofs.
+Require Import Crypto.Experiments.NewPipeline.GENERATEDIdentifiersWithoutTypesProofs.
+Require Import Crypto.Experiments.NewPipeline.Rewriter.
+Require Import Crypto.Experiments.NewPipeline.RewriterWf1.
+Require Import Crypto.Util.Tactics.BreakMatch.
+Require Import Crypto.Util.Tactics.SplitInContext.
+Require Import Crypto.Util.Tactics.SpecializeAllWays.
+Require Import Crypto.Util.Tactics.SpecializeBy.
+Require Import Crypto.Util.Tactics.RewriteHyp.
+Require Import Crypto.Util.Tactics.Head.
+Require Import Crypto.Util.Prod.
+Require Import Crypto.Util.ListUtil.
+Require Import Crypto.Util.Option.
+Require Import Crypto.Util.CPSNotations.
+Require Import Crypto.Util.HProp.
+Require Import Crypto.Util.Decidable.
+Import ListNotations. Local Open Scope list_scope.
+Local Open Scope Z_scope.
+
+Import EqNotations.
+Module Compilers.
+ Import Language.Compilers.
+ Import LanguageInversion.Compilers.
+ Import LanguageWf.Compilers.
+ Import UnderLetsProofs.Compilers.
+ Import GENERATEDIdentifiersWithoutTypesProofs.Compilers.
+ Import Rewriter.Compilers.
+ Import RewriterWf1.Compilers.
+ Import expr.Notations.
+ Import RewriterWf1.Compilers.RewriteRules.
+ Import defaults.
+
+ Module Import RewriteRules.
+ Import Rewriter.Compilers.RewriteRules.
+
+ Lemma nbe_rewrite_head_eq : @nbe_rewrite_head = @nbe_rewrite_head0.
+ Proof. reflexivity. Qed.
+
+ Lemma fancy_rewrite_head_eq invert_low invert_high
+ : (fun var do_again => @fancy_rewrite_head invert_low invert_high var)
+ = (fun var => @fancy_rewrite_head0 var invert_low invert_high).
+ Proof. reflexivity. Qed.
+
+ Lemma arith_rewrite_head_eq max_const_val : @arith_rewrite_head max_const_val = (fun var => @arith_rewrite_head0 var max_const_val).
+ Proof. reflexivity. Qed.
+
+ Lemma nbe_all_rewrite_rules_eq : @nbe_all_rewrite_rules = @nbe_rewrite_rules.
+ Proof. reflexivity. Qed.
+
+ Lemma fancy_all_rewrite_rules_eq : @fancy_all_rewrite_rules = @fancy_rewrite_rules.
+ Proof. reflexivity. Qed.
+
+ Lemma arith_all_rewrite_rules_eq : @arith_all_rewrite_rules = @arith_rewrite_rules.
+ Proof. reflexivity. Qed.
+
+ Section good.
+ Context {var1 var2 : type -> Type}.
+
+ Local Notation rewrite_rules_goodT := (@Compile.rewrite_rules_goodT ident pattern.ident pattern.ident.arg_types var1 var2).
+
+ Lemma rlist_rect_cps_id {var} A P {ivar} N_case C_case ls T k
+ : @rlist_rect var A P ivar N_case C_case ls T k = k (@rlist_rect var A P ivar N_case C_case ls _ id).
+ Proof.
+ cbv [rlist_rect id Compile.option_bind']; rewrite !expr.reflect_list_cps_id.
+ destruct (invert_expr.reflect_list ls) eqn:?; cbn [Option.bind Option.sequence_return]; reflexivity.
+ Qed.
+ Lemma rlist_rect_cast_cps_id {var} A A' P {ivar} N_case C_case ls T k
+ : @rlist_rect_cast var A A' P ivar N_case C_case ls T k = k (@rlist_rect_cast var A A' P ivar N_case C_case ls _ id).
+ Proof.
+ cbv [rlist_rect_cast Compile.castbe Compile.castb id Compile.option_bind']; rewrite_type_transport_correct;
+ break_innermost_match; type_beq_to_eq; subst; cbn [eq_rect Option.bind Option.sequence_return]; [ | reflexivity ].
+ apply rlist_rect_cps_id.
+ Qed.
+
+ Local Ltac start_cps_id :=
+ lazymatch goal with
+ | [ |- In _ ?rewr -> _ ] => let h := head rewr in cbv [h]
+ end;
+ cbn [In combine]; intros; destruct_head'_or; inversion_sigma; subst; try reflexivity; destruct_head' False.
+
+ Local Ltac cps_id_step :=
+ first [ reflexivity
+ | progress destruct_head' False
+ | progress subst
+ | progress inversion_option
+ | progress cbv [id Compile.binding_dataT pattern.ident.arg_types Compile.ptype_interp Compile.ptype_interp_cps Compile.pbase_type_interp_cps Compile.value Compile.value' Compile.app_binding_data Compile.app_ptype_interp_cps Compile.app_pbase_type_interp_cps Compile.lift_with_bindings Compile.lift_ptype_interp_cps Compile.lift_pbase_type_interp_cps cpsbind cpscall cpsreturn cps_option_bind type_base rwhen] in *
+ | progress cbn [UnderLets.splice eq_rect projT1 projT2 Option.bind Option.sequence Option.sequence_return] in *
+ | progress type_beq_to_eq
+ | progress rewrite_type_transport_correct
+ | progress cbv [Compile.option_bind' Compile.castbe Compile.castb Compile.castv] in *
+ | progress break_innermost_match
+ | progress destruct_head'_sigT
+ | rewrite !expr.reflect_list_cps_id
+ | match goal with
+ | [ |- context[@rlist_rect_cast ?var ?A ?A' ?P ?ivar ?N_case ?C_case ?ls ?T ?k] ]
+ => (tryif (let __ := constr:(eq_refl : k = (fun x => x)) in idtac)
+ then fail
+ else rewrite (@rlist_rect_cast_cps_id var A A' P ivar N_case C_case ls T k))
+ | [ |- context[@rlist_rect ?var ?A ?P ?ivar ?N_case ?C_case ?ls ?T ?k] ]
+ => (tryif (let __ := constr:(eq_refl : k = (fun x => x)) in idtac)
+ then fail
+ else rewrite (@rlist_rect_cps_id var A P ivar N_case C_case ls T k))
+ end
+ | progress cbv [Option.bind] in *
+ | break_match_step ltac:(fun _ => idtac) ].
+
+ Local Ltac cps_id_t := start_cps_id; repeat cps_id_step.
+
+ Lemma nbe_cps_id {var} p r
+ : In (existT _ p r) (@nbe_rewrite_rules var)
+ -> forall v T k, r v T k = k (r v _ id).
+ Proof. cps_id_t. Qed.
+
+ Lemma arith_cps_id max_const {var} p r
+ : In (existT _ p r) (@arith_rewrite_rules var max_const)
+ -> forall v T k, r v T k = k (r v _ id).
+ Proof. cps_id_t. Qed.
+
+ Lemma fancy_cps_id invert_low invert_high {var} p r
+ : In (existT _ p r) (@fancy_rewrite_rules var invert_low invert_high)
+ -> forall v T k, r v T k = k (r v _ id).
+ Proof. cps_id_t. Qed.
+
+ Local Ltac start_good cps_id rewrite_rules :=
+ split; [ reflexivity | ];
+ repeat apply conj; try solve [ eapply cps_id ]; [];
+ cbv [rewrite_rules]; cbn [In combine];
+ intros; destruct_head'_or; inversion_prod; inversion_sigma; subst; destruct_head' False;
+ (split; [ reflexivity | ]).
+
+ Local Ltac good_t_step :=
+ first [ progress subst
+ | progress cbv [id Compile.binding_dataT pattern.ident.arg_types Compile.ptype_interp Compile.ptype_interp_cps Compile.pbase_type_interp_cps Compile.value Compile.value' Compile.app_binding_data Compile.app_ptype_interp_cps Compile.app_pbase_type_interp_cps Compile.lift_with_bindings Compile.lift_ptype_interp_cps Compile.lift_pbase_type_interp_cps cpsbind cpscall cpsreturn cps_option_bind type_base Compile.wf_binding_dataT Compile.wf_ptype_interp_id Compile.wf_ptype_interp_cps Compile.wf_pbase_type_interp_cps ident.smart_Literal rwhen AnyExpr.unwrap] in *
+ | progress destruct_head'_sig
+ | progress cbn [eq_rect option_eq projT1 projT2 fst snd base.interp In combine Option.bind Option.sequence Option.sequence_return UnderLets.splice] in *
+ | progress destruct_head'_prod
+ | progress destruct_head'_sigT
+ | progress intros
+ | progress eliminate_hprop_eq
+ | progress cbv [Compile.option_bind' Compile.castbe Compile.castb Compile.castv] in *
+ | progress type_beq_to_eq
+ | progress rewrite_type_transport_correct
+ | break_innermost_match_step
+ | wf_safe_t_step
+ | rewrite !expr.reflect_list_cps_id
+ | congruence
+ | match goal with
+ | [ |- expr.wf _ (reify_list _) (reify_list _) ] => rewrite expr.wf_reify_list
+ | [ |- context[length ?ls] ] => tryif is_var ls then fail else (progress autorewrite with distr_length)
+ | [ |- ex _ ] => eexists
+ | [ |- UnderLets.wf _ _ _ _ ] => constructor
+ | [ |- UnderLets.wf _ _ (UnderLets.splice _ _) (UnderLets.splice _ _) ] => eapply UnderLets.wf_splice
+ | [ |- Compile.wf_anyexpr _ _ _ _ ] => constructor
+ | [ H : Compile.wf_value ?G ?e1 ?e2 |- UnderLets.wf _ ?G (?e1 _) (?e2 _) ] => eapply (H nil)
+ | [ H : Compile.wf_value ?G ?e1 ?e2 |- UnderLets.wf _ ?G (?e1 _ _) (?e2 _ _) ]
+ => eapply UnderLets.wf_Proper_list; [ | | eapply H; [ reflexivity | | reflexivity | ] ]; revgoals
+ | [ |- context[@rlist_rect_cast ?var ?A ?A' ?P ?ivar ?N_case ?C_case ?ls ?T ?k] ]
+ => (tryif (let __ := constr:(eq_refl : k = (fun x => x)) in idtac)
+ then fail
+ else rewrite (@rlist_rect_cast_cps_id var A A' P ivar N_case C_case ls T k))
+ | [ |- context[@rlist_rect ?var ?A ?P ?ivar ?N_case ?C_case ?ls ?T ?k] ]
+ => (tryif (let __ := constr:(eq_refl : k = (fun x => x)) in idtac)
+ then fail
+ else rewrite (@rlist_rect_cps_id var A P ivar N_case C_case ls T k))
+ | [ |- ?x = ?x /\ _ ] => split; [ reflexivity | ]
+ end
+ | solve [ wf_t ]
+(*| progress cbv [Option.bind]
+ | break_match_step ltac:(fun _ => idtac)*) ].
+
+ Lemma nbe_rewrite_rules_good
+ : rewrite_rules_goodT nbe_rewrite_rules nbe_rewrite_rules.
+ Proof.
+ start_good (@nbe_cps_id) (@nbe_rewrite_rules).
+ all: repeat good_t_step.
+ Admitted.
+
+ Lemma arith_rewrite_rules_good max_const
+ : rewrite_rules_goodT (arith_rewrite_rules max_const) (arith_rewrite_rules max_const).
+ Proof.
+ start_good (@arith_cps_id) (@arith_rewrite_rules).
+ all: repeat good_t_step.
+ Admitted.
+
+ Lemma fancy_rewrite_rules_good
+ (invert_low invert_high : Z -> Z -> option Z)
+ (Hlow : forall s v v', invert_low s v = Some v' -> v = Z.land v' (2^(s/2)-1))
+ (Hhigh : forall s v v', invert_high s v = Some v' -> v = Z.shiftr v' (s/2))
+ : rewrite_rules_goodT (fancy_rewrite_rules invert_low invert_high) (fancy_rewrite_rules invert_low invert_high).
+ Proof.
+ start_good (@fancy_cps_id) (@fancy_rewrite_rules).
+ all: repeat good_t_step.
+ all: cbv [Option.bind].
+ all: repeat good_t_step.
+ Qed.
+ End good.
+ End RewriteRules.
+End Compilers.