diff options
author | Jason Gross <jasongross9@gmail.com> | 2016-08-25 18:48:25 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-25 18:48:25 -0700 |
commit | 176e33eac56ef0c86872720c14df8ff1dcb7dbf6 (patch) | |
tree | 267580f90119fc062f24a8ae1ef1ff7aea8a2ce2 /src | |
parent | 027764b7854cc8f1a089d7a962b71a00ec291032 (diff) | |
parent | bf0d9280ebf806eef8ee3280f7976edb3282ae6e (diff) |
Merge pull request #52 from JasonGross/bounded-interface
Initial work on an architecture interface for ℤ/nℤ
Diffstat (limited to 'src')
-rw-r--r-- | src/BoundedArithmetic/ArchitectureToZLike.v | 32 | ||||
-rw-r--r-- | src/BoundedArithmetic/ArchitectureToZLikeProofs.v | 110 | ||||
-rw-r--r-- | src/BoundedArithmetic/DoubleBounded.v | 102 | ||||
-rw-r--r-- | src/BoundedArithmetic/DoubleBoundedProofs.v | 470 | ||||
-rw-r--r-- | src/BoundedArithmetic/Interface.v | 281 | ||||
-rw-r--r-- | src/BoundedArithmetic/InterfaceProofs.v | 221 | ||||
-rw-r--r-- | src/Util/AutoRewrite.v | 56 | ||||
-rw-r--r-- | src/Util/Notations.v | 1 |
8 files changed, 1273 insertions, 0 deletions
diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v new file mode 100644 index 000000000..3388ece78 --- /dev/null +++ b/src/BoundedArithmetic/ArchitectureToZLike.v @@ -0,0 +1,32 @@ +(*** Implementing ℤ-Like via Architecture *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BoundedArithmetic.DoubleBounded. +Require Import Crypto.ModularArithmetic.ZBounded. +Require Import Crypto.Util.Tuple. + +Local Open Scope Z_scope. + +Section fancy_machine_p256_montgomery_foundation. + Context {n_over_two : Z}. + Local Notation n := (2 * n_over_two). + Context (ops : fancy_machine.instructions n) (modulus : Z). + + Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : Z) + : ZLikeOps (2^n) (2^smaller_bound_exp) modulus := + { LargeT := tuple fancy_machine.W 2; + SmallT := fancy_machine.W; + modulus_digits := ldi modulus; + decode_large := decode; + decode_small := decode; + Mod_SmallBound v := fst v; + DivBy_SmallBound v := snd v; + DivBy_SmallerBound v := if smaller_bound_exp =? n + then snd v + else shrd (snd v) (fst v) smaller_bound_exp; + Mul x y := muldw x y; + CarryAdd x y := adc x y false; + CarrySubSmall x y := subc x y false; + ConditionalSubtract b x := let v := selc b (ldi modulus) (ldi 0) in snd (subc x v false); + ConditionalSubtractModulus y := addm y (ldi 0) (ldi modulus) }. +End fancy_machine_p256_montgomery_foundation. diff --git a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v new file mode 100644 index 000000000..804296374 --- /dev/null +++ b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v @@ -0,0 +1,110 @@ +(*** Proving ℤ-Like via Architecture *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BoundedArithmetic.InterfaceProofs. +Require Import Crypto.BoundedArithmetic.DoubleBounded. +Require Import Crypto.BoundedArithmetic.DoubleBoundedProofs. +Require Import Crypto.BoundedArithmetic.ArchitectureToZLike. +Require Import Crypto.ModularArithmetic.ZBounded. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ZUtil Crypto.Util.Tactics. + +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Local Coercion Z.of_nat : nat >-> Z. + +Section fancy_machine_p256_montgomery_foundation. + Context {n_over_two : Z}. + Local Notation n := (2 * n_over_two)%Z. + Context (ops : fancy_machine.instructions n) (modulus : Z). + + Local Arguments Z.mul !_ !_. + Local Arguments BaseSystem.decode !_ !_ / . + Local Arguments BaseSystem.accumulate / . + Local Arguments BaseSystem.decode' !_ !_ / . + + Local Ltac introduce_t_step := + match goal with + | [ |- forall x : bool, _ ] => intros [|] + | [ |- True -> _ ] => intros _ + | [ |- _ <= _ < _ -> _ ] => intro + | _ => let x := fresh "x" in + intro x; + try pose proof (decode_range (fst x)); + try pose proof (decode_range (snd x)); + pose proof (decode_range x) + end. + Local Ltac unfolder_t := + progress unfold LargeT, SmallT, modulus_digits, decode_large, decode_small, Mod_SmallBound, DivBy_SmallBound, DivBy_SmallerBound, Mul, CarryAdd, CarrySubSmall, ConditionalSubtract, ConditionalSubtractModulus, ZLikeOps_of_ArchitectureBoundedOps in *. + Local Ltac saturate_context_step := + match goal with + | _ => unique assert (0 <= 2 * n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ] + | _ => unique assert (0 <= n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ] + | _ => unique assert (0 <= 2 * (2 * n_over_two)) by (eauto using decode_exponent_nonnegative with typeclass_instances) + | [ H : 0 <= ?x < _ |- _ ] => unique pose proof (proj1 H); unique pose proof (proj2 H) + end. + Local Ltac pre_t := + repeat first [ tauto + | introduce_t_step + | unfolder_t + | saturate_context_step ]. + Local Ltac post_t_step := + match goal with + | _ => reflexivity + | _ => progress autorewrite with zsimplify_const + | [ |- fst ?x = (?a <=? ?b) :> bool ] + => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z); + [ destruct (fst x), (a <=? b); intro; congruence | ] + | [ H : (_ =? _) = true |- _ ] => apply Z.eqb_eq in H; subst + | [ H : (_ =? _) = false |- _ ] => apply Z.eqb_neq in H + | _ => autorewrite with push_Zpow in *; solve [ reflexivity | assumption ] + | _ => autorewrite with pull_Zpow in *; pull_decode; reflexivity + | _ => progress push_decode + | _ => rewrite (Z.add_comm (_ << _) _); progress pull_decode + | [ |- context[if ?x =? ?y then _ else _] ] => destruct (x =? y) eqn:? + | _ => autorewrite with Zshift_to_pow; Z.rewrite_mod_small; reflexivity + end. + Local Ltac post_t := repeat post_t_step. + Local Ltac t := pre_t; post_t. + + Global Instance ZLikeProperties_of_ArchitectureBoundedOps + {arith : fancy_machine.arithmetic ops} + (modulus_in_range : 0 <= modulus < 2^n) + (smaller_bound_exp : Z) + (smaller_bound_smaller : 0 <= smaller_bound_exp <= n) + (n_pos : 0 < n) + : ZLikeProperties (ZLikeOps_of_ArchitectureBoundedOps ops modulus smaller_bound_exp) + := { large_valid v := True; + medium_valid v := 0 <= decode_large v < 2^n * 2^smaller_bound_exp; + small_valid v := True }. + Proof. + (* In 8.5: *) + (* par:t. *) + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + Defined. +End fancy_machine_p256_montgomery_foundation. diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v new file mode 100644 index 000000000..b624c5082 --- /dev/null +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -0,0 +1,102 @@ +(*** Implementing Large Bounded Arithmetic via pairs *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BoundedArithmetic.InterfaceProofs. +Require Import Crypto.ModularArithmetic.Pow2Base. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ListUtil. +Require Import Crypto.Util.Notations. + +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Local Coercion Z.of_nat : nat >-> Z. +Local Notation eta x := (fst x, snd x). + +(** The list is low to high; the tuple is low to high *) +Definition tuple_decoder {n W} {decode : decoder n W} {k : nat} : decoder (k * n) (tuple W k) + := {| decode w := BaseSystem.decode (base_from_limb_widths (repeat n k)) + (List.map decode (List.rev (Tuple.to_list _ w))) |}. +Global Arguments tuple_decoder : simpl never. +Hint Extern 3 (decoder _ (tuple ?W ?k)) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode => (@tuple_decoder n W decode k : decoder (kv * n) (tuple W k))) : typeclass_instances. + +Section ripple_carry_definitions. + (** tuple is high to low ([to_list] reverses) *) + Fixpoint ripple_carry_tuple' {T} (f : T -> T -> bool -> bool * T) k + : forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k + := match k return forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k with + | O => f + | S k' => fun xss yss carry => let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (@ripple_carry_tuple' _ f k' xs ys carry) in + let '(carry, z) := eta (f x y carry) in + (carry, (zs, z)) + end. + + Definition ripple_carry_tuple {T} (f : T -> T -> bool -> bool * T) k + : forall (xs ys : tuple T k) (carry : bool), bool * tuple T k + := match k return forall (xs ys : tuple T k) (carry : bool), bool * tuple T k with + | O => fun xs ys carry => (carry, tt) + | S k' => ripple_carry_tuple' f k' + end. +End ripple_carry_definitions. + +Global Instance ripple_carry_adc + {W} (adc : add_with_carry W) {k} + : add_with_carry (tuple W k) + := { adc := ripple_carry_tuple adc k }. + +(** constructions on [tuple W 2] *) +Section tuple2. + Section spread_left. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W}. + + Definition spread_left_from_shift (r : W) (count : Z) : tuple W 2 + := (shl r count, if count =? 0 then ldi 0 else shr r (n - count)). + + (** Require a [decoder] instance to aid typeclass search in + resolving [n] *) + Global Instance sprl_from_shift {decode : decoder n W} : spread_left_immediate W + := { sprl := spread_left_from_shift }. + End spread_left. + + Section double_from_half. + Context {half_n : Z} {W} + {mulhwll : multiply_low_low W} + {mulhwhl : multiply_high_low W} + {mulhwhh : multiply_high_high W} + {adc : add_with_carry W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {ldi : load_immediate W}. + + Definition mul_double (a b : W) : tuple W 2 + := let out : tuple W 2 := (mulhwll a b, mulhwhh a b) in + let tmp := mulhwhl a b in + let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in + let tmp := mulhwhl b a in + let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in + out. + + (** Require a dummy [decoder] for these instances to allow + typeclass inference of the [half_n] argument *) + Global Instance mul_double_multiply {decode : decoder (2 * half_n) W} : multiply_double W + := { muldw a b := mul_double a b }. + End double_from_half. + + Global Instance mul_double_multiply_low_low {W} {muldw : multiply_double W} + : multiply_low_low (tuple W 2) + := { mulhwll a b := muldw (fst a) (fst b) }. + Global Instance mul_double_multiply_high_low {W} {muldw : multiply_double W} + : multiply_high_low (tuple W 2) + := { mulhwhl a b := muldw (snd a) (fst b) }. + Global Instance mul_double_multiply_high_high {W} {muldw : multiply_double W} + : multiply_high_high (tuple W 2) + := { mulhwhh a b := muldw (snd a) (snd b) }. +End tuple2. + +Global Arguments mul_double half_n {_ _ _ _ _ _ _} _ _. diff --git a/src/BoundedArithmetic/DoubleBoundedProofs.v b/src/BoundedArithmetic/DoubleBoundedProofs.v new file mode 100644 index 000000000..53ac59d00 --- /dev/null +++ b/src/BoundedArithmetic/DoubleBoundedProofs.v @@ -0,0 +1,470 @@ +(*** Proofs About Large Bounded Arithmetic via pairs *) +Require Import Coq.ZArith.ZArith Coq.Lists.List Coq.micromega.Psatz. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BoundedArithmetic.InterfaceProofs. +Require Import Crypto.BaseSystem. +Require Import Crypto.BaseSystemProofs. +Require Import Crypto.ModularArithmetic.Pow2Base. +Require Import Crypto.ModularArithmetic.Pow2BaseProofs. +Require Import Crypto.BoundedArithmetic.DoubleBounded. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.ListUtil. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.Notations. + +Import ListNotations. +Local Open Scope list_scope. +Local Open Scope nat_scope. +Local Open Scope type_scope. +Local Open Scope Z_scope. + +Local Coercion Z.of_nat : nat >-> Z. +Local Coercion Pos.to_nat : positive >-> nat. +Local Notation eta x := (fst x, snd x). + +Import BoundedRewriteNotations. +Local Open Scope Z_scope. + +Section decode. + Context {n W} {decode : decoder n W}. + Section with_k. + Context {k : nat}. + Local Notation limb_widths := (repeat n k). + + Lemma decode_bounded {isdecode : is_decode decode} w + : 0 <= n -> bounded limb_widths (map decode (rev (to_list k w))). + Proof. + intro. + eapply bounded_uniform; try solve [ eauto using repeat_spec ]. + { distr_length. } + { intros z H'. + apply in_map_iff in H'. + destruct H' as [? [? H'] ]; subst; apply decode_range. } + Qed. + + (** TODO: Clean up this proof *) + Global Instance tuple_is_decode {isdecode : is_decode decode} + : is_decode (tuple_decoder (k := k)). + Proof. + unfold tuple_decoder; hnf; simpl. + intro w. + destruct (zerop k); [ subst | ]. + { unfold BaseSystem.decode, BaseSystem.decode'; simpl; omega. } + assert (0 <= n) + by (destruct k as [ | [|] ]; [ omega | | destruct w ]; + eauto using decode_exponent_nonnegative). + replace (2^(k * n)) with (upper_bound limb_widths) + by (erewrite upper_bound_uniform by eauto using repeat_spec; distr_length). + apply decode_upper_bound; auto using decode_bounded. + { intros ? H'. + apply repeat_spec in H'; omega. } + { distr_length. } + Qed. + End with_k. + + Local Arguments base_from_limb_widths : simpl never. + Local Arguments repeat : simpl never. + Local Arguments Z.mul !_ !_. + Lemma tuple_decoder_S {k} w : 0 <= n -> (tuple_decoder (k := S (S k)) w = tuple_decoder (k := S k) (fst w) + (decode (snd w) << (S k * n)))%Z. + Proof. + intro Hn. + destruct w as [? w]; simpl. + replace (decode w) with (decode w * 1 + 0)%Z by omega. + rewrite map_app, map_cons, map_nil. + erewrite decode_shift_uniform_app by (eauto using repeat_spec; distr_length). + distr_length. + autorewrite with push_skipn natsimplify push_firstn. + reflexivity. + Qed. + Global Instance tuple_decoder_O w : tuple_decoder (k := 1) w =~> decode w. + Proof. + unfold tuple_decoder, BaseSystem.decode, BaseSystem.decode', accumulate, base_from_limb_widths, repeat. + simpl; hnf. + omega. + Qed. + Lemma tuple_decoder_O_ind_prod + (P : forall n, decoder n W -> Type) + (P_ext : forall n (a b : decoder n W), (forall x, a x = b x) -> P _ a -> P _ b) + : (P _ (tuple_decoder (k := 1)) -> P _ decode) + * (P _ decode -> P _ (tuple_decoder (k := 1))). + Proof. + unfold tuple_decoder, BaseSystem.decode, BaseSystem.decode', accumulate, base_from_limb_widths, repeat. + simpl; hnf. + rewrite Z.mul_1_l. + split; apply P_ext; simpl; intro; autorewrite with zsimplify_const; reflexivity. + Qed. + + Global Instance tuple_decoder_m1 w : tuple_decoder (k := 0) w =~> 0. + Proof. reflexivity. Qed. + + Global Instance tuple_decoder_2' w : (0 <= n)%bounded_rewrite -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << (1%nat * n))%Z. + Proof. + intros; rewrite !tuple_decoder_S, !tuple_decoder_O by assumption. + reflexivity. + Qed. + Global Instance tuple_decoder_2 w : (0 <= n)%bounded_rewrite -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << n)%Z. + Proof. + intros; rewrite !tuple_decoder_S, !tuple_decoder_O by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. +End decode. + +Local Arguments tuple_decoder : simpl never. +Local Opaque tuple_decoder. + +Lemma is_add_with_carry_1tuple {n W decode adc} + (H : @is_add_with_carry n W decode adc) + : @is_add_with_carry (1 * n) W (@tuple_decoder n W decode 1) adc. +Proof. + apply tuple_decoder_O_ind_prod; try assumption. + intros ??? ext [H0 H1]; apply Build_is_add_with_carry'. + intros x y c; specialize (H0 x y c); specialize (H1 x y c). + rewrite <- !ext; split; assumption. +Qed. + +Hint Extern 1 (@is_add_with_carry _ _ (@tuple_decoder ?n ?W ?decode 1) ?adc) +=> apply (@is_add_with_carry_1tuple n W decode adc) : typeclass_instances. + +Hint Resolve (fun n W decode pf => (@tuple_is_decode n W decode 2 pf : @is_decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2))) : typeclass_instances. +Hint Extern 3 (@is_decode _ (tuple ?W ?k) _) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode pf => (@tuple_is_decode n W decode k pf : @is_decode (kv * n) (tuple W k) (@tuple_decoder n W decode k : decoder (kv * n)%Z (tuple W k)))) : typeclass_instances. + +Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder. +Hint Rewrite Z.mul_1_l : simpl_tuple_decoder. +Hint Rewrite + (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2) w = _)) + (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (W * W) (@tuple_decoder n W decode 2) w = _)) + using solve [ auto with zarith ] + : simpl_tuple_decoder. + +Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder. + +Global Instance tuple_decoder_mod {n W} {decode : decoder n W} {k} {isdecode : is_decode decode} (w : tuple W (S (S k))) + : tuple_decoder (k := S k) (fst w) <~= tuple_decoder w mod 2^(S k * n). +Proof. + pose proof (snd w). + assert (0 <= n) by eauto using decode_exponent_nonnegative. + assert (0 <= (S k) * n) by nia. + assert (0 <= tuple_decoder (k := S k) (fst w) < 2^(S k * n)) by apply decode_range. + autorewrite with simpl_tuple_decoder Zshift_to_pow zsimplify. + reflexivity. +Qed. + +Global Instance tuple_decoder_div {n W} {decode : decoder n W} {k} {isdecode : is_decode decode} (w : tuple W (S (S k))) + : decode (snd w) <~= tuple_decoder w / 2^(S k * n). +Proof. + pose proof (snd w). + assert (0 <= n) by eauto using decode_exponent_nonnegative. + assert (0 <= (S k) * n) by nia. + assert (0 <= k * n) by nia. + assert (0 < 2^n) by auto with zarith. + assert (0 <= tuple_decoder (k := S k) (fst w) < 2^(S k * n)) by apply decode_range. + autorewrite with simpl_tuple_decoder Zshift_to_pow zsimplify. + reflexivity. +Qed. + +Global Instance tuple2_decoder_mod {n W} {decode : decoder n W} {isdecode : is_decode decode} (w : tuple W 2) + : decode (fst w) <~= tuple_decoder w mod 2^n. +Proof. + generalize (@tuple_decoder_mod n W decode 0 isdecode w). + autorewrite with simpl_tuple_decoder; trivial. +Qed. + +Global Instance tuple2_decoder_div {n W} {decode : decoder n W} {isdecode : is_decode decode} (w : tuple W 2) + : decode (snd w) <~= tuple_decoder w / 2^n. +Proof. + generalize (@tuple_decoder_div n W decode 0 isdecode w). + autorewrite with simpl_tuple_decoder; trivial. +Qed. + +Lemma decode_is_spread_left_immediate_iff + {n W} + {decode : decoder n W} + {sprl : spread_left_immediate W} + {isdecode : is_decode decode} + : is_spread_left_immediate sprl + <-> (forall r count, + 0 <= count < n + -> tuple_decoder (sprl r count) = decode r << count). +Proof. + rewrite is_spread_left_immediate_alt by assumption. + split; intros H r count Hc; specialize (H r count Hc); revert H; + pose proof (decode_range r); + assert (0 < 2^count < 2^n) by auto with zarith; + autorewrite with simpl_tuple_decoder; + simpl; intro H'; rewrite H'; + autorewrite with Zshift_to_pow; + Z.rewrite_mod_small; reflexivity. +Qed. + +Global Instance decode_is_spread_left_immediate + {n W} + {decode : decoder n W} + {sprl : spread_left_immediate W} + {isdecode : is_decode decode} + {issprl : is_spread_left_immediate sprl} + : forall r count, + (0 <= count < n)%bounded_rewrite + -> tuple_decoder (sprl r count) <~=~> decode r << count + := proj1 decode_is_spread_left_immediate_iff _. + +Lemma decode_mul_double_iff + {n W} + {decode : decoder n W} + {muldw : multiply_double W} + {isdecode : is_decode decode} + : is_mul_double muldw + <-> (forall x y, tuple_decoder (muldw x y) = (decode x * decode y)%Z). +Proof. + rewrite is_mul_double_alt by assumption. + split; intros H x y; specialize (H x y); revert H; + pose proof (decode_range x); pose proof (decode_range y); + assert (0 <= decode x * decode y < 2^n * 2^n) by nia; + assert (0 <= n) by eauto using decode_exponent_nonnegative; + autorewrite with simpl_tuple_decoder; + simpl; intro H'; rewrite H'; + Z.rewrite_mod_small; reflexivity. +Qed. + +Global Instance decode_mul_double + {n W} + {decode : decoder n W} + {muldw : multiply_double W} + {isdecode : is_decode decode} + {ismuldw : is_mul_double muldw} + : forall x y, tuple_decoder (muldw x y) <~=~> (decode x * decode y)%Z + := proj1 decode_mul_double_iff _. + +Lemma ripple_carry_tuple_SS {T} f k xss yss carry + : @ripple_carry_tuple T f (S (S k)) xss yss carry + = let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (@ripple_carry_tuple _ f (S k) xs ys carry) in + let '(carry, z) := eta (f x y carry) in + (carry, (zs, z)). +Proof. reflexivity. Qed. + +Lemma carry_is_good (n z0 z1 k : Z) + : 0 <= n -> + 0 <= k -> + (z1 + z0 >> k) >> n = (z0 + z1 << k) >> (k + n) /\ + (z0 mod 2 ^ k + ((z1 + z0 >> k) mod 2 ^ n) << k)%Z = (z0 + z1 << k) mod (2 ^ k * 2 ^ n). +Proof. + intros. + assert (0 < 2 ^ n) by auto with zarith. + assert (0 < 2 ^ k) by auto with zarith. + assert (0 < 2^n * 2^k) by nia. + autorewrite with Zshift_to_pow push_Zpow. + rewrite <- (Zmod_small ((z0 mod _) + _) (2^k * 2^n)) by (Z.div_mod_to_quot_rem; nia). + rewrite <- !Z.mul_mod_distr_r by lia. + rewrite !(Z.mul_comm (2^k)); pull_Zmod. + split; [ | apply f_equal2 ]; + Z.div_mod_to_quot_rem; nia. +Qed. + +Definition carry_is_good_carry n z0 z1 k H0 H1 := proj1 (@carry_is_good n z0 z1 k H0 H1). +Definition carry_is_good_value n z0 z1 k H0 H1 := proj2 (@carry_is_good n z0 z1 k H0 H1). + +Section ripple_carry_adc. + Context {n W} {decode : decoder n W} (adc : add_with_carry W). + + Lemma ripple_carry_adc_SS k xss yss carry + : ripple_carry_adc (k := S (S k)) adc xss yss carry + = let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (ripple_carry_adc (k := S k) adc xs ys carry) in + let '(carry, z) := eta (adc x y carry) in + (carry, (zs, z)). + Proof. apply ripple_carry_tuple_SS. Qed. + + Local Opaque Z.of_nat. + Global Instance ripple_carry_is_add_with_carry {k} + {isdecode : is_decode decode} + {is_adc : is_add_with_carry adc} + : is_add_with_carry (ripple_carry_adc (k := k) adc). + Proof. + destruct k as [|k]. + { constructor; simpl; intros; autorewrite with zsimplify; reflexivity. } + { induction k as [|k IHk]. + { cbv [ripple_carry_adc ripple_carry_tuple to_list]. + constructor; simpl @fst; simpl @snd; intros; + simpl; pull_decode; reflexivity. } + { apply Build_is_add_with_carry'; intros x y c. + assert (0 <= n) by (destruct x; eauto using decode_exponent_nonnegative). + assert (2^n <> 0) by auto with zarith. + assert (0 <= S k * n) by nia. + rewrite !tuple_decoder_S, !ripple_carry_adc_SS by assumption. + simplify_projections; push_decode; generalize_decode. + erewrite carry_is_good_carry, carry_is_good_value by lia. + autorewrite with pull_Zpow push_Zof_nat zsimplify Zshift_to_pow. + split; apply f_equal2; nia. } } + Qed. + +End ripple_carry_adc. + +Hint Extern 2 (@is_add_with_carry _ (tuple ?W ?k) (@tuple_decoder ?n _ ?decode _) (@ripple_carry_adc _ ?adc _)) +=> apply (@ripple_carry_is_add_with_carry n W decode adc k) : typeclass_instances. + +Section tuple2. + Local Arguments Z.pow !_ !_. + Local Arguments Z.mul !_ !_. + + Section spread_left. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {decode : decoder n W} + {isdecode : is_decode decode} + {isldi : is_load_immediate ldi} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr}. + + Lemma spread_left_from_shift_correct + r count + (H : 0 < count < n) + : (decode (shl r count) + decode (shr r (n - count)) << n = decode r << count mod (2^n*2^n))%Z. + Proof. + assert (0 <= count < n) by lia. + assert (0 <= n - count < n) by lia. + assert (0 < 2^(n-count)) by auto with zarith. + assert (2^count < 2^n) by auto with zarith. + pose proof (decode_range r). + assert (0 <= decode r * 2 ^ count < 2 ^ n * 2^n) by auto with zarith. + push_decode; autorewrite with Zshift_to_pow zsimplify. + replace (decode r / 2^(n-count) * 2^n)%Z with ((decode r / 2^(n-count) * 2^(n-count)) * 2^count)%Z + by (rewrite <- Z.mul_assoc; autorewrite with pull_Zpow zsimplify; reflexivity). + rewrite Z.mul_div_eq' by lia. + autorewrite with push_Zmul zsimplify. + rewrite <- Z.mul_mod_distr_r_full, Z.add_sub_assoc. + repeat autorewrite with pull_Zpow zsimplify in *. + reflexivity. + Qed. + + Global Instance is_spread_left_from_shift + : is_spread_left_immediate (sprl_from_shift n). + Proof. + apply is_spread_left_immediate_alt. + intros r count; intros. + pose proof (decode_range r). + assert (0 < 2^n) by auto with zarith. + assert (decode r < 2^n * 2^n) by (generalize dependent (decode r); intros; nia). + autorewrite with simpl_tuple_decoder. + destruct (Z_zerop count). + { subst; autorewrite with Zshift_to_pow zsimplify. + simpl; push_decode. + autorewrite with push_Zpow zsimplify. + reflexivity. } + simpl. + rewrite <- spread_left_from_shift_correct by lia. + autorewrite with zsimplify Zpow_to_shift. + reflexivity. + Qed. + End spread_left. + + Local Opaque ripple_carry_adc. + Section full_from_half. + Context {W} + {mulhwll : multiply_low_low W} + {mulhwhl : multiply_high_low W} + {mulhwhh : multiply_high_high W} + {adc : add_with_carry W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {half_n : Z} + {ldi : load_immediate W} + {decode : decoder (2 * half_n) W} + {ismulhwll : is_mul_low_low half_n mulhwll} + {ismulhwhl : is_mul_high_low half_n mulhwhl} + {ismulhwhh : is_mul_high_high half_n mulhwhh} + {isadc : is_add_with_carry adc} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr} + {isldi : is_load_immediate ldi} + {isdecode : is_decode decode}. + + Local Arguments Z.mul !_ !_. + Lemma spread_left_from_shift_half_correct + r + : (decode (shl r half_n) + decode (shr r half_n) * (2^half_n * 2^half_n) + = (decode r * 2^half_n) mod (2^half_n*2^half_n*2^half_n*2^half_n))%Z. + Proof. + destruct (0 <? half_n) eqn:Hn; Z.ltb_to_lt. + { pose proof (spread_left_from_shift_correct (2*half_n) r half_n) as H. + specialize_by lia. + autorewrite with Zshift_to_pow push_Zpow zsimplify in *. + rewrite !Z.mul_assoc in *. + simpl in *; rewrite <- H; reflexivity. } + { pose proof (decode_range r). + pose proof (decode_range (shr r half_n)). + pose proof (decode_range (shl r half_n)). + simpl in *. + autorewrite with push_Zpow in *. + destruct (Z_zerop half_n). + { subst; simpl in *. + autorewrite with zsimplify. + nia. } + assert (half_n < 0) by lia. + assert (2^half_n = 0) by auto with zarith. + assert (0 < 0) by nia; omega. } + Qed. + + Lemma decode_mul_double_mod x y + : (tuple_decoder (mul_double half_n x y) = (decode x * decode y) mod (2^(2 * half_n) * 2^(2*half_n)))%Z. + Proof. + assert (0 <= 2 * half_n) by eauto using decode_exponent_nonnegative. + assert (0 <= half_n) by omega. + unfold mul_double. + push_decode; autorewrite with simpl_tuple_decoder; simplify_projections. + autorewrite with zsimplify Zshift_to_pow push_Zpow. + rewrite !spread_left_from_shift_half_correct. + push_decode. + generalize_decode_var. + simpl in *. + autorewrite with push_Zpow in *. + repeat autorewrite with Zshift_to_pow zsimplify push_Zpow. + rewrite <- !(Z.mul_mod_distr_r_full _ _ (_^_ * _^_)), ?Z.mul_assoc. + Z.rewrite_mod_small. + push_Zmod; pull_Zmod. + apply f_equal2; [ | reflexivity ]. + Z.div_mod_to_quot_rem; nia. + Qed. + + Lemma decode_mul_double_function x y + : tuple_decoder (mul_double half_n x y) = (decode x * decode y)%Z. + Proof. + rewrite decode_mul_double_mod; generalize_decode_var. + simpl in *; Z.rewrite_mod_small; reflexivity. + Qed. + + Global Instance mul_double_is_multiply_double : is_mul_double mul_double_multiply. + Proof. + apply decode_mul_double_iff; apply decode_mul_double_function. + Qed. + End full_from_half. + + Section half_from_full. + Context {n W} + {decode : decoder n W} + {muldw : multiply_double W} + {isdecode : is_decode decode} + {ismuldw : is_mul_double muldw}. + + Local Ltac t := + hnf; intros [??] [??]; + assert (0 <= n) by eauto using decode_exponent_nonnegative; + assert (0 < 2^n) by auto with zarith; + assert (forall x y, 0 <= x < 2^n -> 0 <= y < 2^n -> 0 <= x * y < 2^n * 2^n) by auto with zarith; + simpl @Interface.mulhwhh; simpl @Interface.mulhwhl; simpl @Interface.mulhwll; + rewrite decode_mul_double; autorewrite with simpl_tuple_decoder Zshift_to_pow zsimplify push_Zpow; + Z.rewrite_mod_small; + try reflexivity. + + Global Instance mul_double_is_multiply_low_low : is_mul_low_low n mul_double_multiply_low_low. + Proof. t. Qed. + Global Instance mul_double_is_multiply_high_low : is_mul_high_low n mul_double_multiply_high_low. + Proof. t. Qed. + Global Instance mul_double_is_multiply_high_high : is_mul_high_high n mul_double_multiply_high_high. + Proof. t. Qed. + End half_from_full. +End tuple2. diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v new file mode 100644 index 000000000..152c43cee --- /dev/null +++ b/src/BoundedArithmetic/Interface.v @@ -0,0 +1,281 @@ +(*** Interface for bounded arithmetic *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.AutoRewrite. +Require Import Crypto.Util.Notations. + +Local Open Scope type_scope. +Local Open Scope Z_scope. + +Class decoder (n : Z) W := + { decode : W -> Z }. +Coercion decode : decoder >-> Funclass. +Global Arguments decode {n W _} _. + +Class is_decode {n W} (decode : decoder n W) := + decode_range : forall x, 0 <= decode x < 2^n. + +Class bounded_in_range_cls (x y z : Z) := is_bounded_in_range : x <= y < z. +Ltac bounded_solver_tac := + solve [ eassumption | typeclasses eauto | omega ]. +Hint Extern 0 (bounded_in_range_cls _ _ _) => unfold bounded_in_range_cls; bounded_solver_tac : typeclass_instances. +Global Arguments bounded_in_range_cls / . +Global Instance decode_range_bound {n W} {decode : decoder n W} {H : is_decode decode} + : forall x, bounded_in_range_cls 0 (decode x) (2^n) + := H. + +Class bounded_le_cls (x y : Z) := is_bounded_le : x <= y. +Hint Extern 0 (bounded_le_cls _ _) => unfold bounded_le_cls; bounded_solver_tac : typeclass_instances. +Global Arguments bounded_le_cls / . + +Inductive bounded_decode_pusher_tag := decode_tag. + +Ltac push_decode_step := + match goal with + | [ |- context[@decode ?n ?W ?decoder ?w] ] + => tc_rewrite (decode_tag) (@decode n W decoder w) -> + | [ |- context[match @fst ?A ?B ?x with true => 1 | false => 0 end] ] + => tc_rewrite (decode_tag) (match @fst A B x with true => 1 | false => 0 end) -> + end. +Ltac push_decode := repeat push_decode_step. +Ltac pull_decode_step := + match goal with + | [ |- context[?E] ] + => lazymatch type of E with + | Z => idtac + | bool => idtac + end; + tc_rewrite (decode_tag) <- E + end. +Ltac pull_decode := repeat pull_decode_step. + +Delimit Scope bounded_rewrite_scope with bounded_rewrite. + +Infix "<~=~>" := (rewrite_eq decode_tag) : bounded_rewrite_scope. +Infix "=~>" := (rewrite_left_to_right_eq decode_tag) : bounded_rewrite_scope. +Infix "<~=" := (rewrite_right_to_left_eq decode_tag) : bounded_rewrite_scope. +Notation "x <= y" := (bounded_le_cls x y) : bounded_rewrite_scope. +Notation "x <= y < z" := (bounded_in_range_cls x y z) : bounded_rewrite_scope. + +Module Import BoundedRewriteNotations. + Infix "<~=~>" := (rewrite_eq decode_tag) : type_scope. + Infix "=~>" := (rewrite_left_to_right_eq decode_tag) : type_scope. + Infix "<~=" := (rewrite_right_to_left_eq decode_tag) : type_scope. + Open Scope bounded_rewrite_scope. +End BoundedRewriteNotations. + +(** This is required for typeclass resolution to be fast. *) +Typeclasses Opaque decode. + +Section InstructionGallery. + Context (n : Z) (* bit-width of width of [W] *) + {W : Type} (* bounded type, [W] for word *) + (Wdecoder : decoder n W). + Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) + + Class load_immediate := { ldi : imm -> W }. + Global Coercion ldi : load_immediate >-> Funclass. + + Class is_load_immediate {ldi : load_immediate} := + decode_load_immediate :> forall x, 0 <= x < 2^n -> decode (ldi x) =~> x. + + Class shift_right_doubleword_immediate := { shrd : W -> W -> imm -> W }. + Global Coercion shrd : shift_right_doubleword_immediate >-> Funclass. + + Class is_shift_right_doubleword_immediate (shrd : shift_right_doubleword_immediate) := + decode_shift_right_doubleword :> + forall high low count, + 0 <= count < n + -> decode (shrd high low count) <~=~> (((decode high << n) + decode low) >> count) mod 2^n. + + Class shift_left_immediate := { shl : W -> imm -> W }. + Global Coercion shl : shift_left_immediate >-> Funclass. + + Class is_shift_left_immediate (shl : shift_left_immediate) := + decode_shift_left_immediate :> + forall r count, 0 <= count < n -> decode (shl r count) <~=~> (decode r << count) mod 2^n. + + Class shift_right_immediate := { shr : W -> imm -> W }. + Global Coercion shr : shift_right_immediate >-> Funclass. + + Class is_shift_right_immediate (shr : shift_right_immediate) := + decode_shift_right_immediate :> + forall r count, 0 <= count < n -> decode (shr r count) <~=~> (decode r >> count). + + Class spread_left_immediate := { sprl : W -> imm -> tuple W 2 (* [(low, high)] *) }. + Global Coercion sprl : spread_left_immediate >-> Funclass. + + Class is_spread_left_immediate (sprl : spread_left_immediate) := + { + decode_fst_spread_left_immediate :> forall r count, + 0 <= count < n + -> decode (fst (sprl r count)) =~> (decode r << count) mod 2^n; + decode_snd_spread_left_immediate :> forall r count, + 0 <= count < n + -> decode (snd (sprl r count)) =~> (decode r << count) >> n + + }. + + Class mask_keep_low := { mkl :> W -> imm -> W }. + Global Coercion mkl : mask_keep_low >-> Funclass. + + Class is_mask_keep_low (mkl : mask_keep_low) := + decode_mask_keep_low :> forall r count, + 0 <= count < n -> decode (mkl r count) <~=~> decode r mod 2^count. + + Local Notation bit b := (if b then 1 else 0). + + Class add_with_carry := { adc : W -> W -> bool -> bool * W }. + Global Coercion adc : add_with_carry >-> Funclass. + + Class is_add_with_carry (adc : add_with_carry) := + { + bit_fst_add_with_carry :> forall x y c, bit (fst (adc x y c)) <~=~> (decode x + decode y + bit c) >> n; + decode_snd_add_with_carry :> forall x y c, decode (snd (adc x y c)) <~=~> (decode x + decode y + bit c) mod (2^n) + }. + + Class sub_with_carry := { subc : W -> W -> bool -> bool * W }. + Global Coercion subc : sub_with_carry >-> Funclass. + + Class is_sub_with_carry (subc:W->W->bool->bool*W) := + { + fst_sub_with_carry :> forall x y c, fst (subc x y c) <~=~> ((decode x - decode y - bit c) <? 0); + decode_snd_sub_with_carry :> forall x y c, decode (snd (subc x y c)) <~=~> (decode x - decode y - bit c) mod 2^n + }. + + Class multiply := { mul : W -> W -> W }. + Global Coercion mul : multiply >-> Funclass. + + Class is_mul (mul : multiply) := + decode_mul :> forall x y, decode (mul x y) <~=~> (decode x * decode y). + + Class multiply_low_low := { mulhwll : W -> W -> W }. + Global Coercion mulhwll : multiply_low_low >-> Funclass. + Class multiply_high_low := { mulhwhl : W -> W -> W }. + Global Coercion mulhwhl : multiply_high_low >-> Funclass. + Class multiply_high_high := { mulhwhh : W -> W -> W }. + Global Coercion mulhwhh : multiply_high_high >-> Funclass. + Class multiply_double := { muldw : W -> W -> tuple W 2 }. + Global Coercion muldw : multiply_double >-> Funclass. + + Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) := + decode_mul_low_low :> + forall x y, decode (mulhwll x y) <~=~> ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n. + Class is_mul_high_low (w:Z) (mulhwhl : multiply_high_low) := + decode_mul_high_low :> + forall x y, decode (mulhwhl x y) <~=~> ((decode x >> w) * (decode y mod 2^w)) mod 2^n. + Class is_mul_high_high (w:Z) (mulhwhh : multiply_high_high) := + decode_mul_high_high :> + forall x y, decode (mulhwhh x y) <~=~> ((decode x >> w) * (decode y >> w)) mod 2^n. + Class is_mul_double (muldw : multiply_double) := + { + decode_fst_mul_double :> + forall x y, decode (fst (muldw x y)) =~> (decode x * decode y) mod 2^n; + decode_snd_mul_double :> + forall x y, decode (snd (muldw x y)) =~> (decode x * decode y) >> n + }. + + Class select_conditional := { selc : bool -> W -> W -> W }. + Global Coercion selc : select_conditional >-> Funclass. + + Class is_select_conditional (selc : select_conditional) := + decode_select_conditional :> forall b x y, + decode (selc b x y) <~=~> if b then decode x else decode y. + + Class add_modulo := { addm : W -> W -> W (* modulus *) -> W }. + Global Coercion addm : add_modulo >-> Funclass. + + Class is_add_modulo (addm : add_modulo) := + decode_add_modulo :> forall x y modulus, + decode (addm x y modulus) <~=~> (if (decode x + decode y) <? decode modulus + then (decode x + decode y) + else (decode x + decode y) - decode modulus). +End InstructionGallery. + +Global Arguments load_immediate : clear implicits. +Global Arguments shift_right_doubleword_immediate : clear implicits. +Global Arguments shift_left_immediate : clear implicits. +Global Arguments shift_right_immediate : clear implicits. +Global Arguments spread_left_immediate : clear implicits. +Global Arguments mask_keep_low : clear implicits. +Global Arguments add_with_carry : clear implicits. +Global Arguments sub_with_carry : clear implicits. +Global Arguments multiply : clear implicits. +Global Arguments multiply_low_low : clear implicits. +Global Arguments multiply_high_low : clear implicits. +Global Arguments multiply_high_high : clear implicits. +Global Arguments multiply_double : clear implicits. +Global Arguments select_conditional : clear implicits. +Global Arguments add_modulo : clear implicits. +Global Arguments ldi {_ _} _. +Global Arguments shrd {_ _} _ _ _. +Global Arguments shl {_ _} _ _. +Global Arguments shr {_ _} _ _. +Global Arguments sprl {_ _} _ _. +Global Arguments mkl {_ _} _ _. +Global Arguments adc {_ _} _ _ _. +Global Arguments subc {_ _} _ _ _. +Global Arguments mul {_ _} _ _. +Global Arguments mulhwll {_ _} _ _. +Global Arguments mulhwhl {_ _} _ _. +Global Arguments mulhwhh {_ _} _ _. +Global Arguments muldw {_ _} _ _. +Global Arguments selc {_ _} _ _ _. +Global Arguments addm {_ _} _ _ _. + +Global Arguments is_decode {_ _} _. +Global Arguments is_load_immediate {_ _ _} _. +Global Arguments is_shift_right_doubleword_immediate {_ _ _} _. +Global Arguments is_shift_left_immediate {_ _ _} _. +Global Arguments is_shift_right_immediate {_ _ _} _. +Global Arguments is_spread_left_immediate {_ _ _} _. +Global Arguments is_mask_keep_low {_ _ _} _. +Global Arguments is_add_with_carry {_ _ _} _. +Global Arguments is_sub_with_carry {_ _ _} _. +Global Arguments is_mul {_ _ _} _. +Global Arguments is_mul_low_low {_ _ _} _ _. +Global Arguments is_mul_high_low {_ _ _} _ _. +Global Arguments is_mul_high_high {_ _ _} _ _. +Global Arguments is_mul_double {_ _ _} _. +Global Arguments is_select_conditional {_ _ _} _. +Global Arguments is_add_modulo {_ _ _} _. + +Module fancy_machine. + Local Notation imm := Z (only parsing). + + Class instructions (n : Z) := + { + W : Type (* [n]-bit word *); + decode :> decoder n W; + ldi :> load_immediate W; + shrd :> shift_right_doubleword_immediate W; + shl :> shift_left_immediate W; + shr :> shift_right_immediate W; + mkl :> mask_keep_low W; + adc :> add_with_carry W; + subc :> sub_with_carry W; + mulhwll :> multiply_low_low W; + mulhwhl :> multiply_high_low W; + mulhwhh :> multiply_high_high W; + selc :> select_conditional W; + addm :> add_modulo W + }. + + Class arithmetic {n_over_two} (ops:instructions (2 * n_over_two)) := + { + decode_range :> is_decode decode; + load_immediate :> is_load_immediate ldi; + shift_right_doubleword_immediate :> is_shift_right_doubleword_immediate shrd; + shift_left_immediate :> is_shift_left_immediate shl; + shift_right_immediate :> is_shift_right_immediate shr; + mask_keep_low :> is_mask_keep_low mkl; + add_with_carry :> is_add_with_carry adc; + sub_with_carry :> is_sub_with_carry subc; + multiply_low_low :> is_mul_low_low n_over_two mulhwll; + multiply_high_low :> is_mul_high_low n_over_two mulhwhl; + multiply_high_high :> is_mul_high_high n_over_two mulhwhh; + select_conditional :> is_select_conditional selc; + add_modulo :> is_add_modulo addm + }. +End fancy_machine. diff --git a/src/BoundedArithmetic/InterfaceProofs.v b/src/BoundedArithmetic/InterfaceProofs.v new file mode 100644 index 000000000..8256fc23f --- /dev/null +++ b/src/BoundedArithmetic/InterfaceProofs.v @@ -0,0 +1,221 @@ +(** * Alternate forms for Interface for bounded arithmetic *) +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.AutoRewrite. +Require Import Crypto.Util.Notations. + +Local Open Scope type_scope. +Local Open Scope Z_scope. + +Import BoundedRewriteNotations. +Local Notation bit b := (if b then 1 else 0). + +Section InstructionGallery. + Context (n : Z) (* bit-width of width of [W] *) + {W : Type} (* bounded type, [W] for word *) + (Wdecoder : decoder n W). + Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) + + Definition Build_is_spread_left_immediate' (sprl : spread_left_immediate W) + (pf : forall r count, 0 <= count < n + -> _ /\ _) + := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H); + decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}. + + Definition Build_is_add_with_carry' (adc : add_with_carry W) + (pf : forall x y c, _ /\ _) + := {| bit_fst_add_with_carry x y c := proj1 (pf x y c); + decode_snd_add_with_carry x y c := proj2 (pf x y c) |}. + + Definition Build_is_sub_with_carry' (subc : sub_with_carry W) + (pf : forall x y c, _ /\ _) + : is_sub_with_carry subc + := {| fst_sub_with_carry x y c := proj1 (pf x y c); + decode_snd_sub_with_carry x y c := proj2 (pf x y c) |}. + + Definition Build_is_mul_double' (muldw : multiply_double W) + (pf : forall x y, _ /\ _) + := {| decode_fst_mul_double x y := proj1 (pf x y); + decode_snd_mul_double x y := proj2 (pf x y) |}. + + Lemma is_spread_left_immediate_alt + {sprl : spread_left_immediate W} + {isdecode : is_decode Wdecoder} + : is_spread_left_immediate sprl + <-> (forall r count, 0 <= count < n -> decode (fst (sprl r count)) + decode (snd (sprl r count)) << n = (decode r << count) mod (2^n*2^n))%Z. + Proof. + split; intro H; [ | apply Build_is_spread_left_immediate' ]; + intros r count Hc; + [ | specialize (H r count Hc); revert H ]; + unfold bounded_in_range_cls in *; + pose proof (decode_range r); + assert (0 < 2^n) by auto with zarith; + assert (0 <= 2^count < 2^n)%Z by auto with zarith; + assert (0 <= decode r * 2^count < 2^n * 2^n)%Z by (generalize dependent (decode r); intros; nia); + rewrite ?decode_fst_spread_left_immediate, ?decode_snd_spread_left_immediate + by typeclasses eauto with typeclass_instances core; + autorewrite with Zshift_to_pow zsimplify push_Zpow. + { reflexivity. } + { intro H'; rewrite <- H'. + autorewrite with zsimplify; split; reflexivity. } + Qed. + + Lemma is_mul_double_alt + {muldw : multiply_double W} + {isdecode : is_decode Wdecoder} + : is_mul_double muldw + <-> (forall x y, decode (fst (muldw x y)) + decode (snd (muldw x y)) << n = (decode x * decode y) mod (2^n*2^n)). + Proof. + split; intro H; [ | apply Build_is_mul_double' ]; + intros x y; + [ | specialize (H x y); revert H ]; + pose proof (decode_range x); + pose proof (decode_range y); + assert (0 < 2^n) by auto with zarith; + assert (0 <= decode x * decode y < 2^n * 2^n)%Z by nia; + (destruct (0 <=? n) eqn:?; Z.ltb_to_lt; + [ | assert (2^n = 0) by auto with zarith; exfalso; omega ]); + rewrite ?decode_fst_mul_double, ?decode_snd_mul_double + by typeclasses eauto with typeclass_instances core; + autorewrite with Zshift_to_pow zsimplify push_Zpow. + { reflexivity. } + { intro H'; rewrite <- H'. + autorewrite with zsimplify; split; reflexivity. } + Qed. +End InstructionGallery. + +Global Arguments is_spread_left_immediate_alt {_ _ _ _ _}. +Global Arguments is_mul_double_alt {_ _ _ _ _}. + +Ltac bounded_solver_tac := + solve [ eassumption | typeclasses eauto | omega ]. + +Global Instance decode_proj n W (dec : W -> Z) + : @decode n W {| decode := dec |} =~> dec. +Proof. reflexivity. Qed. + +Global Instance decode_if_bool n W (decode : decoder n W) (b : bool) x y + : decode (if b then x else y) + =~> if b then decode x else decode y. +Proof. destruct b; reflexivity. Qed. + +Global Instance decode_mod_small {n W} {decode : decoder n W} {x b} + {H : bounded_in_range_cls 0 (decode x) b} + : decode x <~= decode x mod b. +Proof. + Z.rewrite_mod_small; reflexivity. +Qed. + +Global Instance decode_mod_range {n W decode} {H : @is_decode n W decode} x + : decode x <~= decode x mod 2^n. +Proof. exact _. Qed. + +Lemma decode_exponent_nonnegative {n W} (decode : decoder n W) {isdecode : is_decode decode} + (isinhabited : W) + : (0 <= n)%Z. +Proof. + pose proof (decode_range isinhabited). + assert (0 < 2^n) by omega. + destruct (Z_lt_ge_dec n 0) as [H'|]; [ | omega ]. + assert (2^n = 0) by auto using Z.pow_neg_r. + omega. +Qed. + +Section adc_subc. + Context {n W} + {decode : decoder n W} + {adc : add_with_carry W} + {subc : sub_with_carry W} + {isdecode : is_decode decode} + {isadc : is_add_with_carry adc} + {issubc : is_sub_with_carry subc}. + Global Instance bit_fst_add_with_carry_false + : forall x y, bit (fst (adc x y false)) <~=~> (decode x + decode y) >> n. + Proof. + intros; erewrite bit_fst_add_with_carry by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance bit_fst_add_with_carry_true + : forall x y, bit (fst (adc x y true)) <~=~> (decode x + decode y + 1) >> n. + Proof. + intros; erewrite bit_fst_add_with_carry by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance fst_add_with_carry_leb + : forall x y c, fst (adc x y c) <~= (2^n <=? (decode x + decode y + bit c)). + Proof. + intros x y c; hnf. + assert (0 <= n)%Z by eauto using decode_exponent_nonnegative. + pose proof (decode_range x); pose proof (decode_range y). + assert (0 <= bit c <= 1)%Z by (destruct c; omega). + lazymatch goal with + | [ |- fst ?x = (?a <=? ?b) :> bool ] + => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z); + [ destruct (fst x), (a <=? b); intro; congruence | ] + end. + push_decode. + autorewrite with Zshift_to_pow. + rewrite Z.div_between_0_if by auto with zarith. + reflexivity. + Qed. + Global Instance fst_add_with_carry_false_leb + : forall x y, fst (adc x y false) <~= (2^n <=? (decode x + decode y)). + Proof. + intros; erewrite fst_add_with_carry_leb by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance fst_add_with_carry_true_leb + : forall x y, fst (adc x y true) <~=~> (2^n <=? (decode x + decode y + 1)). + Proof. + intros; erewrite fst_add_with_carry_leb by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance fst_sub_with_carry_false + : forall x y, fst (subc x y false) <~=~> ((decode x - decode y) <? 0). + Proof. + intros; erewrite fst_sub_with_carry by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance fst_sub_with_carry_true + : forall x y, fst (subc x y true) <~=~> ((decode x - decode y - 1) <? 0). + Proof. + intros; erewrite fst_sub_with_carry by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. +End adc_subc. + +Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y))) +=> apply @fst_add_with_carry_false_leb : typeclass_instances. +Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + 1))) +=> apply @fst_add_with_carry_true_leb : typeclass_instances. +Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + if ?c then _ else _))) +=> apply @fst_add_with_carry_leb : typeclass_instances. + + +(* We take special care to handle the case where the decoder is + syntactically different but the decoded expression is judgmentally + the same; we don't want to split apart variables that should be the + same. *) +Ltac set_decode_step check := + match goal with + | [ |- context G[@decode ?n ?W ?dr ?w] ] + => check w; + first [ match goal with + | [ d := @decode _ _ _ w |- _ ] + => change (@decode n W dr w) with d + end + | generalize (@decode_range n W dr _ w); + let d := fresh "d" in + set (d := @decode n W dr w); + intro ] + end. +Ltac set_decode check := repeat set_decode_step check. +Ltac clearbody_decode := + repeat match goal with + | [ H := @decode _ _ _ _ |- _ ] => clearbody H + end. +Ltac generalize_decode_by check := set_decode check; clearbody_decode. +Ltac generalize_decode := generalize_decode_by ltac:(fun w => idtac). +Ltac generalize_decode_var := generalize_decode_by ltac:(fun w => is_var w). diff --git a/src/Util/AutoRewrite.v b/src/Util/AutoRewrite.v new file mode 100644 index 000000000..b5e276f20 --- /dev/null +++ b/src/Util/AutoRewrite.v @@ -0,0 +1,56 @@ +(** * Machinery for reimplementing some bits of [rewrite_strat] with open pattern databases *) +Require Import Crypto.Util.Notations. +(** We build classes for rewriting in each direction, and pick lemmas + by resolving on tags and the term to be rewritten. *) +(** Base class, for bidirectional rewriting. *) +Class rewrite_eq {tagT} (tag : tagT) {A} (x y : A) + := by_rewrite : x = y. +Arguments by_rewrite {tagT} tag {A} _ _ {_}. +Infix "<~=~>" := (rewrite_eq _) : type_scope. + +Class rewrite_right_to_left_eq {tagT} (tag : tagT) {A} (x y : A) + := by_rewrite_right_to_left : rewrite_eq tag x y. +Arguments by_rewrite_right_to_left {tagT} tag {A} _ _ {_}. +Global Instance unfold_rewrite_right_to_left_eq {tagT tag A x y} (H : @rewrite_eq tagT tag A x y) + : @rewrite_right_to_left_eq tagT tag A x y := H. +Infix "<~=" := (rewrite_right_to_left_eq _) : type_scope. + +Class rewrite_left_to_right_eq {tagT} (tag : tagT) {A} (x y : A) + := by_rewrite_left_to_right : rewrite_eq tag x y. +Arguments by_rewrite_left_to_right {tagT} tag {A} _ _ {_}. +Global Instance unfold_rewrite_left_to_right_eq {tagT tag A x y} (H : @rewrite_eq tagT tag A x y) + : @rewrite_left_to_right_eq tagT tag A x y := H. +Infix "=~>" := (rewrite_left_to_right_eq _) : type_scope. + +Ltac typeclass_do_left_to_right tag from tac := + let lem := constr:(by_rewrite_left_to_right tag from _ : from = _) in tac lem. +Ltac typeclass_do_right_to_left tag from tac := + let lem := constr:(by_rewrite_right_to_left tag _ from : _ = from) in tac lem. + + +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" "*" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in * ). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" hyp_list(H) := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in H). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" hyp_list(H) "|-" "*" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in H |- *). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" "*" "|-" "*" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in * |- * ). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" "*" "|-" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in * |- ). + + +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" "*" := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in * ). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" hyp_list(H) := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in H). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" hyp_list(H) "|-" "*" := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in H |- *). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" "*" "|-" "*" := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in * |- * ). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" "*" "|-" := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in * |- ). diff --git a/src/Util/Notations.v b/src/Util/Notations.v index 686443829..8795da399 100644 --- a/src/Util/Notations.v +++ b/src/Util/Notations.v @@ -27,6 +27,7 @@ Reserved Infix "~=" (at level 70). Reserved Infix "==" (at level 70, no associativity). Reserved Infix "=~>" (at level 70, no associativity). Reserved Infix "<~=" (at level 70, no associativity). +Reserved Infix "<~=~>" (at level 70, no associativity). Reserved Infix "≡" (at level 70, no associativity). Reserved Infix "≢" (at level 70, no associativity). Reserved Infix "≡_n" (at level 70, no associativity). |