aboutsummaryrefslogtreecommitdiff
path: root/src/Compilers/Z/Named/RewriteAddToAdcInterp.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Compilers/Z/Named/RewriteAddToAdcInterp.v')
-rw-r--r--src/Compilers/Z/Named/RewriteAddToAdcInterp.v281
1 files changed, 281 insertions, 0 deletions
diff --git a/src/Compilers/Z/Named/RewriteAddToAdcInterp.v b/src/Compilers/Z/Named/RewriteAddToAdcInterp.v
new file mode 100644
index 000000000..91a73cfc6
--- /dev/null
+++ b/src/Compilers/Z/Named/RewriteAddToAdcInterp.v
@@ -0,0 +1,281 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Crypto.Compilers.Named.Context.
+Require Import Crypto.Compilers.Named.ContextDefinitions.
+Require Import Crypto.Compilers.Named.ContextProperties.Proper.
+Require Import Crypto.Compilers.Syntax.
+Require Import Crypto.Compilers.Z.Syntax.
+Require Import Crypto.Compilers.Z.Syntax.Equality.
+Require Import Crypto.Compilers.Z.Named.RewriteAddToAdc.
+Require Import Crypto.Compilers.Named.Syntax.
+Require Import Crypto.Util.Notations.
+Require Import Crypto.Util.Option.
+Require Import Crypto.Util.Tactics.DestructHead.
+Require Import Crypto.Util.Tactics.BreakMatch.
+Require Import Crypto.Util.Prod.
+Require Import Crypto.Util.LetIn.
+Require Import Crypto.Util.Bool.
+Require Import Crypto.Util.ZUtil.AddGetCarry.
+Require Import Crypto.Util.Decidable.
+
+Local Open Scope Z_scope.
+
+Section named.
+ Context {Name : Type}
+ {InterpContext : Context Name interp_base_type}
+ {InterpContextOk : ContextOk InterpContext}
+ (Name_beq : Name -> Name -> bool)
+ (Name_bl : forall n1 n2, Name_beq n1 n2 = true -> n1 = n2)
+ (Name_lb : forall n1 n2, n1 = n2 -> Name_beq n1 n2 = true).
+
+ Local Notation exprf := (@exprf base_type op Name).
+ Local Notation expr := (@expr base_type op Name).
+ Local Notation do_rewrite := (@do_rewrite Name Name_beq).
+ Local Notation do_rewriteo := (@do_rewriteo Name Name_beq).
+ Local Notation rewrite_exprf := (@rewrite_exprf Name Name_beq).
+ Local Notation rewrite_exprf_prestep := (@rewrite_exprf_prestep Name).
+ Local Notation rewrite_expr := (@rewrite_expr Name Name_beq).
+
+ Local Instance Name_dec : DecidableRel (@eq Name)
+ := dec_rel_of_bool_dec_rel Name_beq Name_bl Name_lb.
+
+ Local Notation retT e re :=
+ (forall (ctx : InterpContext)
+ v,
+ Named.interpf (interp_op:=interp_op) (ctx:=ctx) re = Some v
+ -> Named.interpf (interp_op:=interp_op) (ctx:=ctx) e = Some v)
+ (only parsing).
+ Local Notation tZ := (Tbase TZ).
+ Local Notation ADC bw c x y := (Op (@AddWithGetCarry bw TZ TZ TZ TZ TZ)
+ (Pair (Pair (t1:=tZ) c (t2:=tZ) x) (t2:=tZ) y)).
+ Local Notation ADD bw x y := (ADC bw (Op (OpConst 0) TT) x y).
+ Local Notation ADX x y := (Op (@Add TZ TZ TZ) (Pair (t1:=tZ) x (t2:=tZ) y)).
+
+ Local Ltac simple_t_step :=
+ first [ discriminate
+ | exact I
+ | progress intros
+ | progress subst
+ | progress inversion_option ].
+ Local Ltac destruct_t_step :=
+ first [ break_innermost_match_hyps_step
+ | break_innermost_match_step ].
+ Local Ltac do_small_inversion e :=
+ is_var e;
+ lazymatch type of e with
+ | exprf ?T
+ => revert dependent e;
+ let P := match goal with |- forall e, @?P e => P end in
+ intro e;
+ lazymatch T with
+ | Unit
+ => refine match e in Named.exprf _ _ _ t return match t return Named.exprf _ _ _ t -> _ with Unit => P | _ => fun _ => True end e with TT => _ | _ => _ end
+ | tZ
+ => refine match e in Named.exprf _ _ _ t return match t return Named.exprf _ _ _ t -> _ with tZ => P | _ => fun _ => True end e with TT => _ | _ => _ end
+ | (tZ * tZ)%ctype
+ => refine match e in Named.exprf _ _ _ t return match t return Named.exprf _ _ _ t -> _ with (tZ * tZ)%ctype => P | _ => fun _ => True end e with TT => _ | _ => _ end
+ | (tZ * tZ * tZ)%ctype
+ => refine match e in Named.exprf _ _ _ t return match t return Named.exprf _ _ _ t -> _ with (tZ * tZ * tZ)%ctype => P | _ => fun _ => True end e with TT => _ | _ => _ end
+ end;
+ try exact I
+ | op ?a ?T
+ => first [ is_var a;
+ move e at top;
+ revert dependent a;
+ let P := match goal with |- forall a e, @?P a e => P end in
+ intros a e;
+ lazymatch T with
+ | tZ
+ => refine match e in op a t return match t return op a t -> _ with tZ => P a | _ => fun _ => True end e with OpConst _ _ => _ | _ => _ end
+ | (tZ * tZ)%ctype
+ => refine match e in op a t return match t return op a t -> _ with (tZ * tZ)%ctype => P a | _ => fun _ => True end e with OpConst _ _ => _ | _ => _ end
+ end ];
+ try exact I
+ end.
+ Local Ltac small_inversion_step _ :=
+ match goal with
+ | [ H : match ?e with _ => _ end = Some _ |- _ ] => do_small_inversion e
+ | [ H : match ?e with _ => _ end = true |- _ ] => do_small_inversion e
+ | [ H : match ?e with _ => _ end _ = Some _ |- _ ] => do_small_inversion e
+ end.
+
+ Local Ltac rewrite_lookupb_step :=
+ first [ rewrite !lookupb_extendb_different in * by (assumption || congruence)
+ | rewrite !lookupb_extendb_same in * by assumption
+ | rewrite !lookupb_extendb_wrong_type in * by (assumption || congruence)
+ | match goal with
+ | [ H : context[lookupb (extendb _ _ _) _] |- _ ] => revert H
+ | [ |- context[lookupb (extendb _ ?n _) ?n'] ]
+ => (tryif constr_eq n n' then fail else idtac);
+ lazymatch goal with
+ | [ H : n = n' |- _ ] => fail
+ | [ H : n' = n |- _ ] => fail
+ | [ H : n <> n' |- _ ] => fail
+ | [ H : n' <> n |- _ ] => fail
+ | _ => destruct (dec (n = n')); subst
+ end
+ | [ |- context[lookupb (t:=?t0) (extendb (t:=?t1) _ _ _) _] ]
+ => (tryif constr_eq t0 t1 then fail else idtac);
+ lazymatch goal with
+ | [ H : t0 = t1 |- _ ] => fail
+ | [ H : t1 = t0 |- _ ] => fail
+ | [ H : t0 <> t1 |- _ ] => fail
+ | [ H : t1 <> t0 |- _ ] => fail
+ | _ => destruct (dec (t0 = t1)); subst
+ end
+ end ].
+ Local Ltac rewrite_lookupb := repeat rewrite_lookupb_step.
+
+ Local Ltac do_rewrite_adc' P :=
+ let lem := open_constr:(Z.add_get_carry_to_add_with_get_carry_cps _ _ _ _ P) in
+ let T := type of lem in
+ let T := (eval cbv [Let_In Definitions.Z.add_with_get_carry Definitions.Z.add_with_get_carry Definitions.Z.get_carry Definitions.Z.add_get_carry] in T) in
+ etransitivity; [ | eapply (lem : T) ];
+ try reflexivity.
+ Local Ltac do_rewrite_adc :=
+ first [ do_rewrite_adc' uconstr:(fun a b => Some b)
+ | do_rewrite_adc' uconstr:(fun a b => Some a) ].
+
+ Lemma interpf_do_rewrite
+ {t} {e e' : exprf t}
+ (H : do_rewrite e = Some e')
+ : retT e e'.
+ Proof.
+ unfold do_rewrite in H;
+ repeat first [ simple_t_step
+ | small_inversion_step ()
+ | destruct_t_step ].
+ Time all:match goal with
+ | [ H : _ = ?x |- _ = ?x ] => rewrite <- H; clear H
+ end.
+ Time all:split_andb.
+ Time all:progress simpl @negb in *.
+ Time all:repeat match goal with
+ | [ H : Name_beq _ _ = true |- _ ] => apply Name_bl in H
+ | [ H : Z.eqb _ _ = true |- _ ] => apply Z.eqb_eq in H
+ end.
+ Time all:subst.
+ Local Ltac do_small_inversion_ctx :=
+ repeat match goal with
+ | [ H : is_const_or_var ?e = true |- _ ]
+ => do_small_inversion e; break_innermost_match; intros; try exact I;
+ simpl in H; try solve [ clear -H; discriminate ]
+ | [ H : match ?e with _ => _ end = true |- _ ]
+ => do_small_inversion e; break_innermost_match; intros; try exact I;
+ simpl in H; try solve [ clear -H; discriminate ]
+ | [ H : match ?e with _ => _ end _ = true |- _ ]
+ => do_small_inversion e; break_innermost_match; intros; try exact I;
+ simpl in H; try solve [ clear -H; discriminate ]
+ end.
+ Time all:do_small_inversion_ctx.
+ Time all:simpl @negb in *.
+ Time all:rewrite !Bool.negb_orb in *.
+ Time all:split_andb.
+ Time all:rewrite !Bool.negb_true_iff in *.
+ Time all:repeat
+ match goal with
+ | [ H : Name_beq ?x ?y = false |- _ ]
+ => assert (x <> y) by (clear -H Name_lb; intro; rewrite Name_lb in H by assumption; congruence);
+ clear H
+ end.
+ Time all:subst.
+ Time all:simpl @interpf in *.
+ Time all:cbv [interp_op option_map lift_op Zinterp_op] in *; simpl in *.
+ Time all:unfold Let_In in * |- .
+ Time all:break_innermost_match; try reflexivity.
+ Local Ltac t_fin_step :=
+ match goal with
+ | [ |- ?x = ?x ] => reflexivity
+ | [ H : ?x = Some _ |- context[?x] ] => rewrite H
+ | [ H : ?x = None |- context[?x] ] => rewrite H
+ | [ H : ?x = Some ?a, H' : ?x = Some ?b |- _ ]
+ => assert (a = b) by congruence; (subst a || subst b)
+ | _ => progress rewrite_lookupb
+ | _ => progress simpl in *
+ | _ => progress intros
+ | _ => progress subst
+ | _ => progress inversion_option
+ | [ |- (dlet x := _ in _) = (dlet y := _ in _) ]
+ => apply Proper_Let_In_nd_changebody_eq; intros ??
+ | _ => progress unfold Let_In
+ | [ |- interpf ?x = interpf ?x ]
+ => eapply @interpf_Proper; [ eauto with typeclass_instances.. | intros ?? | reflexivity ]
+ | _ => progress break_innermost_match; try reflexivity
+ | _ => progress break_innermost_match_hyps; try reflexivity
+ | _ => progress break_match; try reflexivity
+ end.
+ Local Ltac t_fin :=
+ repeat t_fin_step;
+ try do_rewrite_adc.
+ { Time t_fin. }
+ { Time t_fin. }
+ { Time t_fin. }
+ { Time t_fin. }
+ { Time t_fin. }
+ { Time t_fin. }
+ { Time t_fin. }
+ { Time t_fin. }
+ Time Qed.
+
+ Lemma interpf_do_rewriteo
+ {t} {e : exprf t}
+ : retT e (do_rewriteo e).
+ Proof.
+ unfold do_rewriteo; intros *; break_innermost_match; try congruence.
+ apply interpf_do_rewrite; assumption.
+ Qed.
+
+ Local Opaque RewriteAddToAdc.do_rewriteo.
+ Lemma interpf_rewrite_exprf
+ {t} (e : exprf t)
+ : retT e (rewrite_exprf e).
+ Proof.
+ pose t as T.
+ pose (rewrite_exprf_prestep (@rewrite_exprf) e) as E.
+ induction e; simpl in *;
+ intros ctx v H;
+ pose proof (interpf_do_rewriteo (t:=T) (e:=E) ctx v H); clear H;
+ subst T E;
+ repeat first [ assumption
+ | progress unfold option_map, Let_In in *
+ | progress simpl in *
+ | progress subst
+ | progress inversion_option
+ | apply (f_equal (@Some _))
+ | break_innermost_match_step
+ | break_innermost_match_hyps_step
+ | congruence
+ | solve [ eauto ]
+ | match goal with
+ | [ IH : forall ctx v, interpf ?e = Some v -> _ = Some _, H' : interpf ?e = Some _ |- _ ]
+ => specialize (IH _ _ H')
+ | [ H : ?x = Some ?a, H' : ?x = Some ?b |- _ ]
+ => assert (a = b) by congruence; (subst a || subst b)
+ | [ |- ?rhs = Some _ ]
+ => lazymatch rhs with
+ | Some _ => fail
+ | None => fail
+ | _ => destruct rhs eqn:?
+ end
+ end ].
+ Qed.
+
+ Lemma interp_rewrite_expr
+ {t} (e : expr t)
+ : forall (ctx : InterpContext)
+ v x,
+ Named.interp (interp_op:=interp_op) (ctx:=ctx) (rewrite_expr e) x = Some v
+ -> Named.interp (interp_op:=interp_op) (ctx:=ctx) e x = Some v.
+ Proof.
+ unfold Named.interp, rewrite_expr; destruct e; simpl.
+ intros *; apply interpf_rewrite_exprf.
+ Qed.
+
+ Lemma Interp_rewrite_expr
+ {t} (e : expr t)
+ : forall v x,
+ Named.Interp (Context:=InterpContext) (interp_op:=interp_op) (rewrite_expr e) x = Some v
+ -> Named.Interp (Context:=InterpContext) (interp_op:=interp_op) e x = Some v.
+ Proof.
+ intros *; apply interp_rewrite_expr.
+ Qed.
+End named.