aboutsummaryrefslogtreecommitdiff
path: root/src/BoundedArithmetic
diff options
context:
space:
mode:
authorGravatar Jason Gross <jagro@google.com>2016-08-23 15:59:35 -0700
committerGravatar Jason Gross <jagro@google.com>2016-08-23 16:01:54 -0700
commit6897a4f42c86c4a6bfdbab6887276e7334317661 (patch)
treef5f283d445622171d09584af9ca0ac652c589d3a /src/BoundedArithmetic
parente7554f5525a36699fff33e70ee454cfd0a687808 (diff)
Hook up the bounded interface, finish proofs
Diffstat (limited to 'src/BoundedArithmetic')
-rw-r--r--src/BoundedArithmetic/ArchitectureToZLike.v115
-rw-r--r--src/BoundedArithmetic/ArchitectureToZLikeProofs.v109
-rw-r--r--src/BoundedArithmetic/DoubleBounded.v105
-rw-r--r--src/BoundedArithmetic/DoubleBoundedProofs.v354
-rw-r--r--src/BoundedArithmetic/Interface.v168
5 files changed, 690 insertions, 161 deletions
diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v
index 01387e969..e30fcfd09 100644
--- a/src/BoundedArithmetic/ArchitectureToZLike.v
+++ b/src/BoundedArithmetic/ArchitectureToZLike.v
@@ -3,123 +3,28 @@ Require Import Coq.ZArith.ZArith.
Require Import Crypto.BoundedArithmetic.Interface.
Require Import Crypto.BoundedArithmetic.DoubleBounded.
Require Import Crypto.ModularArithmetic.ZBounded.
-Require Import Coq.Lists.List.
-Import ListNotations.
+Require Import Crypto.Util.Tuple.
-Local Open Scope nat_scope.
Local Open Scope Z_scope.
-Local Open Scope type_scope.
-
-Local Coercion Z.of_nat : nat >-> Z.
Section fancy_machine_p256_montgomery_foundation.
Context {n_over_two : Z}.
- Local Notation n := (2 * n_over_two)%Z.
+ Local Notation n := (2 * n_over_two).
Context (ops : fancy_machine.instructions n) (modulus : Z).
- Definition two_list_to_tuple {A B} (x : A * list B)
- := match x return match x with
- | (a, [b0; b1]) => A * (B * B)
- | _ => True
- end
- with
- | (a, [b0; b1]) => (a, (b0, b1))
- | _ => I
- end.
-(*
- (* make all machine-specific constructions here, preferrably as
- thing wrappers around generic constructions *)
- Local Instance DoubleArchitectureBoundedOps : ArchitectureBoundedOps (2 * n)%nat
- := { BoundedType := BoundedType * BoundedType (* [(high, low)] *);
- decode high_low := (decode (fst high_low) * 2^n + decode (snd high_low))%Z;
- encode z := (encode (z / 2^n), encode (z mod 2^n));
- ShiftRight a high_low
- := let '(high, low) := eta high_low in
- if n <=? a then
- (ShiftRight (a - n)%nat (encode 0, fst high), ShiftRight (a - n)%nat high)
- else
- (ShiftRight a (snd high, fst low), ShiftRight a low);
- ShiftLeft a high_low
- := let '(high, low) := eta high_low in
- if 2 * n <=? a then
- let '(high0, low) := eta (ShiftLeft (a - 2 * n)%nat low) in
- let '(high_high, high1) := eta (ShiftLeft (a - 2 * n)%nat high) in
- ((snd (CarryAdd false high0 high1), low), (encode 0, encode 0))
- else if n <=? a then
- let '(high0, low) := eta (ShiftLeft (a - n)%nat low) in
- let '(high_high, high1) := eta (ShiftLeft (a - n)%nat high) in
- ((high_high, snd (CarryAdd false high0 high1)), (low, encode 0))
- else
- let '(high0, low) := eta (ShiftLeft a low) in
- let '(high_high, high1) := eta (ShiftLeft a high) in
- ((encode 0, high_high), (snd (CarryAdd false high0 high1), low));
- Mod2Pow a high_low
- := let '(high, low) := (fst high_low, snd high_low) in
- (Mod2Pow (a - n)%nat high, Mod2Pow a low);
- CarryAdd carry x_high_low y_high_low
- := let '(xhigh, xlow) := eta x_high_low in
- let '(yhigh, ylow) := eta y_high_low in
- two_list_to_tuple (ripple_carry CarryAdd carry [xhigh; xlow] [yhigh; ylow]);
- CarrySub carry x_high_low y_high_low
- := let '(xhigh, xlow) := eta x_high_low in
- let '(yhigh, ylow) := eta y_high_low in
- two_list_to_tuple (ripple_carry CarrySub carry [xhigh; xlow] [yhigh; ylow]) }.
-
- Definition BoundedOfHalfBounded (x : @BoundedHalfType (2 * n)%nat _) : @BoundedType n _
- := match x with
- | UpperHalf x => fst x
- | LowerHalf x => snd x
- end.
-
- Local Instance DoubleArchitectureBoundedHalfWidthMulOpsOfFullMulOps
- {base_mops : ArchitectureBoundedFullMulOps n}
- : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat :=
- { HalfWidthMul a b
- := Mul (BoundedOfHalfBounded a) (BoundedOfHalfBounded b) }.
- End single.
-
- Local Existing Instance DoubleArchitectureBoundedOps.
-
- Section full_from_half.
- Context (n : size) {base_ops : ArchitectureBoundedOps (2 * n)%nat}.
-
- Local Infix "*" := HalfWidthMul.
-
- Local Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps
- {base_mops : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat}
- : ArchitectureBoundedFullMulOps (2 * n)%nat :=
- { Mul a b
- := let '(a1, a0) := (UpperHalf a, LowerHalf a) in
- let '(b1, b0) := (UpperHalf b, LowerHalf b) in
- let out := a0 * b0 in
- let outHigh := a1 * b1 in
- let tmp := a1 * b0 in
- let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in
- let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in
- let tmp := a0 * b1 in
- let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in
- let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in
- (outHigh, out) }.
- End full_from_half.
-
- Local Existing Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps.
-*)
- Axiom admit : forall {T}, T.
-
Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : Z)
: ZLikeOps (2^n) (2^smaller_bound_exp) modulus :=
- { LargeT := fancy_machine.W * fancy_machine.W;
+ { LargeT := tuple fancy_machine.W 2;
SmallT := fancy_machine.W;
modulus_digits := ldi modulus;
- decode_large := _;
+ decode_large := decode;
decode_small := decode;
- Mod_SmallBound v := snd v;
- DivBy_SmallBound v := fst v;
- DivBy_SmallerBound v := shrd (fst v) (snd v) smaller_bound_exp;
- Mul x y := _ (*mulhwll (ldi 0, x) (ldi 0, y)*);
- CarryAdd x y := _ (*adc x y false*);
+ Mod_SmallBound v := fst v;
+ DivBy_SmallBound v := snd v;
+ DivBy_SmallerBound v := shrd (snd v) (fst v) smaller_bound_exp;
+ Mul x y := mulhwll (W := tuple _ 2) (sprl x 0) (sprl y 0);
+ CarryAdd x y := adc x y false;
CarrySubSmall x y := subc x y false;
- ConditionalSubtract b x := let v := selc b (ldi 0) (ldi modulus) in snd (subc x v false);
+ ConditionalSubtract b x := let v := selc b (ldi modulus) (ldi 0) in snd (subc x v false);
ConditionalSubtractModulus y := addm y (ldi 0) (ldi modulus) }.
- Abort.
End fancy_machine_p256_montgomery_foundation.
diff --git a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v
new file mode 100644
index 000000000..b7cac2bb3
--- /dev/null
+++ b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v
@@ -0,0 +1,109 @@
+(*** Proving ℤ-Like via Architecture *)
+Require Import Coq.ZArith.ZArith.
+Require Import Crypto.BoundedArithmetic.Interface.
+Require Import Crypto.BoundedArithmetic.DoubleBounded.
+Require Import Crypto.BoundedArithmetic.DoubleBoundedProofs.
+Require Import Crypto.BoundedArithmetic.ArchitectureToZLike.
+Require Import Crypto.ModularArithmetic.ZBounded.
+Require Import Crypto.Util.Tuple.
+Require Import Crypto.Util.ZUtil Crypto.Util.Tactics.
+
+Local Open Scope nat_scope.
+Local Open Scope Z_scope.
+Local Open Scope type_scope.
+
+Local Coercion Z.of_nat : nat >-> Z.
+
+Section fancy_machine_p256_montgomery_foundation.
+ Context {n_over_two : Z}.
+ Local Notation n := (2 * n_over_two)%Z.
+ Context (ops : fancy_machine.instructions n) (modulus : Z).
+
+ Local Arguments Z.mul !_ !_.
+ Local Arguments BaseSystem.decode !_ !_ / .
+ Local Arguments BaseSystem.accumulate / .
+ Local Arguments BaseSystem.decode' !_ !_ / .
+
+ Local Ltac introduce_t_step :=
+ match goal with
+ | [ |- forall x : bool, _ ] => intros [|]
+ | [ |- True -> _ ] => intros _
+ | [ |- _ <= _ < _ -> _ ] => intro
+ | _ => let x := fresh "x" in
+ intro x;
+ try pose proof (decode_range (fst x));
+ try pose proof (decode_range (snd x));
+ pose proof (decode_range x)
+ end.
+ Local Ltac unfolder_t :=
+ progress unfold LargeT, SmallT, modulus_digits, decode_large, decode_small, Mod_SmallBound, DivBy_SmallBound, DivBy_SmallerBound, Mul, CarryAdd, CarrySubSmall, ConditionalSubtract, ConditionalSubtractModulus, ZLikeOps_of_ArchitectureBoundedOps in *.
+ Local Ltac saturate_context_step :=
+ match goal with
+ | _ => unique assert (0 <= 2 * n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ]
+ | _ => unique assert (0 <= n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ]
+ | _ => unique assert (0 <= 2 * (2 * n_over_two)) by (eauto using decode_exponent_nonnegative with typeclass_instances)
+ end.
+ Local Ltac pre_t :=
+ repeat first [ tauto
+ | introduce_t_step
+ | unfolder_t
+ | saturate_context_step ].
+ Local Ltac post_t_step :=
+ match goal with
+ | _ => tauto
+ | _ => progress autorewrite with zsimplify_const in *
+ | _ => progress push_decode
+ | _ => progress autorewrite with push_Zpow in *
+ | _ => progress Z.rewrite_mod_small
+ | [ |- fst ?x = (?a <=? ?b) :> bool ]
+ => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z);
+ [ destruct (fst x), (a <=? b); intro; congruence | ]
+ | [ |- appcontext[let (a, b) := ?x in _] ]
+ => rewrite (surjective_pairing x); simplify_projections
+ | _ => progress autorewrite with Zshift_to_pow in *
+ | _ => progress autorewrite with simpl_tuple_decoder in *
+ | _ => progress autorewrite with zsimplify
+ | [ |- _ / ?y = _ / ?y ] => apply f_equal2; omega
+ | [ |- _ / _ = if _ then _ else _ ] => apply Z.div_between_0_if; auto with zarith omega
+ end.
+ Local Ltac post_t := repeat post_t_step.
+ Local Ltac t := pre_t; post_t.
+
+ Global Instance ZLikeProperties_of_ArchitectureBoundedOps
+ {arith : fancy_machine.arithmetic ops}
+ (modulus_in_range : 0 <= modulus < 2^n)
+ (smaller_bound_exp : Z)
+ (smaller_bound_smaller : 0 <= smaller_bound_exp < n)
+ : ZLikeProperties (ZLikeOps_of_ArchitectureBoundedOps ops modulus smaller_bound_exp)
+ := { large_valid v := True;
+ medium_valid v := 0 <= decode_large v < 2^n * 2^smaller_bound_exp;
+ small_valid v := True }.
+ Proof.
+ (* In 8.5: *)
+ (* par:t. *)
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ { abstract t. }
+ Defined.
+End fancy_machine_p256_montgomery_foundation.
diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v
index 7fa0d4db1..a368b96a0 100644
--- a/src/BoundedArithmetic/DoubleBounded.v
+++ b/src/BoundedArithmetic/DoubleBounded.v
@@ -1,13 +1,11 @@
(*** Implementing Large Bounded Arithmetic via pairs *)
-Require Import Coq.ZArith.ZArith Coq.Lists.List.
+Require Import Coq.ZArith.ZArith.
Require Import Crypto.BoundedArithmetic.Interface.
-Require Import Crypto.BaseSystem.
-Require Import Crypto.BaseSystemProofs.
Require Import Crypto.ModularArithmetic.Pow2Base.
-Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.Tuple.
+Require Import Crypto.Util.ListUtil.
Require Import Crypto.Util.Notations.
-Import ListNotations.
Local Open Scope list_scope.
Local Open Scope nat_scope.
Local Open Scope Z_scope.
@@ -17,6 +15,17 @@ Local Coercion Z.of_nat : nat >-> Z.
Local Notation eta x := (fst x, snd x).
Section generic_constructions.
+ Section decode.
+ Context {n W} {decode : decoder n W}.
+ Section with_k.
+ Context {k : nat}.
+ Let limb_widths := repeat n k.
+ (** The list is low to high; the tuple is low to high *)
+ Local Instance tuple_decoder : decoder (k * n) (tuple W k)
+ := { decode w := BaseSystem.decode (base_from_limb_widths limb_widths) (List.map decode (List.rev (Tuple.to_list _ w))) }.
+ End with_k.
+ End decode.
+
Definition ripple_carry {T} (f : T -> T -> bool -> bool * T)
(xs ys : list T) (carry : bool) : bool * list T
:= List.fold_right
@@ -27,17 +36,91 @@ Section generic_constructions.
(carry, nil)
(List.combine xs ys).
+ (** tuple is high to low ([to_list] reverses) *)
+ Fixpoint ripple_carry_tuple' {T} (f : T -> T -> bool -> bool * T) k
+ : forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k
+ := match k return forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k with
+ | O => f
+ | S k' => fun xss yss carry => let '(xs, x) := eta xss in
+ let '(ys, y) := eta yss in
+ let '(carry, zs) := eta (@ripple_carry_tuple' _ f k' xs ys carry) in
+ let '(carry, z) := eta (f x y carry) in
+ (carry, (zs, z))
+ end.
+
+ Definition ripple_carry_tuple {T} (f : T -> T -> bool -> bool * T) k
+ : forall (xs ys : tuple T k) (carry : bool), bool * tuple T k
+ := match k return forall (xs ys : tuple T k) (carry : bool), bool * tuple T k with
+ | O => fun xs ys carry => (carry, tt)
+ | S k' => ripple_carry_tuple' f k'
+ end.
+
Section ripple_carry_adc.
Context {n W} {decode : decoder n W} (adc : add_with_carry W).
- Global Instance ripple_carry_add_with_carry : add_with_carry (list W)
- := {| Interface.adc := ripple_carry adc |}.
- (*
- Global Instance ripple_carry_is_add_with_carry {is_adc : is_add_with_carry adc}
- : is_add_with_carry ripple_carry_add_with_carry.*)
+ Global Instance ripple_carry_adc {k} : add_with_carry (tuple W k)
+ := {| Interface.adc := ripple_carry_tuple adc k |}.
End ripple_carry_adc.
(* TODO: Would it made sense to make generic-width shift operations here? *)
- (* FUTURE: here go proofs about [ripple_carry] with [f] that satisfies [is_add_with_carry] *)
+ Section tuple2.
+ Section spread_left.
+ Context (n : Z) {W}
+ {ldi : load_immediate W}
+ {shl : shift_left_immediate W}
+ {shr : shift_right_immediate W}.
+
+ Definition spread_left_from_shift (r : W) (count : Z) : tuple W 2
+ := (shl r count, if count =? 0 then ldi 0 else shr r (n - count)).
+
+ (** Require a [decode] instance to aid typeclass search in
+ resolving [n] *)
+ Global Instance sprl_from_shift {decode : decoder n W} : spread_left_immediate W
+ := {| Interface.sprl := spread_left_from_shift |}.
+ End spread_left.
+
+ Section full_from_half.
+ Context {W}
+ {mulhwll : multiply_low_low W}
+ {mulhwhl : multiply_high_low W}
+ {mulhwhh : multiply_high_high W}
+ {adc : add_with_carry W}
+ {shl : shift_left_immediate W}
+ {shr : shift_right_immediate W}.
+
+ Section def.
+ Context (half_n : Z).
+ Definition mul_double (a b : W) : tuple W 2
+ := let out : tuple W 2 := (mulhwll a b, mulhwhh a b) in
+ let tmp := mulhwhl a b in
+ let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in
+ let tmp := mulhwhl b a in
+ let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in
+ out.
+ End def.
+
+ Section instances.
+ Context {half_n : Z}
+ {ldi : load_immediate W}.
+
+ (** Require a dummy [decoder] for these instances to allow
+ typeclass inference of the [half_n] argument *)
+ Global Instance mul_double_multiply_low_low {decode : decoder (2 * half_n) W}
+ : multiply_low_low (tuple W 2)
+ := {| Interface.mulhwll a b := mul_double half_n (fst a) (fst b) |}.
+ Global Instance mul_double_multiply_high_low {decode : decoder (2 * half_n) W}
+ : multiply_high_low (tuple W 2)
+ := {| Interface.mulhwhl a b := mul_double half_n (snd a) (fst b) |}.
+ Global Instance mul_double_multiply_high_high {decode : decoder (2 * half_n) W}
+ : multiply_high_high (tuple W 2)
+ := {| Interface.mulhwhh a b := mul_double half_n (snd a) (snd b) |}.
+ End instances.
+ End full_from_half.
+ End tuple2.
End generic_constructions.
+
+Global Arguments tuple_decoder : simpl never.
+
+Hint Resolve (fun n W decode => (@tuple_decoder n W decode 2 : decoder (2 * n) (tuple W 2))) : typeclass_instances.
+Hint Extern 3 (decoder _ (tuple ?W ?k)) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode => (@tuple_decoder n W decode k : decoder (kv * n) (tuple W k))) : typeclass_instances.
diff --git a/src/BoundedArithmetic/DoubleBoundedProofs.v b/src/BoundedArithmetic/DoubleBoundedProofs.v
new file mode 100644
index 000000000..d878a1373
--- /dev/null
+++ b/src/BoundedArithmetic/DoubleBoundedProofs.v
@@ -0,0 +1,354 @@
+(*** Proofs About Large Bounded Arithmetic via pairs *)
+Require Import Coq.ZArith.ZArith Coq.Lists.List Coq.micromega.Psatz.
+Require Import Crypto.BoundedArithmetic.Interface.
+Require Import Crypto.BaseSystem.
+Require Import Crypto.BaseSystemProofs.
+Require Import Crypto.ModularArithmetic.Pow2Base.
+Require Import Crypto.ModularArithmetic.Pow2BaseProofs.
+Require Import Crypto.BoundedArithmetic.DoubleBounded.
+Require Import Crypto.Util.Tuple.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.ListUtil.
+Require Import Crypto.Util.Tactics.
+Require Import Crypto.Util.Notations.
+
+Import ListNotations.
+Local Open Scope list_scope.
+Local Open Scope nat_scope.
+Local Open Scope Z_scope.
+Local Open Scope type_scope.
+
+Local Coercion Z.of_nat : nat >-> Z.
+Local Coercion Pos.to_nat : positive >-> nat.
+Local Notation eta x := (fst x, snd x).
+
+Section generic_constructions.
+ Section decode.
+ Context {n W} {decode : decoder n W}.
+ Section with_k.
+ Context {k : nat}.
+ Local Notation limb_widths := (repeat n k).
+
+ Lemma decode_bounded {isdecode : is_decode decode} w
+ : 0 <= n -> bounded limb_widths (map decode (rev (to_list k w))).
+ Proof.
+ intro.
+ eapply bounded_uniform; try solve [ eauto using repeat_spec ].
+ { distr_length. }
+ { intros z H'.
+ apply in_map_iff in H'.
+ destruct H' as [? [? H'] ]; subst; apply decode_range. }
+ Qed.
+
+ (** TODO: Clean up this proof *)
+ Global Instance tuple_is_decode {isdecode : is_decode decode}
+ : is_decode (tuple_decoder (k := k)).
+ Proof.
+ unfold tuple_decoder; hnf; simpl.
+ intro w.
+ destruct (zerop k); [ subst | ].
+ { unfold BaseSystem.decode, BaseSystem.decode'; simpl; omega. }
+ assert (0 <= n)
+ by (destruct k as [ | [|] ]; [ omega | | destruct w ];
+ eauto using decode_exponent_nonnegative).
+ replace (2^(k * n)) with (upper_bound limb_widths)
+ by (erewrite upper_bound_uniform by eauto using repeat_spec; distr_length).
+ apply decode_upper_bound; auto using decode_bounded.
+ { intros ? H'.
+ apply repeat_spec in H'; omega. }
+ { distr_length. }
+ Qed.
+ End with_k.
+
+ Local Arguments base_from_limb_widths : simpl never.
+ Local Arguments repeat : simpl never.
+ Local Arguments Z.mul !_ !_.
+ Lemma tuple_decoder_S {k} w : 0 <= n -> (tuple_decoder (k := S (S k)) w = tuple_decoder (k := S k) (fst w) + (decode (snd w) << (S k * n)))%Z.
+ Proof.
+ intro Hn.
+ destruct w as [? w]; simpl.
+ replace (decode w) with (decode w * 1 + 0)%Z by omega.
+ rewrite map_app, map_cons, map_nil.
+ erewrite decode_shift_uniform_app by (eauto using repeat_spec; distr_length).
+ distr_length.
+ autorewrite with push_skipn natsimplify push_firstn.
+ reflexivity.
+ Qed.
+ Lemma tuple_decoder_O w : tuple_decoder (k := 1) w = decode w.
+ Proof.
+ unfold tuple_decoder, BaseSystem.decode, BaseSystem.decode', accumulate, base_from_limb_widths, repeat.
+ simpl.
+ omega.
+ Qed.
+ Lemma tuple_decoder_m1 w : tuple_decoder (k := 0) w = 0.
+ Proof. reflexivity. Qed.
+ End decode.
+ Local Arguments tuple_decoder : simpl never.
+ Local Opaque tuple_decoder.
+ Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder.
+
+ Hint Extern 1 (decoder _ (tuple ?W 2)) => apply (fun n decode => @tuple_decoder n W decode 2 : decoder (2 * n) (tuple W 2)) : typeclass_instances.
+
+ Lemma ripple_carry_tuple_SS {T} f k xss yss carry
+ : @ripple_carry_tuple T f (S (S k)) xss yss carry
+ = let '(xs, x) := eta xss in
+ let '(ys, y) := eta yss in
+ let '(carry, zs) := eta (@ripple_carry_tuple _ f (S k) xs ys carry) in
+ let '(carry, z) := eta (f x y carry) in
+ (carry, (zs, z)).
+ Proof. reflexivity. Qed.
+
+ Lemma carry_is_good (n z0 z1 k : Z)
+ : 0 <= n ->
+ 0 <= k ->
+ (z1 + z0 >> k) >> n = (z0 + z1 << k) >> (k + n) /\
+ (z0 mod 2 ^ k + ((z1 + z0 >> k) mod 2 ^ n) << k)%Z = (z0 + z1 << k) mod (2 ^ k * 2 ^ n).
+ Proof.
+ intros.
+ assert (0 < 2 ^ n) by auto with zarith.
+ assert (0 < 2 ^ k) by auto with zarith.
+ assert (0 < 2^n * 2^k) by nia.
+ autorewrite with Zshift_to_pow push_Zpow.
+ rewrite <- (Zmod_small ((z0 mod _) + _) (2^k * 2^n)) by (Z.div_mod_to_quot_rem; nia).
+ rewrite <- !Z.mul_mod_distr_r by lia.
+ rewrite !(Z.mul_comm (2^k)); pull_Zmod.
+ split; [ | apply f_equal2 ];
+ Z.div_mod_to_quot_rem; nia.
+ Qed.
+
+ Definition carry_is_good_carry n z0 z1 k H0 H1 := proj1 (@carry_is_good n z0 z1 k H0 H1).
+ Definition carry_is_good_value n z0 z1 k H0 H1 := proj2 (@carry_is_good n z0 z1 k H0 H1).
+
+ Section ripple_carry_adc.
+ Context {n W} {decode : decoder n W} (adc : add_with_carry W).
+
+ Lemma ripple_carry_adc_SS k xss yss carry
+ : ripple_carry_adc (k := S (S k)) adc xss yss carry
+ = let '(xs, x) := eta xss in
+ let '(ys, y) := eta yss in
+ let '(carry, zs) := eta (ripple_carry_adc (k := S k) adc xs ys carry) in
+ let '(carry, z) := eta (adc x y carry) in
+ (carry, (zs, z)).
+ Proof. apply ripple_carry_tuple_SS. Qed.
+
+ Local Existing Instance tuple_decoder.
+
+ Global Instance ripple_carry_is_add_with_carry {k}
+ {isdecode : is_decode decode}
+ {is_adc : is_add_with_carry adc}
+ : is_add_with_carry (ripple_carry_adc (k := k) adc).
+ Proof.
+ destruct k as [|k].
+ { constructor; simpl; intros; autorewrite with zsimplify; reflexivity. }
+ { induction k as [|k IHk].
+ { cbv [ripple_carry_adc ripple_carry_tuple to_list].
+ constructor; simpl @fst; simpl @snd; intros;
+ autorewrite with simpl_tuple_decoder;
+ push_decode;
+ autorewrite with zsimplify; reflexivity. }
+ { apply Build_is_add_with_carry'; intros x y c.
+ assert (0 <= n) by (destruct x; eauto using decode_exponent_nonnegative).
+ assert (2^n <> 0) by auto with zarith.
+ assert (0 <= S k * n) by nia.
+ rewrite !tuple_decoder_S, !ripple_carry_adc_SS by assumption.
+ simplify_projections; push_decode; generalize_decode.
+ erewrite carry_is_good_carry, carry_is_good_value by lia.
+ autorewrite with pull_Zpow push_Zof_nat zsimplify Zshift_to_pow.
+ split; apply f_equal2; nia. } }
+ Qed.
+
+ End ripple_carry_adc.
+
+ Section tuple2.
+ Section spread_left_correct.
+ Context {n W} {decode : decoder n W} {sprl : spread_left_immediate W}
+ {isdecode : is_decode decode}.
+ Lemma is_spread_left_immediate_alt
+ : is_spread_left_immediate sprl
+ <-> (forall r count, 0 <= count < n -> tuple_decoder (k := 2) (sprl r count) = (decode r << count) mod 2^(2*n)).
+ Proof.
+ split; intro H; [ | apply Build_is_spread_left_immediate' ];
+ intros r count Hc;
+ [ | specialize (H r count Hc); revert H ];
+ pose proof (decode_range r);
+ assert (0 < 2^n) by auto with zarith;
+ assert (0 <= 2^count < 2^n) by auto with zarith;
+ assert (0 <= decode r * 2^count < 2^n * 2^n) by (generalize dependent (decode r); intros; nia);
+ autorewrite with simpl_tuple_decoder; push_decode;
+ autorewrite with Zshift_to_pow zsimplify push_Zpow.
+ { reflexivity. }
+ { intro H'; rewrite <- H'.
+ autorewrite with zsimplify; split; reflexivity. }
+ Qed.
+ End spread_left_correct.
+
+ Local Arguments Z.pow !_ !_.
+ Local Arguments Z.mul !_ !_.
+
+ Section spread_left.
+ Context (n : Z) {W}
+ {ldi : load_immediate W}
+ {shl : shift_left_immediate W}
+ {shr : shift_right_immediate W}
+ {decode : decoder n W}
+ {isdecode : is_decode decode}
+ {isldi : is_load_immediate ldi}
+ {isshl : is_shift_left_immediate shl}
+ {isshr : is_shift_right_immediate shr}.
+
+ Lemma spread_left_from_shift_correct
+ r count
+ (H : 0 < count < n)
+ : (decode (shl r count) + decode (shr r (n - count)) << n = decode r << count mod 2^(2*n))%Z.
+ Proof.
+ assert (0 <= n - count < n) by lia.
+ assert (0 < 2^(n-count)) by auto with zarith.
+ assert (2^count < 2^n) by auto with zarith.
+ pose proof (decode_range r).
+ assert (0 <= decode r * 2 ^ count < 2 ^ n * 2^n) by auto with zarith.
+ simpl.
+ push_decode; autorewrite with Zshift_to_pow zsimplify.
+ replace (decode r / 2^(n-count) * 2^n)%Z with ((decode r / 2^(n-count) * 2^(n-count)) * 2^count)%Z
+ by (rewrite <- Z.mul_assoc; autorewrite with pull_Zpow zsimplify; reflexivity).
+ rewrite Z.mul_div_eq' by lia.
+ autorewrite with push_Zmul zsimplify.
+ rewrite <- Z.mul_mod_distr_r_full, Z.add_sub_assoc.
+ repeat autorewrite with pull_Zpow zsimplify in *.
+ reflexivity.
+ Qed.
+
+ Global Instance is_spread_left_from_shift
+ : is_spread_left_immediate (sprl_from_shift n).
+ Proof.
+ apply is_spread_left_immediate_alt.
+ intros r count; intros.
+ pose proof (decode_range r).
+ assert (0 < 2^n) by auto with zarith.
+ assert (decode r < 2^n * 2^n) by (generalize dependent (decode r); intros; nia).
+ autorewrite with simpl_tuple_decoder.
+ destruct (Z_zerop count).
+ { subst; autorewrite with Zshift_to_pow zsimplify.
+ simpl; push_decode.
+ autorewrite with push_Zpow zsimplify.
+ reflexivity. }
+ simpl.
+ rewrite <- spread_left_from_shift_correct by lia.
+ autorewrite with zsimplify Zpow_to_shift.
+ reflexivity.
+ Qed.
+ End spread_left.
+
+ Local Opaque ripple_carry_adc.
+ Section full_from_half.
+ Context {W}
+ {mulhwll : multiply_low_low W}
+ {mulhwhl : multiply_high_low W}
+ {mulhwhh : multiply_high_high W}
+ {adc : add_with_carry W}
+ {shl : shift_left_immediate W}
+ {shr : shift_right_immediate W}
+ {half_n : Z}
+ {ldi : load_immediate W}
+ {decode : decoder (2 * half_n) W}
+ {ismulhwll : is_mul_low_low half_n mulhwll}
+ {ismulhwhl : is_mul_high_low half_n mulhwhl}
+ {ismulhwhh : is_mul_high_high half_n mulhwhh}
+ {isadc : is_add_with_carry adc}
+ {isshl : is_shift_left_immediate shl}
+ {isshr : is_shift_right_immediate shr}
+ {isldi : is_load_immediate ldi}
+ {isdecode : is_decode decode}.
+
+ Local Arguments Z.mul !_ !_.
+ Lemma spread_left_from_shift_half_correct
+ r
+ : (decode (shl r half_n) + decode (shr r half_n) * (2^half_n * 2^half_n)
+ = (decode r * 2^half_n) mod (2^half_n*2^half_n*2^half_n*2^half_n))%Z.
+ Proof.
+ destruct (0 <? half_n) eqn:Hn; Z.ltb_to_lt.
+ { pose proof (spread_left_from_shift_correct (2*half_n) r half_n) as H.
+ specialize_by lia.
+ autorewrite with Zshift_to_pow push_Zpow zsimplify in *.
+ rewrite !Z.mul_assoc in *.
+ simpl in *; rewrite <- H; reflexivity. }
+ { pose proof (decode_range r).
+ pose proof (decode_range (shr r half_n)).
+ pose proof (decode_range (shl r half_n)).
+ simpl in *.
+ autorewrite with push_Zpow in *.
+ destruct (Z_zerop half_n).
+ { subst; simpl in *.
+ autorewrite with zsimplify.
+ nia. }
+ assert (half_n < 0) by lia.
+ assert (2^half_n = 0) by auto with zarith.
+ assert (0 < 0) by nia; omega. }
+ Qed.
+
+ Lemma decode_mul_double_mod x y
+ : (tuple_decoder (mul_double half_n x y) = (decode x * decode y) mod (2^(2 * half_n) * 2^(2*half_n)))%Z.
+ Proof.
+ assert (0 <= 2 * half_n) by eauto using decode_exponent_nonnegative.
+ assert (0 <= half_n) by omega.
+ unfold mul_double.
+ push_decode; autorewrite with simpl_tuple_decoder; simplify_projections.
+ autorewrite with zsimplify Zshift_to_pow push_Zpow.
+ rewrite !spread_left_from_shift_half_correct.
+ push_decode.
+ generalize_decode_var.
+ simpl in *.
+ autorewrite with push_Zpow in *.
+ repeat autorewrite with Zshift_to_pow zsimplify push_Zpow.
+ rewrite <- !(Z.mul_mod_distr_r_full _ _ (_^_ * _^_)), ?Z.mul_assoc.
+ Z.rewrite_mod_small.
+ push_Zmod; pull_Zmod.
+ apply f_equal2; [ | reflexivity ].
+ Z.div_mod_to_quot_rem; nia.
+ Qed.
+
+ Lemma decode_mul_double x y
+ : tuple_decoder (mul_double half_n x y) = (decode x * decode y)%Z.
+ Proof.
+ rewrite decode_mul_double_mod; generalize_decode_var.
+ simpl in *; Z.rewrite_mod_small; reflexivity.
+ Qed.
+
+ Hint Rewrite
+ (fun n (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2) w = _))
+ (fun n (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (W * W) (@tuple_decoder n W decode 2) w = _))
+ using solve [ auto with zarith ]
+ : simpl_tuple_decoder.
+ Local Ltac t :=
+ hnf; intros [??] [??];
+ assert (0 <= 2 * half_n) by eauto using decode_exponent_nonnegative;
+ assert (0 <= half_n) by omega;
+ simpl @Interface.mulhwhh; simpl @Interface.mulhwhl; simpl @Interface.mulhwll;
+ rewrite decode_mul_double_mod; push_decode; autorewrite with simpl_tuple_decoder;
+ simpl;
+ push_decode; generalize_decode_var;
+ autorewrite with Zshift_to_pow zsimplify;
+ autorewrite with push_Zpow in *; Z.rewrite_mod_small;
+ try reflexivity.
+
+ Global Instance mul_double_is_multiply_low_low : is_mul_low_low (2 * half_n) mul_double_multiply_low_low.
+ Proof. t. Qed.
+ Global Instance mul_double_is_multiply_high_low : is_mul_high_low (2 * half_n) mul_double_multiply_high_low.
+ Proof. t. Qed.
+ Global Instance mul_double_is_multiply_high_high : is_mul_high_high (2 * half_n) mul_double_multiply_high_high.
+ Proof. t. Qed.
+ End full_from_half.
+ End tuple2.
+End generic_constructions.
+
+Hint Resolve (fun n W decode pf => (@tuple_is_decode n W decode 2 pf : @is_decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2))) : typeclass_instances.
+Hint Extern 3 (@is_decode _ (tuple ?W ?k) _) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode pf => (@tuple_is_decode n W decode k pf : @is_decode (kv * n) (tuple W k) (@tuple_decoder n W decode k : decode (kv * n) W))) : typeclass_instances.
+
+Hint Extern 2 (@is_add_with_carry _ (tuple ?W ?k) (@tuple_decoder ?n _ ?decode _) (@ripple_carry_adc _ ?adc _))
+ => apply (@ripple_carry_is_add_with_carry n W decode adc k) : typeclass_instances.
+
+Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder.
+Hint Rewrite
+ (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2) w = _))
+ (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (W * W) (@tuple_decoder n W decode 2) w = _))
+ using solve [ auto with zarith ]
+ : simpl_tuple_decoder.
diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v
index 4a14a160b..fe64cd37e 100644
--- a/src/BoundedArithmetic/Interface.v
+++ b/src/BoundedArithmetic/Interface.v
@@ -53,18 +53,32 @@ Section InstructionGallery.
decode_shift_left_immediate :
forall r count, 0 <= count < n -> decode (shl r count) = (decode r << count) mod 2^n.
- Record spread_left_immediate := { sprl :> W -> imm -> W * W (* [(high, low)] *) }.
+ Record shift_right_immediate := { shr :> W -> imm -> W }.
+
+ Class is_shift_right_immediate (shr : shift_right_immediate) :=
+ decode_shift_right_immediate :
+ forall r count, 0 <= count < n -> decode (shr r count) = (decode r >> count).
+
+ Record spread_left_immediate := { sprl :> W -> imm -> W * W (* [(low, high)] *) }.
Class is_spread_left_immediate (sprl : spread_left_immediate) :=
{
decode_fst_spread_left_immediate : forall r count,
- 0 <= count < n
- -> decode (fst (sprl r count)) = (decode r << count) >> n;
- decode_snd_spread_left_immediate : forall r count,
0 <= count < n
- -> decode (snd (sprl r count)) = (decode r << count) mod 2^n
+ -> decode (fst (sprl r count)) = (decode r << count) mod 2^n;
+ decode_snd_spread_left_immediate : forall r count,
+ 0 <= count < n
+ -> decode (snd (sprl r count)) = (decode r << count) >> n;
+
}.
+ Definition Build_is_spread_left_immediate' (sprl : spread_left_immediate)
+ (pf : forall r count, 0 <= count < n
+ -> decode (fst (sprl r count)) = (decode r << count) mod 2^n
+ /\ decode (snd (sprl r count)) = (decode r << count) >> n)
+ := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H);
+ decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}.
+
Record mask_keep_low := { mkl :> W -> imm -> W }.
Class is_mask_keep_low (mkl : mask_keep_low) :=
@@ -81,6 +95,11 @@ Section InstructionGallery.
decode_snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n)
}.
+ Definition Build_is_add_with_carry' (adc : add_with_carry)
+ (pf : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n /\ decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n))
+ := {| bit_fst_add_with_carry x y c := proj1 (pf x y c);
+ decode_snd_add_with_carry x y c := proj2 (pf x y c) |}.
+
Record sub_with_carry := { subc :> W -> W -> bool -> bool * W }.
Class is_sub_with_carry (subc:W->W->bool->bool*W) :=
@@ -89,6 +108,11 @@ Section InstructionGallery.
decode_snd_sub_with_carry : forall x y c, decode (snd (subc x y c)) = (decode x - decode y - bit c) mod 2^n
}.
+ Definition Build_is_sub_with_carry' (subc : sub_with_carry)
+ (pf : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) <? 0) /\ decode (snd (subc x y c)) = (decode x - decode y - bit c) mod 2^n)
+ := {| fst_sub_with_carry x y c := proj1 (pf x y c);
+ decode_snd_sub_with_carry x y c := proj2 (pf x y c) |}.
+
Record multiply := { mul :> W -> W -> W }.
Class is_mul (mul : multiply) :=
@@ -118,12 +142,15 @@ Section InstructionGallery.
Class is_add_modulo (addm : add_modulo) :=
decode_add_modulo : forall x y modulus,
- decode (addm x y modulus) = (decode x + decode y) mod (decode modulus).
+ decode (addm x y modulus) = (if (decode x + decode y) <? decode modulus
+ then (decode x + decode y)
+ else (decode x + decode y) - decode modulus)%Z.
End InstructionGallery.
Global Arguments load_immediate : clear implicits.
Global Arguments shift_right_doubleword_immediate : clear implicits.
Global Arguments shift_left_immediate : clear implicits.
+Global Arguments shift_right_immediate : clear implicits.
Global Arguments spread_left_immediate : clear implicits.
Global Arguments mask_keep_low : clear implicits.
Global Arguments add_with_carry : clear implicits.
@@ -137,6 +164,7 @@ Global Arguments add_modulo : clear implicits.
Global Arguments ldi {_ _} _.
Global Arguments shrd {_ _} _ _ _.
Global Arguments shl {_ _} _ _.
+Global Arguments shr {_ _} _ _.
Global Arguments sprl {_ _} _ _.
Global Arguments mkl {_ _} _ _.
Global Arguments adc {_ _} _ _ _.
@@ -151,6 +179,7 @@ Global Arguments addm {_ _} _ _ _.
Existing Class load_immediate.
Existing Class shift_right_doubleword_immediate.
Existing Class shift_left_immediate.
+Existing Class shift_right_immediate.
Existing Class spread_left_immediate.
Existing Class mask_keep_low.
Existing Class add_with_carry.
@@ -166,6 +195,7 @@ Global Arguments is_decode {_ _} _.
Global Arguments is_load_immediate {_ _ _} _.
Global Arguments is_shift_right_doubleword_immediate {_ _ _} _.
Global Arguments is_shift_left_immediate {_ _ _} _.
+Global Arguments is_shift_right_immediate {_ _ _} _.
Global Arguments is_spread_left_immediate {_ _ _} _.
Global Arguments is_mask_keep_low {_ _ _} _.
Global Arguments is_add_with_carry {_ _ _} _.
@@ -177,45 +207,91 @@ Global Arguments is_mul_high_high {_ _ _} _ _.
Global Arguments is_select_conditional {_ _ _} _.
Global Arguments is_add_modulo {_ _ _} _.
-Ltac bounded_sovlver_tac :=
- solve [ eassumption | typeclasses eauto | omega | auto 6 using decode_range with typeclass_instances omega ].
-
-Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_select_conditional @decode_add_modulo using bounded_sovlver_tac : push_decode.
-
-Ltac push_decode :=
- repeat first [ erewrite !decode_load_immediate by bounded_sovlver_tac
- | erewrite !decode_shift_right_doubleword by bounded_sovlver_tac
- | erewrite !decode_shift_left_immediate by bounded_sovlver_tac
- | erewrite !decode_fst_spread_left_immediate by bounded_sovlver_tac
- | erewrite !decode_snd_spread_left_immediate by bounded_sovlver_tac
- | erewrite !decode_mask_keep_low by bounded_sovlver_tac
- | erewrite !bit_fst_add_with_carry by bounded_sovlver_tac
- | erewrite !decode_snd_add_with_carry by bounded_sovlver_tac
- | erewrite !fst_sub_with_carry by bounded_sovlver_tac
- | erewrite !decode_snd_sub_with_carry by bounded_sovlver_tac
- | erewrite !decode_mul by bounded_sovlver_tac
- | erewrite !decode_mul_low_low by bounded_sovlver_tac
- | erewrite !decode_mul_high_low by bounded_sovlver_tac
- | erewrite !decode_mul_high_high by bounded_sovlver_tac
- | erewrite !decode_select_conditional by bounded_sovlver_tac
- | erewrite !decode_add_modulo by bounded_sovlver_tac ].
-Ltac pull_decode :=
- repeat first [ erewrite <- !decode_load_immediate by bounded_sovlver_tac
- | erewrite <- !decode_shift_right_doubleword by bounded_sovlver_tac
- | erewrite <- !decode_shift_left_immediate by bounded_sovlver_tac
- | erewrite <- !decode_fst_spread_left_immediate by bounded_sovlver_tac
- | erewrite <- !decode_snd_spread_left_immediate by bounded_sovlver_tac
- | erewrite <- !decode_mask_keep_low by bounded_sovlver_tac
- | erewrite <- !bit_fst_add_with_carry by bounded_sovlver_tac
- | erewrite <- !decode_snd_add_with_carry by bounded_sovlver_tac
- | erewrite <- !fst_sub_with_carry by bounded_sovlver_tac
- | erewrite <- !decode_snd_sub_with_carry by bounded_sovlver_tac
- | erewrite <- !decode_mul by bounded_sovlver_tac
- | erewrite <- !decode_mul_low_low by bounded_sovlver_tac
- | erewrite <- !decode_mul_high_low by bounded_sovlver_tac
- | erewrite <- !decode_mul_high_high by bounded_sovlver_tac
- | erewrite <- !decode_select_conditional by bounded_sovlver_tac
- | erewrite <- !decode_add_modulo by bounded_sovlver_tac ].
+Ltac bounded_solver_tac :=
+ solve [ eassumption | typeclasses eauto | omega ].
+
+Lemma decode_proj n W (dec : W -> Z)
+ : @decode n W {| decode := dec |} = dec.
+Proof. reflexivity. Qed.
+
+Lemma decode_exponent_nonnegative {n W} (decode : decoder n W) {isdecode : is_decode decode}
+ (isinhabited : W)
+ : 0 <= n.
+Proof.
+ pose proof (decode_range isinhabited).
+ assert (0 < 2^n) by omega.
+ destruct (Z_lt_ge_dec n 0) as [H'|]; [ | omega ].
+ assert (2^n = 0) by auto using Z.pow_neg_r.
+ omega.
+Qed.
+
+Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_shift_right_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_select_conditional @decode_add_modulo @decode_proj using bounded_solver_tac : push_decode.
+
+Ltac push_decode_step :=
+ first [ rewrite !decode_proj
+ | erewrite !decode_load_immediate by bounded_solver_tac
+ | erewrite !decode_shift_right_doubleword by bounded_solver_tac
+ | erewrite !decode_shift_left_immediate by bounded_solver_tac
+ | erewrite !decode_shift_right_immediate by bounded_solver_tac
+ | erewrite !decode_fst_spread_left_immediate by bounded_solver_tac
+ | erewrite !decode_snd_spread_left_immediate by bounded_solver_tac
+ | erewrite !decode_mask_keep_low by bounded_solver_tac
+ | erewrite !bit_fst_add_with_carry by bounded_solver_tac
+ | erewrite !decode_snd_add_with_carry by bounded_solver_tac
+ | erewrite !fst_sub_with_carry by bounded_solver_tac
+ | erewrite !decode_snd_sub_with_carry by bounded_solver_tac
+ | erewrite !decode_mul by bounded_solver_tac
+ | erewrite !decode_mul_low_low by bounded_solver_tac
+ | erewrite !decode_mul_high_low by bounded_solver_tac
+ | erewrite !decode_mul_high_high by bounded_solver_tac
+ | erewrite !decode_select_conditional by bounded_solver_tac
+ | erewrite !decode_add_modulo by bounded_solver_tac ].
+Ltac pull_decode_step :=
+ first [ erewrite <- !decode_load_immediate by bounded_solver_tac
+ | erewrite <- !decode_shift_right_doubleword by bounded_solver_tac
+ | erewrite <- !decode_shift_left_immediate by bounded_solver_tac
+ | erewrite <- !decode_shift_right_immediate by bounded_solver_tac
+ | erewrite <- !decode_fst_spread_left_immediate by bounded_solver_tac
+ | erewrite <- !decode_snd_spread_left_immediate by bounded_solver_tac
+ | erewrite <- !decode_mask_keep_low by bounded_solver_tac
+ | erewrite <- !bit_fst_add_with_carry by bounded_solver_tac
+ | erewrite <- !decode_snd_add_with_carry by bounded_solver_tac
+ | erewrite <- !fst_sub_with_carry by bounded_solver_tac
+ | erewrite <- !decode_snd_sub_with_carry by bounded_solver_tac
+ | erewrite <- !decode_mul by bounded_solver_tac
+ | erewrite <- !decode_mul_low_low by bounded_solver_tac
+ | erewrite <- !decode_mul_high_low by bounded_solver_tac
+ | erewrite <- !decode_mul_high_high by bounded_solver_tac
+ | erewrite <- !decode_select_conditional by bounded_solver_tac
+ | erewrite <- !decode_add_modulo by bounded_solver_tac ].
+Ltac push_decode := repeat push_decode_step.
+Ltac pull_decode := repeat pull_decode_step.
+
+(* We take special care to handle the case where the decoder is
+ syntactically different but the decoded expression is judgmentally
+ the same; we don't want to split apart variables that should be the
+ same. *)
+Ltac set_decode_step check :=
+ match goal with
+ | [ |- context G[@Interface.decode ?n ?W ?dr ?w] ]
+ => check w;
+ first [ match goal with
+ | [ d := @Interface.decode _ _ _ w |- _ ]
+ => change (@Interface.decode n W dr w) with d
+ end
+ | generalize (@decode_range n W dr _ w);
+ let d := fresh "d" in
+ set (d := @Interface.decode n W dr w);
+ intro ]
+ end.
+Ltac set_decode check := repeat set_decode_step check.
+Ltac clearbody_decode :=
+ repeat match goal with
+ | [ H := @Interface.decode _ _ _ _ |- _ ] => clearbody H
+ end.
+Ltac generalize_decode_by check := set_decode check; clearbody_decode.
+Ltac generalize_decode := generalize_decode_by ltac:(fun w => idtac).
+Ltac generalize_decode_var := generalize_decode_by ltac:(fun w => is_var w).
Module fancy_machine.
Local Notation imm := Z (only parsing).
@@ -227,6 +303,7 @@ Module fancy_machine.
ldi :> load_immediate W;
shrd :> shift_right_doubleword_immediate W;
shl :> shift_left_immediate W;
+ shr :> shift_right_immediate W;
mkl :> mask_keep_low W;
adc :> add_with_carry W;
subc :> sub_with_carry W;
@@ -243,6 +320,7 @@ Module fancy_machine.
load_immediate :> is_load_immediate ldi;
shift_right_doubleword_immediate :> is_shift_right_doubleword_immediate shrd;
shift_left_immediate :> is_shift_left_immediate shl;
+ shift_right_immediate :> is_shift_right_immediate shr;
mask_keep_low :> is_mask_keep_low mkl;
add_with_carry :> is_add_with_carry adc;
sub_with_carry :> is_sub_with_carry subc;