From c9fc5a3cdf1f5ea2d104c150c30d1b1a6ac64239 Mon Sep 17 00:00:00 2001 From: Andres Erbsen Date: Thu, 6 Apr 2017 22:53:07 -0400 Subject: rename-everything --- src/LegacyArithmetic/ArchitectureToZLike.v | 38 ++ src/LegacyArithmetic/ArchitectureToZLikeProofs.v | 127 +++++ src/LegacyArithmetic/BarretReduction.v | 100 ++++ src/LegacyArithmetic/BaseSystem.v | 39 ++ src/LegacyArithmetic/BaseSystemProofs.v | 133 +++++ src/LegacyArithmetic/Double/Core.v | 253 ++++++++++ src/LegacyArithmetic/Double/Proofs/BitwiseOr.v | 31 ++ src/LegacyArithmetic/Double/Proofs/Decode.v | 184 +++++++ src/LegacyArithmetic/Double/Proofs/LoadImmediate.v | 32 ++ src/LegacyArithmetic/Double/Proofs/Multiply.v | 132 +++++ .../Double/Proofs/RippleCarryAddSub.v | 198 ++++++++ .../Double/Proofs/SelectConditional.v | 25 + src/LegacyArithmetic/Double/Proofs/ShiftLeft.v | 43 ++ .../Double/Proofs/ShiftLeftRightTactic.v | 41 ++ src/LegacyArithmetic/Double/Proofs/ShiftRight.v | 44 ++ .../Double/Proofs/ShiftRightDoubleWordImmediate.v | 42 ++ .../Double/Proofs/SpreadLeftImmediate.v | 148 ++++++ src/LegacyArithmetic/Interface.v | 450 +++++++++++++++++ src/LegacyArithmetic/InterfaceProofs.v | 224 +++++++++ src/LegacyArithmetic/MontgomeryReduction.v | 114 +++++ src/LegacyArithmetic/Pow2Base.v | 19 + src/LegacyArithmetic/Pow2BaseProofs.v | 555 +++++++++++++++++++++ src/LegacyArithmetic/README.md | 3 + src/LegacyArithmetic/VerdiTactics.v | 414 +++++++++++++++ src/LegacyArithmetic/ZBounded.v | 158 ++++++ src/LegacyArithmetic/ZBoundedZ.v | 88 ++++ 26 files changed, 3635 insertions(+) create mode 100644 src/LegacyArithmetic/ArchitectureToZLike.v create mode 100644 src/LegacyArithmetic/ArchitectureToZLikeProofs.v create mode 100644 src/LegacyArithmetic/BarretReduction.v create mode 100644 src/LegacyArithmetic/BaseSystem.v create mode 100644 src/LegacyArithmetic/BaseSystemProofs.v create mode 100644 src/LegacyArithmetic/Double/Core.v create mode 100644 src/LegacyArithmetic/Double/Proofs/BitwiseOr.v create mode 100644 
src/LegacyArithmetic/Double/Proofs/Decode.v create mode 100644 src/LegacyArithmetic/Double/Proofs/LoadImmediate.v create mode 100644 src/LegacyArithmetic/Double/Proofs/Multiply.v create mode 100644 src/LegacyArithmetic/Double/Proofs/RippleCarryAddSub.v create mode 100644 src/LegacyArithmetic/Double/Proofs/SelectConditional.v create mode 100644 src/LegacyArithmetic/Double/Proofs/ShiftLeft.v create mode 100644 src/LegacyArithmetic/Double/Proofs/ShiftLeftRightTactic.v create mode 100644 src/LegacyArithmetic/Double/Proofs/ShiftRight.v create mode 100644 src/LegacyArithmetic/Double/Proofs/ShiftRightDoubleWordImmediate.v create mode 100644 src/LegacyArithmetic/Double/Proofs/SpreadLeftImmediate.v create mode 100644 src/LegacyArithmetic/Interface.v create mode 100644 src/LegacyArithmetic/InterfaceProofs.v create mode 100644 src/LegacyArithmetic/MontgomeryReduction.v create mode 100644 src/LegacyArithmetic/Pow2Base.v create mode 100644 src/LegacyArithmetic/Pow2BaseProofs.v create mode 100644 src/LegacyArithmetic/README.md create mode 100644 src/LegacyArithmetic/VerdiTactics.v create mode 100644 src/LegacyArithmetic/ZBounded.v create mode 100644 src/LegacyArithmetic/ZBoundedZ.v (limited to 'src/LegacyArithmetic') diff --git a/src/LegacyArithmetic/ArchitectureToZLike.v b/src/LegacyArithmetic/ArchitectureToZLike.v new file mode 100644 index 000000000..19450f831 --- /dev/null +++ b/src/LegacyArithmetic/ArchitectureToZLike.v @@ -0,0 +1,38 @@ +(*** Implementing ℤ-Like via Architecture *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.LegacyArithmetic.Double.Core. +Require Import Crypto.LegacyArithmetic.ZBounded. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.LetIn. + +Local Open Scope Z_scope. + +Section fancy_machine_p256_montgomery_foundation. + Context {n_over_two : Z}. + Local Notation n := (2 * n_over_two). + Context (ops : fancy_machine.instructions n) (modulus : Z). 
+ + Local Instance ZLikeOps_of_ArchitectureBoundedOps_Factored (smaller_bound_exp : Z) + ldi_modulus ldi_0 + : ZLikeOps (2^n) (2^smaller_bound_exp) modulus := + { LargeT := tuple fancy_machine.W 2; + SmallT := fancy_machine.W; + modulus_digits := ldi_modulus; + decode_large := decode; + decode_small := decode; + Mod_SmallBound v := fst v; + DivBy_SmallBound v := snd v; + DivBy_SmallerBound v := if smaller_bound_exp =? n + then snd v + else dlet v := v in shrd (snd v) (fst v) smaller_bound_exp; + Mul x y := muldw x y; + CarryAdd x y := adc x y false; + CarrySubSmall x y := subc x y false; + ConditionalSubtract b x := let v := selc b (ldi_modulus) (ldi_0) in snd (subc x v false); + ConditionalSubtractModulus y := addm y (ldi_0) (ldi_modulus) }. + + Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : Z) + : ZLikeOps (2^n) (2^smaller_bound_exp) modulus := + @ZLikeOps_of_ArchitectureBoundedOps_Factored smaller_bound_exp (ldi modulus) (ldi 0). +End fancy_machine_p256_montgomery_foundation. diff --git a/src/LegacyArithmetic/ArchitectureToZLikeProofs.v b/src/LegacyArithmetic/ArchitectureToZLikeProofs.v new file mode 100644 index 000000000..8d4b59ceb --- /dev/null +++ b/src/LegacyArithmetic/ArchitectureToZLikeProofs.v @@ -0,0 +1,127 @@ +(*** Proving ℤ-Like via Architecture *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.LegacyArithmetic.InterfaceProofs. +Require Import Crypto.LegacyArithmetic.Double.Core. +Require Import Crypto.LegacyArithmetic.Double.Proofs.RippleCarryAddSub. +Require Import Crypto.LegacyArithmetic.Double.Proofs.Multiply. +Require Import Crypto.LegacyArithmetic.ArchitectureToZLike. +Require Import Crypto.LegacyArithmetic.ZBounded. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.LetIn. + +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. 
+ +Local Coercion Z.of_nat : nat >-> Z. + +Section fancy_machine_p256_montgomery_foundation. + Context {n_over_two : Z}. + Local Notation n := (2 * n_over_two)%Z. + Context (ops : fancy_machine.instructions n) (modulus : Z). + + Local Arguments Z.mul !_ !_. + Local Arguments BaseSystem.decode !_ !_ / . + Local Arguments BaseSystem.accumulate / _ _. + Local Arguments BaseSystem.decode' !_ !_ / . + + Local Ltac introduce_t_step := + match goal with + | [ |- forall x : bool, _ ] => intros [|] + | [ |- True -> _ ] => intros _ + | [ |- _ <= _ < _ -> _ ] => intro + | _ => let x := fresh "x" in + intro x; + try pose proof (decode_range (fst x)); + try pose proof (decode_range (snd x)); + pose proof (decode_range x) + end. + Local Ltac unfolder_t := + progress unfold LargeT, SmallT, modulus_digits, decode_large, decode_small, Mod_SmallBound, DivBy_SmallBound, DivBy_SmallerBound, Mul, CarryAdd, CarrySubSmall, ConditionalSubtract, ConditionalSubtractModulus, ZLikeOps_of_ArchitectureBoundedOps, ZLikeOps_of_ArchitectureBoundedOps_Factored in *. + Local Ltac saturate_context_step := + match goal with + | _ => unique assert (0 <= 2 * n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ] + | _ => unique assert (0 <= n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ] + | _ => unique assert (0 <= 2 * (2 * n_over_two)) by (eauto using decode_exponent_nonnegative with typeclass_instances) + | [ H : 0 <= ?x < _ |- _ ] => unique pose proof (proj1 H); unique pose proof (proj2 H) + end. + Local Ltac pre_t := + repeat first [ tauto + | introduce_t_step + | unfolder_t + | saturate_context_step ]. + Local Ltac post_t_step := + match goal with + | _ => reflexivity + | _ => progress subst + | _ => progress unfold Let_In + | _ => progress autorewrite with zsimplify_const + | [ |- fst ?x = (?a <=? ?b) :> bool ] + => cut (((if fst x then 1 else 0) = (if a <=? 
b then 1 else 0))%Z); + [ destruct (fst x), (a <=? b); intro; congruence | ] + | [ H : (_ =? _) = true |- _ ] => apply Z.eqb_eq in H; subst + | [ H : (_ =? _) = false |- _ ] => apply Z.eqb_neq in H + | _ => autorewrite with push_Zpow in *; solve [ reflexivity | assumption ] + | _ => autorewrite with pull_Zpow in *; pull_decode; reflexivity + | _ => progress push_decode + | _ => rewrite (Z.add_comm (_ << _) _); progress pull_decode + | [ |- context[if ?x =? ?y then _ else _] ] => destruct (x =? y) eqn:? + | _ => autorewrite with Zshift_to_pow; Z.rewrite_mod_small; reflexivity + end. + Local Ltac post_t := repeat post_t_step. + Local Ltac t := pre_t; post_t. + + Global Instance ZLikeProperties_of_ArchitectureBoundedOps_Factored + {arith : fancy_machine.arithmetic ops} + ldi_modulus ldi_0 + (Hldi_modulus : ldi_modulus = ldi modulus) + (Hldi_0 : ldi_0 = ldi 0) + (modulus_in_range : 0 <= modulus < 2^n) + (smaller_bound_exp : Z) + (smaller_bound_smaller : 0 <= smaller_bound_exp <= n) + (n_pos : 0 < n) + : ZLikeProperties (ZLikeOps_of_ArchitectureBoundedOps_Factored ops modulus smaller_bound_exp ldi_modulus ldi_0) + := { large_valid v := True; + medium_valid v := 0 <= decode_large v < 2^n * 2^smaller_bound_exp; + small_valid v := True }. + Proof. + (* In 8.5: *) + (* par:t. *) + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + Defined. 
+ + Global Instance ZLikeProperties_of_ArchitectureBoundedOps + {arith : fancy_machine.arithmetic ops} + (modulus_in_range : 0 <= modulus < 2^n) + (smaller_bound_exp : Z) + (smaller_bound_smaller : 0 <= smaller_bound_exp <= n) + (n_pos : 0 < n) + : ZLikeProperties (ZLikeOps_of_ArchitectureBoundedOps ops modulus smaller_bound_exp) + := ZLikeProperties_of_ArchitectureBoundedOps_Factored _ _ eq_refl eq_refl modulus_in_range _ smaller_bound_smaller n_pos. +End fancy_machine_p256_montgomery_foundation. diff --git a/src/LegacyArithmetic/BarretReduction.v b/src/LegacyArithmetic/BarretReduction.v new file mode 100644 index 000000000..1be9361ba --- /dev/null +++ b/src/LegacyArithmetic/BarretReduction.v @@ -0,0 +1,100 @@ +(*** Barrett Reduction *) +(** This file implements Barrett Reduction on [ZLikeOps]. We follow + [BarretReduction/ZHandbook.v]. *) +Require Import Coq.ZArith.ZArith Coq.Lists.List Coq.Classes.Morphisms Coq.micromega.Psatz. +Require Import Crypto.Arithmetic.BarrettReduction.HAC. +Require Import Crypto.LegacyArithmetic.ZBounded. +Require Import Crypto.Util.ZUtil. +(*Require Import Crypto.Util.Tactics.*) +Require Import Crypto.Util.Notations. + +Local Open Scope small_zlike_scope. +Local Open Scope large_zlike_scope. +Local Open Scope Z_scope. + +Section barrett. + Context (m b k μ offset : Z) + (m_pos : 0 < m) + (base_pos : 0 < b) + (k_good : m < b^k) + (μ_good : μ = b^(2*k) / m) (* [/] is [Z.div], i.e. integer (floor) division, not exact division *) + (offset_nonneg : 0 <= offset) + (k_big_enough : offset <= k) + (m_small : 3 * m <= b^(k+offset)) + (m_large : b^(k-offset) <= m + 1). + Context {ops : ZLikeOps (b^(k+offset)) (b^(k-offset)) m} {props : ZLikeProperties ops} + (μ' : SmallT) + (μ'_good : small_valid μ') + (μ'_eq : decode_small μ' = μ). + + Definition barrett_reduce : forall x : LargeT, + { barrett_reduce : SmallT + | medium_valid x + -> decode_small barrett_reduce = (decode_large x) mod m + /\ small_valid barrett_reduce }. + Proof. + intro x. evar (pr : SmallT); exists pr.
intros x_valid. + assert (0 <= decode_large x < b^(k+offset) * b^(k-offset)) by auto using decode_medium_valid. + assert (0 <= decode_large x < b^(2 * k)) by (autorewrite with pull_Zpow zsimplify in *; omega). + assert ((decode_large x) mod b^(k-offset) < b^(k-offset)) by auto with zarith omega. + rewrite (barrett_reduction_small m b (decode_large x) k μ offset) by omega. + rewrite <- μ'_eq. + pull_zlike_decode; cbv zeta; pull_zlike_decode. (* Extra [cbv zeta; pull_zlike_decode] to work around bug #4165 (https://coq.inria.fr/bugs/show_bug.cgi?id=4165) in 8.4 *) + subst pr; split; [ reflexivity | exact _ ]. + Defined. + + Definition barrett_reduce_function : LargeT -> SmallT + := Eval cbv [proj1_sig barrett_reduce] + in fun x => proj1_sig (barrett_reduce x). + Lemma barrett_reduce_correct x + : medium_valid x + -> decode_small (barrett_reduce_function x) = (decode_large x) mod m + /\ small_valid (barrett_reduce_function x). + Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg μ'_eq μ'_good μ_good. + exact (proj2_sig (barrett_reduce x)). + Qed. +End barrett. + +Module BarrettBundled. + Class BarrettParameters := + { m : Z; + b : Z; + k : Z; + offset : Z; + μ := b ^ (2 * k) / m; + ops : ZLikeOps (b ^ (k + offset)) (b ^ (k - offset)) m; + μ' : SmallT }. + Global Existing Instance ops. + + Class BarrettParametersCorrect {params : BarrettParameters} := + { m_pos : 0 < m; + base_pos : 0 < b; + offset_nonneg : 0 <= offset; + k_big_enough : offset <= k; + m_small : 3 * m <= b ^ (k + offset); + m_large : b ^ (k - offset) <= m + 1; + props : ZLikeProperties ops; + μ'_good : small_valid μ'; + μ'_eq : decode_small μ' = μ }. + Global Arguments BarrettParametersCorrect : clear implicits. + Global Existing Instance props. + + Module Export functions. + Definition barrett_reduce_function_bundled {params : BarrettParameters} + : LargeT -> SmallT + := barrett_reduce_function m b k offset μ'.
+ Definition barrett_reduce_correct_bundled {params : BarrettParameters} {params_proofs : BarrettParametersCorrect params} + : forall x, medium_valid x + -> decode_small (barrett_reduce_function_bundled x) = (decode_large x) mod m + /\ small_valid (barrett_reduce_function_bundled x) + := @barrett_reduce_correct + m b k μ offset + m_pos base_pos eq_refl offset_nonneg + k_big_enough m_small m_large + ops props μ' μ'_good μ'_eq. + End functions. +End BarrettBundled. +Export BarrettBundled.functions. +Global Existing Instance BarrettBundled.ops. +Global Arguments BarrettBundled.BarrettParametersCorrect : clear implicits. +Global Existing Instance BarrettBundled.props. diff --git a/src/LegacyArithmetic/BaseSystem.v b/src/LegacyArithmetic/BaseSystem.v new file mode 100644 index 000000000..a54bc483f --- /dev/null +++ b/src/LegacyArithmetic/BaseSystem.v @@ -0,0 +1,39 @@ +Require Import Coq.Lists.List. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zdiv. +Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith. +Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil. +Require Import Crypto.Util.Notations. +Require Export Crypto.Util.FixCoqMistakes. +Import Nat. + +Local Open Scope Z. + +Class BaseVector (base : list Z):= { + base_positive : forall b, In b base -> b > 0; (* nonzero would probably work too... *) + b0_1 : forall x, nth_default x base 0 = 1; (** TODO(jadep,jgross): change to [nth_error base 0 = Some 1], then use [nth_error_value_eq_nth_default] to prove a [forall x, nth_default x base 0 = 1] as a lemma *) + base_good : + forall i j, (i+j < length base)%nat -> + let b := nth_default 0 base in + let r := (b i * b j) / b (i+j)%nat in + b i * b j = r * b (i+j)%nat +}. + +Section BaseSystem. + Context (base : list Z). + (** [BaseSystem] implements a constrained positional number system.
+ A wide variety of bases are supported: the base coefficients are not + required to be powers of 2, and it is NOT necessarily the case that + $b_{i+j} = b_i b_j$. Implementations of addition and multiplication are + provided, with focus on near-optimal multiplication performance on + non-trivial but small operands: maybe 10 32-bit integers or so. This + module does not handle carries automatically: if no restrictions are put + on the use of a [BaseSystem], each digit is unbounded. This has nothing + to do with modular arithmetic either. + *) + Definition digits : Type := list Z. + + Definition accumulate p acc := fst p * snd p + acc. + Definition decode' bs u := fold_right accumulate 0 (combine u bs). + Definition decode := decode' base. + Definition mul_each u := map (Z.mul u). +End BaseSystem. \ No newline at end of file diff --git a/src/LegacyArithmetic/BaseSystemProofs.v b/src/LegacyArithmetic/BaseSystemProofs.v new file mode 100644 index 000000000..9a06109d1 --- /dev/null +++ b/src/LegacyArithmetic/BaseSystemProofs.v @@ -0,0 +1,133 @@ +Require Import Coq.Lists.List Coq.micromega.Psatz. +Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zdiv. +Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith. +Require Import Crypto.LegacyArithmetic.BaseSystem. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.Notations. +Import Morphisms. +Local Open Scope Z. + +Local Hint Extern 1 (@eq Z _ _) => ring. + +Section BaseSystemProofs. + Context `(base_vector : BaseVector). + + Lemma decode'_truncate : forall bs us, decode' bs us = decode' bs (firstn (length bs) us). + Proof using Type. + unfold decode'; intros; f_equal; apply combine_truncate_l. + Qed. + + Lemma decode'_splice : forall xs ys bs, + decode' bs (xs ++ ys) = + decode' (firstn (length xs) bs) xs + decode' (skipn (length xs) bs) ys. + Proof using Type. + unfold decode'. + induction xs; destruct ys, bs; boring. 
+ + rewrite combine_truncate_r. + do 2 rewrite Z.add_0_r; auto. + + unfold accumulate. + apply Z.add_assoc. + Qed. + + Lemma decode_nil : forall bs, decode' bs nil = 0. + Proof using Type. + + auto. + Qed. + Hint Rewrite decode_nil. + + Lemma decode_base_nil : forall us, decode' nil us = 0. + Proof using Type. + intros; rewrite decode'_truncate; auto. + Qed. + + Lemma mul_each_rep : forall bs u vs, + decode' bs (mul_each u vs) = u * decode' bs vs. + Proof using Type. + unfold decode', accumulate; induction bs; destruct vs; boring; ring. + Qed. + + Lemma base_eq_1cons: base = 1 :: skipn 1 base. + Proof using Type*. + pose proof (b0_1 0) as H. + destruct base; compute in H; try discriminate; boring. + Qed. + + Lemma decode'_cons : forall x1 x2 xs1 xs2, + decode' (x1 :: xs1) (x2 :: xs2) = x1 * x2 + decode' xs1 xs2. + Proof using Type. + unfold decode', accumulate; boring; ring. + Qed. + Hint Rewrite decode'_cons. + + Lemma decode_cons : forall x us, + decode base (x :: us) = x + decode base (0 :: us). + Proof using Type*. + unfold decode; intros. + rewrite base_eq_1cons. + autorewrite with core; ring_simplify; auto. + Qed. + + Lemma decode'_map_mul : forall v xs bs, + decode' (map (Z.mul v) bs) xs = + Z.mul v (decode' bs xs). + Proof using Type. + unfold decode'. + induction xs; destruct bs; boring. + unfold accumulate; simpl; nia. + Qed. + + Lemma decode_map_mul : forall v xs, + decode (map (Z.mul v) base) xs = + Z.mul v (decode base xs). + Proof using Type. + unfold decode; intros; apply decode'_map_mul. + Qed. + + Lemma mul_each_base : forall us bs c, + decode' bs (mul_each c us) = decode' (mul_each c bs) us. + Proof using Type. + induction us; destruct bs; boring; ring. + Qed. + + Hint Rewrite (@nth_default_nil Z). + Hint Rewrite (@firstn_nil Z). + Hint Rewrite (@skipn_nil Z). + + Lemma peel_decode : forall xs ys x y, decode' (x::xs) (y::ys) = x*y + decode' xs ys. + Proof using Type. + boring. + Qed. + Hint Rewrite peel_decode. + + Hint Rewrite plus_0_r. 
+ + Lemma set_higher : forall bs vs x, + decode' bs (vs++x::nil) = decode' bs vs + nth_default 0 bs (length vs) * x. + Proof using Type. + intros. + rewrite !decode'_splice. + cbv [decode' nth_default]; break_match; ring_simplify; + match goal with + | [H:_ |- _] => unique pose proof (nth_error_error_length _ _ _ H) + | [H:_ |- _] => unique pose proof (nth_error_value_length _ _ _ _ H) + end; + repeat match goal with + | _ => solve [simpl;ring_simplify; trivial] + | _ => progress ring_simplify + | _ => progress rewrite skipn_all by trivial + | _ => progress rewrite combine_nil_r + | _ => progress rewrite firstn_all2 by trivial + end. + rewrite (combine_truncate_r vs bs); apply (f_equal2 Z.add); trivial; []. + unfold combine; break_match. + { apply (f_equal (@length _)) in Heql; simpl length in Heql; rewrite skipn_length in Heql; omega. } + { cbv -[Z.add Z.mul]; ring_simplify; f_equal. + assert (HH: nth_error (z0 :: l) 0 = Some z) by + ( + pose proof @nth_error_skipn _ (length vs) bs 0; + rewrite plus_0_r in *; + congruence); simpl in HH; congruence. } + Qed. +End BaseSystemProofs. \ No newline at end of file diff --git a/src/LegacyArithmetic/Double/Core.v b/src/LegacyArithmetic/Double/Core.v new file mode 100644 index 000000000..b7be2d18a --- /dev/null +++ b/src/LegacyArithmetic/Double/Core.v @@ -0,0 +1,253 @@ +(*** Implementing Large Bounded Arithmetic via pairs *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.LegacyArithmetic.InterfaceProofs. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ListUtil. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.LetIn. +Import Bug5107WorkAround. + +Require Crypto.LegacyArithmetic.BaseSystem. +Require Crypto.LegacyArithmetic.Pow2Base. + +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Local Coercion Z.of_nat : nat >-> Z. +Local Notation eta x := (fst x, snd x). 
+ +(** The list is low to high; the tuple is low to high *) +Definition tuple_decoder {n W} {decode : decoder n W} {k : nat} : decoder (k * n) (tuple W k) + := {| decode w := BaseSystem.decode (Pow2Base.base_from_limb_widths (repeat n k)) + (List.map decode (List.rev (Tuple.to_list _ w))) |}. +Global Arguments tuple_decoder : simpl never. +Hint Extern 3 (decoder _ (tuple ?W ?k)) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode => (@tuple_decoder n W decode k : decoder (kv * n) (tuple W k))) : typeclass_instances. + +Section ripple_carry_definitions. + (** tuple is high to low ([to_list] reverses) *) + Fixpoint ripple_carry_tuple' {T} (f : T -> T -> bool -> bool * T) k + : forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k + := match k return forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k with + | O => f + | S k' => fun xss yss carry => dlet xss := xss in + dlet yss := yss in + let (xs, x) := eta xss in + let (ys, y) := eta yss in + dlet addv := (@ripple_carry_tuple' _ f k' xs ys carry) in + let (carry, zs) := eta addv in + dlet fxy := (f x y carry) in + let (carry, z) := eta fxy in + (carry, (zs, z)) + end. + + Definition ripple_carry_tuple {T} (f : T -> T -> bool -> bool * T) k + : forall (xs ys : tuple T k) (carry : bool), bool * tuple T k + := match k return forall (xs ys : tuple T k) (carry : bool), bool * tuple T k with + | O => fun xs ys carry => (carry, tt) + | S k' => ripple_carry_tuple' f k' + end. +End ripple_carry_definitions. + +Global Instance ripple_carry_adc + {W} (adc : add_with_carry W) {k} + : add_with_carry (tuple W k) + := { adc := ripple_carry_tuple adc k }. + +Global Instance ripple_carry_subc + {W} (subc : sub_with_carry W) {k} + : sub_with_carry (tuple W k) + := { subc := ripple_carry_tuple subc k }. + +(** constructions on [tuple W 2] *) +Section tuple2. + Section select_conditional. + Context {W} + {selc : select_conditional W}. 
+ + Definition select_conditional_double (b : bool) (x : tuple W 2) (y : tuple W 2) : tuple W 2 + := dlet x := x in + dlet y := y in + let (x1, x2) := eta x in + let (y1, y2) := eta y in + (selc b x1 y1, selc b x2 y2). + + Global Instance selc_double : select_conditional (tuple W 2) + := { selc := select_conditional_double }. + End select_conditional. + + Section load_immediate. + Context (n : Z) {W} + {ldi : load_immediate W}. + + Definition load_immediate_double (r : Z) : tuple W 2 + := (ldi (r mod 2^n), ldi (r / 2^n)). + + (** Require a [decoder] instance to aid typeclass search in + resolving [n] *) + Global Instance ldi_double {decode : decoder n W} : load_immediate (tuple W 2) + := { ldi := load_immediate_double }. + End load_immediate. + + Section bitwise_or. + Context {W} + {or : bitwise_or W}. + + Definition bitwise_or_double (x : tuple W 2) (y : tuple W 2) : tuple W 2 + := dlet x := x in + dlet y := y in + let (x1, x2) := eta x in + let (y1, y2) := eta y in + (or x1 y1, or x2 y2). + + Global Instance or_double : bitwise_or (tuple W 2) + := { or := bitwise_or_double }. + End bitwise_or. + + Section bitwise_and. + Context {W} + {and : bitwise_and W}. + + Definition bitwise_and_double (x : tuple W 2) (y : tuple W 2) : tuple W 2 + := dlet x := x in + dlet y := y in + let (x1, x2) := eta x in + let (y1, y2) := eta y in + (and x1 y1, and x2 y2). + + Global Instance and_double : bitwise_and (tuple W 2) + := { and := bitwise_and_double }. + End bitwise_and. + + Section spread_left. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W}. + + Definition spread_left_from_shift (r : W) (count : Z) : tuple W 2 + := dlet r := r in + (shl r count, if count =? 0 then ldi 0 else shr r (n - count)). + + (** Require a [decoder] instance to aid typeclass search in + resolving [n] *) + Global Instance sprl_from_shift {decode : decoder n W} : spread_left_immediate W + := { sprl := spread_left_from_shift }. 
+ End spread_left. + + Section shl_shr. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {or : bitwise_or W}. + + Definition shift_left_immediate_double (r : tuple W 2) (count : Z) : tuple W 2 + := dlet r := r in + let (r1, r2) := eta r in + (if count =? 0 + then r1 + else if count -> Z. + +Import BoundedRewriteNotations. +Local Open Scope Z_scope. + +Section decode. + Context {n W} {decode : decoder n W}. + Section with_k. + Context {k : nat}. + Local Notation limb_widths := (repeat n k). + + Lemma decode_bounded {isdecode : is_decode decode} w + : 0 <= n -> Pow2Base.bounded limb_widths (List.map decode (rev (to_list k w))). + Proof using Type. + intro. + eapply Pow2BaseProofs.bounded_uniform; try solve [ eauto using repeat_spec ]. + { distr_length. } + { intros z H'. + apply in_map_iff in H'. + destruct H' as [? [? H'] ]; subst; apply decode_range. } + Qed. + + (** TODO: Clean up this proof *) + Global Instance tuple_is_decode {isdecode : is_decode decode} + : is_decode (tuple_decoder (k := k)). + Proof using Type. + unfold tuple_decoder; hnf; simpl. + intro w. + destruct (zerop k); [ subst | ]. + { cbv; intuition congruence. } + assert (0 <= n) + by (destruct k as [ | [|] ]; [ omega | | destruct w ]; + eauto using decode_exponent_nonnegative). + replace (2^(k * n)) with (Pow2Base.upper_bound limb_widths) + by (erewrite Pow2BaseProofs.upper_bound_uniform by eauto using repeat_spec; distr_length). + apply Pow2BaseProofs.decode_upper_bound; auto using decode_bounded. + { intros ? H'. + apply repeat_spec in H'; omega. } + { distr_length. } + Qed. + End with_k. + + Local Arguments Pow2Base.base_from_limb_widths : simpl never. + Local Arguments repeat : simpl never. + Local Arguments Z.mul !_ !_. + Lemma tuple_decoder_S {k} w : 0 <= n -> (tuple_decoder (k := S (S k)) w = tuple_decoder (k := S k) (fst w) + (decode (snd w) << (S k * n)))%Z. + Proof using Type. + intro Hn. + destruct w as [? w]; simpl. 
+ replace (decode w) with (decode w * 1 + 0)%Z by omega. + rewrite map_app, map_cons, map_nil. + erewrite Pow2BaseProofs.decode_shift_uniform_app by (eauto using repeat_spec; distr_length). + distr_length. + autorewrite with push_skipn natsimplify push_firstn. + reflexivity. + Qed. + Global Instance tuple_decoder_O w : tuple_decoder (k := 1) w =~> decode w. + Proof using Type. + cbv [tuple_decoder LegacyArithmetic.BaseSystem.decode LegacyArithmetic.BaseSystem.decode' LegacyArithmetic.BaseSystem.accumulate Pow2Base.base_from_limb_widths repeat]. + simpl; hnf; lia. + Qed. + Global Instance tuple_decoder_m1 w : tuple_decoder (k := 0) w =~> 0. + Proof using Type. reflexivity. Qed. + + Lemma tuple_decoder_n_neg k w {H : is_decode decode} : n <= 0 -> tuple_decoder (k := k) w =~> 0. + Proof using Type. + pose proof (tuple_is_decode w) as H'; hnf in H'. + intro; assert (k * n <= 0) by nia. + assert (2^(k * n) <= 2^0) by (apply Z.pow_le_mono_r; omega). + simpl in *; hnf. + omega. + Qed. + Lemma tuple_decoder_O_ind_prod + (P : forall n, decoder n W -> Type) + (P_ext : forall n (a b : decoder n W), (forall x, a x = b x) -> P _ a -> P _ b) + : (P _ (tuple_decoder (k := 1)) -> P _ decode) + * (P _ decode -> P _ (tuple_decoder (k := 1))). + Proof using Type. + unfold tuple_decoder, BaseSystem.decode, BaseSystem.decode', BaseSystem.accumulate, Pow2Base.base_from_limb_widths, repeat. + simpl; hnf. + rewrite Z.mul_1_l. + split; apply P_ext; simpl; intro; autorewrite with zsimplify_const; reflexivity. + Qed. + + Global Instance tuple_decoder_2' w : (0 <= n)%bounded_rewrite -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << (1%nat * n))%Z. + Proof using Type. + intros; rewrite !tuple_decoder_S, !tuple_decoder_O by assumption. + reflexivity. + Qed. + Global Instance tuple_decoder_2 w : (0 <= n)%bounded_rewrite -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << n)%Z. + Proof using Type. 
+ intros; rewrite !tuple_decoder_S, !tuple_decoder_O by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. +End decode. + +Global Arguments tuple_decoder : simpl never. +Local Opaque tuple_decoder. + +Global Instance tuple_decoder_n_O + {W} {decode : decoder 0 W} + {is_decode : is_decode decode} + : forall k w, tuple_decoder (k := k) w =~> 0. +Proof. intros; apply tuple_decoder_n_neg; easy. Qed. + +Lemma is_add_with_carry_1tuple {n W decode adc} + (H : @is_add_with_carry n W decode adc) + : @is_add_with_carry (1 * n) W (@tuple_decoder n W decode 1) adc. +Proof. + apply tuple_decoder_O_ind_prod; try assumption. + intros ??? ext [H0 H1]; apply Build_is_add_with_carry'. + intros x y c; specialize (H0 x y c); specialize (H1 x y c). + rewrite <- !ext; split; assumption. +Qed. + +Hint Extern 1 (@is_add_with_carry _ _ (@tuple_decoder ?n ?W ?decode 1) ?adc) +=> apply (@is_add_with_carry_1tuple n W decode adc) : typeclass_instances. + +Hint Resolve (fun n W decode pf => (@tuple_is_decode n W decode 2 pf : @is_decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2))) : typeclass_instances. +Hint Extern 3 (@is_decode _ (tuple ?W ?k) _) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode pf => (@tuple_is_decode n W decode k pf : @is_decode (kv * n) (tuple W k) (@tuple_decoder n W decode k : decoder (kv * n)%Z (tuple W k)))) : typeclass_instances. + +Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 @tuple_decoder_n_O using solve [ auto with zarith ] : simpl_tuple_decoder. +Hint Rewrite Z.mul_1_l : simpl_tuple_decoder. 
+Hint Rewrite + (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2) w = _)) + (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (W * W) (@tuple_decoder n W decode 2) w = _)) + (fun n W decode w => @tuple_decoder_m1 n W decode w : @Interface.decode (Z.of_nat 0 * n) unit (@tuple_decoder n W decode 0) w = _) + using solve [ auto with zarith ] + : simpl_tuple_decoder. + +Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder. + +Global Instance tuple_decoder_mod {n W} {decode : decoder n W} {k} {isdecode : is_decode decode} (w : tuple W (S (S k))) + : tuple_decoder (k := S k) (fst w) <~= tuple_decoder w mod 2^(S k * n). +Proof. + pose proof (snd w). + assert (0 <= n) by eauto using decode_exponent_nonnegative. + assert (0 <= (S k) * n) by nia. + assert (0 <= tuple_decoder (k := S k) (fst w) < 2^(S k * n)) by apply decode_range. + autorewrite with simpl_tuple_decoder Zshift_to_pow zsimplify. + reflexivity. +Qed. + +Global Instance tuple_decoder_div {n W} {decode : decoder n W} {k} {isdecode : is_decode decode} (w : tuple W (S (S k))) + : decode (snd w) <~= tuple_decoder w / 2^(S k * n). +Proof. + pose proof (snd w). + assert (0 <= n) by eauto using decode_exponent_nonnegative. + assert (0 <= (S k) * n) by nia. + assert (0 <= k * n) by nia. + assert (0 < 2^n) by auto with zarith. + assert (0 <= tuple_decoder (k := S k) (fst w) < 2^(S k * n)) by apply decode_range. + autorewrite with simpl_tuple_decoder Zshift_to_pow zsimplify. + reflexivity. +Qed. + +Global Instance tuple2_decoder_mod {n W} {decode : decoder n W} {isdecode : is_decode decode} (w : tuple W 2) + : decode (fst w) <~= tuple_decoder w mod 2^n. +Proof. + generalize (@tuple_decoder_mod n W decode 0 isdecode w). + autorewrite with simpl_tuple_decoder; trivial. +Qed. 
+ (* 2-tuple special case, obtained from [tuple_decoder_div] at [k := 0]. *) +Global Instance tuple2_decoder_div {n W} {decode : decoder n W} {isdecode : is_decode decode} (w : tuple W 2) + : decode (snd w) <~= tuple_decoder w / 2^n. +Proof. + generalize (@tuple_decoder_div n W decode 0 isdecode w). + autorewrite with simpl_tuple_decoder; trivial. +Qed. diff --git a/src/LegacyArithmetic/Double/Proofs/LoadImmediate.v b/src/LegacyArithmetic/Double/Proofs/LoadImmediate.v new file mode 100644 index 000000000..2c7f87dd7 --- /dev/null +++ b/src/LegacyArithmetic/Double/Proofs/LoadImmediate.v @@ -0,0 +1,32 @@ +Require Import Coq.ZArith.ZArith. +Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.LegacyArithmetic.InterfaceProofs. +Require Import Crypto.LegacyArithmetic.Double.Core. +Require Import Crypto.LegacyArithmetic.Double.Proofs.Decode. +Require Import Crypto.Util.ZUtil. + +Local Open Scope Z_scope. +Local Opaque tuple_decoder. +Local Arguments Z.mul !_ !_. + +Section load_immediate. + Context {n W} + {decode : decoder n W} + {is_decode : is_decode decode} + {ldi : load_immediate W} + {is_ldi : is_load_immediate ldi}. + (* [ldi_double n] satisfies [is_load_immediate]; the proof bounds [x mod 2^n] and [x / 2^n] by [2^n] before pushing [decode] through. *) + Global Instance is_load_immediate_double + : is_load_immediate (ldi_double n). + Proof using Type*. + intros x H; hnf in H. + pose proof (decode_exponent_nonnegative decode (ldi x)). + assert (0 <= x mod 2^n < 2^n) by auto with zarith. + assert (x / 2^n < 2^n) + by (apply Z.div_lt_upper_bound; autorewrite with pull_Zpow zsimplify; auto with zarith). + assert (0 <= x / 2^n < 2^n) by (split; Z.zero_bounds). + unfold ldi_double, load_immediate_double; simpl. + autorewrite with simpl_tuple_decoder Zshift_to_pow; simpl; push_decode. + autorewrite with zsimplify; reflexivity. + Qed. +End load_immediate. diff --git a/src/LegacyArithmetic/Double/Proofs/Multiply.v b/src/LegacyArithmetic/Double/Proofs/Multiply.v new file mode 100644 index 000000000..8fed917d9 --- /dev/null +++ b/src/LegacyArithmetic/Double/Proofs/Multiply.v @@ -0,0 +1,132 @@ +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. 
+Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.LegacyArithmetic.InterfaceProofs. +Require Import Crypto.LegacyArithmetic.Double.Core. +Require Import Crypto.LegacyArithmetic.Double.Proofs.Decode. +Require Import Crypto.LegacyArithmetic.Double.Proofs.SpreadLeftImmediate. +Require Import Crypto.LegacyArithmetic.Double.Proofs.RippleCarryAddSub. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics.SimplifyProjections. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Prod. +Import Bug5107WorkAround. +Import BoundedRewriteNotations. + +Local Open Scope Z_scope. + +Local Opaque tuple_decoder. + (* [is_mul_double muldw] holds exactly when the tuple decoding of [muldw x y] equals [decode x * decode y] for all [x], [y]. *) +Lemma decode_mul_double_iff + {n W} + {decode : decoder n W} + {muldw : multiply_double W} + {isdecode : is_decode decode} + : is_mul_double muldw + <-> (forall x y, tuple_decoder (muldw x y) = (decode x * decode y)%Z). +Proof. + rewrite is_mul_double_alt by assumption. + split; intros H x y; specialize (H x y); revert H; + pose proof (decode_range x); pose proof (decode_range y); + assert (0 <= decode x * decode y < 2^n * 2^n) by nia; + assert (0 <= n) by eauto using decode_exponent_nonnegative; + autorewrite with simpl_tuple_decoder; + simpl; intro H'; rewrite H'; + Z.rewrite_mod_small; reflexivity. +Qed. + (* Forward direction of [decode_mul_double_iff], exposed as a two-way rewrite instance. *) +Global Instance decode_mul_double + {n W} + {decode : decoder n W} + {muldw : multiply_double W} + {isdecode : is_decode decode} + {ismuldw : is_mul_double muldw} + : forall x y, tuple_decoder (muldw x y) <~=~> (decode x * decode y)%Z + := proj1 decode_mul_double_iff _. + +Section tuple2. + Local Arguments Z.pow !_ !_. + Local Arguments Z.mul !_ !_. + + Local Opaque ripple_carry_adc. + Section full_from_half. 
+ Context {W} + {mulhwll : multiply_low_low W} + {mulhwhl : multiply_high_low W} + {mulhwhh : multiply_high_high W} + {adc : add_with_carry W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {half_n : Z} + {ldi : load_immediate W} + {decode : decoder (2 * half_n) W} + {ismulhwll : is_mul_low_low half_n mulhwll} + {ismulhwhl : is_mul_high_low half_n mulhwhl} + {ismulhwhh : is_mul_high_high half_n mulhwhh} + {isadc : is_add_with_carry adc} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr} + {isldi : is_load_immediate ldi} + {isdecode : is_decode decode}. + + Local Arguments Z.mul !_ !_. + (* [mul_double half_n x y] decodes to [decode x * decode y] modulo 2^(2*half_n) * 2^(2*half_n). *) + Lemma decode_mul_double_mod x y + : (tuple_decoder (mul_double half_n x y) = (decode x * decode y) mod (2^(2 * half_n) * 2^(2*half_n)))%Z. + Proof using Type*. + assert (0 <= 2 * half_n) by eauto using decode_exponent_nonnegative. + assert (0 <= half_n) by omega. + unfold mul_double, Let_In. + push_decode; autorewrite with simpl_tuple_decoder; simplify_projections. + autorewrite with zsimplify Zshift_to_pow push_Zpow. + rewrite !spread_left_from_shift_half_correct. + push_decode. + generalize_decode_var. + simpl in *. + autorewrite with push_Zpow in *. + repeat autorewrite with Zshift_to_pow zsimplify push_Zpow. + rewrite <- !(Z.mul_mod_distr_r_full _ _ (_^_ * _^_)), ?Z.mul_assoc. + Z.rewrite_mod_small. + push_Zmod; pull_Zmod. + apply f_equal2; [ | reflexivity ]. + Z.div_mod_to_quot_rem; nia. + Qed. + (* Same equation with the modulus dropped; follows from [decode_mul_double_mod] by [Z.rewrite_mod_small]. *) + Lemma decode_mul_double_function x y + : tuple_decoder (mul_double half_n x y) = (decode x * decode y)%Z. + Proof using Type*. + rewrite decode_mul_double_mod; generalize_decode_var. + simpl in *; Z.rewrite_mod_small; reflexivity. + Qed. + (* Hence [mul_double_multiply] satisfies [is_mul_double], via [decode_mul_double_iff]. *) + Global Instance mul_double_is_multiply_double : is_mul_double mul_double_multiply. + Proof using Type*. + apply decode_mul_double_iff; apply decode_mul_double_function. + Qed. + End full_from_half. + + Section half_from_full. 
+ Context {n W} + {decode : decoder n W} + {muldw : multiply_double W} + {isdecode : is_decode decode} + {ismuldw : is_mul_double muldw}. + (* Shared proof script for the three instances that follow. *) + Local Ltac t := + hnf; intros [??] [??]; + assert (0 <= n) by eauto using decode_exponent_nonnegative; + assert (0 < 2^n) by auto with zarith; + assert (forall x y, 0 <= x < 2^n -> 0 <= y < 2^n -> 0 <= x * y < 2^n * 2^n) by auto with zarith; + simpl @Interface.mulhwhh; simpl @Interface.mulhwhl; simpl @Interface.mulhwll; + rewrite decode_mul_double; autorewrite with simpl_tuple_decoder Zshift_to_pow zsimplify push_Zpow; + Z.rewrite_mod_small; + try reflexivity. + (* [mul_double_multiply_low_low], [mul_double_multiply_high_low] and [mul_double_multiply_high_high] satisfy their respective [is_mul_*] specifications at width [n]. *) + Global Instance mul_double_is_multiply_low_low : is_mul_low_low n mul_double_multiply_low_low. + Proof using Type*. t. Qed. + Global Instance mul_double_is_multiply_high_low : is_mul_high_low n mul_double_multiply_high_low. + Proof using Type*. t. Qed. + Global Instance mul_double_is_multiply_high_high : is_mul_high_high n mul_double_multiply_high_high. + Proof using Type*. t. Qed. + End half_from_full. +End tuple2. diff --git a/src/LegacyArithmetic/Double/Proofs/RippleCarryAddSub.v b/src/LegacyArithmetic/Double/Proofs/RippleCarryAddSub.v new file mode 100644 index 000000000..e703c2e57 --- /dev/null +++ b/src/LegacyArithmetic/Double/Proofs/RippleCarryAddSub.v @@ -0,0 +1,198 @@ +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.LegacyArithmetic.InterfaceProofs. +Require Import Crypto.LegacyArithmetic.Double.Core. +Require Import Crypto.LegacyArithmetic.Double.Proofs.Decode. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.SimplifyProjections. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Prod. +Import Bug5107WorkAround. +Import BoundedRewriteNotations. + (* Let [nat] tuple lengths be used where a [Z] is expected; [eta x] re-pairs the projections of [x]. *) +Local Coercion Z.of_nat : nat >-> Z. +Local Notation eta x := (fst x, snd x). 
+ +Local Open Scope Z_scope. +Local Opaque tuple_decoder. + (* One-step unfolding of [ripple_carry_tuple] at length [S (S k)], stated with [dlet] bindings; holds by [reflexivity]. *) +Lemma ripple_carry_tuple_SS' {T} f k xss yss carry + : @ripple_carry_tuple T f (S (S k)) xss yss carry + = dlet xss := xss in + dlet yss := yss in + let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + dlet addv := (@ripple_carry_tuple _ f (S k) xs ys carry) in + let '(carry, zs) := eta addv in + dlet fxy := (f x y carry) in + let '(carry, z) := eta fxy in + (carry, (zs, z)). +Proof. reflexivity. Qed. + (* The same unfolding with the [dlet]s removed, leaving only the destructuring [let]s. *) +Lemma ripple_carry_tuple_SS {T} f k xss yss carry + : @ripple_carry_tuple T f (S (S k)) xss yss carry + = let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (@ripple_carry_tuple _ f (S k) xs ys carry) in + let '(carry, z) := eta (f x y carry) in + (carry, (zs, z)). +Proof. + rewrite ripple_carry_tuple_SS'. + eta_expand. + reflexivity. +Qed. + (* Shift/mod identities for [z0 + z1 << k] with nonnegative [n] and [k]: the first conjunct concerns the carry, the second the value. *) +Lemma carry_is_good (n z0 z1 k : Z) + : 0 <= n -> + 0 <= k -> + (z1 + z0 >> k) >> n = (z0 + z1 << k) >> (k + n) /\ + (z0 mod 2 ^ k + ((z1 + z0 >> k) mod 2 ^ n) << k)%Z = (z0 + z1 << k) mod (2 ^ k * 2 ^ n). +Proof. + intros. + assert (0 < 2 ^ n) by auto with zarith. + assert (0 < 2 ^ k) by auto with zarith. + assert (0 < 2^n * 2^k) by nia. + autorewrite with Zshift_to_pow push_Zpow. + rewrite <- (Zmod_small ((z0 mod _) + _) (2^k * 2^n)) by (Z.div_mod_to_quot_rem; nia). + rewrite <- !Z.mul_mod_distr_r by lia. + rewrite !(Z.mul_comm (2^k)); pull_Zmod. + split; [ | apply f_equal2 ]; + Z.div_mod_to_quot_rem; nia. +Qed. +Section carry_sub_is_good. + Context (n k z0 z1 : Z) + (Hn : 0 <= n) + (Hk : 0 <= k) + (Hz1 : -2^n < z1 < 2^n) + (Hz0 : -2^k <= z0 < 2^k). + (* NOTE(review): the statement and proof script of the next lemma look truncated in this copy (text around the comparison operators is missing); restore from the original file. *) + Lemma carry_sub_is_good_carry + : ((z1 - if z0 progress break_match + | [ |- context[?x destruct (x reflexivity + | _ => progress Z.ltb_to_lt + | [ |- true = false ] => exfalso + | [ |- false = true ] => exfalso + | [ |- False ] => nia + end. + Qed. 
+ (* NOTE(review): the statement and proof script of this lemma look truncated in this copy; restore from the original file. *) Lemma carry_sub_is_good_value + : (z0 mod 2 ^ k + ((z1 - if z0 Z ] + => first [ cut (q = -1); [ intro; subst; ring | nia ] + | cut (q = 0); [ intro; subst; ring | nia ] + | cut (q = 1); [ intro; subst; ring | nia ] ] + end. + Qed. +End carry_sub_is_good. + (* The two conjuncts of [carry_is_good], exposed as separate facts. *) +Definition carry_is_good_carry n z0 z1 k H0 H1 := proj1 (@carry_is_good n z0 z1 k H0 H1). +Definition carry_is_good_value n z0 z1 k H0 H1 := proj2 (@carry_is_good n z0 z1 k H0 H1). + +Section ripple_carry_adc. + Context {n W} {decode : decoder n W} (adc : add_with_carry W). + (* [ripple_carry_tuple_SS] specialised to [ripple_carry_adc]. *) + Lemma ripple_carry_adc_SS k xss yss carry + : ripple_carry_adc (k := S (S k)) adc xss yss carry + = let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (ripple_carry_adc (k := S k) adc xs ys carry) in + let '(carry, z) := eta (adc x y carry) in + (carry, (zs, z)). + Proof using Type. apply ripple_carry_tuple_SS. Qed. + + Local Opaque Z.of_nat. + (* If [adc] satisfies [is_add_with_carry], so does [ripple_carry_adc (k := k) adc] for every [k]; cases [k = 0] and [k = 1] are handled directly, longer tuples by induction using [carry_is_good]. *) + Global Instance ripple_carry_is_add_with_carry {k} + {isdecode : is_decode decode} + {is_adc : is_add_with_carry adc} + : is_add_with_carry (ripple_carry_adc (k := k) adc). + Proof using Type. + destruct k as [|k]. + { constructor; simpl; intros; autorewrite with zsimplify; reflexivity. } + { induction k as [|k IHk]. + { cbv [ripple_carry_adc ripple_carry_tuple to_list]. + constructor; simpl @fst; simpl @snd; intros; + simpl; pull_decode; reflexivity. } + { apply Build_is_add_with_carry'; intros x y c. + assert (0 <= n) by (destruct x; eauto using decode_exponent_nonnegative). + assert (2^n <> 0) by auto with zarith. + assert (0 <= S k * n) by nia. + rewrite !tuple_decoder_S, !ripple_carry_adc_SS by assumption. + simplify_projections; push_decode; generalize_decode. + erewrite carry_is_good_carry, carry_is_good_value by lia. + autorewrite with pull_Zpow push_Zof_nat zsimplify Zshift_to_pow. + split; apply f_equal2; nia. } } + Qed. + +End ripple_carry_adc. 
+ (* Typeclass hints exposing [ripple_carry_is_add_with_carry]: a generic one for any tuple length, and one for the 2-tuple written as [W * W]. *) +Hint Extern 2 (@is_add_with_carry _ (tuple ?W ?k) (@tuple_decoder ?n _ ?decode _) (@ripple_carry_adc _ ?adc _)) +=> apply (@ripple_carry_is_add_with_carry n W decode adc k) : typeclass_instances. +Hint Resolve (fun n W decode adc isdecode isadc + => @ripple_carry_is_add_with_carry n W decode adc 2 isdecode isadc + : @is_add_with_carry (Z.of_nat 2 * n) (W * W) (@tuple_decoder n W decode 2) (@ripple_carry_adc W adc 2)) + : typeclass_instances. + +Section ripple_carry_subc. + Context {n W} {decode : decoder n W} (subc : sub_with_carry W). + (* [ripple_carry_tuple_SS] specialised to [ripple_carry_subc]. *) + Lemma ripple_carry_subc_SS k xss yss carry + : ripple_carry_subc (k := S (S k)) subc xss yss carry + = let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (ripple_carry_subc (k := S k) subc xs ys carry) in + let '(carry, z) := eta (subc x y carry) in + (carry, (zs, z)). + Proof using Type. apply ripple_carry_tuple_SS. Qed. + + Local Opaque Z.of_nat. + (* If [subc] satisfies [is_sub_with_carry], so does [ripple_carry_subc (k := k) subc] for every [k]; the inductive step uses [carry_sub_is_good_carry] and [carry_sub_is_good_value]. *) + Global Instance ripple_carry_is_sub_with_carry {k} + {isdecode : is_decode decode} + {is_subc : is_sub_with_carry subc} + : is_sub_with_carry (ripple_carry_subc (k := k) subc). + Proof using Type. + destruct k as [|k]. + { constructor; repeat (intros [] || intro); autorewrite with simpl_tuple_decoder zsimplify; reflexivity. } + { induction k as [|k IHk]. + { cbv [ripple_carry_subc ripple_carry_tuple to_list]. + constructor; simpl @fst; simpl @snd; intros; + simpl; push_decode; autorewrite with zsimplify; reflexivity. } + { apply Build_is_sub_with_carry'; intros x y c. + assert (0 <= n) by (destruct x; eauto using decode_exponent_nonnegative). + assert (2^n <> 0) by auto with zarith. + assert (0 <= S k * n) by nia. + rewrite !tuple_decoder_S, !ripple_carry_subc_SS by assumption. + simplify_projections; push_decode; generalize_decode. + erewrite (carry_sub_is_good_carry (S k * n)), carry_sub_is_good_value by (break_match; lia). + autorewrite with pull_Zpow push_Zof_nat zsimplify Zshift_to_pow. + split; apply f_equal2; nia. } } + Qed. 
+ +End ripple_carry_subc. + (* Typeclass hints exposing [ripple_carry_is_sub_with_carry]: a generic one for any tuple length, and one for the 2-tuple written as [W * W]. *) +Hint Extern 2 (@is_sub_with_carry _ (tuple ?W ?k) (@tuple_decoder ?n _ ?decode _) (@ripple_carry_subc _ ?subc _)) +=> apply (@ripple_carry_is_sub_with_carry n W decode subc k) : typeclass_instances. +Hint Resolve (fun n W decode subc isdecode issubc + => @ripple_carry_is_sub_with_carry n W decode subc 2 isdecode issubc + : @is_sub_with_carry (Z.of_nat 2 * n) (W * W) (@tuple_decoder n W decode 2) (@ripple_carry_subc W subc 2)) + : typeclass_instances. diff --git a/src/LegacyArithmetic/Double/Proofs/SelectConditional.v b/src/LegacyArithmetic/Double/Proofs/SelectConditional.v new file mode 100644 index 000000000..953acf056 --- /dev/null +++ b/src/LegacyArithmetic/Double/Proofs/SelectConditional.v @@ -0,0 +1,25 @@ +Require Import Coq.ZArith.ZArith. +Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.LegacyArithmetic.Double.Core. +Require Import Crypto.LegacyArithmetic.Double.Proofs.Decode. + +Section select_conditional. + Context {n W} + {decode : decoder n W} + {is_decode : is_decode decode} + {selc : select_conditional W} + {is_selc : is_select_conditional selc}. + (* [selc_double] satisfies [is_select_conditional]; the proof case-splits on [n], giving one branch each for the zero, positive and negative constructors. *) + Global Instance is_select_conditional_double + : is_select_conditional selc_double. + Proof using Type*. + intros b x y. + destruct n. + { rewrite !(tuple_decoder_n_O (W:=W) 2); now destruct b. } + { rewrite (tuple_decoder_2 x), (tuple_decoder_2 y), (tuple_decoder_2 (selc_double b x y)) + by apply Zle_0_pos. + push_decode. + now destruct b. } + { rewrite !(tuple_decoder_n_neg (W:=W) 2); now destruct b. } + Qed. +End select_conditional. diff --git a/src/LegacyArithmetic/Double/Proofs/ShiftLeft.v b/src/LegacyArithmetic/Double/Proofs/ShiftLeft.v new file mode 100644 index 000000000..2230e36b6 --- /dev/null +++ b/src/LegacyArithmetic/Double/Proofs/ShiftLeft.v @@ -0,0 +1,43 @@ +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.LegacyArithmetic.Double.Core. 
+Require Import Crypto.LegacyArithmetic.Double.Proofs.Decode. +Require Import Crypto.LegacyArithmetic.Double.Proofs.ShiftLeftRightTactic. +Require Import Crypto.Util.ZUtil. +(*Require Import Crypto.Util.Tactics.*) + +Local Open Scope Z_scope. + +Local Opaque tuple_decoder. +Local Arguments Z.pow !_ !_. +Local Arguments Z.mul !_ !_. + +Section shl. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {or : bitwise_or W} + {decode : decoder n W} + {isdecode : is_decode decode} + {isldi : is_load_immediate ldi} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr} + {isor : is_bitwise_or or}. + (* [shl_double n] satisfies [is_shift_left_immediate]; range facts about [r] and its halves are collected first, then [shift_left_right_t] finishes the proof. *) + Global Instance is_shift_left_immediate_double : is_shift_left_immediate (shl_double n). + Proof using Type*. + intros r count H; hnf in H. + assert (0 < 2^count) by auto with zarith. + assert (0 < 2^(n+count)) by auto with zarith. + assert (forall x, 0 <= Z.pow2_mod x n < 2^n) by auto with zarith. + unfold shl_double; simpl. + generalize (decode_range r). + pose proof (decode_range (fst r)). + pose proof (decode_range (snd r)). + assert (forall n', 2^n <= 2^n' -> 0 <= decode (fst r) < 2^n') by (simpl in *; auto with zarith). + assert (forall n', n <= n' -> 0 <= decode (fst r) < 2^n') by auto with zarith omega. + autorewrite with simpl_tuple_decoder; push_decode. + shift_left_right_t. + Qed. +End shl. diff --git a/src/LegacyArithmetic/Double/Proofs/ShiftLeftRightTactic.v b/src/LegacyArithmetic/Double/Proofs/ShiftLeftRightTactic.v new file mode 100644 index 000000000..41234bf6e --- /dev/null +++ b/src/LegacyArithmetic/Double/Proofs/ShiftLeftRightTactic.v @@ -0,0 +1,41 @@ +Require Import Coq.ZArith.ZArith. +Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.Tactics.BreakMatch. + +Local Open Scope Z_scope. + +Local Arguments Z.pow !_ !_. +Local Arguments Z.mul !_ !_. 
+ (* Shared automation that finishes the double-word shift proofs: it repeatedly simplifies the goal, turns [Z] equalities into bitwise ones with [Z.bits_inj'], and discharges [Z.testbit] and linear-arithmetic side goals. *) +Ltac shift_left_right_t := + repeat match goal with + | [ |- ?x = ?x ] => reflexivity + | [ |- Z.testbit ?x ?n = Z.testbit ?x ?n' ] => apply f_equal; try omega + | [ |- orb (Z.testbit ?x _) (Z.testbit ?y _) = orb (Z.testbit ?x _) (Z.testbit ?y _) ] + => apply f_equal2 + | _ => progress Z.ltb_to_lt + | _ => progress subst + | _ => progress unfold AutoRewrite.rewrite_eq + | _ => progress intros + | _ => omega + | _ => solve [ trivial ] + | _ => progress break_match_step ltac:(fun _ => idtac) + | [ |- context[Z.lor (?x >> ?count) (Z.pow2_mod (?y << (?n - ?count)) ?n)] ] + => unique assert (0 <= Z.lor (x >> count) (Z.pow2_mod (y << (n - count)) n) < 2 ^ n) by (autorewrite with Zshift_to_pow; auto with zarith nia) + | _ => progress push_decode + | [ |- context[Interface.decode (fst ?x)] ] => is_var x; destruct x; simpl in * + | [ |- context[@Interface.decode ?n ?W ?d ?x] ] => is_var x; generalize dependent (@Interface.decode n W d x); intros + | _ => progress Z.rewrite_mod_small + | _ => progress autorewrite with convert_to_Ztestbit + | _ => progress autorewrite with zsimplify_fast + | [ |- _ = _ :> Z ] => apply Z.bits_inj'; intros + | _ => progress autorewrite with Ztestbit_full + | _ => progress autorewrite with bool_congr + | [ |- Z.testbit _ (?x - ?y + (?y - ?z)) = false ] + => autorewrite with zsimplify + | [ H : 0 <= ?x < 2^?n |- Z.testbit ?x ?n' = false ] + => assert (n <= n') by auto with zarith; progress Ztestbit + | _ => progress Ztestbit_full + end. diff --git a/src/LegacyArithmetic/Double/Proofs/ShiftRight.v b/src/LegacyArithmetic/Double/Proofs/ShiftRight.v new file mode 100644 index 000000000..16e7c5d6a --- /dev/null +++ b/src/LegacyArithmetic/Double/Proofs/ShiftRight.v @@ -0,0 +1,44 @@ +Require Import Coq.ZArith.ZArith. +Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.LegacyArithmetic.Double.Core. +Require Import Crypto.LegacyArithmetic.Double.Proofs.Decode. 
+Require Import Crypto.LegacyArithmetic.Double.Proofs.ShiftLeftRightTactic. +Require Import Crypto.Util.ZUtil. +(*Require Import Crypto.Util.Tactics.*) + +Local Open Scope Z_scope. + +Local Opaque tuple_decoder. +Local Arguments Z.pow !_ !_. +Local Arguments Z.mul !_ !_. + +Section shr. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {or : bitwise_or W} + {decode : decoder n W} + {isdecode : is_decode decode} + {isldi : is_load_immediate ldi} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr} + {isor : is_bitwise_or or}. + (* [shr_double n] satisfies [is_shift_right_immediate]; range facts about [r] and its halves are collected first, then [shift_left_right_t] finishes the proof. *) + Global Instance is_shift_right_immediate_double : is_shift_right_immediate (shr_double n). + Proof using Type*. + intros r count H; hnf in H. + assert (0 < 2^count) by auto with zarith. + assert (0 < 2^(n+count)) by auto with zarith. + assert (forall n', ~n' + count < n -> 2^n <= 2^(n'+count)) by auto with zarith omega. + assert (forall n', ~n' + count < n -> 2^n <= 2^(n'+count)) by auto with zarith omega. (* NOTE(review): this assertion repeats the previous line verbatim; it looks redundant, confirm before removing. *) + unfold shr_double; simpl. + generalize (decode_range r). + pose proof (decode_range (fst r)). + pose proof (decode_range (snd r)). + assert (forall n', 2^n <= 2^n' -> 0 <= decode (fst r) < 2^n') by (simpl in *; auto with zarith). + assert (forall n', n <= n' -> 0 <= decode (fst r) < 2^n') by auto with zarith omega. + autorewrite with simpl_tuple_decoder; push_decode. + shift_left_right_t. + Qed. +End shr. diff --git a/src/LegacyArithmetic/Double/Proofs/ShiftRightDoubleWordImmediate.v b/src/LegacyArithmetic/Double/Proofs/ShiftRightDoubleWordImmediate.v new file mode 100644 index 000000000..00a6d03cd --- /dev/null +++ b/src/LegacyArithmetic/Double/Proofs/ShiftRightDoubleWordImmediate.v @@ -0,0 +1,42 @@ +Require Import Coq.ZArith.ZArith. +Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.LegacyArithmetic.Double.Core. +Require Import Crypto.LegacyArithmetic.Double.Proofs.Decode. 
+Require Import Crypto.LegacyArithmetic.Double.Proofs.ShiftLeftRightTactic. +Require Import Crypto.Util.ZUtil. +(*Require Import Crypto.Util.Tactics.*) + +Local Open Scope Z_scope. + +Local Opaque tuple_decoder. +Local Arguments Z.pow !_ !_. +Local Arguments Z.mul !_ !_. + +Section shrd. + Context (n : Z) {W} + {ldi : load_immediate W} + {shrd : shift_right_doubleword_immediate W} + {decode : decoder n W} + {isdecode : is_decode decode} + {isldi : is_load_immediate ldi} + {isshrd : is_shift_right_doubleword_immediate shrd}. + (* Within this section, [zutil_arith] is redefined to [auto with nocore omega]. *) + Local Ltac zutil_arith ::= solve [ auto with nocore omega ]. + (* [shrd_double n] satisfies [is_shift_right_doubleword_immediate]; range facts for [low], [high] and their halves are generalised first, then [shift_left_right_t] finishes the proof. *) + Global Instance is_shift_right_doubleword_immediate_double : is_shift_right_doubleword_immediate (shrd_double n). + Proof using isdecode isshrd. + intros high low count Hcount; hnf in Hcount. + unfold shrd_double, shift_right_doubleword_immediate_double; simpl. + generalize (decode_range low). + generalize (decode_range high). + generalize (decode_range (fst low)). + generalize (decode_range (snd low)). + generalize (decode_range (fst high)). + generalize (decode_range (snd high)). + assert (forall x, 0 <= Z.pow2_mod x n < 2^n) by auto with zarith. + assert (forall n' x, 2^n <= 2^n' -> 0 <= x < 2^n -> 0 <= x < 2^n') by auto with zarith. + assert (forall n' x, n <= n' -> 0 <= x < 2^n -> 0 <= x < 2^n') by auto with zarith omega. + autorewrite with simpl_tuple_decoder; push_decode. + shift_left_right_t. + Qed. +End shrd. diff --git a/src/LegacyArithmetic/Double/Proofs/SpreadLeftImmediate.v b/src/LegacyArithmetic/Double/Proofs/SpreadLeftImmediate.v new file mode 100644 index 000000000..c50d43616 --- /dev/null +++ b/src/LegacyArithmetic/Double/Proofs/SpreadLeftImmediate.v @@ -0,0 +1,148 @@ +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.LegacyArithmetic.InterfaceProofs. +Require Import Crypto.LegacyArithmetic.Double.Core. +Require Import Crypto.LegacyArithmetic.Double.Proofs.Decode. 
+Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.LetIn. +Import Bug5107WorkAround. +Import BoundedRewriteNotations. + +Local Open Scope Z_scope. + (* [is_spread_left_immediate sprl] holds exactly when, for [0 <= count < n], the tuple decoding of [sprl r count] equals [decode r << count]. *) +Lemma decode_is_spread_left_immediate_iff + {n W} + {decode : decoder n W} + {sprl : spread_left_immediate W} + {isdecode : is_decode decode} + : is_spread_left_immediate sprl + <-> (forall r count, + 0 <= count < n + -> tuple_decoder (sprl r count) = decode r << count). +Proof. + rewrite is_spread_left_immediate_alt by assumption. + split; intros H r count Hc; specialize (H r count Hc); revert H; + pose proof (decode_range r); + assert (0 < 2^count < 2^n) by auto with zarith; + autorewrite with simpl_tuple_decoder; + simpl; intro H'; rewrite H'; + autorewrite with Zshift_to_pow; + Z.rewrite_mod_small; reflexivity. +Qed. + (* Forward direction of [decode_is_spread_left_immediate_iff], exposed as a two-way rewrite instance. *) +Global Instance decode_is_spread_left_immediate + {n W} + {decode : decoder n W} + {sprl : spread_left_immediate W} + {isdecode : is_decode decode} + {issprl : is_spread_left_immediate sprl} + : forall r count, + (0 <= count < n)%bounded_rewrite + -> tuple_decoder (sprl r count) <~=~> decode r << count + := proj1 decode_is_spread_left_immediate_iff _. + + +Section tuple2. + Section spread_left. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {decode : decoder n W} + {isdecode : is_decode decode} + {isldi : is_load_immediate ldi} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr}. + (* For [0 < count < n], [shl r count] plus [shr r (n - count)] shifted left by [n] decodes to [decode r << count] mod 2^n * 2^n. *) + Lemma spread_left_from_shift_correct + r count + (H : 0 < count < n) + : (decode (shl r count) + decode (shr r (n - count)) << n = decode r << count mod (2^n*2^n))%Z. + Proof using isdecode isshl isshr. + assert (0 <= count < n) by lia. + assert (0 <= n - count < n) by lia. + assert (0 < 2^(n-count)) by auto with zarith. + assert (2^count < 2^n) by auto with zarith. 
+ pose proof (decode_range r). + assert (0 <= decode r * 2 ^ count < 2 ^ n * 2^n) by auto with zarith. + push_decode; autorewrite with Zshift_to_pow zsimplify. + replace (decode r / 2^(n-count) * 2^n)%Z with ((decode r / 2^(n-count) * 2^(n-count)) * 2^count)%Z + by (rewrite <- Z.mul_assoc; autorewrite with pull_Zpow zsimplify; reflexivity). + rewrite Z.mul_div_eq' by lia. + autorewrite with push_Zmul zsimplify. + rewrite <- Z.mul_mod_distr_r_full, Z.add_sub_assoc. + repeat autorewrite with pull_Zpow zsimplify in *. + reflexivity. + Qed. + (* [sprl_from_shift n] satisfies [is_spread_left_immediate]; [count = 0] is handled separately, otherwise [spread_left_from_shift_correct] applies. *) + Global Instance is_spread_left_from_shift + : is_spread_left_immediate (sprl_from_shift n). + Proof using Type*. + apply is_spread_left_immediate_alt. + intros r count; intros. + pose proof (decode_range r). + assert (0 < 2^n) by auto with zarith. + assert (decode r < 2^n * 2^n) by (generalize dependent (decode r); intros; nia). + autorewrite with simpl_tuple_decoder. + destruct (Z_zerop count). + { subst; autorewrite with Zshift_to_pow zsimplify. + simpl; push_decode. + autorewrite with push_Zpow zsimplify. + reflexivity. } + simpl. + rewrite <- spread_left_from_shift_correct by lia. + autorewrite with zsimplify Zpow_to_shift. + reflexivity. + Qed. + End spread_left. + + Section full_from_half. + Context {W} + {mulhwll : multiply_low_low W} + {mulhwhl : multiply_high_low W} + {mulhwhh : multiply_high_high W} + {adc : add_with_carry W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {half_n : Z} + {ldi : load_immediate W} + {decode : decoder (2 * half_n) W} + {ismulhwll : is_mul_low_low half_n mulhwll} + {ismulhwhl : is_mul_high_low half_n mulhwhl} + {ismulhwhh : is_mul_high_high half_n mulhwhh} + {isadc : is_add_with_carry adc} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr} + {isldi : is_load_immediate ldi} + {isdecode : is_decode decode}. + + Local Arguments Z.mul !_ !_. 
+ (* NOTE(review): this copy is truncated inside the proof below; the rest of this file and the start of Interface.v, up to the end of the [decoder] class, are missing; restore from the original files. *) Lemma spread_left_from_shift_half_correct + r + : (decode (shl r half_n) + decode (shr r half_n) * (2^half_n * 2^half_n) + = (decode r * 2^half_n) mod (2^half_n*2^half_n*2^half_n*2^half_n))%Z. + Proof using Type*. + destruct (0 Z }. +Coercion decode : decoder >-> Funclass. +Global Arguments decode {n W _} _. + (* A decoder is valid when every decoded value lies in [0, 2^n). *) +Class is_decode {n W} (decode : decoder n W) := + decode_range : forall x, 0 <= decode x < 2^n. + (* Range and order propositions wrapped as classes so that they can be discharged by [bounded_solver_tac] during typeclass resolution. *) +Class bounded_in_range_cls (x y z : Z) := is_bounded_in_range : x <= y < z. +Ltac bounded_solver_tac := + solve [ eassumption | typeclasses eauto | omega ]. +Hint Extern 0 (bounded_in_range_cls _ _ _) => unfold bounded_in_range_cls; bounded_solver_tac : typeclass_instances. +Global Arguments bounded_in_range_cls / _ _ _. +Global Instance decode_range_bound {n W} {decode : decoder n W} {H : is_decode decode} + : forall x, bounded_in_range_cls 0 (decode x) (2^n) + := H. + +Class bounded_le_cls (x y : Z) := is_bounded_le : x <= y. +Hint Extern 0 (bounded_le_cls _ _) => unfold bounded_le_cls; bounded_solver_tac : typeclass_instances. +Global Arguments bounded_le_cls / _ _. + (* Tag that identifies the rewrite instances used by [push_decode] and [pull_decode]. *) +Inductive bounded_decode_pusher_tag := decode_tag. + (* [push_decode] repeatedly rewrites left-to-right with [decode_tag] instances on [decode] applications, on [fst] of boolean pairs, and on [if]-style matches over them. *) +Ltac push_decode_step := + match goal with + | [ |- context[@decode ?n ?W ?decoder ?w] ] + => tc_rewrite (decode_tag) (@decode n W decoder w) -> + | [ |- context[match @fst ?A ?B ?x with true => 1 | false => 0 end] ] + => tc_rewrite (decode_tag) (match @fst A B x with true => 1 | false => 0 end) -> + | [ |- context[@fst bool ?B ?x] ] + => tc_rewrite (decode_tag) (@fst bool B x) -> + end. +Ltac push_decode := repeat push_decode_step. (* [pull_decode] rewrites right-to-left on any subterm of type [Z] or [bool]. *) +Ltac pull_decode_step := + match goal with + | [ |- context[?E] ] + => lazymatch type of E with + | Z => idtac + | bool => idtac + end; + tc_rewrite (decode_tag) <- E + end. +Ltac pull_decode := repeat pull_decode_step. + +Delimit Scope bounded_rewrite_scope with bounded_rewrite. + +Infix "<~=~>" := (rewrite_eq decode_tag) : bounded_rewrite_scope. 
+Infix "=~>" := (rewrite_left_to_right_eq decode_tag) : bounded_rewrite_scope. +Infix "<~=" := (rewrite_right_to_left_eq decode_tag) : bounded_rewrite_scope. +Notation "x <= y" := (bounded_le_cls x y) : bounded_rewrite_scope. +Notation "x <= y < z" := (bounded_in_range_cls x y z) : bounded_rewrite_scope. + (* Importing this module makes the three rewrite arrows available in [type_scope] and opens [bounded_rewrite_scope]. *) +Module Import BoundedRewriteNotations. + Infix "<~=~>" := (rewrite_eq decode_tag) : type_scope. + Infix "=~>" := (rewrite_left_to_right_eq decode_tag) : type_scope. + Infix "<~=" := (rewrite_right_to_left_eq decode_tag) : type_scope. + Open Scope bounded_rewrite_scope. +End BoundedRewriteNotations. + +(** This is required for typeclass resolution to be fast. *) +Typeclasses Opaque decode. + +Section InstructionGallery. + Context (n : Z) (* bit-width of [W] *) + {W : Type} (* bounded type, [W] for word *) + (Wdecoder : decoder n W). + Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) + (* Each instruction below is a class holding the operation, coerced to a function, paired with an [is_*] class stating its specification in terms of [decode]. *) + Class load_immediate := { ldi : imm -> W }. + Global Coercion ldi : load_immediate >-> Funclass. + + Class is_load_immediate {ldi : load_immediate} := + decode_load_immediate :> forall x, 0 <= x < 2^n -> decode (ldi x) =~> x. + + Class shift_right_doubleword_immediate := { shrd : W -> W -> imm -> W }. + Global Coercion shrd : shift_right_doubleword_immediate >-> Funclass. + (* Specification: for [0 <= count < n], the result decodes to the concatenation of [high] and [low] shifted right by [count], reduced mod 2^n. *) + Class is_shift_right_doubleword_immediate (shrd : shift_right_doubleword_immediate) := + decode_shift_right_doubleword :> + forall high low count, + 0 <= count < n + -> decode (shrd high low count) <~=~> (((decode high << n) + decode low) >> count) mod 2^n. + + (** Quoting http://www.felixcloutier.com/x86/SHRD.html: + + If the count is 1 or greater, the CF flag is filled with the + last bit shifted out of the destination operand and the SF, ZF, + and PF flags are set according to the value of the result. For a + 1-bit shift, the OF flag is set if a sign change occurred; + otherwise, it is cleared. For shifts greater than 1 bit, the OF + flag is undefined. 
If a shift occurs, the AF flag is + undefined. If the count operand is 0, the flags are not + affected. If the count is greater than the operand size, the + flags are undefined. + + We ignore the CF in the specification; we only have it so that + we can ensure that the CF flag gets appropriately clobbered. *) + Class shift_right_doubleword_immediate_with_CF := { shrdf : W -> W -> imm -> bool * W }. + Global Coercion shrdf : shift_right_doubleword_immediate_with_CF >-> Funclass. + (* Only the word component ([snd]) of the result is specified; the boolean component is left unconstrained. *) + Class is_shift_right_doubleword_immediate_with_CF (shrdf : shift_right_doubleword_immediate_with_CF) := + decode_snd_shift_right_doubleword_with_CF :> + forall high low count, + 0 <= count < n + -> decode (snd (shrdf high low count)) <~=~> (((decode high << n) + decode low) >> count) mod 2^n. + + Class shift_left_immediate := { shl : W -> imm -> W }. + Global Coercion shl : shift_left_immediate >-> Funclass. + (* Specification: for [0 <= count < n], the result decodes to [decode r << count] reduced mod 2^n. *) + Class is_shift_left_immediate (shl : shift_left_immediate) := + decode_shift_left_immediate :> + forall r count, 0 <= count < n -> decode (shl r count) <~=~> (decode r << count) mod 2^n. + + (** Quoting http://www.felixcloutier.com/x86/SAL:SAR:SHL:SHR.html: + + The CF flag contains the value of the last bit shifted out of + the destination operand; it is undefined for SHL and SHR + instructions where the count is greater than or equal to the + size (in bits) of the destination operand. The OF flag is + affected only for 1-bit shifts (see “Description” above); + otherwise, it is undefined. The SF, ZF, and PF flags are set + according to the result. If the count is 0, the flags are not + affected. For a non-zero count, the AF flag is undefined. + + We ignore the CF in the specification; we only have it so that + we can ensure that the CF flag gets appropriately clobbered. *) + Class shift_left_immediate_with_CF := { shlf : W -> imm -> bool * W }. + Global Coercion shlf : shift_left_immediate_with_CF >-> Funclass. 
+
+  (** For [0 <= count < n], the second component of [shlf r count]
+      decodes to [(decode r << count) mod 2^n]; the boolean first
+      component is left unspecified. *)
+  Class is_shift_left_immediate_with_CF (shlf : shift_left_immediate_with_CF) :=
+    decode_shift_left_immediate_with_CF :>
+      forall r count, 0 <= count < n -> decode (snd (shlf r count)) <~=~> (decode r << count) mod 2^n.
+
+  (** [shr r count]: right shift of a word by an immediate. *)
+  Class shift_right_immediate := { shr : W -> imm -> W }.
+  Global Coercion shr : shift_right_immediate >-> Funclass.
+
+  (** For [0 <= count < n], the result decodes to [decode r >> count]. *)
+  Class is_shift_right_immediate (shr : shift_right_immediate) :=
+    decode_shift_right_immediate :>
+      forall r count, 0 <= count < n -> decode (shr r count) <~=~> (decode r >> count).
+
+  (** Variant of [shift_right_immediate] that also returns a boolean. *)
+  Class shift_right_immediate_with_CF := { shrf : W -> imm -> bool * W }.
+  Global Coercion shrf : shift_right_immediate_with_CF >-> Funclass.
+
+  (** Same specification as [is_shift_right_immediate], stated on the
+      second component; the boolean first component is left unspecified. *)
+  Class is_shift_right_immediate_with_CF (shrf : shift_right_immediate_with_CF) :=
+    decode_shift_right_immediate_with_CF :>
+      forall r count, 0 <= count < n -> decode (snd (shrf r count)) <~=~> (decode r >> count).
+
+  (** [sprl r count] returns a pair of words. *)
+  Class spread_left_immediate := { sprl : W -> imm -> tuple W 2 (* [(low, high)] *) }.
+  Global Coercion sprl : spread_left_immediate >-> Funclass.
+
+  (** For [0 <= count < n], the first component decodes to the low [n]
+      bits of [decode r << count] and the second component to the
+      remaining bits, [(decode r << count) >> n]. *)
+  Class is_spread_left_immediate (sprl : spread_left_immediate) :=
+    {
+      decode_fst_spread_left_immediate :> forall r count,
+          0 <= count < n
+          -> decode (fst (sprl r count)) =~> (decode r << count) mod 2^n;
+      decode_snd_spread_left_immediate :> forall r count,
+          0 <= count < n
+          -> decode (snd (sprl r count)) =~> (decode r << count) >> n
+    }.
+
+  (** [mkl r count]: keep the low [count] bits of a word.  The field is
+      declared with a plain [:] like every other operation class in this
+      section; its function type is not a class, so a [:>] declaration
+      has nothing to register, and the coercion to [Funclass] is
+      declared explicitly on the next line. *)
+  Class mask_keep_low := { mkl : W -> imm -> W }.
+  Global Coercion mkl : mask_keep_low >-> Funclass.
+
+  (** For [0 <= count < n], the result decodes to [decode r mod 2^count]. *)
+  Class is_mask_keep_low (mkl : mask_keep_low) :=
+    decode_mask_keep_low :> forall r count,
+        0 <= count < n -> decode (mkl r count) <~=~> decode r mod 2^count.
+
+  (** [and x y]: bitwise conjunction of two words. *)
+  Class bitwise_and := { and : W -> W -> W }.
+  Global Coercion and : bitwise_and >-> Funclass.
+
+  (** The result decodes to [Z.land] of the decoded operands. *)
+  Class is_bitwise_and (and : bitwise_and) :=
+    {
+      decode_bitwise_and :> forall x y, decode (and x y) <~=~> Z.land (decode x) (decode y)
+    }.
+
+  (** Quoting http://www.felixcloutier.com/x86/AND.html:
+
+        The OF and CF flags are cleared; the SF, ZF, and PF flags are set
+        according to the result. The state of the AF flag is
+        undefined. *)
+  Class bitwise_and_with_CF := { andf : W -> W -> bool * W }.
+  Global Coercion andf : bitwise_and_with_CF >-> Funclass.
+
+  (** The second component decodes to [Z.land] of the decoded operands,
+      and the boolean first component is always [false]. *)
+  Class is_bitwise_and_with_CF (andf : bitwise_and_with_CF) :=
+    {
+      decode_snd_bitwise_and_with_CF :> forall x y, decode (snd (andf x y)) <~=~> Z.land (decode x) (decode y);
+      fst_bitwise_and_with_CF :> forall x y, fst (andf x y) =~> false
+    }.
+
+  (** [or x y]: bitwise disjunction of two words. *)
+  Class bitwise_or := { or : W -> W -> W }.
+  Global Coercion or : bitwise_or >-> Funclass.
+
+  (** The result decodes to [Z.lor] of the decoded operands. *)
+  Class is_bitwise_or (or : bitwise_or) :=
+    {
+      decode_bitwise_or :> forall x y, decode (or x y) <~=~> Z.lor (decode x) (decode y)
+    }.
+
+  (** Quoting http://www.felixcloutier.com/x86/OR.html:
+
+        The OF and CF flags are cleared; the SF, ZF, and PF flags are set
+        according to the result. The state of the AF flag is
+        undefined. *)
+  Class bitwise_or_with_CF := { orf : W -> W -> bool * W }.
+  Global Coercion orf : bitwise_or_with_CF >-> Funclass.
+
+  (** The second component decodes to [Z.lor] of the decoded operands,
+      and the boolean first component is always [false]. *)
+  Class is_bitwise_or_with_CF (orf : bitwise_or_with_CF) :=
+    {
+      decode_snd_bitwise_or_with_CF :> forall x y, decode (snd (orf x y)) <~=~> Z.lor (decode x) (decode y);
+      fst_bitwise_or_with_CF :> forall x y, fst (orf x y) =~> false
+    }.
+
+  (* Reads a boolean as the integer 1 ([true]) or 0 ([false]). *)
+  Local Notation bit b := (if b then 1 else 0).
+
+  (** [adc x y c]: addition with a boolean carry-in, returning a
+      boolean together with a word. *)
+  Class add_with_carry := { adc : W -> W -> bool -> bool * W }.
+  Global Coercion adc : add_with_carry >-> Funclass.
+
+  (** With [s := decode x + decode y + bit c]: the boolean component,
+      read through [bit], equals [s >> n], and the word component
+      decodes to [s mod 2^n]. *)
+  Class is_add_with_carry (adc : add_with_carry) :=
+    {
+      bit_fst_add_with_carry :> forall x y c, bit (fst (adc x y c)) <~=~> (decode x + decode y + bit c) >> n;
+      decode_snd_add_with_carry :> forall x y c, decode (snd (adc x y c)) <~=~> (decode x + decode y + bit c) mod (2^n)
+    }.
+
+  (** [subc x y c]: subtraction with a boolean carry-in, returning a
+      boolean together with a word. *)
+  Class sub_with_carry := { subc : W -> W -> bool -> bool * W }.
+  Global Coercion subc : sub_with_carry >-> Funclass.
+
+  (** With [d := decode x - decode y - bit c]: the boolean component is
+      [d <? 0], i.e. [true] exactly when the difference is negative, and
+      the word component decodes to [d mod 2^n]. *)
+  Class is_sub_with_carry (subc:W->W->bool->bool*W) :=
+    {
+      fst_sub_with_carry :> forall x y c, fst (subc x y c) <~=~> ((decode x - decode y - bit c) <? 0);
+      decode_snd_sub_with_carry :> forall x y c, decode (snd (subc x y c)) <~=~> (decode x - decode y - bit c) mod 2^n
+    }.
+
+  (** [mul x y]: word multiplication. *)
+  Class multiply := { mul : W -> W -> W }.
+  Global Coercion mul : multiply >-> Funclass.
+
+  (** The result decodes to the product of the decoded operands, with no
+      modular reduction in the statement. *)
+  Class is_mul (mul : multiply) :=
+    decode_mul :> forall x y, decode (mul x y) <~=~> (decode x * decode y).
+
+  (** Half-word and double-word multiplication signatures; their
+      specifications are the [is_mul_*] classes below. *)
+  Class multiply_low_low := { mulhwll : W -> W -> W }.
+  Global Coercion mulhwll : multiply_low_low >-> Funclass.
+  Class multiply_high_low := { mulhwhl : W -> W -> W }.
+  Global Coercion mulhwhl : multiply_high_low >-> Funclass.
+  Class multiply_high_high := { mulhwhh : W -> W -> W }.
+  Global Coercion mulhwhh : multiply_high_high >-> Funclass.
+  Class multiply_double := { muldw : W -> W -> tuple W 2 }.
+  Global Coercion muldw : multiply_double >-> Funclass.
+  (** Quoting http://www.felixcloutier.com/x86/MUL.html:
+
+        The OF and CF flags are set to 0 if the upper half of the result
+        is 0; otherwise, they are set to 1. The SF, ZF, AF, and PF flags
+        are undefined.
+
+      We ignore the CF in the specification; we only have it so that
+      we can ensure that the CF flag gets appropriately clobbered. *)
+  Class multiply_double_with_CF := { muldwf : W -> W -> bool * tuple W 2 }.
+  Global Coercion muldwf : multiply_double_with_CF >-> Funclass.
+
+  (** In the three classes below [w] is the split point: the low part
+      of an operand is [decode _ mod 2^w] and the high part is
+      [decode _ >> w].  Each product is reduced mod [2^n]. *)
+  Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) :=
+    decode_mul_low_low :>
+      forall x y, decode (mulhwll x y) <~=~> ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n.
+  Class is_mul_high_low (w:Z) (mulhwhl : multiply_high_low) :=
+    decode_mul_high_low :>
+      forall x y, decode (mulhwhl x y) <~=~> ((decode x >> w) * (decode y mod 2^w)) mod 2^n.
+  Class is_mul_high_high (w:Z) (mulhwhh : multiply_high_high) :=
+    decode_mul_high_high :>
+      forall x y, decode (mulhwhh x y) <~=~> ((decode x >> w) * (decode y >> w)) mod 2^n.
+ Class is_mul_double (muldw : multiply_double) := + { + decode_fst_mul_double :> + forall x y, decode (fst (muldw x y)) =~> (decode x * decode y) mod 2^n; + decode_snd_mul_double :> + forall x y, decode (snd (muldw x y)) =~> (decode x * decode y) >> n + }. + + Class is_mul_double_with_CF (muldwf : multiply_double_with_CF) := + { + decode_fst_mul_double_with_CF :> + forall x y, decode (fst (snd (muldwf x y))) =~> (decode x * decode y) mod 2^n; + decode_snd_mul_double_with_CF :> + forall x y, decode (snd (snd (muldwf x y))) =~> (decode x * decode y) >> n + }. + + Class select_conditional := { selc : bool -> W -> W -> W }. + Global Coercion selc : select_conditional >-> Funclass. + + Class is_select_conditional (selc : select_conditional) := + decode_select_conditional :> forall b x y, + decode (selc b x y) <~=~> if b then decode x else decode y. + + Class add_modulo := { addm : W -> W -> W (* modulus *) -> W }. + Global Coercion addm : add_modulo >-> Funclass. + + Class is_add_modulo (addm : add_modulo) := + decode_add_modulo :> forall x y modulus, + decode (addm x y modulus) <~=~> (if (decode x + decode y) decoder n W; + ldi :> load_immediate W; + shrd :> shift_right_doubleword_immediate W; + shl :> shift_left_immediate W; + shr :> shift_right_immediate W; + adc :> add_with_carry W; + subc :> sub_with_carry W; + mulhwll :> multiply_low_low W; + mulhwhl :> multiply_high_low W; + mulhwhh :> multiply_high_high W; + selc :> select_conditional W; + addm :> add_modulo W + }. 
+
+  (** Correctness bundle for a fancy-machine instruction set whose word
+      width is [2 * n_over_two]: one [is_*] specification per
+      instruction of [ops].  The half-word multiplications are
+      specified with split point [n_over_two]. *)
+  Class arithmetic {n_over_two} (ops:instructions (2 * n_over_two)) :=
+    {
+      decode_range :> is_decode decode;
+      load_immediate :> is_load_immediate ldi;
+      shift_right_doubleword_immediate :> is_shift_right_doubleword_immediate shrd;
+      shift_left_immediate :> is_shift_left_immediate shl;
+      shift_right_immediate :> is_shift_right_immediate shr;
+      add_with_carry :> is_add_with_carry adc;
+      sub_with_carry :> is_sub_with_carry subc;
+      multiply_low_low :> is_mul_low_low n_over_two mulhwll;
+      multiply_high_low :> is_mul_high_low n_over_two mulhwhl;
+      multiply_high_high :> is_mul_high_high n_over_two mulhwhh;
+      select_conditional :> is_select_conditional selc;
+      add_modulo :> is_add_modulo addm
+    }.
+End fancy_machine.
+
+(** Instruction set built from the [_with_CF] variants of the shifts,
+    the double-word multiplication and the bitwise or, together with
+    load-immediate, add/subtract with carry and conditional select. *)
+Module x86.
+  Local Notation imm := Z (only parsing).
+
+  (** The word type [W], its decoder, and one instance per instruction. *)
+  Class instructions (n : Z) :=
+    {
+      W : Type (* [n]-bit word *);
+      decode :> decoder n W;
+      ldi :> load_immediate W;
+      shrdf :> shift_right_doubleword_immediate_with_CF W;
+      shlf :> shift_left_immediate_with_CF W;
+      shrf :> shift_right_immediate_with_CF W;
+      adc :> add_with_carry W;
+      subc :> sub_with_carry W;
+      muldwf :> multiply_double_with_CF W;
+      selc :> select_conditional W;
+      orf :> bitwise_or_with_CF W
+    }.
+
+  (** Correctness bundle: one [is_*] specification per instruction of
+      [ops]. *)
+  Class arithmetic {n} (ops:instructions n) :=
+    {
+      decode_range :> is_decode decode;
+      load_immediate :> is_load_immediate ldi;
+      shift_right_doubleword_immediate_with_CF :> is_shift_right_doubleword_immediate_with_CF shrdf;
+      shift_left_immediate_with_CF :> is_shift_left_immediate_with_CF shlf;
+      shift_right_immediate_with_CF :> is_shift_right_immediate_with_CF shrf;
+      add_with_carry :> is_add_with_carry adc;
+      sub_with_carry :> is_sub_with_carry subc;
+      multiply_double_with_CF :> is_mul_double_with_CF muldwf;
+      select_conditional :> is_select_conditional selc;
+      bitwise_or_with_CF :> is_bitwise_or_with_CF orf
+    }.
+End x86.
diff --git a/src/LegacyArithmetic/InterfaceProofs.v b/src/LegacyArithmetic/InterfaceProofs.v new file mode 100644 index 000000000..9ef97fa55 --- /dev/null +++ b/src/LegacyArithmetic/InterfaceProofs.v @@ -0,0 +1,224 @@ +(** * Alternate forms for Interface for bounded arithmetic *) +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Crypto.LegacyArithmetic.Interface. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.AutoRewrite. +Require Import Crypto.Util.Notations. + +Local Open Scope type_scope. +Local Open Scope Z_scope. + +Import BoundedRewriteNotations. +Local Notation bit b := (if b then 1 else 0). + +Lemma decoder_eta {n W} (decode : decoder n W) : decode = {| Interface.decode := decode |}. +Proof. destruct decode; reflexivity. Defined. + +Section InstructionGallery. + Context (n : Z) (* bit-width of width of [W] *) + {W : Type} (* bounded type, [W] for word *) + (Wdecoder : decoder n W). + Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) + + Definition Build_is_spread_left_immediate' (sprl : spread_left_immediate W) + (pf : forall r count, 0 <= count < n + -> _ /\ _) + := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H); + decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}. + + Definition Build_is_add_with_carry' (adc : add_with_carry W) + (pf : forall x y c, _ /\ _) + := {| bit_fst_add_with_carry x y c := proj1 (pf x y c); + decode_snd_add_with_carry x y c := proj2 (pf x y c) |}. + + Definition Build_is_sub_with_carry' (subc : sub_with_carry W) + (pf : forall x y c, _ /\ _) + : is_sub_with_carry subc + := {| fst_sub_with_carry x y c := proj1 (pf x y c); + decode_snd_sub_with_carry x y c := proj2 (pf x y c) |}. + + Definition Build_is_mul_double' (muldw : multiply_double W) + (pf : forall x y, _ /\ _) + := {| decode_fst_mul_double x y := proj1 (pf x y); + decode_snd_mul_double x y := proj2 (pf x y) |}. 
+ + Lemma is_spread_left_immediate_alt + {sprl : spread_left_immediate W} + {isdecode : is_decode Wdecoder} + : is_spread_left_immediate sprl + <-> (forall r count, 0 <= count < n -> decode (fst (sprl r count)) + decode (snd (sprl r count)) << n = (decode r << count) mod (2^n*2^n))%Z. + Proof using Type. + split; intro H; [ | apply Build_is_spread_left_immediate' ]; + intros r count Hc; + [ | specialize (H r count Hc); revert H ]; + unfold bounded_in_range_cls in *; + pose proof (decode_range r); + assert (0 < 2^n) by auto with zarith; + assert (0 <= 2^count < 2^n)%Z by auto with zarith; + assert (0 <= decode r * 2^count < 2^n * 2^n)%Z by (generalize dependent (decode r); intros; nia); + rewrite ?decode_fst_spread_left_immediate, ?decode_snd_spread_left_immediate + by typeclasses eauto with typeclass_instances core; + autorewrite with Zshift_to_pow zsimplify push_Zpow. + { reflexivity. } + { intro H'; rewrite <- H'. + autorewrite with zsimplify; split; reflexivity. } + Qed. + + Lemma is_mul_double_alt + {muldw : multiply_double W} + {isdecode : is_decode Wdecoder} + : is_mul_double muldw + <-> (forall x y, decode (fst (muldw x y)) + decode (snd (muldw x y)) << n = (decode x * decode y) mod (2^n*2^n)). + Proof using Type. + split; intro H; [ | apply Build_is_mul_double' ]; + intros x y; + [ | specialize (H x y); revert H ]; + pose proof (decode_range x); + pose proof (decode_range y); + assert (0 < 2^n) by auto with zarith; + assert (0 <= decode x * decode y < 2^n * 2^n)%Z by nia; + (destruct (0 <=? n) eqn:?; Z.ltb_to_lt; + [ | assert (2^n = 0) by auto with zarith; exfalso; omega ]); + rewrite ?decode_fst_mul_double, ?decode_snd_mul_double + by typeclasses eauto with typeclass_instances core; + autorewrite with Zshift_to_pow zsimplify push_Zpow. + { reflexivity. } + { intro H'; rewrite <- H'. + autorewrite with zsimplify; split; reflexivity. } + Qed. +End InstructionGallery. + +Global Arguments is_spread_left_immediate_alt {_ _ _ _ _}. 
+Global Arguments is_mul_double_alt {_ _ _ _ _}. + +Ltac bounded_solver_tac := + solve [ eassumption | typeclasses eauto | omega ]. + +Global Instance decode_proj n W (dec : W -> Z) + : @decode n W {| decode := dec |} =~> dec. +Proof. reflexivity. Qed. + +Global Instance decode_if_bool n W (decode : decoder n W) (b : bool) x y + : decode (if b then x else y) + =~> if b then decode x else decode y. +Proof. destruct b; reflexivity. Qed. + +Global Instance decode_mod_small {n W} {decode : decoder n W} {x b} + {H : bounded_in_range_cls 0 (decode x) b} + : decode x <~= decode x mod b. +Proof. + Z.rewrite_mod_small; reflexivity. +Qed. + +Global Instance decode_mod_range {n W decode} {H : @is_decode n W decode} x + : decode x <~= decode x mod 2^n. +Proof. exact _. Qed. + +Lemma decode_exponent_nonnegative {n W} (decode : decoder n W) {isdecode : is_decode decode} + (isinhabited : W) + : (0 <= n)%Z. +Proof. + pose proof (decode_range isinhabited). + assert (0 < 2^n) by omega. + destruct (Z_lt_ge_dec n 0) as [H'|]; [ | omega ]. + assert (2^n = 0) by auto using Z.pow_neg_r. + omega. +Qed. + +Section adc_subc. + Context {n W} + {decode : decoder n W} + {adc : add_with_carry W} + {subc : sub_with_carry W} + {isdecode : is_decode decode} + {isadc : is_add_with_carry adc} + {issubc : is_sub_with_carry subc}. + Global Instance bit_fst_add_with_carry_false + : forall x y, bit (fst (adc x y false)) <~=~> (decode x + decode y) >> n. + Proof using isadc. + intros; erewrite bit_fst_add_with_carry by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance bit_fst_add_with_carry_true + : forall x y, bit (fst (adc x y true)) <~=~> (decode x + decode y + 1) >> n. + Proof using isadc. + intros; erewrite bit_fst_add_with_carry by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance fst_add_with_carry_leb + : forall x y c, fst (adc x y c) <~= (2^n <=? (decode x + decode y + bit c)). + Proof using isadc isdecode. 
+ intros x y c; hnf. + assert (0 <= n)%Z by eauto using decode_exponent_nonnegative. + pose proof (decode_range x); pose proof (decode_range y). + assert (0 <= bit c <= 1)%Z by (destruct c; omega). + lazymatch goal with + | [ |- fst ?x = (?a <=? ?b) :> bool ] + => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z); + [ destruct (fst x), (a <=? b); intro; congruence | ] + end. + push_decode. + autorewrite with Zshift_to_pow. + rewrite Z.div_between_0_if by auto with zarith. + reflexivity. + Qed. + Global Instance fst_add_with_carry_false_leb + : forall x y, fst (adc x y false) <~= (2^n <=? (decode x + decode y)). + Proof using isadc isdecode. + intros; erewrite fst_add_with_carry_leb by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance fst_add_with_carry_true_leb + : forall x y, fst (adc x y true) <~=~> (2^n <=? (decode x + decode y + 1)). + Proof using isadc isdecode. + intros; erewrite fst_add_with_carry_leb by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance fst_sub_with_carry_false + : forall x y, fst (subc x y false) <~=~> ((decode x - decode y) ((decode x - decode y - 1) apply @fst_add_with_carry_false_leb : typeclass_instances. +Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + 1))) +=> apply @fst_add_with_carry_true_leb : typeclass_instances. +Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + if ?c then _ else _))) +=> apply @fst_add_with_carry_leb : typeclass_instances. + + +(* We take special care to handle the case where the decoder is + syntactically different but the decoded expression is judgmentally + the same; we don't want to split apart variables that should be the + same. 
*) +Ltac set_decode_step check := + match goal with + | [ |- context G[@decode ?n ?W ?dr ?w] ] + => check w; + first [ match goal with + | [ d := @decode _ _ _ w |- _ ] + => change (@decode n W dr w) with d + end + | generalize (@decode_range n W dr _ w); + let d := fresh "d" in + set (d := @decode n W dr w); + intro ] + end. +Ltac set_decode check := repeat set_decode_step check. +Ltac clearbody_decode := + repeat match goal with + | [ H := @decode _ _ _ _ |- _ ] => clearbody H + end. +Ltac generalize_decode_by check := set_decode check; clearbody_decode. +Ltac generalize_decode := generalize_decode_by ltac:(fun w => idtac). +Ltac generalize_decode_var := generalize_decode_by ltac:(fun w => is_var w). diff --git a/src/LegacyArithmetic/MontgomeryReduction.v b/src/LegacyArithmetic/MontgomeryReduction.v new file mode 100644 index 000000000..c3538dd01 --- /dev/null +++ b/src/LegacyArithmetic/MontgomeryReduction.v @@ -0,0 +1,114 @@ +(*** Montgomery Multiplication *) +(** This file implements Montgomery Form, Montgomery Reduction, and + Montgomery Multiplication on [ZLikeOps]. We follow [Montgomery/Z.v]. *) +Require Import Coq.ZArith.ZArith Coq.Lists.List Coq.Classes.Morphisms Coq.micromega.Psatz. +Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. +Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. +Require Import Crypto.LegacyArithmetic.ZBounded. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics.Test. +Require Import Crypto.Util.Tactics.Not. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Notations. + +Local Open Scope small_zlike_scope. +Local Open Scope large_zlike_scope. +Local Open Scope Z_scope. + +Section montgomery. + Context (small_bound modulus : Z) {ops : ZLikeOps small_bound small_bound modulus} {props : ZLikeProperties ops} + (modulus' : SmallT) + (modulus'_valid : small_valid modulus') + (modulus_nonzero : modulus <> 0). 
+ + (** pull out a common subexpression *) + Local Ltac cse := + let RHS := match goal with |- _ = ?decode ?RHS /\ _ => RHS end in + let v := fresh in + match RHS with + | context[?e] => not is_var e; set (v := e) at 1 2; test clearbody v + end; + revert v; + match goal with + | [ |- let v := ?val in ?LHS = ?decode ?RHS /\ ?P ] + => change (LHS = decode (dlet v := val in RHS) /\ P) + end. + + Definition partial_reduce : forall v : LargeT, + { partial_reduce : SmallT + | large_valid v + -> decode_small partial_reduce = MontgomeryReduction.Definition.partial_reduce modulus small_bound (decode_small modulus') (decode_large v) + /\ small_valid partial_reduce }. + Proof. + intro T. evar (pr : SmallT); exists pr. intros T_valid. + assert (0 <= decode_large T < small_bound * small_bound) by auto using decode_large_valid. + assert (0 <= decode_small (Mod_SmallBound T) < small_bound) by auto using decode_small_valid, Mod_SmallBound_valid. + assert (0 <= decode_small modulus' < small_bound) by auto using decode_small_valid. + assert (0 <= decode_small modulus_digits < small_bound) by auto using decode_small_valid, modulus_digits_valid. + assert (0 <= modulus) by apply (modulus_nonneg _). + assert (modulus < small_bound) by (rewrite <- modulus_digits_correct; omega). + rewrite <- partial_reduce_alt_eq by omega. + cbv [MontgomeryReduction.Definition.partial_reduce MontgomeryReduction.Definition.partial_reduce_alt MontgomeryReduction.Definition.prereduce]. + pull_zlike_decode. + cse. + subst pr; split; [ reflexivity | exact _ ]. + Defined. + + Definition reduce_via_partial : forall v : LargeT, + { reduce : SmallT + | large_valid v + -> decode_small reduce = MontgomeryReduction.Definition.reduce_via_partial modulus small_bound (decode_small modulus') (decode_large v) + /\ small_valid reduce }. + Proof. + intro T. evar (pr : SmallT); exists pr. intros T_valid. + assert (0 <= decode_large T < small_bound * small_bound) by auto using decode_large_valid. 
+ assert (0 <= decode_small (Mod_SmallBound T) < small_bound) by auto using decode_small_valid, Mod_SmallBound_valid. + assert (0 <= decode_small modulus' < small_bound) by auto using decode_small_valid. + assert (0 <= decode_small modulus_digits < small_bound) by auto using decode_small_valid, modulus_digits_valid. + assert (0 <= modulus) by apply (modulus_nonneg _). + assert (modulus < small_bound) by (rewrite <- modulus_digits_correct; omega). + unfold reduce_via_partial. + rewrite <- partial_reduce_alt_eq by omega. + cbv [MontgomeryReduction.Definition.partial_reduce MontgomeryReduction.Definition.partial_reduce_alt MontgomeryReduction.Definition.prereduce]. + pull_zlike_decode. + cse. + subst pr; split; [ reflexivity | exact _ ]. + Defined. + + Section correctness. + Context (R' : Z) + (Hmod : Z.equiv_modulo modulus (small_bound * R') 1) + (Hmod' : Z.equiv_modulo small_bound (modulus * (decode_small modulus')) (-1)) + (v : LargeT) + (H : large_valid v) + (Hv : 0 <= decode_large v <= small_bound * modulus). + Lemma reduce_via_partial_correct' + : Z.equiv_modulo modulus + (decode_small (proj1_sig (reduce_via_partial v))) + (decode_large v * R') + /\ Z.min 0 (small_bound - modulus) <= (decode_small (proj1_sig (reduce_via_partial v))) < modulus. + Proof using H Hmod Hmod' Hv. + rewrite (proj1 (proj2_sig (reduce_via_partial v) H)). + eauto 6 using reduce_via_partial_correct, reduce_via_partial_in_range, decode_small_valid. + Qed. + + Lemma reduce_via_partial_correct'' + : Z.equiv_modulo modulus + (decode_small (proj1_sig (reduce_via_partial v))) + (decode_large v * R') + /\ 0 <= (decode_small (proj1_sig (reduce_via_partial v))) < modulus. + Proof using H Hmod Hmod' Hv. + pose proof (proj2 (proj2_sig (reduce_via_partial v) H)) as H'. + apply decode_small_valid in H'. + destruct reduce_via_partial_correct'; split; eauto; omega. + Qed. + + Theorem reduce_via_partial_correct + : decode_small (proj1_sig (reduce_via_partial v)) = (decode_large v * R') mod modulus. 
+ Proof using H Hmod Hmod' Hv. + rewrite <- (proj1 reduce_via_partial_correct''). + rewrite Z.mod_small by apply reduce_via_partial_correct''. + reflexivity. + Qed. + End correctness. +End montgomery. diff --git a/src/LegacyArithmetic/Pow2Base.v b/src/LegacyArithmetic/Pow2Base.v new file mode 100644 index 000000000..62f1f742d --- /dev/null +++ b/src/LegacyArithmetic/Pow2Base.v @@ -0,0 +1,19 @@ +Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith. +Require Import Crypto.Util.ListUtil. +Require Import Crypto.Util.ZUtil. +Require Import Coq.Lists.List. + +Local Open Scope Z_scope. + +Section Pow2Base. + Context (limb_widths : list Z). + Local Notation "w[ i ]" := (nth_default 0 limb_widths i). + Fixpoint base_from_limb_widths limb_widths := + match limb_widths with + | nil => nil + | w :: lw => 1 :: map (Z.mul (two_p w)) (base_from_limb_widths lw) + end. + Local Notation base := (base_from_limb_widths limb_widths). + Definition bounded us := forall i, 0 <= nth_default 0 us i < 2 ^ w[i]. + Definition upper_bound := 2 ^ (sum_firstn limb_widths (length limb_widths)). +End Pow2Base. diff --git a/src/LegacyArithmetic/Pow2BaseProofs.v b/src/LegacyArithmetic/Pow2BaseProofs.v new file mode 100644 index 000000000..8a38275dd --- /dev/null +++ b/src/LegacyArithmetic/Pow2BaseProofs.v @@ -0,0 +1,555 @@ +Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import Coq.Lists.List. +Require Import Coq.funind.Recdef. +Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil. +Require Import Crypto.LegacyArithmetic.VerdiTactics. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.Tactics.RewriteHyp. +Require Import Crypto.LegacyArithmetic.Pow2Base. +Require Import Crypto.Util.Notations. +Require Export Crypto.Util.Bool. +Require Export Crypto.Util.FixCoqMistakes. 
+Local Open Scope Z_scope. + +Require Crypto.LegacyArithmetic.BaseSystemProofs. + +Create HintDb simpl_add_to_nth discriminated. +Create HintDb push_upper_bound discriminated. +Create HintDb pull_upper_bound discriminated. +Create HintDb push_base_from_limb_widths discriminated. +Create HintDb pull_base_from_limb_widths discriminated. + +Hint Extern 1 => progress autorewrite with push_upper_bound in * : push_upper_bound. +Hint Extern 1 => progress autorewrite with pull_upper_bound in * : pull_upper_bound. +Hint Extern 1 => progress autorewrite with push_base_from_limb_widths in * : push_base_from_limb_widths. +Hint Extern 1 => progress autorewrite with pull_base_from_limb_widths in * : pull_base_from_limb_widths. + +Section Pow2BaseProofs. + Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w). + Local Notation base := (base_from_limb_widths limb_widths). + + Lemma base_from_limb_widths_length ls : length (base_from_limb_widths ls) = length ls. + Proof using Type. + clear limb_widths limb_widths_nonneg. + induction ls; [ reflexivity | simpl in * ]. + autorewrite with distr_length; auto. + Qed. + Hint Rewrite base_from_limb_widths_length : distr_length. + + Lemma base_from_limb_widths_cons : forall l0 l, + base_from_limb_widths (l0 :: l) = 1 :: map (Z.mul (two_p l0)) (base_from_limb_widths l). + Proof using Type. reflexivity. Qed. + Hint Rewrite base_from_limb_widths_cons : push_base_from_limb_widths. + Hint Rewrite <- base_from_limb_widths_cons : pull_base_from_limb_widths. + + Lemma base_from_limb_widths_nil : base_from_limb_widths nil = nil. + Proof using Type. reflexivity. Qed. + Hint Rewrite base_from_limb_widths_nil : push_base_from_limb_widths. + + Lemma firstn_base_from_limb_widths : forall n, firstn n (base_from_limb_widths limb_widths) = base_from_limb_widths (firstn n limb_widths). + Proof using Type. + clear limb_widths_nonneg. 
(* don't use this in the inductive hypothesis *) + induction limb_widths as [|l ls IHls]; intros [|n]; try reflexivity. + autorewrite with push_base_from_limb_widths push_firstn; boring. + Qed. + Hint Rewrite <- @firstn_base_from_limb_widths : push_base_from_limb_widths. + Hint Rewrite <- @firstn_base_from_limb_widths : pull_firstn. + Hint Rewrite @firstn_base_from_limb_widths : pull_base_from_limb_widths. + Hint Rewrite @firstn_base_from_limb_widths : push_firstn. + + Lemma sum_firstn_limb_widths_nonneg : forall n, 0 <= sum_firstn limb_widths n. + Proof using Type*. + unfold sum_firstn; intros. + apply fold_right_invariant; try omega. + eauto using Z.add_nonneg_nonneg, limb_widths_nonneg, In_firstn. + Qed. Hint Resolve sum_firstn_limb_widths_nonneg. + + Lemma base_from_limb_widths_step : forall i b w, (S i < length limb_widths)%nat -> + nth_error base i = Some b -> + nth_error limb_widths i = Some w -> + nth_error base (S i) = Some (two_p w * b). + Proof using Type. + clear limb_widths_nonneg. (* don't use this in the inductive hypothesis *) + induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b; + unfold base_from_limb_widths in *; fold base_from_limb_widths in *; + [rewrite (@nil_length0 Z) in *; omega | ]. + simpl in *. + case_eq i; intros; subst. + + subst; apply nth_error_first in nth_err_w. + apply nth_error_first in nth_err_b; subst. + apply map_nth_error. + case_eq l; intros; subst; [simpl in *; omega | ]. + unfold base_from_limb_widths; fold base_from_limb_widths. + reflexivity. + + simpl in nth_err_w. + apply nth_error_map in nth_err_w. + destruct nth_err_w as [x [A B] ]. + subst. + replace (two_p w * (two_p a * x)) with (two_p a * (two_p w * x)) by ring. + apply map_nth_error. + apply IHl; auto. omega. + Qed. + + + Lemma nth_error_base : forall i, (i < length limb_widths)%nat -> + nth_error base i = Some (two_p (sum_firstn limb_widths i)). + Proof using Type*. + induction i; intros. 
+ + unfold sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity. + intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega. + + assert (i < length limb_widths)%nat as lt_i_length by omega. + specialize (IHi lt_i_length). + destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w]. + erewrite base_from_limb_widths_step; eauto. + f_equal. + simpl. + destruct (NPeano.Nat.eq_dec i 0). + - subst; unfold sum_firstn; simpl. + apply nth_error_exists_first in nth_err_w. + destruct nth_err_w as [l' lw_destruct]; subst. + simpl; ring_simplify. + f_equal; ring. + - erewrite sum_firstn_succ; eauto. + symmetry. + apply two_p_is_exp; auto using sum_firstn_limb_widths_nonneg. + apply limb_widths_nonneg. + eapply nth_error_value_In; eauto. + Qed. + + Lemma nth_default_base : forall d i, (i < length limb_widths)%nat -> + nth_default d base i = 2 ^ (sum_firstn limb_widths i). + Proof using Type*. + intros ? ? i_lt_length. + apply nth_error_value_eq_nth_default. + rewrite nth_error_base, two_p_correct by assumption. + reflexivity. + Qed. + + Lemma b0_1 : forall x : Z, limb_widths <> nil -> nth_default x base 0 = 1. + Proof using Type. + case_eq limb_widths; intros; [congruence | reflexivity]. + Qed. + + Lemma base_from_limb_widths_app : forall l0 l + (l0_nonneg : forall x, In x l0 -> 0 <= x) + (l_nonneg : forall x, In x l -> 0 <= x), + base_from_limb_widths (l0 ++ l) + = base_from_limb_widths l0 ++ map (Z.mul (two_p (sum_firstn l0 (length l0)))) (base_from_limb_widths l). + Proof using Type. + induction l0 as [|?? IHl0]. + { simpl; intros; rewrite <- map_id at 1; apply map_ext; intros; omega. } + { simpl; intros; rewrite !IHl0, !map_app, map_map, sum_firstn_succ_cons, two_p_is_exp by auto with znonzero. + do 2 f_equal; apply map_ext; intros; lia. } + Qed. + + Lemma skipn_base_from_limb_widths : forall n, skipn n (base_from_limb_widths limb_widths) = map (Z.mul (two_p (sum_firstn limb_widths n))) (base_from_limb_widths (skipn n limb_widths)). 
+ Proof using Type*. + intro n; pose proof (base_from_limb_widths_app (firstn n limb_widths) (skipn n limb_widths)) as H. + specialize_by eauto using In_firstn, In_skipn. + autorewrite with simpl_firstn simpl_skipn in *. + rewrite H, skipn_app, skipn_all by auto with arith distr_length; clear H. + simpl; distr_length. + apply Min.min_case_strong; intro; + unfold sum_firstn; autorewrite with natsimplify simpl_skipn simpl_firstn; + reflexivity. + Qed. + Hint Rewrite <- @skipn_base_from_limb_widths : push_base_from_limb_widths. + Hint Rewrite <- @skipn_base_from_limb_widths : pull_skipn. + Hint Rewrite @skipn_base_from_limb_widths : pull_base_from_limb_widths. + Hint Rewrite @skipn_base_from_limb_widths : push_skipn. + + Lemma pow2_mod_bounded :forall lw us i, (forall w, In w lw -> 0 <= w) -> bounded lw us -> + Z.pow2_mod (nth_default 0 us i) (nth_default 0 lw i) = nth_default 0 us i. + Proof using Type. + clear. + repeat match goal with + | |- _ => progress (cbv [bounded]; intros) + | |- _ => break_if + | |- _ => apply Z.bits_inj' + | |- _ => rewrite Z.testbit_pow2_mod by (apply nth_default_preserves_properties; auto; omega) + | |- _ => reflexivity + end. + specialize (H0 i). + symmetry. + rewrite <- (Z.mod_pow2_bits_high (nth_default 0 us i) (nth_default 0 lw i) n); + [ rewrite Z.mod_small by omega; reflexivity | ]. + split; try omega. + apply nth_default_preserves_properties; auto; omega. + Qed. + + Lemma bounded_nil_iff : forall us, bounded nil us <-> (forall u, In u us -> u = 0). + Proof using Type. + clear. + split; cbv [bounded]; intros. + + edestruct (In_nth_error_value us u); try assumption. + specialize (H x). + replace u with (nth_default 0 us x) by (auto using nth_error_value_eq_nth_default). + rewrite nth_default_nil, Z.pow_0_r in H. + omega. + + rewrite nth_default_nil, Z.pow_0_r. + apply nth_default_preserves_properties; try omega. + intros. + apply H in H0. + omega. + Qed. 
+
+  (** [bounded lw us] is, by unfolding, exactly the statement that every
+      index [i] satisfies [0 <= nth_default 0 us i < 2 ^ nth_default 0 lw i]. *)
+  Lemma bounded_iff : forall lw us, bounded lw us <-> forall i, 0 <= nth_default 0 us i < 2 ^ nth_default 0 lw i.
+  Proof using Type.
+    clear.
+    cbv [bounded]; intros.
+    reflexivity.
+  Qed.
+
+  (** For a bounded limb list, limb [i] equals the decoded integer shifted
+      right by [sum_firstn limb_widths i] and then truncated with
+      [Z.pow2_mod] to [nth_default 0 limb_widths i] bits.
+      The proof is by induction on [us] with [limb_widths] generalized, then
+      case analysis on [i] and on whether [limb_widths] is empty. *)
+  Lemma digit_select : forall us i, bounded limb_widths us ->
+    nth_default 0 us i = Z.pow2_mod (BaseSystem.decode base us >> sum_firstn limb_widths i) (nth_default 0 limb_widths i).
+  Proof using Type*.
+    intro; revert limb_widths limb_widths_nonneg; induction us; intros.
+    (* us = nil *)
+    + rewrite nth_default_nil, BaseSystemProofs.decode_nil, Z.shiftr_0_l, Z.pow2_mod_spec, Z.mod_0_l by
+          (try (apply Z.pow_nonzero; try omega); apply nth_default_preserves_properties; auto; omega).
+      reflexivity.
+    (* us = a :: us *)
+    + destruct i.
+      (* i = 0: select the head limb *)
+      - rewrite nth_default_cons, sum_firstn_0, Z.shiftr_0_r.
+        destruct limb_widths as [|w lw].
+        * cbv [base_from_limb_widths].
+          rewrite <-pow2_mod_bounded with (lw := nil); rewrite bounded_nil_iff in *; auto using in_cons;
+            try solve [intros; exfalso; eauto using in_nil].
+          rewrite !nth_default_nil, BaseSystemProofs.decode_base_nil; auto.
+          cbv. auto using in_eq.
+        * rewrite nth_default_cons, base_from_limb_widths_cons, BaseSystemProofs.peel_decode.
+          fold (BaseSystem.mul_each (two_p w)).
+          rewrite <-BaseSystemProofs.mul_each_base, BaseSystemProofs.mul_each_rep.
+          rewrite two_p_correct, (Z.mul_comm (2 ^ w)).
+          rewrite <-Z.shiftl_mul_pow2 by auto using in_eq.
+          rewrite bounded_iff in *.
+          specialize (H 0%nat); rewrite !nth_default_cons in H.
+          rewrite <-Z.lor_shiftl by (auto using in_eq; omega).
+          (* finish bit by bit *)
+          apply Z.bits_inj'; intros.
+          rewrite Z.testbit_pow2_mod by auto using in_eq.
+          break_if. {
+            autorewrite with Ztestbit; break_match;
+              try rewrite Z.testbit_neg_r with (n := n - w) by omega;
+              autorewrite with bool_congr;
+              f_equal; ring.
+          } {
+            replace a with (a mod 2 ^ w) by (auto using Z.mod_small).
+            apply Z.mod_pow2_bits_high. split; auto using in_eq; omega.
+          }
+      (* i = S i: drop the head limb and use the induction hypothesis *)
+      - rewrite nth_default_cons_S.
+        destruct limb_widths as [|w lw].
+        * cbv [base_from_limb_widths].
+          rewrite <-pow2_mod_bounded with (lw := nil); rewrite bounded_nil_iff in *; auto using in_cons.
+          rewrite sum_firstn_nil, !nth_default_nil, BaseSystemProofs.decode_base_nil, Z.shiftr_0_r.
+          apply nth_default_preserves_properties; intros; auto using in_cons.
+          f_equal; auto using in_cons.
+        * rewrite sum_firstn_succ_cons, nth_default_cons_S, base_from_limb_widths_cons, BaseSystemProofs.peel_decode.
+          fold (BaseSystem.mul_each (two_p w)).
+          rewrite <-BaseSystemProofs.mul_each_base, BaseSystemProofs.mul_each_rep.
+          rewrite two_p_correct, (Z.mul_comm (2 ^ w)).
+          rewrite <-Z.shiftl_mul_pow2 by auto using in_eq.
+          rewrite bounded_iff in *.
+          rewrite Z.shiftr_add_shiftl_high by first
+            [ pose proof (sum_firstn_nonnegative i lw); split; auto using in_eq; specialize_by auto using in_cons; omega
+            | specialize (H 0%nat); rewrite !nth_default_cons in H; omega ].
+          rewrite IHus with (limb_widths := lw) by
+              (auto using in_cons; rewrite ?bounded_iff; intro j; specialize (H (S j));
+               rewrite !nth_default_cons_S in H; assumption).
+          repeat f_equal; try ring.
+  Qed.
+
+  (** Every entry of [limb_widths], read with default [0], is nonnegative.
+      Registered as a [Hint Resolve] immediately after the proof. *)
+  Lemma nth_default_limb_widths_nonneg : forall i, 0 <= nth_default 0 limb_widths i.
+  Proof using Type*.
+    intros; apply nth_default_preserves_properties; auto; omega.
+  Qed. Hint Resolve nth_default_limb_widths_nonneg.
+
+  (** Decoding only the first [i] limbs of a full-length bounded list gives
+      the full decoding truncated ([Z.pow2_mod]) to [sum_firstn limb_widths i]
+      bits. Proved by induction on [i]; the successor case is finished
+      bitwise after rewriting with [digit_select]. *)
+  Lemma decode_firstn_pow2_mod : forall us i,
+    (i <= length us)%nat ->
+    length us = length limb_widths ->
+    bounded limb_widths us ->
+    BaseSystem.decode' base (firstn i us) = Z.pow2_mod (BaseSystem.decode' base us) (sum_firstn limb_widths i).
+  Proof using Type*.
+    intros; induction i;
+    repeat match goal with
+           | |- _ => rewrite sum_firstn_0, BaseSystemProofs.decode_nil, Z.pow2_mod_0_r; reflexivity
+           | |- _ => progress distr_length
+           | |- _ => progress autorewrite with simpl_firstn
+           | |- _ => rewrite firstn_succ with (d := 0)
+           | |- _ => rewrite BaseSystemProofs.set_higher
+           | |- _ => rewrite nth_default_base
+           | |- _ => rewrite IHi
+           | |- _ => rewrite <-Z.lor_shiftl by (rewrite ?Z.pow2_mod_spec; try apply Z.mod_pos_bound; zero_bounds)
+           | |- appcontext[min ?x ?y] => (rewrite Nat.min_l by omega || rewrite Nat.min_r by omega)
+           | |- appcontext[2 ^ ?a * _] => rewrite (Z.mul_comm (2 ^ a)); rewrite <-Z.shiftl_mul_pow2
+           | |- _ => solve [auto]
+           | |- _ => lia
+           end.
+    rewrite digit_select by assumption; apply Z.bits_inj'.
+    repeat match goal with
+           | |- _ => progress intros
+           | |- _ => progress autorewrite with Ztestbit
+           | |- _ => rewrite Z.testbit_pow2_mod by (omega || trivial)
+           | |- _ => break_if; try omega
+           | H : ?a < ?b |- appcontext[Z.testbit _ (?a - ?b)] =>
+             rewrite (Z.testbit_neg_r _ (a-b)) by omega
+           | |- _ => reflexivity
+           | |- _ => solve [f_equal; ring]
+           | |- _ => rewrite sum_firstn_succ_default in *;
+                     pose proof (nth_default_limb_widths_nonneg i); omega
+           end.
+  Qed.
+
+  (** Every bit at position [n >= sum_firstn limb_widths i] of the decoding
+      of the first [i] limbs is [false]; follows from
+      [decode_firstn_pow2_mod] and [Z.testbit_pow2_mod]. *)
+  Lemma testbit_decode_firstn_high : forall us i n,
+    (i <= length us)%nat ->
+    length us = length limb_widths ->
+    bounded limb_widths us ->
+    sum_firstn limb_widths i <= n ->
+    Z.testbit (BaseSystem.decode base (firstn i us)) n = false.
+  Proof using Type*.
+    repeat match goal with
+           | |- _ => progress intros
+           | |- _ => progress autorewrite with Ztestbit
+           | |- _ => rewrite decode_firstn_pow2_mod
+           | |- _ => rewrite Z.testbit_pow2_mod
+           | |- _ => break_if
+           | |- _ => assumption
+           | |- _ => solve [auto]
+           | H : ?a <= ?b |- 0 <= ?b => assert (0 <= a) by (omega || auto); omega
+           end.
+  Qed.
+
+  (** Every bit of a full-length bounded decoding at or above the total
+      width [sum_firstn limb_widths (length us)] is [false]. *)
+  Lemma testbit_decode_high : forall us n,
+    length us = length limb_widths ->
+    bounded limb_widths us ->
+    sum_firstn limb_widths (length us) <= n ->
+    Z.testbit (BaseSystem.decode base us) n = false.
+  Proof using Type*.
+    intros.
+    erewrite <-(firstn_all _ us) by reflexivity.
+    auto using testbit_decode_firstn_high.
+  Qed.
+
+  (** TODO: Figure out how to automate and clean up this proof *)
+  (** The decoding of a full-length bounded limb list is nonnegative.
+      The proof generalizes the base to [map (Z.mul (two_p zero)) base]
+      (initially with [zero = 0]) so that the induction hypothesis also
+      covers the shifted tail of the base. *)
+  Lemma decode_nonneg : forall us,
+    length us = length limb_widths ->
+    bounded limb_widths us ->
+    0 <= BaseSystem.decode base us.
+  Proof using Type*.
+    intros.
+    unfold bounded, BaseSystem.decode, BaseSystem.decode' in *; simpl in *.
+    pose 0 as zero.
+    assert (0 <= zero) by reflexivity.
+    replace base with (map (Z.mul (two_p zero)) base)
+      by (etransitivity; [ | apply map_id ]; apply map_ext; auto with zarith).
+    clearbody zero.
+    revert dependent zero.
+    generalize dependent limb_widths.
+    induction us as [|u us IHus]; intros [|w limb_widths'] ?? Hbounded ??; simpl in *;
+      try (reflexivity || congruence).
+    pose proof (Hbounded 0%nat) as Hbounded0.
+    pose proof (fun n => Hbounded (S n)) as HboundedS.
+    unfold nth_default, nth_error in Hbounded0.
+    unfold nth_default in HboundedS.
+    rewrite map_map.
+    unfold BaseSystem.accumulate at 1; simpl.
+    assert (0 < two_p zero) by (rewrite two_p_equiv; auto with zarith).
+    replace (map (fun x => two_p zero * (two_p w * x)) (base_from_limb_widths limb_widths')) with (map (Z.mul (two_p (zero + w))) (base_from_limb_widths limb_widths'))
+      by (apply map_ext; rewrite two_p_is_exp by auto with zarith omega; auto with zarith).
+    change 0 with (0 + 0) at 1.
+    apply Z.add_le_mono; simpl in *; auto with zarith.
+  Qed.
+
+  (** A full-length bounded decoding lies in [[0, upper_bound limb_widths)]:
+      the lower bound is [decode_nonneg], the upper bound follows from
+      [testbit_decode_high] via [Z.testbit_false_bound]. *)
+  Lemma decode_upper_bound : forall us,
+    length us = length limb_widths ->
+    bounded limb_widths us ->
+    0 <= BaseSystem.decode base us < upper_bound limb_widths.
+  Proof using Type*.
+    cbv [upper_bound]; intros.
+    split.
+    { apply decode_nonneg; auto. }
+    { apply Z.testbit_false_bound; auto; intros.
+      rewrite testbit_decode_high; auto;
+        replace (length us) with (length limb_widths); try omega. }
+  Qed.
+
+  (** Decoding [us0 ++ us1] splits into the decoding of [us0] over the first
+      [length us0] limb widths plus the decoding of [us1] over the remaining
+      widths, shifted left by [sum_firstn limb_widths (length us0)]. *)
+  Lemma decode_shift_app : forall us0 us1, (length (us0 ++ us1) <= length limb_widths)%nat ->
+    BaseSystem.decode base (us0 ++ us1) = (BaseSystem.decode (base_from_limb_widths (firstn (length us0) limb_widths)) us0) + ((BaseSystem.decode (base_from_limb_widths (skipn (length us0) limb_widths)) us1) << sum_firstn limb_widths (length us0)).
+  Proof using Type*.
+    unfold BaseSystem.decode; intros us0 us1 ?.
+    assert (0 <= sum_firstn limb_widths (length us0)) by auto using sum_firstn_nonnegative.
+    rewrite BaseSystemProofs.decode'_splice; autorewrite with push_firstn.
+    apply Z.add_cancel_l.
+    autorewrite with pull_base_from_limb_widths Zshift_to_pow zsimplify.
+    rewrite BaseSystemProofs.decode'_map_mul, two_p_correct; nia.
+  Qed.
+
+  (** Single-limb instance of [decode_shift_app]: peel off the head limb
+      [u0] and shift the decoding of the tail by the first limb width. *)
+  Lemma decode_shift : forall us u0, (length (u0 :: us) <= length limb_widths)%nat ->
+    BaseSystem.decode base (u0 :: us) = u0 + ((BaseSystem.decode (base_from_limb_widths (tl limb_widths)) us) << (nth_default 0 limb_widths 0)).
+  Proof using Type*.
+    intros; etransitivity; [ apply (decode_shift_app (u0::nil)); assumption | ].
+    transitivity (u0 * 1 + 0 + ((BaseSystem.decode (base_from_limb_widths (tl limb_widths)) us) << (nth_default 0 limb_widths 0 + 0))); [ | autorewrite with zsimplify; reflexivity ].
+    destruct limb_widths; distr_length; reflexivity.
+  Qed.
+
+  (** [upper_bound] of the empty width list is [1] (by computation). *)
+  Lemma upper_bound_nil : upper_bound nil = 1.
+  Proof using Type. reflexivity. Qed.
+
+  (** [upper_bound] of [x :: xs] factors as [2^x * upper_bound xs], given
+      the stated nonnegativity side conditions. *)
+  Lemma upper_bound_cons x xs : 0 <= x -> 0 <= sum_firstn xs (length xs) -> upper_bound (x::xs) = 2^x * upper_bound xs.
+  Proof using Type.
+    intros Hx Hxs.
+    unfold upper_bound; simpl.
+    autorewrite with simpl_sum_firstn pull_Zpow.
+    reflexivity.
+  Qed.
+
+  (** [upper_bound] is multiplicative over [++], given the stated
+      nonnegativity side conditions. *)
+  Lemma upper_bound_app xs ys : 0 <= sum_firstn xs (length xs) -> 0 <= sum_firstn ys (length ys) -> upper_bound (xs ++ ys) = upper_bound xs * upper_bound ys.
+  Proof using Type.
+    intros Hxs Hys.
+    unfold upper_bound; simpl.
+    autorewrite with distr_length simpl_sum_firstn pull_Zpow.
+    reflexivity.
+  Qed.
+
+End Pow2BaseProofs.
+Hint Rewrite base_from_limb_widths_cons base_from_limb_widths_nil : push_base_from_limb_widths.
+Hint Rewrite <- base_from_limb_widths_cons : pull_base_from_limb_widths.
+
+Hint Rewrite <- @firstn_base_from_limb_widths : push_base_from_limb_widths.
+Hint Rewrite <- @firstn_base_from_limb_widths : pull_firstn.
+Hint Rewrite @firstn_base_from_limb_widths : pull_base_from_limb_widths.
+Hint Rewrite @firstn_base_from_limb_widths : push_firstn.
+Hint Rewrite <- @skipn_base_from_limb_widths : push_base_from_limb_widths.
+Hint Rewrite <- @skipn_base_from_limb_widths : pull_skipn.
+Hint Rewrite @skipn_base_from_limb_widths : pull_base_from_limb_widths.
+Hint Rewrite @skipn_base_from_limb_widths : push_skipn.
+
+Hint Rewrite @base_from_limb_widths_length : distr_length.
+Hint Rewrite @upper_bound_nil @upper_bound_cons @upper_bound_app using solve [ eauto with znonzero ] : push_upper_bound.
+Hint Rewrite <- @upper_bound_cons @upper_bound_app using solve [ eauto with znonzero ] : pull_upper_bound.
+
+(** Facts specialized to limb-width lists in which every entry equals a
+    single nonnegative [width]. *)
+Section UniformBase.
+  Context {width : Z} (limb_width_nonneg : 0 <= width).
+  Context (limb_widths : list Z)
+    (limb_widths_uniform : forall w, In w limb_widths -> w = width).
+  Local Notation base := (base_from_limb_widths limb_widths).
+
+  (** With uniform widths, and [us] no longer than [limb_widths],
+      [bounded] is equivalent to every element of [us] lying in
+      [[0, 2 ^ width)]. *)
+  Lemma bounded_uniform : forall us, (length us <= length limb_widths)%nat ->
+    (bounded limb_widths us <-> (forall u, In u us -> 0 <= u < 2 ^ width)).
+  Proof using Type*.
+    cbv [bounded]; split; intro A; intros.
+    + let G := fresh "G" in
+      match goal with H : In _ us |- _ =>
+        eapply In_nth in H; destruct H as [? G]; destruct G as [? G];
+        rewrite <-nth_default_eq in G; rewrite <-G end.
+      specialize (A x).
+      split; try eapply A.
+      eapply Z.lt_le_trans; try apply A.
+      apply nth_default_preserves_properties; [ | apply Z.pow_le_mono_r; omega ] .
+      intros; apply Z.eq_le_incl.
+      f_equal; auto.
+    + apply nth_default_preserves_properties_length_dep;
+        try solve [apply nth_default_preserves_properties; split; zero_bounds; rewrite limb_widths_uniform; auto || omega].
+      intros; apply nth_default_preserves_properties_length_dep; try solve [intros; omega].
+      let x := fresh "x" in intro x; intros;
+        replace x with width; try symmetry; auto.
+  Qed.
+
+  (** Every entry of a uniform width list is nonnegative, because it equals
+      [width]. *)
+  Lemma uniform_limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w.
+  Proof using Type*.
+    intros.
+    replace w with width by (symmetry; auto).
+    assumption.
+  Qed.
+
+  (** Reading index [i] of a uniform width list with default [0] yields
+      [width] when [i] is in range and [0] otherwise.
+
+      This lemma was previously closed with [Admitted]. Its statement
+      only holds under [limb_widths_uniform], so it is now proved and
+      declares that dependency explicitly.
+      NOTE(review): after the section closes the lemma takes
+      [limb_widths_uniform] as a premise; any out-of-section user that
+      relied on the admitted form must now supply it. The proof relies on
+      [nth_default_out_of_bounds] from the list utilities -- confirm it is
+      in scope in this file. *)
+  Lemma nth_default_uniform_base_full : forall i,
+      nth_default 0 limb_widths i = if lt_dec i (length limb_widths)
+                                    then width else 0.
+  Proof using limb_widths_uniform.
+    intro i; destruct (lt_dec i (length limb_widths)) as [Hlt|Hnlt]; cbv beta iota.
+    { (* in range: the entry exists, is a member of the list, hence is [width] *)
+      destruct (nth_error_length_exists_value _ _ Hlt) as [w Hw].
+      erewrite nth_error_value_eq_nth_default by eassumption.
+      apply limb_widths_uniform; eapply nth_error_value_In; eassumption. }
+    { (* out of range: [nth_default] falls back to the default [0] *)
+      apply nth_default_out_of_bounds; omega. }
+  Qed.
+
+  (** In-range specialization of [nth_default_uniform_base_full]. *)
+  Lemma nth_default_uniform_base : forall i, (i < length limb_widths)%nat ->
+      nth_default 0 limb_widths i = width.
+  Proof using Type*.
+    intros; rewrite nth_default_uniform_base_full.
+    edestruct lt_dec; omega.
+  Qed.
+
+  (** The sum of the first [i] uniform widths is [i * width], for [i] up to
+      the length of the list. *)
+  Lemma sum_firstn_uniform_base : forall i, (i <= length limb_widths)%nat ->
+                                            sum_firstn limb_widths i = Z.of_nat i * width.
+  Proof using limb_widths_uniform.
+    clear limb_width_nonneg. (* clear this before induction so we don't depend on this *)
+    induction limb_widths as [|x xs IHxs]; (intros [|i] ?);
+      simpl @length in *;
+      autorewrite with simpl_sum_firstn push_Zof_nat zsimplify;
+      try reflexivity;
+      try omega.
+    assert (x = width) by auto with datatypes; subst.
+    rewrite IHxs by auto with datatypes omega; omega.
+  Qed.
+
+  (** Past the end of the list the sum saturates at
+      [length limb_widths * width]. *)
+  Lemma sum_firstn_uniform_base_strong : forall i, (length limb_widths <= i)%nat ->
+                                                   sum_firstn limb_widths i = Z.of_nat (length limb_widths) * width.
+  Proof using limb_widths_uniform.
+    intros; rewrite sum_firstn_all, sum_firstn_uniform_base by omega; reflexivity.
+  Qed.
+
+  (** Closed form of [upper_bound] for uniform widths. *)
+  Lemma upper_bound_uniform : upper_bound limb_widths = 2^(Z.of_nat (length limb_widths) * width).
+  Proof using limb_widths_uniform.
+    unfold upper_bound; rewrite sum_firstn_uniform_base_strong by omega; reflexivity.
+  Qed.
+
+  (* TODO : move *)
+  (** Decoding [us] against a base [bs] only looks at the first [length us]
+      entries of [bs]. Proved by induction on [us] with case analysis on
+      [bs]. *)
+  Lemma decode_truncate_base : forall us bs, BaseSystem.decode bs us = BaseSystem.decode (firstn (length us) bs) us.
+  Proof using Type.
+    clear.
+    induction us; intros.
+    + rewrite !BaseSystemProofs.decode_nil; reflexivity.
+    + distr_length.
+      destruct bs.
+      - rewrite firstn_nil, !BaseSystemProofs.decode_base_nil; reflexivity.
+      - rewrite firstn_cons, !BaseSystemProofs.peel_decode.
+        f_equal.
+        apply IHus.
+  Qed.
+
+  (* TODO : move *)
+  (** If every element of [xs] equals [x] and [n < length xs], then the
+      first [n] elements of [xs] and of [tl xs] coincide. The proof first
+      rewrites [xs] as a [repeat] and then uses [ListUtil.tl_repeat]. *)
+  Lemma tl_repeat : forall {A} xs n (x : A), (forall y, In y xs -> y = x) ->
+                                             (n < length xs)%nat ->
+                                             firstn n xs = firstn n (tl xs).
+  Proof using Type.
+    intros.
+    erewrite (repeat_spec_eq xs) by first [ eassumption | reflexivity ].
+    rewrite ListUtil.tl_repeat.
+    autorewrite with push_firstn.
+    apply f_equal; omega *.
+  Qed.
+
+  (** For uniform widths, decoding a list strictly shorter than
+      [limb_widths] is unchanged when the first limb width is dropped from
+      the base. *)
+  Lemma decode_tl_base : forall us, (length us < length limb_widths)%nat ->
+      BaseSystem.decode base us = BaseSystem.decode (base_from_limb_widths (tl limb_widths)) us.
+  Proof using limb_widths_uniform.
+    intros.
+    match goal with |- BaseSystem.decode ?b1 _ = BaseSystem.decode ?b2 _ =>
+      rewrite (decode_truncate_base _ b1), (decode_truncate_base _ b2) end.
+    rewrite !firstn_base_from_limb_widths.
+    do 2 f_equal.
+    eauto using tl_repeat.
+  Qed.
+
+  (** Uniform-width form of [decode_shift]: the shift amount becomes
+      [width]. *)
+  Lemma decode_shift_uniform_tl : forall us u0, (length (u0 :: us) <= length limb_widths)%nat ->
+    BaseSystem.decode base (u0 :: us) = u0 + ((BaseSystem.decode (base_from_limb_widths (tl limb_widths)) us) << width).
+  Proof using Type*.
+    intros.
+    rewrite <- (nth_default_uniform_base 0) by distr_length.
+    rewrite decode_shift by auto using uniform_limb_widths_nonneg.
+    reflexivity.
+  Qed.
+
+  (** Uniform-width form of [decode_shift_app]: the shift amount becomes
+      [Z.of_nat (length us0) * width]. *)
+  Lemma decode_shift_uniform_app : forall us0 us1, (length (us0 ++ us1) <= length limb_widths)%nat ->
+    BaseSystem.decode base (us0 ++ us1) = (BaseSystem.decode (base_from_limb_widths (firstn (length us0) limb_widths)) us0) + ((BaseSystem.decode (base_from_limb_widths (skipn (length us0) limb_widths)) us1) << (Z.of_nat (length us0) * width)).
+  Proof using Type*.
+    intros.
+    rewrite <- sum_firstn_uniform_base by (distr_length; omega).
+    rewrite decode_shift_app by auto using uniform_limb_widths_nonneg.
+    reflexivity.
+  Qed.
+End UniformBase.
\ No newline at end of file
diff --git a/src/LegacyArithmetic/README.md b/src/LegacyArithmetic/README.md
new file mode 100644
index 000000000..b0137664c
--- /dev/null
+++ b/src/LegacyArithmetic/README.md
@@ -0,0 +1,3 @@
+The development of this directory predates `src/Arithmetic`, and should probably
+be considered to be superseded by it. The p256 Montgomery reduction for
+a 128-bit cpu synthesized here still works.
diff --git a/src/LegacyArithmetic/VerdiTactics.v b/src/LegacyArithmetic/VerdiTactics.v
new file mode 100644
index 000000000..4060fc675
--- /dev/null
+++ b/src/LegacyArithmetic/VerdiTactics.v
@@ -0,0 +1,414 @@
+(*
+Copyright (c) 2014-2015, Verdi Team
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*)
+
+(** Each tactic in this group first prints the deprecation notice via
+    [idtac] and then does its work. *)
+
+(** Repeatedly [subst] any term that appears as the left- or right-hand
+    side of an equality hypothesis. *)
+Ltac subst_max := idtac "VerdiTactics is deprecated in fiat-crypto";
+  repeat match goal with
+           | [ H : ?X = _ |- _ ] => subst X
+           | [H : _ = ?X |- _] => subst X
+         end.
+
+(** [inv H]: [inversion H] followed by [subst_max]. *)
+Ltac inv H := idtac "VerdiTactics is deprecated in fiat-crypto"; inversion H; subst_max.
+(** [invc H]: [inv H], then clear [H]. *)
+Ltac invc H := idtac "VerdiTactics is deprecated in fiat-crypto"; inv H; clear H.
+(** [invcs H]: [invc H], then [simpl in *]. *)
+Ltac invcs H := idtac "VerdiTactics is deprecated in fiat-crypto"; invc H; simpl in *.
+
+(** Destruct the scrutinee of an [if] found in the goal or, failing that,
+    in a hypothesis. A [sumbool] scrutinee is destructed plainly; anything
+    else is destructed with [eqn:?] so the equation is remembered. *)
+Ltac break_if := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ |- context [ if ?X then _ else _ ] ] =>
+      match type of X with
+        | sumbool _ _ => destruct X
+        | _ => destruct X eqn:?
+      end
+    | [ H : context [ if ?X then _ else _ ] |- _] =>
+      match type of X with
+        | sumbool _ _ => destruct X
+        | _ => destruct X eqn:?
+      end
+  end.
+
+(** Destruct the scrutinee of a [match] occurring in some hypothesis
+    (same [sumbool] / [eqn:?] distinction as [break_if]). *)
+Ltac break_match_hyp := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : context [ match ?X with _ => _ end ] |- _] =>
+      match type of X with
+        | sumbool _ _ => destruct X
+        | _ => destruct X eqn:?
+      end
+  end.
+
+(** Destruct the scrutinee of a [match] occurring in the goal
+    (same [sumbool] / [eqn:?] distinction as [break_if]). *)
+Ltac break_match_goal := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ |- context [ match ?X with _ => _ end ] ] =>
+      match type of X with
+        | sumbool _ _ => destruct X
+        | _ => destruct X eqn:?
+      end
+  end.
+
+(** Try [break_match_goal]; if it fails, fall back to [break_match_hyp]. *)
+Ltac break_match := idtac "VerdiTactics is deprecated in fiat-crypto"; break_match_goal || break_match_hyp.
+
+
+(** Repeatedly destruct every existential hypothesis. *)
+Ltac break_exists := idtac "VerdiTactics is deprecated in fiat-crypto";
+  repeat match goal with
+           | [H : exists _, _ |- _ ] => destruct H
+         end.
+
+(** Repeatedly destruct an existential hypothesis and immediately supply
+    its witness to the goal with [exists]. *)
+Ltac break_exists_exists := idtac "VerdiTactics is deprecated in fiat-crypto";
+  repeat match goal with
+           | H:exists _, _ |- _ =>
+             let x := fresh "x" in
+             destruct H as [x]; exists x
+         end.
+
+(** Repeatedly destruct every conjunction hypothesis. *)
+Ltac break_and := idtac "VerdiTactics is deprecated in fiat-crypto";
+  repeat match goal with
+           | [H : _ /\ _ |- _ ] => destruct H
+         end.
+
+(** Solve the goal by [inv]-ing some hypothesis and then running [tac];
+    backtracks over hypotheses until [solve] succeeds. *)
+Ltac solve_by_inversion' tac := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [H : _ |- _] => solve [inv H; tac]
+  end.
+
+(** [solve_by_inversion'] with [auto] as the finishing tactic. *)
+Ltac solve_by_inversion := idtac "VerdiTactics is deprecated in fiat-crypto"; solve_by_inversion' auto.
+
+(** Given [H : X = Y], open an assertion [f X = f Y]; no proof is
+    supplied, so the assertion is left as a goal. *)
+Ltac apply_fun f H:= idtac "VerdiTactics is deprecated in fiat-crypto";
+  match type of H with
+    | ?X = ?Y => assert (f X = f Y)
+  end.
+
+(** Given [H : P -> _], prove [P] with [tac], specialize [H] with that
+    proof, and clear the temporary proof. *)
+Ltac conclude H tac := idtac "VerdiTactics is deprecated in fiat-crypto";
+  (let H' := fresh in
+   match type of H with
+     | ?P -> _ => assert P as H' by (tac)
+   end; specialize (H H'); clear H').
+
+(** [conclude] some implication hypothesis using [auto]. *)
+Ltac concludes := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : ?P -> _ |- _ ] => conclude H auto
+  end.
+
+(** Given [H : P -> _], open an assertion of the premise [P] under a fresh
+    name; [H] itself is not modified. *)
+Ltac forward H := idtac "VerdiTactics is deprecated in fiat-crypto";
+  let H' := fresh in
+   match type of H with
+     | ?P -> _ => assert P as H'
+   end.
+
+(** [forward] on some implication hypothesis. *)
+Ltac forwards := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : ?P -> _ |- _ ] => forward H
+  end.
+
+(** Find two equalities with the same left-hand side, rewrite one in the
+    other, and finish with [solve_by_inversion]. *)
+Ltac find_contradiction := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : ?X = _, H' : ?X = _ |- _ ] => rewrite H in H'; solve_by_inversion
+  end.
+
+(** Rewrite with an equality hypothesis [H : X = _]: preferably inside
+    another equality about the same head or term, otherwise in any
+    hypothesis mentioning [X], otherwise in the goal. The cases are tried
+    in the order written. *)
+Ltac find_rewrite := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : ?X _ _ _ _ = _, H' : ?X _ _ _ _ = _ |- _ ] => rewrite H in H'
+    | [ H : ?X = _, H' : ?X = _ |- _ ] => rewrite H in H'
+    | [ H : ?X = _, H' : context [ ?X ] |- _ ] => rewrite H in H'
+    | [ H : ?X = _ |- context [ ?X ] ] => rewrite H
+  end.
+
+(** Rewrite with lemma [lem] in some hypothesis; [[idtac]] requires that
+    exactly one goal remains (no side conditions). *)
+Ltac find_rewrite_lem lem := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : _ |- _ ] =>
+      rewrite lem in H; [idtac]
+  end.
+
+(** Rewrite with [lem] in some hypothesis, discharging side conditions
+    with [t]. *)
+Ltac find_rewrite_lem_by lem t := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : _ |- _ ] =>
+      rewrite lem in H by t
+  end.
+
+(** [erewrite] with [lem] in some hypothesis, discharging side conditions
+    with [eauto]. *)
+Ltac find_erewrite_lem lem := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : _ |- _] => erewrite lem in H by eauto
+  end.
+
+(** Right-to-left counterpart of [find_rewrite]: uses [H : _ = X] to
+    rewrite [X] away in another hypothesis or in the goal. *)
+Ltac find_reverse_rewrite := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : _ = ?X _ _ _ _, H' : ?X _ _ _ _ = _ |- _ ] => rewrite <- H in H'
+    | [ H : _ = ?X, H' : context [ ?X ] |- _ ] => rewrite <- H in H'
+    | [ H : _ = ?X |- context [ ?X ] ] => rewrite <- H
+  end.
+
+(** [invc] an equality whose two sides share the same head applied to
+    between one and six arguments (largest arity is tried first). *)
+Ltac find_inversion := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : ?X _ _ _ _ _ _ = ?X _ _ _ _ _ _ |- _ ] => invc H
+    | [ H : ?X _ _ _ _ _ = ?X _ _ _ _ _ |- _ ] => invc H
+    | [ H : ?X _ _ _ _ = ?X _ _ _ _ |- _ ] => invc H
+    | [ H : ?X _ _ _ = ?X _ _ _ |- _ ] => invc H
+    | [ H : ?X _ _ = ?X _ _ |- _ ] => invc H
+    | [ H : ?X _ = ?X _ |- _ ] => invc H
+  end.
+
+(** From an equality between two applications of the same head (one to
+    three arguments), assert the argument-wise equalities by
+    [congruence] and clear the original hypothesis. *)
+Ltac prove_eq := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : ?X ?x1 ?x2 ?x3 = ?X ?y1 ?y2 ?y3 |- _ ] =>
+      assert (x1 = y1) by congruence;
+        assert (x2 = y2) by congruence;
+        assert (x3 = y3) by congruence;
+        clear H
+    | [ H : ?X ?x1 ?x2 = ?X ?y1 ?y2 |- _ ] =>
+      assert (x1 = y1) by congruence;
+        assert (x2 = y2) by congruence;
+        clear H
+    | [ H : ?X ?x1 = ?X ?y1 |- _ ] =>
+      assert (x1 = y1) by congruence;
+        clear H
+  end.
+
+(** [invc] an equality between two tuples of arity four, three or two. *)
+Ltac tuple_inversion := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : (_, _, _, _) = (_, _, _, _) |- _ ] => invc H
+    | [ H : (_, _, _) = (_, _, _) |- _ ] => invc H
+    | [ H : (_, _) = (_, _) |- _ ] => invc H
+  end.
+
+(** Given [H : X = Y], add [f X = f Y] to the context, proved by
+    rewriting with [H] and [auto]. Note the argument order: hypothesis
+    first, then the function. *)
+Ltac f_apply H f := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match type of H with
+    | ?X = ?Y =>
+      assert (f X = f Y) by (rewrite H; auto)
+  end.
+
+(** Destruct (with [eqn:?]) the scrutinee of a pair-destructuring [let]
+    found in a hypothesis or, failing that, in the goal. *)
+Ltac break_let := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : context [ (let (_,_) := ?X in _) ] |- _ ] => destruct X eqn:?
+    | [ |- context [ (let (_,_) := ?X in _) ] ] => destruct X eqn:?
+  end.
+
+(** [invc] some disjunction hypothesis. *)
+Ltac break_or_hyp := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : _ \/ _ |- _ ] => invc H
+  end.
+
+(** Copy [H] under a fresh name and [apply lem] in the copy, leaving [H]
+    itself untouched. *)
+Ltac copy_apply lem H := idtac "VerdiTactics is deprecated in fiat-crypto";
+  let x := fresh in
+  pose proof H as x;
+    apply lem in x.
+
+(** As [copy_apply], but using [eapply]. *)
+Ltac copy_eapply lem H := idtac "VerdiTactics is deprecated in fiat-crypto";
+  let x := fresh in
+  pose proof H as x;
+    eapply lem in x.
+
+(** [conclude] some implication hypothesis using the given tactic. *)
+Ltac conclude_using tac := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : ?P -> _ |- _ ] => conclude H tac
+  end.
+
+(** Rewrite everywhere ([in *]) with an equality hypothesis that is
+    quantified over zero, one or two variables. *)
+Ltac find_higher_order_rewrite := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : _ = _ |- _ ] => rewrite H in *
+    | [ H : forall _, _ = _ |- _ ] => rewrite H in *
+    | [ H : forall _ _, _ = _ |- _ ] => rewrite H in *
+  end.
+
+(** Right-to-left counterpart of [find_higher_order_rewrite]. *)
+Ltac find_reverse_higher_order_rewrite := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : _ = _ |- _ ] => rewrite <- H in *
+    | [ H : forall _, _ = _ |- _ ] => rewrite <- H in *
+    | [ H : forall _ _, _ = _ |- _ ] => rewrite <- H in *
+  end.
+
+(** Clear one trivial hypothesis of the form [X = X]. *)
+Ltac clean := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : ?X = ?X |- _ ] => clear H
+  end.
+
+(** Solve the goal by applying some hypothesis to it. *)
+Ltac find_apply_hyp_goal := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : _ |- _ ] => solve [apply H]
+  end.
+
+(** [copy_apply lem] on some hypothesis. *)
+Ltac find_copy_apply_lem_hyp lem := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : _ |- _ ] => copy_apply lem H
+  end.
+
+(** Apply one hypothesis (an implication, possibly under one quantifier)
+    inside another hypothesis; [[idtac]] requires that exactly one goal
+    remains afterwards. *)
+Ltac find_apply_hyp_hyp := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : forall _, _ -> _,
+        H' : _ |- _ ] =>
+      apply H in H'; [idtac]
+    | [ H : _ -> _ , H' : _ |- _ ] =>
+      apply H in H'; auto; [idtac]
+  end.
+
+(** As [find_apply_hyp_hyp], but works on a copy of the target hypothesis
+    (via [copy_apply]). *)
+Ltac find_copy_apply_hyp_hyp := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : forall _, _ -> _,
+        H' : _ |- _ ] =>
+      copy_apply H H'; [idtac]
+    | [ H : _ -> _ , H' : _ |- _ ] =>
+      copy_apply H H'; auto; [idtac]
+  end.
+
+(** [apply lem] in some hypothesis. *)
+Ltac find_apply_lem_hyp lem := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : _ |- _ ] => apply lem in H
+  end.
+
+(** [eapply lem] in some hypothesis. *)
+Ltac find_eapply_lem_hyp lem := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : _ |- _ ] => eapply lem in H
+  end.
+
+(** Instantiate the leading quantifier of [H] with a fresh existential
+    variable of the quantified type. *)
+Ltac insterU H := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match type of H with
+    | forall _ : ?T, _ =>
+      let x := fresh "x" in
+      evar (x : T);
+      let x' := (eval unfold x in x) in
+        clear x; specialize (H x')
+  end.
+
+(** [insterU] on some universally quantified hypothesis. *)
+Ltac find_insterU := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : forall _, _ |- _ ] => insterU H
+  end.
+
+(** [eapply] a hypothesis whose statement is [P] applied to one argument. *)
+Ltac eapply_prop P := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | H : P _ |- _ =>
+      eapply H
+  end.
+
+(** Succeed iff [t] is syntactically one of the names in the context. *)
+Ltac isVar t := idtac "VerdiTactics is deprecated in fiat-crypto";
+    match goal with
+      | v : _ |- _ =>
+        match t with
+          | v => idtac
+        end
+    end.
+
+(** [remember] the term [t] under a fresh name and move the resulting
+    equation back into the goal with [generalize dependent]. *)
+Ltac remGen t := idtac "VerdiTactics is deprecated in fiat-crypto";
+  let x := fresh in
+  let H := fresh in
+  remember t as x eqn:H;
+    generalize dependent H.
+
+(** Do nothing when [t] is already a context variable; otherwise
+    [remGen t]. *)
+Ltac remGenIfNotVar t := idtac "VerdiTactics is deprecated in fiat-crypto"; first [isVar t| remGen t].
+
+(** For a hypothesis whose statement is a head applied to between one and
+    five arguments, run [remGenIfNotVar] on each argument (largest arity
+    is tried first). *)
+Ltac rememberNonVars H := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match type of H with
+    | _ ?a ?b ?c ?d ?e =>
+      remGenIfNotVar a;
+      remGenIfNotVar b;
+      remGenIfNotVar c;
+      remGenIfNotVar d;
+      remGenIfNotVar e
+    | _ ?a ?b ?c ?d =>
+      remGenIfNotVar a;
+      remGenIfNotVar b;
+      remGenIfNotVar c;
+      remGenIfNotVar d
+    | _ ?a ?b ?c =>
+      remGenIfNotVar a;
+      remGenIfNotVar b;
+      remGenIfNotVar c
+    | _ ?a ?b =>
+      remGenIfNotVar a;
+      remGenIfNotVar b
+    | _ ?a =>
+      remGenIfNotVar a
+  end.
+
+(** Revert every context entry except [H] itself and anything mentioned in
+    the type of [H]; the [fail 2] in the first two branches aborts the
+    current [match goal] case so that another entry is tried. *)
+Ltac generalizeEverythingElse H := idtac "VerdiTactics is deprecated in fiat-crypto";
+  repeat match goal with
+           | [ x : ?T |- _ ] =>
+             first [
+                 match H with
+                   | x => fail 2
+                 end |
+                 match type of H with
+                   | context [x] => fail 2
+                 end |
+                 revert x]
+         end.
+
+(** Prepare for induction on [H]: [rememberNonVars H] followed by
+    [generalizeEverythingElse H]. *)
+Ltac prep_induction H := idtac "VerdiTactics is deprecated in fiat-crypto";
+  rememberNonVars H;
+  generalizeEverythingElse H.
+
+(** [conclude] some implication hypothesis using [eauto]. *)
+Ltac econcludes := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : ?P -> _ |- _ ] => conclude H eauto
+  end.
+
+(** [copy_eapply lem] on some hypothesis. *)
+Ltac find_copy_eapply_lem_hyp lem := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : _ |- _ ] => copy_eapply lem H
+  end.
+
+(** Apply a hypothesis mentioning [P] inside a hypothesis mentioning [Q]. *)
+Ltac apply_prop_hyp P Q := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+  | [ H : context [ P ], H' : context [ Q ] |- _ ] =>
+    apply H in H'
+  end.
+
+
+(** As [apply_prop_hyp], but using [eapply]. *)
+Ltac eapply_prop_hyp P Q := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+  | [ H : context [ P ], H' : context [ Q ] |- _ ] =>
+    eapply H in H'
+  end.
+
+(** As [eapply_prop_hyp], but works on a copy of the target hypothesis
+    (via [copy_eapply]). *)
+Ltac copy_eapply_prop_hyp P Q := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : context [ P ], H' : context [ Q ] |- _ ] =>
+      copy_eapply H H'
+  end.
+
+(** Switch the goal to [False] and apply a hypothesis of the form
+    [_ -> False], leaving its premise to prove. *)
+Ltac find_false := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | H : _ -> False |- _ => exfalso; apply H
+  end.
+
+(** [injection H], clear [H], introduce one resulting equation, then
+    [subst_max]. *)
+Ltac injc H := idtac "VerdiTactics is deprecated in fiat-crypto";
+  injection H; clear H; intro; subst_max.
+
+(** [injc] an equality whose two sides share the same head applied to
+    between one and six arguments (largest arity is tried first). *)
+Ltac find_injection := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+    | [ H : ?X _ _ _ _ _ _ = ?X _ _ _ _ _ _ |- _ ] => injc H
+    | [ H : ?X _ _ _ _ _ = ?X _ _ _ _ _ |- _ ] => injc H
+    | [ H : ?X _ _ _ _ = ?X _ _ _ _ |- _ ] => injc H
+    | [ H : ?X _ _ _ = ?X _ _ _ |- _ ] => injc H
+    | [ H : ?X _ _ = ?X _ _ |- _ ] => injc H
+    | [ H : ?X _ = ?X _ |- _ ] => injc H
+  end.
+
+(** Rewrite the goal with some hypothesis (the first one for which
+    [rewrite] succeeds). *)
+Ltac aggresive_rewrite_goal := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with H : _ |- _ => rewrite H end.
+
+(** Destruct an existential hypothesis, naming the witness [x] and
+    reusing the hypothesis name for the body. *)
+Ltac break_exists_name x := idtac "VerdiTactics is deprecated in fiat-crypto";
+  match goal with
+  | [ H : exists _, _ |- _ ] => destruct H as [x H]
+  end.
diff --git a/src/LegacyArithmetic/ZBounded.v b/src/LegacyArithmetic/ZBounded.v
new file mode 100644
index 000000000..bccbf7428
--- /dev/null
+++ b/src/LegacyArithmetic/ZBounded.v
@@ -0,0 +1,158 @@
+(*** Bounded ℤ-Like Types *)
+(** This file specifies a ℤ-like type of bounded integers, with
+    operations for Montgomery Reduction and Barrett Reduction. *)
+Require Import Coq.ZArith.ZArith.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.Tactics.BreakMatch.
+Require Import Crypto.Util.Notations.
+
+Local Open Scope Z_scope.
+
+(** Operations of a ℤ-like structure, indexed by two bounds and a modulus.
+    It provides two carrier types, [LargeT] and [SmallT], each with a
+    decoder into [Z], a [SmallT] representation of the modulus
+    ([modulus_digits]), and the reduction, division, multiplication,
+    carrying add/subtract and conditional-subtract operations listed
+    below. The fields are operations only; no laws are stated in this
+    class. *)
+Class ZLikeOps (small_bound smaller_bound : Z) (modulus : Z) :=
+  {
+    LargeT : Type;
+    SmallT : Type;
+    modulus_digits : SmallT;
+    decode_large : LargeT -> Z;
+    decode_small : SmallT -> Z;
+    Mod_SmallBound : LargeT -> SmallT;
+    DivBy_SmallBound : LargeT -> SmallT;
+    DivBy_SmallerBound : LargeT -> SmallT;
+    Mul : SmallT -> SmallT -> LargeT;
+    CarryAdd : LargeT -> LargeT -> bool * LargeT;
+    CarrySubSmall : SmallT -> SmallT -> bool * SmallT;
+    ConditionalSubtract : bool -> SmallT -> SmallT;
+    ConditionalSubtractModulus : SmallT -> SmallT
+  }.
+
+Delimit Scope small_zlike_scope with small_zlike.
+Delimit Scope large_zlike_scope with large_zlike.
+Local Open Scope small_zlike_scope. +Local Open Scope large_zlike_scope. +Local Open Scope Z_scope. +Bind Scope small_zlike_scope with SmallT. +Bind Scope large_zlike_scope with LargeT. +Arguments decode_large (_ _ _)%Z _ _%large_zlike. +Arguments decode_small (_ _ _)%Z _ _%small_zlike. +Arguments Mod_SmallBound (_ _ _)%Z _ _%large_zlike. +Arguments DivBy_SmallBound (_ _ _)%Z _ _%large_zlike. +Arguments DivBy_SmallerBound (_ _ _)%Z _ _%large_zlike. +Arguments Mul (_ _ _)%Z _ (_ _)%small_zlike. +Arguments CarryAdd (_ _ _)%Z _ (_ _)%large_zlike. +Arguments CarrySubSmall (_ _ _)%Z _ (_ _)%large_zlike. +Arguments ConditionalSubtract (_ _ _)%Z _ _%bool _%small_zlike. +Arguments ConditionalSubtractModulus (_ _ _)%Z _ _%small_zlike. + +Infix "*" := Mul : large_zlike_scope. +Notation "x + y" := (snd (CarryAdd x y)) : large_zlike_scope. + +Class ZLikeProperties {small_bound smaller_bound modulus : Z} (Zops : ZLikeOps small_bound smaller_bound modulus) := + { + large_valid : LargeT -> Prop; + medium_valid : LargeT -> Prop; + small_valid : SmallT -> Prop; + decode_large_valid : forall v, large_valid v -> 0 <= decode_large v < small_bound * small_bound; + decode_medium_valid : forall v, medium_valid v -> 0 <= decode_large v < small_bound * smaller_bound; + medium_to_large_valid : forall v, medium_valid v -> large_valid v; + decode_small_valid : forall v, small_valid v -> 0 <= decode_small v < small_bound; + modulus_digits_valid : small_valid modulus_digits; + modulus_digits_correct : decode_small modulus_digits = modulus; + Mod_SmallBound_valid : forall v, large_valid v -> small_valid (Mod_SmallBound v); + Mod_SmallBound_correct + : forall v, large_valid v -> decode_small (Mod_SmallBound v) = decode_large v mod small_bound; + DivBy_SmallBound_valid : forall v, large_valid v -> small_valid (DivBy_SmallBound v); + DivBy_SmallBound_correct + : forall v, large_valid v -> decode_small (DivBy_SmallBound v) = decode_large v / small_bound; + DivBy_SmallerBound_valid : forall v, 
medium_valid v -> small_valid (DivBy_SmallerBound v); + DivBy_SmallerBound_correct + : forall v, medium_valid v -> decode_small (DivBy_SmallerBound v) = decode_large v / smaller_bound; + Mul_valid : forall x y, small_valid x -> small_valid y -> large_valid (Mul x y); + Mul_correct + : forall x y, small_valid x -> small_valid y -> decode_large (Mul x y) = decode_small x * decode_small y; + CarryAdd_valid : forall x y, large_valid x -> large_valid y -> large_valid (snd (CarryAdd x y)); + CarryAdd_correct_fst + : forall x y, large_valid x -> large_valid y -> fst (CarryAdd x y) = (small_bound * small_bound <=? decode_large x + decode_large y); + CarryAdd_correct_snd + : forall x y, large_valid x -> large_valid y -> decode_large (snd (CarryAdd x y)) = (decode_large x + decode_large y) mod (small_bound * small_bound); + CarrySubSmall_valid : forall x y, small_valid x -> small_valid y -> small_valid (snd (CarrySubSmall x y)); + CarrySubSmall_correct_fst + : forall x y, small_valid x -> small_valid y -> fst (CarrySubSmall x y) = (decode_small x - decode_small y <? 0); + CarrySubSmall_correct_snd + : forall x y, small_valid x -> small_valid y -> decode_small (snd (CarrySubSmall x y)) = (decode_small x - decode_small y) mod small_bound; + ConditionalSubtract_valid : forall b x, small_valid x -> small_valid (ConditionalSubtract b x); + ConditionalSubtract_correct + : forall b x, small_valid x -> decode_small (ConditionalSubtract b x) + = if b then (decode_small x - decode_small modulus_digits) mod small_bound else decode_small x; + ConditionalSubtractModulus_valid : forall x, small_valid x -> small_valid (ConditionalSubtractModulus x); + ConditionalSubtractModulus_correct + : forall x, small_valid x -> decode_small (ConditionalSubtractModulus x) + = if (decode_small x <? decode_small modulus_digits) then decode_small x else decode_small x - decode_small modulus_digits + }. + +Lemma ConditionalSubtractModulus_correct' {small_bound smaller_bound modulus} (Zops : ZLikeOps small_bound smaller_bound modulus) {_ : ZLikeProperties Zops} + : forall x, small_valid x -> decode_small (ConditionalSubtractModulus x) + = if (decode_small modulus_digits <=? decode_small x) then decode_small x - decode_small modulus_digits else decode_small x. +Proof. + intros; rewrite ConditionalSubtractModulus_correct by assumption. 
+ break_match; Z.ltb_to_lt; omega. +Qed. + +Lemma modulus_nonneg {small_bound smaller_bound modulus} (Zops : ZLikeOps small_bound smaller_bound modulus) {_ : ZLikeProperties Zops} : 0 <= modulus. +Proof. + pose proof (decode_small_valid _ modulus_digits_valid) as H. + rewrite modulus_digits_correct in H. + omega. +Qed. + +Create HintDb push_zlike_decode discriminated. +Create HintDb pull_zlike_decode discriminated. +Hint Rewrite @Mod_SmallBound_correct @DivBy_SmallBound_correct @DivBy_SmallerBound_correct @Mul_correct @CarryAdd_correct_fst @CarryAdd_correct_snd @CarrySubSmall_correct_fst @CarrySubSmall_correct_snd @ConditionalSubtract_correct @ConditionalSubtractModulus_correct @ConditionalSubtractModulus_correct' @modulus_digits_correct using solve [ typeclasses eauto ] : push_zlike_decode. +Hint Rewrite <- @Mod_SmallBound_correct @DivBy_SmallBound_correct @DivBy_SmallerBound_correct @Mul_correct @CarryAdd_correct_fst @CarryAdd_correct_snd @CarrySubSmall_correct_fst @CarrySubSmall_correct_snd @ConditionalSubtract_correct @ConditionalSubtractModulus_correct @modulus_digits_correct using solve [ typeclasses eauto ] : pull_zlike_decode. + +Ltac get_modulus := + match goal with + | [ _ : ZLikeOps _ _ ?modulus |- _ ] => modulus + end. 
+ +Ltac push_zlike_decode := + let modulus := get_modulus in + repeat first [ erewrite !Mod_SmallBound_correct by typeclasses eauto + | erewrite !DivBy_SmallBound_correct by typeclasses eauto + | erewrite !DivBy_SmallerBound_correct by typeclasses eauto + | erewrite !Mul_correct by typeclasses eauto + | erewrite !CarryAdd_correct_fst by typeclasses eauto + | erewrite !CarryAdd_correct_snd by typeclasses eauto + | erewrite !CarrySubSmall_correct_fst by typeclasses eauto + | erewrite !CarrySubSmall_correct_snd by typeclasses eauto + | erewrite !ConditionalSubtract_correct by typeclasses eauto + | erewrite !ConditionalSubtractModulus_correct by typeclasses eauto + | erewrite !ConditionalSubtractModulus_correct' by typeclasses eauto + | erewrite !(@modulus_digits_correct _ modulus _ _) by typeclasses eauto ]. +Ltac pull_zlike_decode := + let modulus := get_modulus in + repeat first [ match goal with + | [ |- context G[modulus] ] + => let G' := context G[decode_small modulus_digits] in + cut G'; [ rewrite !modulus_digits_correct by typeclasses eauto; exact (fun x => x) | ] + end + | erewrite <- !Mod_SmallBound_correct by typeclasses eauto + | erewrite <- !DivBy_SmallBound_correct by typeclasses eauto + | erewrite <- !DivBy_SmallerBound_correct by typeclasses eauto + | erewrite <- !Mul_correct by typeclasses eauto + | erewrite <- !CarryAdd_correct_fst by typeclasses eauto + | erewrite <- !CarryAdd_correct_snd by typeclasses eauto + | erewrite <- !ConditionalSubtract_correct by typeclasses eauto + | erewrite <- !CarrySubSmall_correct_fst by typeclasses eauto + | erewrite <- !CarrySubSmall_correct_snd by typeclasses eauto + | erewrite <- !ConditionalSubtractModulus_correct by typeclasses eauto + | erewrite <- !ConditionalSubtractModulus_correct' by typeclasses eauto + | erewrite <- !(@modulus_digits_correct _ modulus _ _) by typeclasses eauto ]. 
diff --git a/src/LegacyArithmetic/ZBoundedZ.v b/src/LegacyArithmetic/ZBoundedZ.v new file mode 100644 index 000000000..fef654f47 --- /dev/null +++ b/src/LegacyArithmetic/ZBoundedZ.v @@ -0,0 +1,88 @@ +(*** ℤ can be a bounded ℤ-Like type *) +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Crypto.LegacyArithmetic.ZBounded. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Notations. + +Local Open Scope Z_scope. + +Global Instance ZZLikeOps small_bound_exp smaller_bound_exp modulus : ZLikeOps (2^small_bound_exp) (2^smaller_bound_exp) modulus + := { LargeT := Z; + SmallT := Z; + modulus_digits := modulus; + decode_large x := x; + decode_small x := x; + Mod_SmallBound x := Z.pow2_mod x small_bound_exp; + DivBy_SmallBound x := Z.shiftr x small_bound_exp; + DivBy_SmallerBound x := Z.shiftr x smaller_bound_exp; + Mul x y := (x * y)%Z; + CarryAdd x y := dlet xpy := x + y in + ((2^small_bound_exp * 2^small_bound_exp <=? xpy), Z.pow2_mod xpy (2 * small_bound_exp)); + CarrySubSmall x y := dlet xmy := x - y in (xmy vm_compute; reflexivity : typeclass_instances. + +Local Ltac pre_t := + unfold cls_is_true, Let_In in *; Z.ltb_to_lt; + match goal with + | [ H : ?smaller_bound_exp <= ?small_bound_exp |- _ ] + => is_var smaller_bound_exp; is_var small_bound_exp; + assert (2^smaller_bound_exp <= 2^small_bound_exp) by auto with zarith; + assert (2^small_bound_exp * 2^smaller_bound_exp <= 2^small_bound_exp * 2^small_bound_exp) by auto with zarith + end. + +Local Ltac t_step := + first [ progress simpl in * + | progress intros + | progress autorewrite with push_Zpow Zshift_to_pow in * + | rewrite Z.pow2_mod_spec by omega + | progress Z.ltb_to_lt + | progress unfold Let_In in * + | solve [ auto with zarith ] + | nia + | progress break_match ]. +Local Ltac t := pre_t; repeat t_step. 
+ +Global Instance ZZLikeProperties {small_bound_exp smaller_bound_exp modulus} + {Hss : cls_is_true (0 <=? smaller_bound_exp)} + {Hs : cls_is_true (0 <=? small_bound_exp)} + {Hs_ss : cls_is_true (smaller_bound_exp <=? small_bound_exp)} + {Hmod0 : cls_is_true (0 <=? modulus)} + {Hmod1 : cls_is_true (modulus