From 6897a4f42c86c4a6bfdbab6887276e7334317661 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 23 Aug 2016 15:59:35 -0700 Subject: Hook up the bounded interface, finish proofs --- src/BoundedArithmetic/ArchitectureToZLike.v | 115 +------ src/BoundedArithmetic/ArchitectureToZLikeProofs.v | 109 +++++++ src/BoundedArithmetic/DoubleBounded.v | 105 ++++++- src/BoundedArithmetic/DoubleBoundedProofs.v | 354 ++++++++++++++++++++++ src/BoundedArithmetic/Interface.v | 168 +++++++--- 5 files changed, 690 insertions(+), 161 deletions(-) create mode 100644 src/BoundedArithmetic/ArchitectureToZLikeProofs.v create mode 100644 src/BoundedArithmetic/DoubleBoundedProofs.v (limited to 'src/BoundedArithmetic') diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v index 01387e969..e30fcfd09 100644 --- a/src/BoundedArithmetic/ArchitectureToZLike.v +++ b/src/BoundedArithmetic/ArchitectureToZLike.v @@ -3,123 +3,28 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.BoundedArithmetic.Interface. Require Import Crypto.BoundedArithmetic.DoubleBounded. Require Import Crypto.ModularArithmetic.ZBounded. -Require Import Coq.Lists.List. -Import ListNotations. +Require Import Crypto.Util.Tuple. -Local Open Scope nat_scope. Local Open Scope Z_scope. -Local Open Scope type_scope. - -Local Coercion Z.of_nat : nat >-> Z. Section fancy_machine_p256_montgomery_foundation. Context {n_over_two : Z}. - Local Notation n := (2 * n_over_two)%Z. + Local Notation n := (2 * n_over_two). Context (ops : fancy_machine.instructions n) (modulus : Z). - Definition two_list_to_tuple {A B} (x : A * list B) - := match x return match x with - | (a, [b0; b1]) => A * (B * B) - | _ => True - end - with - | (a, [b0; b1]) => (a, (b0, b1)) - | _ => I - end. -(* - (* make all machine-specific constructions here, preferrably as - thing wrappers around generic constructions *) - Local Instance DoubleArchitectureBoundedOps : ArchitectureBoundedOps (2 * n)%nat - := { BoundedType := BoundedType * BoundedType (* [(high, low)] *); - decode high_low := (decode (fst high_low) * 2^n + decode (snd high_low))%Z; - encode z := (encode (z / 2^n), encode (z mod 2^n)); - ShiftRight a high_low - := let '(high, low) := eta high_low in - if n <=? a then - (ShiftRight (a - n)%nat (encode 0, fst high), ShiftRight (a - n)%nat high) - else - (ShiftRight a (snd high, fst low), ShiftRight a low); - ShiftLeft a high_low - := let '(high, low) := eta high_low in - if 2 * n <=? a then - let '(high0, low) := eta (ShiftLeft (a - 2 * n)%nat low) in - let '(high_high, high1) := eta (ShiftLeft (a - 2 * n)%nat high) in - ((snd (CarryAdd false high0 high1), low), (encode 0, encode 0)) - else if n <=? a then - let '(high0, low) := eta (ShiftLeft (a - n)%nat low) in - let '(high_high, high1) := eta (ShiftLeft (a - n)%nat high) in - ((high_high, snd (CarryAdd false high0 high1)), (low, encode 0)) - else - let '(high0, low) := eta (ShiftLeft a low) in - let '(high_high, high1) := eta (ShiftLeft a high) in - ((encode 0, high_high), (snd (CarryAdd false high0 high1), low)); - Mod2Pow a high_low - := let '(high, low) := (fst high_low, snd high_low) in - (Mod2Pow (a - n)%nat high, Mod2Pow a low); - CarryAdd carry x_high_low y_high_low - := let '(xhigh, xlow) := eta x_high_low in - let '(yhigh, ylow) := eta y_high_low in - two_list_to_tuple (ripple_carry CarryAdd carry [xhigh; xlow] [yhigh; ylow]); - CarrySub carry x_high_low y_high_low - := let '(xhigh, xlow) := eta x_high_low in - let '(yhigh, ylow) := eta y_high_low in - two_list_to_tuple (ripple_carry CarrySub carry [xhigh; xlow] [yhigh; ylow]) }. - - Definition BoundedOfHalfBounded (x : @BoundedHalfType (2 * n)%nat _) : @BoundedType n _ - := match x with - | UpperHalf x => fst x - | LowerHalf x => snd x - end. - - Local Instance DoubleArchitectureBoundedHalfWidthMulOpsOfFullMulOps - {base_mops : ArchitectureBoundedFullMulOps n} - : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat := - { HalfWidthMul a b - := Mul (BoundedOfHalfBounded a) (BoundedOfHalfBounded b) }. - End single. - - Local Existing Instance DoubleArchitectureBoundedOps. - - Section full_from_half. - Context (n : size) {base_ops : ArchitectureBoundedOps (2 * n)%nat}. - - Local Infix "*" := HalfWidthMul. - - Local Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps - {base_mops : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat} - : ArchitectureBoundedFullMulOps (2 * n)%nat := - { Mul a b - := let '(a1, a0) := (UpperHalf a, LowerHalf a) in - let '(b1, b0) := (UpperHalf b, LowerHalf b) in - let out := a0 * b0 in - let outHigh := a1 * b1 in - let tmp := a1 * b0 in - let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in - let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in - let tmp := a0 * b1 in - let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in - let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in - (outHigh, out) }. - End full_from_half. - - Local Existing Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps. -*) - Axiom admit : forall {T}, T. - Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : Z) : ZLikeOps (2^n) (2^smaller_bound_exp) modulus := - { LargeT := fancy_machine.W * fancy_machine.W; + { LargeT := tuple fancy_machine.W 2; SmallT := fancy_machine.W; modulus_digits := ldi modulus; - decode_large := _; + decode_large := decode; decode_small := decode; - Mod_SmallBound v := snd v; - DivBy_SmallBound v := fst v; - DivBy_SmallerBound v := shrd (fst v) (snd v) smaller_bound_exp; - Mul x y := _ (*mulhwll (ldi 0, x) (ldi 0, y)*); - CarryAdd x y := _ (*adc x y false*); + Mod_SmallBound v := fst v; + DivBy_SmallBound v := snd v; + DivBy_SmallerBound v := shrd (snd v) (fst v) smaller_bound_exp; + Mul x y := mulhwll (W := tuple _ 2) (sprl x 0) (sprl y 0); + CarryAdd x y := adc x y false; CarrySubSmall x y := subc x y false; - ConditionalSubtract b x := let v := selc b (ldi 0) (ldi modulus) in snd (subc x v false); + ConditionalSubtract b x := let v := selc b (ldi modulus) (ldi 0) in snd (subc x v false); ConditionalSubtractModulus y := addm y (ldi 0) (ldi modulus) }. - Abort. End fancy_machine_p256_montgomery_foundation. diff --git a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v new file mode 100644 index 000000000..b7cac2bb3 --- /dev/null +++ b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v @@ -0,0 +1,109 @@ +(*** Proving ℤ-Like via Architecture *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BoundedArithmetic.DoubleBounded. +Require Import Crypto.BoundedArithmetic.DoubleBoundedProofs. +Require Import Crypto.BoundedArithmetic.ArchitectureToZLike. +Require Import Crypto.ModularArithmetic.ZBounded. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ZUtil Crypto.Util.Tactics. + +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Local Coercion Z.of_nat : nat >-> Z. + +Section fancy_machine_p256_montgomery_foundation. + Context {n_over_two : Z}. + Local Notation n := (2 * n_over_two)%Z. + Context (ops : fancy_machine.instructions n) (modulus : Z). + + Local Arguments Z.mul !_ !_. + Local Arguments BaseSystem.decode !_ !_ / . + Local Arguments BaseSystem.accumulate / . + Local Arguments BaseSystem.decode' !_ !_ / . + + Local Ltac introduce_t_step := + match goal with + | [ |- forall x : bool, _ ] => intros [|] + | [ |- True -> _ ] => intros _ + | [ |- _ <= _ < _ -> _ ] => intro + | _ => let x := fresh "x" in + intro x; + try pose proof (decode_range (fst x)); + try pose proof (decode_range (snd x)); + pose proof (decode_range x) + end. + Local Ltac unfolder_t := + progress unfold LargeT, SmallT, modulus_digits, decode_large, decode_small, Mod_SmallBound, DivBy_SmallBound, DivBy_SmallerBound, Mul, CarryAdd, CarrySubSmall, ConditionalSubtract, ConditionalSubtractModulus, ZLikeOps_of_ArchitectureBoundedOps in *. + Local Ltac saturate_context_step := + match goal with + | _ => unique assert (0 <= 2 * n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ] + | _ => unique assert (0 <= n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ] + | _ => unique assert (0 <= 2 * (2 * n_over_two)) by (eauto using decode_exponent_nonnegative with typeclass_instances) + end. + Local Ltac pre_t := + repeat first [ tauto + | introduce_t_step + | unfolder_t + | saturate_context_step ]. + Local Ltac post_t_step := + match goal with + | _ => tauto + | _ => progress autorewrite with zsimplify_const in * + | _ => progress push_decode + | _ => progress autorewrite with push_Zpow in * + | _ => progress Z.rewrite_mod_small + | [ |- fst ?x = (?a <=? ?b) :> bool ] + => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z); + [ destruct (fst x), (a <=? b); intro; congruence | ] + | [ |- appcontext[let (a, b) := ?x in _] ] + => rewrite (surjective_pairing x); simplify_projections + | _ => progress autorewrite with Zshift_to_pow in * + | _ => progress autorewrite with simpl_tuple_decoder in * + | _ => progress autorewrite with zsimplify + | [ |- _ / ?y = _ / ?y ] => apply f_equal2; omega + | [ |- _ / _ = if _ then _ else _ ] => apply Z.div_between_0_if; auto with zarith omega + end. + Local Ltac post_t := repeat post_t_step. + Local Ltac t := pre_t; post_t. + + Global Instance ZLikeProperties_of_ArchitectureBoundedOps + {arith : fancy_machine.arithmetic ops} + (modulus_in_range : 0 <= modulus < 2^n) + (smaller_bound_exp : Z) + (smaller_bound_smaller : 0 <= smaller_bound_exp < n) + : ZLikeProperties (ZLikeOps_of_ArchitectureBoundedOps ops modulus smaller_bound_exp) + := { large_valid v := True; + medium_valid v := 0 <= decode_large v < 2^n * 2^smaller_bound_exp; + small_valid v := True }. + Proof. + (* In 8.5: *) + (* par:t. *) + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + Defined. +End fancy_machine_p256_montgomery_foundation. diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index 7fa0d4db1..a368b96a0 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -1,13 +1,11 @@ (*** Implementing Large Bounded Arithmetic via pairs *) -Require Import Coq.ZArith.ZArith Coq.Lists.List. +Require Import Coq.ZArith.ZArith. Require Import Crypto.BoundedArithmetic.Interface. -Require Import Crypto.BaseSystem. -Require Import Crypto.BaseSystemProofs. Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ListUtil. Require Import Crypto.Util.Notations. -Import ListNotations. Local Open Scope list_scope. Local Open Scope nat_scope. Local Open Scope Z_scope. @@ -17,6 +15,17 @@ Local Coercion Z.of_nat : nat >-> Z. Local Notation eta x := (fst x, snd x). Section generic_constructions. + Section decode. + Context {n W} {decode : decoder n W}. + Section with_k. + Context {k : nat}. + Let limb_widths := repeat n k. + (** The list is low to high; the tuple is low to high *) + Local Instance tuple_decoder : decoder (k * n) (tuple W k) + := { decode w := BaseSystem.decode (base_from_limb_widths limb_widths) (List.map decode (List.rev (Tuple.to_list _ w))) }. + End with_k. + End decode. + Definition ripple_carry {T} (f : T -> T -> bool -> bool * T) (xs ys : list T) (carry : bool) : bool * list T := List.fold_right @@ -27,17 +36,91 @@ Section generic_constructions. (carry, nil) (List.combine xs ys). + (** tuple is high to low ([to_list] reverses) *) + Fixpoint ripple_carry_tuple' {T} (f : T -> T -> bool -> bool * T) k + : forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k + := match k return forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k with + | O => f + | S k' => fun xss yss carry => let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (@ripple_carry_tuple' _ f k' xs ys carry) in + let '(carry, z) := eta (f x y carry) in + (carry, (zs, z)) + end. + + Definition ripple_carry_tuple {T} (f : T -> T -> bool -> bool * T) k + : forall (xs ys : tuple T k) (carry : bool), bool * tuple T k + := match k return forall (xs ys : tuple T k) (carry : bool), bool * tuple T k with + | O => fun xs ys carry => (carry, tt) + | S k' => ripple_carry_tuple' f k' + end. + Section ripple_carry_adc. Context {n W} {decode : decoder n W} (adc : add_with_carry W). - Global Instance ripple_carry_add_with_carry : add_with_carry (list W) - := {| Interface.adc := ripple_carry adc |}. - (* - Global Instance ripple_carry_is_add_with_carry {is_adc : is_add_with_carry adc} - : is_add_with_carry ripple_carry_add_with_carry.*) + Global Instance ripple_carry_adc {k} : add_with_carry (tuple W k) + := {| Interface.adc := ripple_carry_tuple adc k |}. End ripple_carry_adc. (* TODO: Would it made sense to make generic-width shift operations here? *) - (* FUTURE: here go proofs about [ripple_carry] with [f] that satisfies [is_add_with_carry] *) + Section tuple2. + Section spread_left. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W}. + + Definition spread_left_from_shift (r : W) (count : Z) : tuple W 2 + := (shl r count, if count =? 0 then ldi 0 else shr r (n - count)). + + (** Require a [decode] instance to aid typeclass search in + resolving [n] *) + Global Instance sprl_from_shift {decode : decoder n W} : spread_left_immediate W + := {| Interface.sprl := spread_left_from_shift |}. + End spread_left. + + Section full_from_half. + Context {W} + {mulhwll : multiply_low_low W} + {mulhwhl : multiply_high_low W} + {mulhwhh : multiply_high_high W} + {adc : add_with_carry W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W}. + + Section def. + Context (half_n : Z). + Definition mul_double (a b : W) : tuple W 2 + := let out : tuple W 2 := (mulhwll a b, mulhwhh a b) in + let tmp := mulhwhl a b in + let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in + let tmp := mulhwhl b a in + let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in + out. + End def. + + Section instances. + Context {half_n : Z} + {ldi : load_immediate W}. + + (** Require a dummy [decoder] for these instances to allow + typeclass inference of the [half_n] argument *) + Global Instance mul_double_multiply_low_low {decode : decoder (2 * half_n) W} + : multiply_low_low (tuple W 2) + := {| Interface.mulhwll a b := mul_double half_n (fst a) (fst b) |}. + Global Instance mul_double_multiply_high_low {decode : decoder (2 * half_n) W} + : multiply_high_low (tuple W 2) + := {| Interface.mulhwhl a b := mul_double half_n (snd a) (fst b) |}. + Global Instance mul_double_multiply_high_high {decode : decoder (2 * half_n) W} + : multiply_high_high (tuple W 2) + := {| Interface.mulhwhh a b := mul_double half_n (snd a) (snd b) |}. + End instances. + End full_from_half. + End tuple2. End generic_constructions. + +Global Arguments tuple_decoder : simpl never. + +Hint Resolve (fun n W decode => (@tuple_decoder n W decode 2 : decoder (2 * n) (tuple W 2))) : typeclass_instances. +Hint Extern 3 (decoder _ (tuple ?W ?k)) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode => (@tuple_decoder n W decode k : decoder (kv * n) (tuple W k))) : typeclass_instances. diff --git a/src/BoundedArithmetic/DoubleBoundedProofs.v b/src/BoundedArithmetic/DoubleBoundedProofs.v new file mode 100644 index 000000000..d878a1373 --- /dev/null +++ b/src/BoundedArithmetic/DoubleBoundedProofs.v @@ -0,0 +1,354 @@ +(*** Proofs About Large Bounded Arithmetic via pairs *) +Require Import Coq.ZArith.ZArith Coq.Lists.List Coq.micromega.Psatz. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BaseSystem. +Require Import Crypto.BaseSystemProofs. +Require Import Crypto.ModularArithmetic.Pow2Base. +Require Import Crypto.ModularArithmetic.Pow2BaseProofs. +Require Import Crypto.BoundedArithmetic.DoubleBounded. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.ListUtil. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.Notations. + +Import ListNotations. +Local Open Scope list_scope. +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Local Coercion Z.of_nat : nat >-> Z. +Local Coercion Pos.to_nat : positive >-> nat. +Local Notation eta x := (fst x, snd x). + +Section generic_constructions. + Section decode. + Context {n W} {decode : decoder n W}. + Section with_k. + Context {k : nat}. + Local Notation limb_widths := (repeat n k). + + Lemma decode_bounded {isdecode : is_decode decode} w + : 0 <= n -> bounded limb_widths (map decode (rev (to_list k w))). + Proof. + intro. + eapply bounded_uniform; try solve [ eauto using repeat_spec ]. + { distr_length. } + { intros z H'. + apply in_map_iff in H'. + destruct H' as [? [? H'] ]; subst; apply decode_range. } + Qed. + + (** TODO: Clean up this proof *) + Global Instance tuple_is_decode {isdecode : is_decode decode} + : is_decode (tuple_decoder (k := k)). + Proof. + unfold tuple_decoder; hnf; simpl. + intro w. + destruct (zerop k); [ subst | ]. + { unfold BaseSystem.decode, BaseSystem.decode'; simpl; omega. } + assert (0 <= n) + by (destruct k as [ | [|] ]; [ omega | | destruct w ]; + eauto using decode_exponent_nonnegative). + replace (2^(k * n)) with (upper_bound limb_widths) + by (erewrite upper_bound_uniform by eauto using repeat_spec; distr_length). + apply decode_upper_bound; auto using decode_bounded. + { intros ? H'. + apply repeat_spec in H'; omega. } + { distr_length. } + Qed. + End with_k. + + Local Arguments base_from_limb_widths : simpl never. + Local Arguments repeat : simpl never. + Local Arguments Z.mul !_ !_. + Lemma tuple_decoder_S {k} w : 0 <= n -> (tuple_decoder (k := S (S k)) w = tuple_decoder (k := S k) (fst w) + (decode (snd w) << (S k * n)))%Z. + Proof. + intro Hn. + destruct w as [? w]; simpl. + replace (decode w) with (decode w * 1 + 0)%Z by omega. + rewrite map_app, map_cons, map_nil. + erewrite decode_shift_uniform_app by (eauto using repeat_spec; distr_length). + distr_length. + autorewrite with push_skipn natsimplify push_firstn. + reflexivity. + Qed. + Lemma tuple_decoder_O w : tuple_decoder (k := 1) w = decode w. + Proof. + unfold tuple_decoder, BaseSystem.decode, BaseSystem.decode', accumulate, base_from_limb_widths, repeat. + simpl. + omega. + Qed. + Lemma tuple_decoder_m1 w : tuple_decoder (k := 0) w = 0. + Proof. reflexivity. Qed. + End decode. + Local Arguments tuple_decoder : simpl never. + Local Opaque tuple_decoder. + Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder. + + Hint Extern 1 (decoder _ (tuple ?W 2)) => apply (fun n decode => @tuple_decoder n W decode 2 : decoder (2 * n) (tuple W 2)) : typeclass_instances. + + Lemma ripple_carry_tuple_SS {T} f k xss yss carry + : @ripple_carry_tuple T f (S (S k)) xss yss carry + = let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (@ripple_carry_tuple _ f (S k) xs ys carry) in + let '(carry, z) := eta (f x y carry) in + (carry, (zs, z)). + Proof. reflexivity. Qed. + + Lemma carry_is_good (n z0 z1 k : Z) + : 0 <= n -> + 0 <= k -> + (z1 + z0 >> k) >> n = (z0 + z1 << k) >> (k + n) /\ + (z0 mod 2 ^ k + ((z1 + z0 >> k) mod 2 ^ n) << k)%Z = (z0 + z1 << k) mod (2 ^ k * 2 ^ n). + Proof. + intros. + assert (0 < 2 ^ n) by auto with zarith. + assert (0 < 2 ^ k) by auto with zarith. + assert (0 < 2^n * 2^k) by nia. + autorewrite with Zshift_to_pow push_Zpow. + rewrite <- (Zmod_small ((z0 mod _) + _) (2^k * 2^n)) by (Z.div_mod_to_quot_rem; nia). + rewrite <- !Z.mul_mod_distr_r by lia. + rewrite !(Z.mul_comm (2^k)); pull_Zmod. + split; [ | apply f_equal2 ]; + Z.div_mod_to_quot_rem; nia. + Qed. + + Definition carry_is_good_carry n z0 z1 k H0 H1 := proj1 (@carry_is_good n z0 z1 k H0 H1). + Definition carry_is_good_value n z0 z1 k H0 H1 := proj2 (@carry_is_good n z0 z1 k H0 H1). + + Section ripple_carry_adc. + Context {n W} {decode : decoder n W} (adc : add_with_carry W). + + Lemma ripple_carry_adc_SS k xss yss carry + : ripple_carry_adc (k := S (S k)) adc xss yss carry + = let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (ripple_carry_adc (k := S k) adc xs ys carry) in + let '(carry, z) := eta (adc x y carry) in + (carry, (zs, z)). + Proof. apply ripple_carry_tuple_SS. Qed. + + Local Existing Instance tuple_decoder. + + Global Instance ripple_carry_is_add_with_carry {k} + {isdecode : is_decode decode} + {is_adc : is_add_with_carry adc} + : is_add_with_carry (ripple_carry_adc (k := k) adc). + Proof. + destruct k as [|k]. + { constructor; simpl; intros; autorewrite with zsimplify; reflexivity. } + { induction k as [|k IHk]. + { cbv [ripple_carry_adc ripple_carry_tuple to_list]. + constructor; simpl @fst; simpl @snd; intros; + autorewrite with simpl_tuple_decoder; + push_decode; + autorewrite with zsimplify; reflexivity. } + { apply Build_is_add_with_carry'; intros x y c. + assert (0 <= n) by (destruct x; eauto using decode_exponent_nonnegative). + assert (2^n <> 0) by auto with zarith. + assert (0 <= S k * n) by nia. + rewrite !tuple_decoder_S, !ripple_carry_adc_SS by assumption. + simplify_projections; push_decode; generalize_decode. + erewrite carry_is_good_carry, carry_is_good_value by lia. + autorewrite with pull_Zpow push_Zof_nat zsimplify Zshift_to_pow. + split; apply f_equal2; nia. } } + Qed. + + End ripple_carry_adc. + + Section tuple2. + Section spread_left_correct. + Context {n W} {decode : decoder n W} {sprl : spread_left_immediate W} + {isdecode : is_decode decode}. + Lemma is_spread_left_immediate_alt + : is_spread_left_immediate sprl + <-> (forall r count, 0 <= count < n -> tuple_decoder (k := 2) (sprl r count) = (decode r << count) mod 2^(2*n)). + Proof. + split; intro H; [ | apply Build_is_spread_left_immediate' ]; + intros r count Hc; + [ | specialize (H r count Hc); revert H ]; + pose proof (decode_range r); + assert (0 < 2^n) by auto with zarith; + assert (0 <= 2^count < 2^n) by auto with zarith; + assert (0 <= decode r * 2^count < 2^n * 2^n) by (generalize dependent (decode r); intros; nia); + autorewrite with simpl_tuple_decoder; push_decode; + autorewrite with Zshift_to_pow zsimplify push_Zpow. + { reflexivity. } + { intro H'; rewrite <- H'. + autorewrite with zsimplify; split; reflexivity. } + Qed. + End spread_left_correct. + + Local Arguments Z.pow !_ !_. + Local Arguments Z.mul !_ !_. + + Section spread_left. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {decode : decoder n W} + {isdecode : is_decode decode} + {isldi : is_load_immediate ldi} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr}. + + Lemma spread_left_from_shift_correct + r count + (H : 0 < count < n) + : (decode (shl r count) + decode (shr r (n - count)) << n = decode r << count mod 2^(2*n))%Z. + Proof. + assert (0 <= n - count < n) by lia. + assert (0 < 2^(n-count)) by auto with zarith. + assert (2^count < 2^n) by auto with zarith. + pose proof (decode_range r). + assert (0 <= decode r * 2 ^ count < 2 ^ n * 2^n) by auto with zarith. + simpl. + push_decode; autorewrite with Zshift_to_pow zsimplify. + replace (decode r / 2^(n-count) * 2^n)%Z with ((decode r / 2^(n-count) * 2^(n-count)) * 2^count)%Z + by (rewrite <- Z.mul_assoc; autorewrite with pull_Zpow zsimplify; reflexivity). + rewrite Z.mul_div_eq' by lia. + autorewrite with push_Zmul zsimplify. + rewrite <- Z.mul_mod_distr_r_full, Z.add_sub_assoc. + repeat autorewrite with pull_Zpow zsimplify in *. + reflexivity. + Qed. + + Global Instance is_spread_left_from_shift + : is_spread_left_immediate (sprl_from_shift n). + Proof. + apply is_spread_left_immediate_alt. + intros r count; intros. + pose proof (decode_range r). + assert (0 < 2^n) by auto with zarith. + assert (decode r < 2^n * 2^n) by (generalize dependent (decode r); intros; nia). + autorewrite with simpl_tuple_decoder. + destruct (Z_zerop count). + { subst; autorewrite with Zshift_to_pow zsimplify. + simpl; push_decode. + autorewrite with push_Zpow zsimplify. + reflexivity. } + simpl. + rewrite <- spread_left_from_shift_correct by lia. + autorewrite with zsimplify Zpow_to_shift. + reflexivity. + Qed. + End spread_left. + + Local Opaque ripple_carry_adc. + Section full_from_half. + Context {W} + {mulhwll : multiply_low_low W} + {mulhwhl : multiply_high_low W} + {mulhwhh : multiply_high_high W} + {adc : add_with_carry W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {half_n : Z} + {ldi : load_immediate W} + {decode : decoder (2 * half_n) W} + {ismulhwll : is_mul_low_low half_n mulhwll} + {ismulhwhl : is_mul_high_low half_n mulhwhl} + {ismulhwhh : is_mul_high_high half_n mulhwhh} + {isadc : is_add_with_carry adc} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr} + {isldi : is_load_immediate ldi} + {isdecode : is_decode decode}. + + Local Arguments Z.mul !_ !_. + Lemma spread_left_from_shift_half_correct + r + : (decode (shl r half_n) + decode (shr r half_n) * (2^half_n * 2^half_n) + = (decode r * 2^half_n) mod (2^half_n*2^half_n*2^half_n*2^half_n))%Z. + Proof. + destruct (0 (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2) w = _)) + (fun n (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (W * W) (@tuple_decoder n W decode 2) w = _)) + using solve [ auto with zarith ] + : simpl_tuple_decoder. + Local Ltac t := + hnf; intros [??] [??]; + assert (0 <= 2 * half_n) by eauto using decode_exponent_nonnegative; + assert (0 <= half_n) by omega; + simpl @Interface.mulhwhh; simpl @Interface.mulhwhl; simpl @Interface.mulhwll; + rewrite decode_mul_double_mod; push_decode; autorewrite with simpl_tuple_decoder; + simpl; + push_decode; generalize_decode_var; + autorewrite with Zshift_to_pow zsimplify; + autorewrite with push_Zpow in *; Z.rewrite_mod_small; + try reflexivity. + + Global Instance mul_double_is_multiply_low_low : is_mul_low_low (2 * half_n) mul_double_multiply_low_low. + Proof. t. Qed. + Global Instance mul_double_is_multiply_high_low : is_mul_high_low (2 * half_n) mul_double_multiply_high_low. + Proof. t. Qed. + Global Instance mul_double_is_multiply_high_high : is_mul_high_high (2 * half_n) mul_double_multiply_high_high. + Proof. t. Qed. + End full_from_half. + End tuple2. +End generic_constructions. + +Hint Resolve (fun n W decode pf => (@tuple_is_decode n W decode 2 pf : @is_decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2))) : typeclass_instances. +Hint Extern 3 (@is_decode _ (tuple ?W ?k) _) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode pf => (@tuple_is_decode n W decode k pf : @is_decode (kv * n) (tuple W k) (@tuple_decoder n W decode k : decode (kv * n) W))) : typeclass_instances. + +Hint Extern 2 (@is_add_with_carry _ (tuple ?W ?k) (@tuple_decoder ?n _ ?decode _) (@ripple_carry_adc _ ?adc _)) + => apply (@ripple_carry_is_add_with_carry n W decode adc k) : typeclass_instances. + +Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder. +Hint Rewrite + (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2) w = _)) + (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (W * W) (@tuple_decoder n W decode 2) w = _)) + using solve [ auto with zarith ] + : simpl_tuple_decoder. diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index 4a14a160b..fe64cd37e 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -53,18 +53,32 @@ Section InstructionGallery. decode_shift_left_immediate : forall r count, 0 <= count < n -> decode (shl r count) = (decode r << count) mod 2^n. - Record spread_left_immediate := { sprl :> W -> imm -> W * W (* [(high, low)] *) }. + Record shift_right_immediate := { shr :> W -> imm -> W }. + + Class is_shift_right_immediate (shr : shift_right_immediate) := + decode_shift_right_immediate : + forall r count, 0 <= count < n -> decode (shr r count) = (decode r >> count). + + Record spread_left_immediate := { sprl :> W -> imm -> W * W (* [(low, high)] *) }. Class is_spread_left_immediate (sprl : spread_left_immediate) := { decode_fst_spread_left_immediate : forall r count, - 0 <= count < n - -> decode (fst (sprl r count)) = (decode r << count) >> n; - decode_snd_spread_left_immediate : forall r count, 0 <= count < n - -> decode (snd (sprl r count)) = (decode r << count) mod 2^n + -> decode (fst (sprl r count)) = (decode r << count) mod 2^n; + decode_snd_spread_left_immediate : forall r count, + 0 <= count < n + -> decode (snd (sprl r count)) = (decode r << count) >> n; + }. + Definition Build_is_spread_left_immediate' (sprl : spread_left_immediate) + (pf : forall r count, 0 <= count < n + -> decode (fst (sprl r count)) = (decode r << count) mod 2^n + /\ decode (snd (sprl r count)) = (decode r << count) >> n) + := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H); + decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}. + Record mask_keep_low := { mkl :> W -> imm -> W }. Class is_mask_keep_low (mkl : mask_keep_low) := @@ -81,6 +95,11 @@ Section InstructionGallery. decode_snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n) }. + Definition Build_is_add_with_carry' (adc : add_with_carry) + (pf : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n /\ decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n)) + := {| bit_fst_add_with_carry x y c := proj1 (pf x y c); + decode_snd_add_with_carry x y c := proj2 (pf x y c) |}. + Record sub_with_carry := { subc :> W -> W -> bool -> bool * W }. Class is_sub_with_carry (subc:W->W->bool->bool*W) := @@ -89,6 +108,11 @@ Section InstructionGallery. decode_snd_sub_with_carry : forall x y c, decode (snd (subc x y c)) = (decode x - decode y - bit c) mod 2^n }. + Definition Build_is_sub_with_carry' (subc : sub_with_carry) + (pf : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) W -> W -> W }. Class is_mul (mul : multiply) := @@ -118,12 +142,15 @@ Section InstructionGallery. Class is_add_modulo (addm : add_modulo) := decode_add_modulo : forall x y modulus, - decode (addm x y modulus) = (decode x + decode y) mod (decode modulus). + decode (addm x y modulus) = (if (decode x + decode y) Z) + : @decode n W {| decode := dec |} = dec. +Proof. reflexivity. Qed. + +Lemma decode_exponent_nonnegative {n W} (decode : decoder n W) {isdecode : is_decode decode} + (isinhabited : W) + : 0 <= n. +Proof. + pose proof (decode_range isinhabited). + assert (0 < 2^n) by omega. + destruct (Z_lt_ge_dec n 0) as [H'|]; [ | omega ]. + assert (2^n = 0) by auto using Z.pow_neg_r. + omega. +Qed. + +Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_shift_right_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_select_conditional @decode_add_modulo @decode_proj using bounded_solver_tac : push_decode. + +Ltac push_decode_step := + first [ rewrite !decode_proj + | erewrite !decode_load_immediate by bounded_solver_tac + | erewrite !decode_shift_right_doubleword by bounded_solver_tac + | erewrite !decode_shift_left_immediate by bounded_solver_tac + | erewrite !decode_shift_right_immediate by bounded_solver_tac + | erewrite !decode_fst_spread_left_immediate by bounded_solver_tac + | erewrite !decode_snd_spread_left_immediate by bounded_solver_tac + | erewrite !decode_mask_keep_low by bounded_solver_tac + | erewrite !bit_fst_add_with_carry by bounded_solver_tac + | erewrite !decode_snd_add_with_carry by bounded_solver_tac + | erewrite !fst_sub_with_carry by bounded_solver_tac + | erewrite !decode_snd_sub_with_carry by bounded_solver_tac + | erewrite !decode_mul by bounded_solver_tac + | erewrite !decode_mul_low_low by bounded_solver_tac + | erewrite !decode_mul_high_low by bounded_solver_tac + | erewrite !decode_mul_high_high by bounded_solver_tac + | erewrite !decode_select_conditional by bounded_solver_tac + | erewrite !decode_add_modulo by bounded_solver_tac ]. +Ltac pull_decode_step := + first [ erewrite <- !decode_load_immediate by bounded_solver_tac + | erewrite <- !decode_shift_right_doubleword by bounded_solver_tac + | erewrite <- !decode_shift_left_immediate by bounded_solver_tac + | erewrite <- !decode_shift_right_immediate by bounded_solver_tac + | erewrite <- !decode_fst_spread_left_immediate by bounded_solver_tac + | erewrite <- !decode_snd_spread_left_immediate by bounded_solver_tac + | erewrite <- !decode_mask_keep_low by bounded_solver_tac + | erewrite <- !bit_fst_add_with_carry by bounded_solver_tac + | erewrite <- !decode_snd_add_with_carry by bounded_solver_tac + | erewrite <- !fst_sub_with_carry by bounded_solver_tac + | erewrite <- !decode_snd_sub_with_carry by bounded_solver_tac + | erewrite <- !decode_mul by bounded_solver_tac + | erewrite <- !decode_mul_low_low by bounded_solver_tac + | erewrite <- !decode_mul_high_low by bounded_solver_tac + | erewrite <- !decode_mul_high_high by bounded_solver_tac + | erewrite <- !decode_select_conditional by bounded_solver_tac + | erewrite <- !decode_add_modulo by bounded_solver_tac ]. +Ltac push_decode := repeat push_decode_step. +Ltac pull_decode := repeat pull_decode_step. + +(* We take special care to handle the case where the decoder is + syntactically different but the decoded expression is judgmentally + the same; we don't want to split apart variables that should be the + same. *) +Ltac set_decode_step check := + match goal with + | [ |- context G[@Interface.decode ?n ?W ?dr ?w] ] + => check w; + first [ match goal with + | [ d := @Interface.decode _ _ _ w |- _ ] + => change (@Interface.decode n W dr w) with d + end + | generalize (@decode_range n W dr _ w); + let d := fresh "d" in + set (d := @Interface.decode n W dr w); + intro ] + end. +Ltac set_decode check := repeat set_decode_step check. +Ltac clearbody_decode := + repeat match goal with + | [ H := @Interface.decode _ _ _ _ |- _ ] => clearbody H + end. +Ltac generalize_decode_by check := set_decode check; clearbody_decode. +Ltac generalize_decode := generalize_decode_by ltac:(fun w => idtac). +Ltac generalize_decode_var := generalize_decode_by ltac:(fun w => is_var w). Module fancy_machine. Local Notation imm := Z (only parsing). @@ -227,6 +303,7 @@ Module fancy_machine. ldi :> load_immediate W; shrd :> shift_right_doubleword_immediate W; shl :> shift_left_immediate W; + shr :> shift_right_immediate W; mkl :> mask_keep_low W; adc :> add_with_carry W; subc :> sub_with_carry W; @@ -243,6 +320,7 @@ Module fancy_machine. load_immediate :> is_load_immediate ldi; shift_right_doubleword_immediate :> is_shift_right_doubleword_immediate shrd; shift_left_immediate :> is_shift_left_immediate shl; + shift_right_immediate :> is_shift_right_immediate shr; mask_keep_low :> is_mask_keep_low mkl; add_with_carry :> is_add_with_carry adc; sub_with_carry :> is_sub_with_carry subc; -- cgit v1.2.3