aboutsummaryrefslogtreecommitdiff
path: root/src/LegacyArithmetic/InterfaceProofs.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/LegacyArithmetic/InterfaceProofs.v')
-rw-r--r--src/LegacyArithmetic/InterfaceProofs.v224
1 files changed, 224 insertions, 0 deletions
diff --git a/src/LegacyArithmetic/InterfaceProofs.v b/src/LegacyArithmetic/InterfaceProofs.v
new file mode 100644
index 000000000..9ef97fa55
--- /dev/null
+++ b/src/LegacyArithmetic/InterfaceProofs.v
@@ -0,0 +1,224 @@
+(** * Alternate forms for Interface for bounded arithmetic *)
+Require Import Coq.ZArith.ZArith Coq.micromega.Psatz.
+Require Import Crypto.LegacyArithmetic.Interface.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.Tuple.
+Require Import Crypto.Util.AutoRewrite.
+Require Import Crypto.Util.Notations.
+
+Local Open Scope type_scope.
+Local Open Scope Z_scope.
+
+Import BoundedRewriteNotations.
+Local Notation bit b := (if b then 1 else 0).
+
+Lemma decoder_eta {n W} (decode : decoder n W) : decode = {| Interface.decode := decode |}.
+Proof. destruct decode; reflexivity. Defined.
+
+Section InstructionGallery.
+ Context (n : Z) (* bit-width of width of [W] *)
+ {W : Type} (* bounded type, [W] for word *)
+ (Wdecoder : decoder n W).
+ Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *)
+
+ Definition Build_is_spread_left_immediate' (sprl : spread_left_immediate W)
+ (pf : forall r count, 0 <= count < n
+ -> _ /\ _)
+ := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H);
+ decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}.
+
+ Definition Build_is_add_with_carry' (adc : add_with_carry W)
+ (pf : forall x y c, _ /\ _)
+ := {| bit_fst_add_with_carry x y c := proj1 (pf x y c);
+ decode_snd_add_with_carry x y c := proj2 (pf x y c) |}.
+
+ Definition Build_is_sub_with_carry' (subc : sub_with_carry W)
+ (pf : forall x y c, _ /\ _)
+ : is_sub_with_carry subc
+ := {| fst_sub_with_carry x y c := proj1 (pf x y c);
+ decode_snd_sub_with_carry x y c := proj2 (pf x y c) |}.
+
+ Definition Build_is_mul_double' (muldw : multiply_double W)
+ (pf : forall x y, _ /\ _)
+ := {| decode_fst_mul_double x y := proj1 (pf x y);
+ decode_snd_mul_double x y := proj2 (pf x y) |}.
+
+ Lemma is_spread_left_immediate_alt
+ {sprl : spread_left_immediate W}
+ {isdecode : is_decode Wdecoder}
+ : is_spread_left_immediate sprl
+ <-> (forall r count, 0 <= count < n -> decode (fst (sprl r count)) + decode (snd (sprl r count)) << n = (decode r << count) mod (2^n*2^n))%Z.
+ Proof using Type.
+ split; intro H; [ | apply Build_is_spread_left_immediate' ];
+ intros r count Hc;
+ [ | specialize (H r count Hc); revert H ];
+ unfold bounded_in_range_cls in *;
+ pose proof (decode_range r);
+ assert (0 < 2^n) by auto with zarith;
+ assert (0 <= 2^count < 2^n)%Z by auto with zarith;
+ assert (0 <= decode r * 2^count < 2^n * 2^n)%Z by (generalize dependent (decode r); intros; nia);
+ rewrite ?decode_fst_spread_left_immediate, ?decode_snd_spread_left_immediate
+ by typeclasses eauto with typeclass_instances core;
+ autorewrite with Zshift_to_pow zsimplify push_Zpow.
+ { reflexivity. }
+ { intro H'; rewrite <- H'.
+ autorewrite with zsimplify; split; reflexivity. }
+ Qed.
+
+ Lemma is_mul_double_alt
+ {muldw : multiply_double W}
+ {isdecode : is_decode Wdecoder}
+ : is_mul_double muldw
+ <-> (forall x y, decode (fst (muldw x y)) + decode (snd (muldw x y)) << n = (decode x * decode y) mod (2^n*2^n)).
+ Proof using Type.
+ split; intro H; [ | apply Build_is_mul_double' ];
+ intros x y;
+ [ | specialize (H x y); revert H ];
+ pose proof (decode_range x);
+ pose proof (decode_range y);
+ assert (0 < 2^n) by auto with zarith;
+ assert (0 <= decode x * decode y < 2^n * 2^n)%Z by nia;
+ (destruct (0 <=? n) eqn:?; Z.ltb_to_lt;
+ [ | assert (2^n = 0) by auto with zarith; exfalso; omega ]);
+ rewrite ?decode_fst_mul_double, ?decode_snd_mul_double
+ by typeclasses eauto with typeclass_instances core;
+ autorewrite with Zshift_to_pow zsimplify push_Zpow.
+ { reflexivity. }
+ { intro H'; rewrite <- H'.
+ autorewrite with zsimplify; split; reflexivity. }
+ Qed.
+End InstructionGallery.
+
+Global Arguments is_spread_left_immediate_alt {_ _ _ _ _}.
+Global Arguments is_mul_double_alt {_ _ _ _ _}.
+
+Ltac bounded_solver_tac :=
+ solve [ eassumption | typeclasses eauto | omega ].
+
+Global Instance decode_proj n W (dec : W -> Z)
+ : @decode n W {| decode := dec |} =~> dec.
+Proof. reflexivity. Qed.
+
+Global Instance decode_if_bool n W (decode : decoder n W) (b : bool) x y
+ : decode (if b then x else y)
+ =~> if b then decode x else decode y.
+Proof. destruct b; reflexivity. Qed.
+
+Global Instance decode_mod_small {n W} {decode : decoder n W} {x b}
+ {H : bounded_in_range_cls 0 (decode x) b}
+ : decode x <~= decode x mod b.
+Proof.
+ Z.rewrite_mod_small; reflexivity.
+Qed.
+
+Global Instance decode_mod_range {n W decode} {H : @is_decode n W decode} x
+ : decode x <~= decode x mod 2^n.
+Proof. exact _. Qed.
+
+Lemma decode_exponent_nonnegative {n W} (decode : decoder n W) {isdecode : is_decode decode}
+ (isinhabited : W)
+ : (0 <= n)%Z.
+Proof.
+ pose proof (decode_range isinhabited).
+ assert (0 < 2^n) by omega.
+ destruct (Z_lt_ge_dec n 0) as [H'|]; [ | omega ].
+ assert (2^n = 0) by auto using Z.pow_neg_r.
+ omega.
+Qed.
+
+Section adc_subc.
+ Context {n W}
+ {decode : decoder n W}
+ {adc : add_with_carry W}
+ {subc : sub_with_carry W}
+ {isdecode : is_decode decode}
+ {isadc : is_add_with_carry adc}
+ {issubc : is_sub_with_carry subc}.
+ Global Instance bit_fst_add_with_carry_false
+ : forall x y, bit (fst (adc x y false)) <~=~> (decode x + decode y) >> n.
+ Proof using isadc.
+ intros; erewrite bit_fst_add_with_carry by assumption.
+ autorewrite with zsimplify_const; reflexivity.
+ Qed.
+ Global Instance bit_fst_add_with_carry_true
+ : forall x y, bit (fst (adc x y true)) <~=~> (decode x + decode y + 1) >> n.
+ Proof using isadc.
+ intros; erewrite bit_fst_add_with_carry by assumption.
+ autorewrite with zsimplify_const; reflexivity.
+ Qed.
+ Global Instance fst_add_with_carry_leb
+ : forall x y c, fst (adc x y c) <~= (2^n <=? (decode x + decode y + bit c)).
+ Proof using isadc isdecode.
+ intros x y c; hnf.
+ assert (0 <= n)%Z by eauto using decode_exponent_nonnegative.
+ pose proof (decode_range x); pose proof (decode_range y).
+ assert (0 <= bit c <= 1)%Z by (destruct c; omega).
+ lazymatch goal with
+ | [ |- fst ?x = (?a <=? ?b) :> bool ]
+ => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z);
+ [ destruct (fst x), (a <=? b); intro; congruence | ]
+ end.
+ push_decode.
+ autorewrite with Zshift_to_pow.
+ rewrite Z.div_between_0_if by auto with zarith.
+ reflexivity.
+ Qed.
+ Global Instance fst_add_with_carry_false_leb
+ : forall x y, fst (adc x y false) <~= (2^n <=? (decode x + decode y)).
+ Proof using isadc isdecode.
+ intros; erewrite fst_add_with_carry_leb by assumption.
+ autorewrite with zsimplify_const; reflexivity.
+ Qed.
+ Global Instance fst_add_with_carry_true_leb
+ : forall x y, fst (adc x y true) <~=~> (2^n <=? (decode x + decode y + 1)).
+ Proof using isadc isdecode.
+ intros; erewrite fst_add_with_carry_leb by assumption.
+ autorewrite with zsimplify_const; reflexivity.
+ Qed.
+ Global Instance fst_sub_with_carry_false
+ : forall x y, fst (subc x y false) <~=~> ((decode x - decode y) <? 0).
+ Proof using issubc.
+ intros; erewrite fst_sub_with_carry by assumption.
+ autorewrite with zsimplify_const; reflexivity.
+ Qed.
+ Global Instance fst_sub_with_carry_true
+ : forall x y, fst (subc x y true) <~=~> ((decode x - decode y - 1) <? 0).
+ Proof using issubc.
+ intros; erewrite fst_sub_with_carry by assumption.
+ autorewrite with zsimplify_const; reflexivity.
+ Qed.
+End adc_subc.
+
+Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y)))
+=> apply @fst_add_with_carry_false_leb : typeclass_instances.
+Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + 1)))
+=> apply @fst_add_with_carry_true_leb : typeclass_instances.
+Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + if ?c then _ else _)))
+=> apply @fst_add_with_carry_leb : typeclass_instances.
+
+
+(* We take special care to handle the case where the decoder is
+ syntactically different but the decoded expression is judgmentally
+ the same; we don't want to split apart variables that should be the
+ same. *)
+Ltac set_decode_step check :=
+ match goal with
+ | [ |- context G[@decode ?n ?W ?dr ?w] ]
+ => check w;
+ first [ match goal with
+ | [ d := @decode _ _ _ w |- _ ]
+ => change (@decode n W dr w) with d
+ end
+ | generalize (@decode_range n W dr _ w);
+ let d := fresh "d" in
+ set (d := @decode n W dr w);
+ intro ]
+ end.
+Ltac set_decode check := repeat set_decode_step check.
+Ltac clearbody_decode :=
+ repeat match goal with
+ | [ H := @decode _ _ _ _ |- _ ] => clearbody H
+ end.
+Ltac generalize_decode_by check := set_decode check; clearbody_decode.
+Ltac generalize_decode := generalize_decode_by ltac:(fun w => idtac).
+Ltac generalize_decode_var := generalize_decode_by ltac:(fun w => is_var w).