diff options
Diffstat (limited to 'src/Experiments/NewPipeline/AbstractInterpretationProofs.v')
-rw-r--r-- | src/Experiments/NewPipeline/AbstractInterpretationProofs.v | 341 |
1 files changed, 240 insertions, 101 deletions
diff --git a/src/Experiments/NewPipeline/AbstractInterpretationProofs.v b/src/Experiments/NewPipeline/AbstractInterpretationProofs.v index b50064e9c..66ed19540 100644 --- a/src/Experiments/NewPipeline/AbstractInterpretationProofs.v +++ b/src/Experiments/NewPipeline/AbstractInterpretationProofs.v @@ -18,6 +18,7 @@ Require Import Crypto.Util.Tactics.SplitInContext. Require Import Crypto.Util.Tactics.UniquePose. Require Import Crypto.Util.Tactics.SpecializeBy. Require Import Crypto.Util.Tactics.SpecializeAllWays. +Require Import Crypto.Util.Tactics.Head. Require Import Crypto.Experiments.NewPipeline.Language. Require Import Crypto.Experiments.NewPipeline.LanguageInversion. Require Import Crypto.Experiments.NewPipeline.LanguageWf. @@ -45,19 +46,147 @@ Module Compilers. Let type_base (x : base_type) : type := type.base x. Local Coercion type_base : base_type >-> type. Context {ident : type -> Type}. - Local Notation expr := (@expr base_type ident). Local Notation Expr := (@expr.Expr base_type ident). - Local Notation UnderLets := (@UnderLets base_type ident). - Context (abstract_domain' : base_type -> Type) + Context (abstract_domain' base_interp : base_type -> Type) + (abstraction_relation' : forall t, abstract_domain' t -> base_interp t -> Prop) (bottom' : forall A, abstract_domain' A) + (bottom'_related : forall t v, abstraction_relation' t (bottom' t) v) (abstract_interp_ident : forall t, ident t -> type.interp abstract_domain' t) - (abstract_domain'_R : forall t, abstract_domain' t -> abstract_domain' t -> Prop) - {abstract_interp_ident_Proper : forall t, Proper (eq ==> abstract_domain'_R t) (abstract_interp_ident t)} - {bottom'_Proper : forall t, Proper (abstract_domain'_R t) (bottom' t)}. + (interp_ident : forall t, ident t -> type.interp base_interp t) + (interp_ident_Proper : forall t (idc : ident t), type.related_hetero abstraction_relation' (abstract_interp_ident t idc) (interp_ident t idc)). Local Notation abstract_domain := (@abstract_domain base_type abstract_domain'). + Definition abstraction_relation {t} : abstract_domain t -> type.interp base_interp t -> Prop + := type.related_hetero (@abstraction_relation'). Local Notation bottom := (@bottom base_type abstract_domain' (@bottom')). Local Notation bottom_for_each_lhs_of_arrow := (@bottom_for_each_lhs_of_arrow base_type abstract_domain' (@bottom')). - Local Notation abstract_domain_R := (@abstract_domain_R base_type abstract_domain' abstract_domain'_R). + Local Notation var := (type.interp base_interp). + Local Notation expr := (@expr base_type ident var). + Local Notation UnderLets := (@UnderLets base_type ident var). + Local Notation value := (@value base_type ident var abstract_domain'). + Local Notation value_with_lets := (@value_with_lets base_type ident var abstract_domain'). + Local Notation state_of_value := (@state_of_value base_type ident var abstract_domain'). + Context (annotate : forall (is_let_bound : bool) t, abstract_domain' t -> expr t -> UnderLets (expr t)). + (* + Local Notation reify1 := (@reify base_type ident var1 abstract_domain' annotate1 bottom'). + Local Notation reify2 := (@reify base_type ident var2 abstract_domain' annotate2 bottom'). + Local Notation reflect1 := (@reflect base_type ident var1 abstract_domain' annotate1 bottom'). + Local Notation reflect2 := (@reflect base_type ident var2 abstract_domain' annotate2 bottom'). + Local Notation interp1 := (@interp base_type ident var1 abstract_domain' annotate1 bottom' interp_ident1). + Local Notation interp2 := (@interp base_type ident var2 abstract_domain' annotate2 bottom' interp_ident2). + Local Notation eval_with_bound'1 := (@eval_with_bound' base_type ident var1 abstract_domain' annotate1 bottom' interp_ident1). + Local Notation eval_with_bound'2 := (@eval_with_bound' base_type ident var2 abstract_domain' annotate2 bottom' interp_ident2). + Local Notation eval'1 := (@eval' base_type ident var1 abstract_domain' annotate1 bottom' interp_ident1). + Local Notation eval'2 := (@eval' base_type ident var2 abstract_domain' annotate2 bottom' interp_ident2). + Local Notation eta_expand_with_bound'1 := (@eta_expand_with_bound' base_type ident var1 abstract_domain' annotate1 bottom'). + Local Notation eta_expand_with_bound'2 := (@eta_expand_with_bound' base_type ident var2 abstract_domain' annotate2 bottom'). +*) +(* + + Fixpoint value (t : type) + := (abstract_domain t + * match t return Type (* COQBUG(https://github.com/coq/coq/issues/7727) *) with + | type.base t + => @expr var t + | type.arrow s d + => value s -> UnderLets (value d) + end)%type. + + Definition value_with_lets (t : type) + := UnderLets (value t). + + + Fixpoint bottom {t} : abstract_domain t + := match t with + | type.base t => bottom' t + | type.arrow s d => fun _ => @bottom d + end. + + Fixpoint bottom_for_each_lhs_of_arrow {t} : type.for_each_lhs_of_arrow abstract_domain t + := match t return type.for_each_lhs_of_arrow abstract_domain t with + | type.base t => tt + | type.arrow s d => (bottom, @bottom_for_each_lhs_of_arrow d) + end. + + Definition state_of_value {t} : value t -> abstract_domain t + := match t return value t -> abstract_domain t with + | type.base t => fun '(st, v) => st + | type.arrow s d => fun '(st, v) => st + end. + + Fixpoint reify (is_let_bound : bool) {t} : value t -> type.for_each_lhs_of_arrow abstract_domain t -> UnderLets (@expr var t) + := match t return value t -> type.for_each_lhs_of_arrow abstract_domain t -> UnderLets (@expr var t) with + | type.base t + => fun '(st, v) 'tt + => annotate is_let_bound t st v + | type.arrow s d + => fun '(f_st, f_e) '(sv, dv) + => Base + (λ x , (UnderLets.to_expr + (fx <-- f_e (@reflect _ (expr.Var x) sv); + @reify false _ fx dv))) + end%core%expr + with reflect {t} : @expr var t -> abstract_domain t -> value t + := match t return @expr var t -> abstract_domain t -> value t with + | type.base t + => fun e st => (st, e) + | type.arrow s d + => fun e absf + => (absf, + (fun v + => let stv := state_of_value v in + (rv <-- (@reify false s v bottom_for_each_lhs_of_arrow); + Base (@reflect d (e @ rv) (absf stv))%expr))) + end%under_lets. + + (* N.B. Because the [App] case only looks at the second argument + of arrow-values, we are free to set the state of [Abs] + nodes to [bottom], because for any [Abs] nodes which are + actually applied (here and in places where we don't + rewrite), we just drop it. *) + Fixpoint interp {t} (e : @expr value_with_lets t) : value_with_lets t + := match e in expr.expr t return value_with_lets t with + | expr.Ident t idc => interp_ident _ idc (* Base (reflect (###idc) (abstract_interp_ident _ idc))*) + | expr.Var t v => v + | expr.Abs s d f => Base (bottom, fun x => @interp d (f (Base x))) + | expr.App s d f x + => (x' <-- @interp s x; + f' <-- @interp (s -> d)%etype f; + snd f' x') + | expr.LetIn (type.arrow _ _) B x f + => (x' <-- @interp _ x; + @interp _ (f (Base x'))) + | expr.LetIn (type.base A) B x f + => (x' <-- @interp _ x; + x'' <-- reify true (* this forces a let-binder here *) x' tt; + @interp _ (f (Base (reflect x'' (state_of_value x'))))) + end%under_lets. + + Definition eval_with_bound' {t} (e : @expr value_with_lets t) + (st : type.for_each_lhs_of_arrow abstract_domain t) + : expr t + := UnderLets.to_expr (e' <-- interp e; reify false e' st). + + Definition eval' {t} (e : @expr value_with_lets t) : expr t + := eval_with_bound' e bottom_for_each_lhs_of_arrow. + + Definition eta_expand_with_bound' {t} (e : @expr var t) + (st : type.for_each_lhs_of_arrow abstract_domain t) + : expr t + := UnderLets.to_expr (reify false (reflect e bottom) st). + + Section extract. + Context (ident_extract : forall t, ident t -> abstract_domain t). + + Definition extract' {t} (e : @expr abstract_domain t) : abstract_domain t + := expr.interp (@ident_extract) e. + + Definition extract_gen {t} (e : @expr abstract_domain t) (bound : type.for_each_lhs_of_arrow abstract_domain t) + : abstract_domain' (type.final_codomain t) + := type.app_curried (extract' e) bound. + End extract. + End with_var. + + Section extract. Context (ident_extract : forall t, ident t -> abstract_domain t) @@ -82,8 +211,9 @@ Module Compilers. eapply extract'_Proper; eassumption. Qed. End extract. +*) End with_type. - +(* Module ident. Import defaults. Local Notation UnderLets := (@UnderLets base.type ident). @@ -118,7 +248,7 @@ Module Compilers. End extract. End with_type. End ident. - +*) Section specialized. Import defaults. Local Notation abstract_domain' := ZRange.type.base.option.interp (only parsing). @@ -201,58 +331,91 @@ Module Compilers. End partial. Import defaults. + Module Import CheckCasts. + Module ident. + Lemma interp_eqv_without_casts t idc + cast_outside_of_range1 cast_outside_of_range2 + (Hc : partial.is_annotation t idc = false) + : ident.gen_interp cast_outside_of_range1 idc + == ident.gen_interp cast_outside_of_range2 idc. + Proof. + generalize (@ident.gen_interp_Proper cast_outside_of_range1 t idc idc eq_refl); + destruct idc; try exact id; cbn in Hc; discriminate. + Qed. + End ident. + + Lemma interp_eqv_without_casts + cast_outside_of_range1 cast_outside_of_range2 + G {t} e1 e2 e3 + (HG : forall t v1 v2 v3, List.In (existT _ t (v1, v2, v3)) G -> v2 == v3) + (Hwf : expr.wf3 G e1 e2 e3) + (Hc : @CheckCasts.get_casts t e1 = nil) + : expr.interp (@ident.gen_interp cast_outside_of_range1) e2 + == expr.interp (@ident.gen_interp cast_outside_of_range2) e3. + Proof. + induction Hwf; + repeat first [ progress cbn [CheckCasts.get_casts] in * + | discriminate + | match goal with + | [ H : (_ ++ _)%list = nil |- _ ] => apply List.app_eq_nil in H + end + | progress destruct_head'_and + | progress break_innermost_match_hyps + | progress interp_safe_t + | solve [ eauto using ident.interp_eqv_without_casts ] ]. + Qed. + + Lemma Interp_WithoutUnsupportedCasts {t} (e : Expr t) + (Hc : CheckCasts.GetUnsupportedCasts e = nil) + (Hwf : expr.Wf3 e) + cast_outside_of_range1 cast_outside_of_range2 + : expr.Interp (@ident.gen_interp cast_outside_of_range1) e + == expr.Interp (@ident.gen_interp cast_outside_of_range2) e. + Proof. eapply interp_eqv_without_casts with (G:=nil); wf_safe_t. Qed. + End CheckCasts. + Module RelaxZRange. + Definition relaxed_cast_outside_of_range + (relax_zrange : zrange -> option zrange) + (cast_outside_of_range : zrange -> Z -> Z) + : zrange -> Z -> Z + := fun r v + => match relax_zrange r with + | Some r' => ident.cast cast_outside_of_range r' v + | None => cast_outside_of_range r v + end. + Module ident. Section relax. Context (relax_zrange : zrange -> option zrange) + (cast_outside_of_range : zrange -> Z -> Z) (Hrelax : forall r r' z, is_tighter_than_bool z r = true -> relax_zrange r = Some r' -> is_tighter_than_bool z r' = true). - Lemma interp_relax {t} (idc idc' : ident t) - (Hidc : @RelaxZRange.ident.relax relax_zrange t idc = Some idc') - v - (Hinterp : forall cast_outside_of_range, type.app_curried (ident.gen_interp cast_outside_of_range idc) v = type.app_curried (ident.interp idc) v) - : forall cast_outside_of_range, type.app_curried (ident.gen_interp cast_outside_of_range idc') v = type.app_curried (ident.interp idc) v. + Local Notation relaxed_cast_outside_of_range := (@relaxed_cast_outside_of_range relax_zrange cast_outside_of_range). + + Lemma interp_relax {t} (idc : ident t) + : ident.gen_interp cast_outside_of_range (@RelaxZRange.ident.relax relax_zrange t idc) + == ident.gen_interp relaxed_cast_outside_of_range idc. Proof. - intro cast_outside_of_range. - pose proof (Hinterp (fun _ => id)). - pose proof (fun myrange => Hinterp (fun _ => cast_outside_of_range myrange)). - destruct idc; cbv [RelaxZRange.ident.relax Option.bind] in *; inversion_option; break_innermost_match_hyps; inversion_option; subst; - repeat match goal with - | [ H : relax_zrange _ = Some _ |- _ ] => unique pose proof (fun zl zu pf => Hrelax _ _ (Build_zrange zl zu) pf H) - end; - repeat first [ reflexivity + pose proof (@ident.gen_interp_Proper cast_outside_of_range t idc idc eq_refl) as Hp. + destruct idc; cbn [type.related] in *; repeat (let x := fresh "x" in intro x; specialize (Hp x)); + repeat first [ assumption + | reflexivity | discriminate | congruence - | progress cbv [RelaxZRange.ident.relax Option.bind id ident.cast is_tighter_than_bool] in * - | progress cbn [fst snd] in * | progress subst - | progress inversion_option - | progress inversion_prod + | progress cbv [relaxed_cast_outside_of_range RelaxZRange.ident.relax Option.bind ident.cast respectful is_tighter_than_bool id] in * + | progress cbn [ident.gen_interp type.related type.interp base.interp upper lower] in * | progress destruct_head'_prod - | progress destruct_head'_and - | progress cbn in * - | progress Bool.split_andb - | progress intros + | progress specialize_by (exact eq_refl) + | break_match_step ltac:(fun x => let h := head x in constr_eq h relax_zrange) | match goal with - | [ H : forall x, (_, _) = (_, _) |- _ ] - => pose proof (fun x => f_equal (@fst _ _) (H x)); - pose proof (fun x => f_equal (@snd _ _) (H x)); - clear H - | [ H : context[andb _ _ = true] |- _ ] => rewrite Bool.andb_true_iff in H || setoid_rewrite Bool.andb_true_iff in H - | [ H : context[Z.leb _ _ = true] |- _ ] => rewrite Z.leb_le in H || setoid_rewrite Z.leb_le in H - | [ H : forall a b, and (Z.le ?x a) (Z.le b ?y) -> _ /\ _, H' : Z.le ?x _, H'' : Z.le _ ?y |- _ ] - => unique pose proof (proj1 (H _ _ (conj H' H''))); - unique pose proof (proj2 (H _ _ (conj H' H''))) + | [ H : relax_zrange ?r = Some ?r' |- context[Z.leb (lower ?r) ?v] ] + => pose proof (fun pf => Hrelax _ _ (Build_zrange v v) pf H); clear H end - | progress rewrite ?Bool.andb_false_iff in * - | progress destruct_head'_or - | progress break_innermost_match_hyps - | progress break_innermost_match - | progress Z.ltb_to_lt - | apply (f_equal2 (@pair _ _)) - | lia ]. + | break_innermost_match_step ]. Qed. End relax. End ident. @@ -262,63 +425,27 @@ Module Compilers. Context (relax_zrange : zrange -> option zrange) (Hrelax : forall r r' z, is_tighter_than_bool z r = true -> relax_zrange r = Some r' - -> is_tighter_than_bool z r' = true). - Lemma interp_relax {t} (e : expr t) - v - (Hinterp : forall cast_outside_of_range, type.app_curried (expr.interp (@ident.gen_interp cast_outside_of_range) e) v - = type.app_curried (defaults.interp e) v) - : forall cast_outside_of_range, type.app_curried (expr.interp (@ident.gen_interp cast_outside_of_range) (RelaxZRange.expr.relax relax_zrange e)) v - = type.app_curried (defaults.interp e) v. + -> is_tighter_than_bool z r' = true) + (cast_outside_of_range : zrange -> Z -> Z). + + Local Notation relaxed_cast_outside_of_range := (@relaxed_cast_outside_of_range relax_zrange cast_outside_of_range). + + Lemma interp_relax G {t} (e1 e2 : expr t) + (HG : forall t v1 v2, List.In (existT _ t (v1, v2)) G -> v1 == v2) + (Hwf : expr.wf G e1 e2) + : expr.interp (@ident.gen_interp cast_outside_of_range) (RelaxZRange.expr.relax relax_zrange e1) + == expr.interp (@ident.gen_interp relaxed_cast_outside_of_range) e2. Proof. - intro cast_outside_of_range; rewrite <- (Hinterp cast_outside_of_range); pose proof (Hinterp cast_outside_of_range). - induction e; cbn -[RelaxZRange.ident.relax] in *; interp_safe_t; cbv [option_map] in *; + induction Hwf; cbn -[RelaxZRange.ident.relax] in *; interp_safe_t; cbv [option_map] in *; break_innermost_match; cbn -[RelaxZRange.ident.relax] in *; interp_safe_t; - eauto using tt. - all: repeat first [ reflexivity - | progress intros - | progress specialize_by_assumption - | progress cbn -[RelaxZRange.ident.relax] in * - | match goal with - | [ H : unit -> ?T |- _ ] => specialize (H tt) - | [ H : forall x : _ * _, _ |- _ ] => specialize (fun a b => H (a, b)) - | [ e : expr (type.base (base.type.type_base base.type.unit)) |- _ ] - => match goal with - | [ |- context[expr.interp ?ii e] ] => destruct (expr.interp ii e) - | [ H : context[expr.interp ?ii e] |- _ ] => destruct (expr.interp ii e) - end - end - | progress cbn [fst snd] in * - | match goal with - | [ H : _ |- _ ] => rewrite H - end ]. - all: specialize_all_ways. - all: repeat first [ reflexivity - | progress intros - | progress specialize_by_assumption - | progress cbn -[RelaxZRange.ident.relax] in * - | match goal with - | [ H : unit -> ?T |- _ ] => specialize (H tt) - | [ H : forall x : _ * _, _ |- _ ] => specialize (fun a b => H (a, b)) - | [ e : expr (type.base (base.type.type_base base.type.unit)) |- _ ] - => match goal with - | [ |- context[expr.interp ?ii e] ] => destruct (expr.interp ii e) - | [ H : context[expr.interp ?ii e] |- _ ] => destruct (expr.interp ii e) - end - end - | progress cbn [fst snd] in * - | match goal with - | [ H : _ |- _ ] => rewrite H - end ]. - Admitted. + eauto using tt, @ident.interp_relax. + Qed. Lemma Interp_Relax {t} (e : Expr t) - (Hwf : expr.Wf3 e) - v - (Hinterp : forall cast_outside_of_range, type.app_curried (expr.Interp (@ident.gen_interp cast_outside_of_range) e) v - = type.app_curried (defaults.Interp e) v) - : forall cast_outside_of_range, type.app_curried (expr.Interp (@ident.gen_interp cast_outside_of_range) (RelaxZRange.expr.Relax relax_zrange e)) v - = type.app_curried (defaults.Interp e) v. - Proof. eapply @interp_relax; try assumption. Qed. + (Hwf : Wf e) + : expr.Interp (@ident.gen_interp cast_outside_of_range) (RelaxZRange.expr.Relax relax_zrange e) + == expr.Interp (@ident.gen_interp relaxed_cast_outside_of_range) e. + Proof. apply interp_relax with (G:=nil); wf_safe_t. Qed. End relax. End expr. End RelaxZRange. @@ -336,11 +463,12 @@ Module Compilers. (Harg12 : type.and_for_each_lhs_of_arrow (@type.eqv) arg1 arg2) (Harg1 : type.andb_bool_for_each_lhs_of_arrow (@ZRange.type.option.is_bounded_by) b_in arg1 = true), type.app_curried (expr.Interp (@ident.gen_interp cast_outside_of_range) (PartialEvaluateWithBounds E b_in)) arg1 - = type.app_curried (Interp E) arg2. + = type.app_curried (expr.Interp (@ident.gen_interp cast_outside_of_range) E) arg2. Proof. Admitted. Lemma Interp_PartialEvaluateWithBounds_bounded + cast_outside_of_range {t} (E : Expr t) (Hwf : Wf E) (b_in : type.for_each_lhs_of_arrow ZRange.type.option.interp t) @@ -348,7 +476,7 @@ Module Compilers. (Harg1 : type.andb_bool_for_each_lhs_of_arrow (@ZRange.type.option.is_bounded_by) b_in arg1 = true), ZRange.type.base.option.is_bounded_by (partial.Extract (PartialEvaluateWithBounds E b_in) b_in) - (type.app_curried (expr.Interp (@ident.interp) (PartialEvaluateWithBounds E b_in)) arg1) + (type.app_curried (expr.Interp (@ident.gen_interp cast_outside_of_range) (PartialEvaluateWithBounds E b_in)) arg1) = true. Proof. Admitted. @@ -362,6 +490,7 @@ Module Compilers. (Harg1 : type.andb_bool_for_each_lhs_of_arrow (@ZRange.type.option.is_bounded_by) b_in arg1 = true), type.app_curried (Interp (PartialEvaluateWithListInfoFromBounds E b_in)) arg1 = type.app_curried (Interp E) arg2. Proof. + cbv [PartialEvaluateWithListInfoFromBounds]. Admitted. Theorem CheckedPartialEvaluateWithBounds_Correct @@ -385,6 +514,14 @@ Module Compilers. Proof. cbv [CheckedPartialEvaluateWithBounds CheckPartialEvaluateWithBounds Let_In] in *; break_innermost_match_hyps; inversion_sum; subst. + let H := lazymatch goal with H : _ = nil |- _ => H end in + pose proof (@Interp_WithoutUnsupportedCasts _ _ H ltac:(solve [ auto with wf ])) as H'; clear H; + assert (forall cast_outside_of_range1 cast_outside_of_range2, + expr.Interp (@ident.gen_interp cast_outside_of_range1) E == expr.Interp (@ident.gen_interp cast_outside_of_range2) E) + by (intros c1 c2; specialize (H' c1 c2); + rewrite !@GeneralizeVar.Interp_gen1_FromFlat_ToFlat in H' by eauto with wf typeclass_instances; + assumption). + clear H'. split. { intros arg1 arg2 Harg12 Harg1. assert (arg1_Proper : Proper (type.and_for_each_lhs_of_arrow (@type.related base.type base.interp (fun _ => eq))) arg1) @@ -398,6 +535,8 @@ Module Compilers. | rewrite RelaxZRange.expr.Interp_Relax; eauto | erewrite !Interp_PartialEvaluateWithBounds | solve [ eauto with wf ] + | apply type.app_curried_Proper + | apply expr.Wf_Interp_Proper_gen | progress intros ]. } { auto with wf. } Qed. |