diff options
author | jadep <jade.philipoom@gmail.com> | 2016-09-20 23:32:46 -0400 |
---|---|---|
committer | jadep <jade.philipoom@gmail.com> | 2016-09-21 13:44:10 -0400 |
commit | 3482333812490f41f2bb962fa1c9a48811ec189f (patch) | |
tree | a38009fd5063924e32f5ff1a11864713c626e6f6 /src/ModularArithmetic | |
parent | 639e6cc7cf989bf88c35cbffe2d5ac71e527d479 (diff) |
Proved specification of constant-time modulus comparison (except for one ZUtil lemma)
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemList.v | 4 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemListProofs.v | 226 |
2 files changed, 225 insertions, 5 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystemList.v b/src/ModularArithmetic/ModularBaseSystemList.v index 836c7644c..a472c3534 100644 --- a/src/ModularArithmetic/ModularBaseSystemList.v +++ b/src/ModularArithmetic/ModularBaseSystemList.v @@ -57,11 +57,11 @@ Section Defs. are less than 2 ^ their respective limb width. *) Fixpoint ge_modulus' us acc i := match i with - | O => andb (Z.ltb (modulus_digits [0]) (us [0])) acc + | O => andb (Z.leb (modulus_digits [0]) (us [0])) acc | S i' => ge_modulus' us (andb (Z.eqb (modulus_digits [i]) (us [i])) acc) i' end. - Definition ge_modulus us := ge_modulus' us true (length base - 1)%nat. + Definition ge_modulus us := ge_modulus' us true (length limb_widths - 1)%nat. Definition conditional_subtract_modulus (us : digits) (cond : bool) := let and_term := if cond then max_ones else 0 in diff --git a/src/ModularArithmetic/ModularBaseSystemListProofs.v b/src/ModularArithmetic/ModularBaseSystemListProofs.v index b3eff4caa..16cc2fb3c 100644 --- a/src/ModularArithmetic/ModularBaseSystemListProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemListProofs.v @@ -2,6 +2,7 @@ Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith. Require Import Coq.Numbers.Natural.Peano.NPeano. Require Import Coq.Lists.List. Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.BaseSystem. Require Import Crypto.BaseSystemProofs. Require Import Crypto.ModularArithmetic.Pow2Base. Require Import Crypto.ModularArithmetic.Pow2BaseProofs. @@ -11,6 +12,7 @@ Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. Require Import Crypto.Util.Tactics. Require Import Crypto.Util.ListUtil. +Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Notations. Require Import Crypto.ModularArithmetic.ModularBaseSystemList. @@ -146,14 +148,232 @@ Section LengthProofs. End LengthProofs. Section ConditionalSubtractModulusProofs. - Context `{prm :PseudoMersenneBaseParams}. + Context `{prm :PseudoMersenneBaseParams} + (c_upper_bound : c - 1 < 2 ^ nth_default 0 limb_widths 0). Local Notation base := (base_from_limb_widths limb_widths). + Local Hint Resolve sum_firstn_limb_widths_nonneg. + Local Hint Resolve limb_widths_nonneg. - Lemma ge_modulus_spec : forall u, length u = length limb_widths -> - (ge_modulus u = false <-> 0 <= BaseSystem.decode base u < modulus). + Fixpoint compare' us vs i := + match i with + | O => Eq + | S i' => if Z_eq_dec (nth_default 0 us i') (nth_default 0 vs i') + then compare' us vs i' + else Z.compare (nth_default 0 us i') (nth_default 0 vs i') + end. + + (* Lexicographically compare two vectors of equal length, starting from the END of the list + (in our context, this is the most significant end). NOT constant time. *) + Definition compare us vs := compare' us vs (length us). + + (* TODO : move to ZUtil *) + Lemma add_compare_mono_r: forall n m p, (n + p ?= m + p) = (n ?= m). + Proof. + intros. + rewrite <-!(Z.add_comm p). + apply Z.add_compare_mono_l. + Qed. + + (* TODO : move to ZUtil *) + Lemma pow2_mod_id_iff : forall a n, 0 <= n -> + Z.pow2_mod a n = a <-> 0 <= a < 2 ^ n. + Proof. + intros. + rewrite Z.pow2_mod_spec by assumption. + assert (0 < 2 ^ n) by zero_bounds. + rewrite Z.mod_small_iff by omega. + split; intros; intuition omega. + Qed. + + (* TODO : move to ZUtil *) + Lemma compare_add_shiftl : forall x1 y1 x2 y2 n, 0 <= n -> + Z.pow2_mod x1 n = x1 -> Z.pow2_mod x2 n = x2 -> + x1 + (y1 << n) ?= x2 + (y2 << n) = + if Z_eq_dec y1 y2 + then x1 ?= x2 + else y1 ?= y2. + Proof. + repeat match goal with + | |- _ => progress intros + | |- _ => progress subst y1 + | |- _ => rewrite Z.shiftl_mul_pow2 by omega + | |- _ => rewrite add_compare_mono_r + | |- _ => rewrite <-Z.mul_sub_distr_r + | |- _ => break_if + | H : Z.pow2_mod _ _ = _ |- _ => rewrite pow2_mod_id_iff in H by omega + | H : ?a <> ?b |- _ = (?a ?= ?b) => + case_eq (a ?= b); rewrite ?Z.compare_eq_iff, ?Z.compare_gt_iff, ?Z.compare_lt_iff + | |- _ + (_ * _) > _ + (_ * _) => cbv [Z.gt] + | |- _ + (_ * ?x) < _ + (_ * ?x) => + apply Z.lt_sub_lt_add; apply Z.lt_le_trans with (m := 1 * x); [omega|] + | |- _ => apply Z.mul_le_mono_nonneg_r; omega + | |- _ => reflexivity + | |- _ => congruence + end. + Qed. + + Lemma decode_firstn_compare' : forall us vs i, + (i <= length limb_widths)%nat -> + length us = length limb_widths -> bounded limb_widths us -> + length vs = length limb_widths -> bounded limb_widths vs -> + (Z.compare (decode' base (firstn i us)) (decode' base (firstn i vs)) + = compare' us vs i). + Proof. + induction i; + repeat match goal with + | |- _ => progress intros + | |- _ => progress (simpl compare') + | |- _ => progress specialize_by (assumption || omega) + | |- _ => rewrite sum_firstn_0 + | |- _ => rewrite set_higher + | |- _ => rewrite nth_default_base by eauto + | |- _ => rewrite firstn_length, Min.min_l by omega + | |- _ => rewrite firstn_O + | |- _ => rewrite firstn_succ with (d := 0) by omega + | |- _ => rewrite compare_add_shiftl by + (eauto || (rewrite decode_firstn_pow2_mod, Z.pow2_mod_pow2_mod, Z.min_id by + (eauto || omega); reflexivity)) + | |- appcontext[2 ^ ?x * ?y] => replace (2 ^ x * y) with (y << x) by + (rewrite (Z.mul_comm (2 ^ x)); apply Z.shiftl_mul_pow2; eauto) + | |- _ => tauto + | |- _ => split + | |- _ => break_if + end. + Qed. + + Lemma decode_compare' : forall us vs, + length us = length limb_widths -> bounded limb_widths us -> + length vs = length limb_widths -> bounded limb_widths vs -> + (Z.compare (decode' base us) (decode' base vs) + = compare' us vs (length limb_widths)). + Proof. + intros. + rewrite <-decode_firstn_compare' by (auto || omega). + rewrite !firstn_all by auto. + reflexivity. + Qed. + + Lemma ge_modulus'_false : forall us i, + ge_modulus' us false i = false. + Proof. + induction i; intros; simpl; rewrite Bool.andb_false_r; auto. + Qed. + + (* TODO : ZUtil *) + Lemma add_pow_mod_l : forall a b c, a <> 0 -> 0 < b -> + ((a ^ b) + c) mod a = c mod a. + Proof. + intros. + replace b with (b - 1 + 1) by ring. + rewrite Z.pow_add_r, Z.pow_1_r by omega. + auto using Z.mod_add_l. + Qed. + + (* TODO : ZUtil *) + Lemma testbit_sub_pow2 : forall n i x, 0 <= i < n -> 0 < x < 2 ^ n -> + Z.testbit (2 ^ n - x) i = negb (Z.testbit (x - 1) i). Proof. Admitted. + Lemma decode_modulus_digits : decode' base modulus_digits = modulus. + Proof. + cbv [modulus_digits]. + pose proof c_pos. pose proof modulus_pos. + rewrite encodeZ_spec by eauto using limb_widths_nonnil, limb_widths_good. + apply Z.mod_small. + cbv [upper_bound]. fold k. + assert (modulus = 2 ^ k - c) by (cbv [c]; ring). + omega. + Qed. + + Lemma modulus_digits_ones : forall i, (0 < i < length limb_widths)%nat -> + nth_default 0 modulus_digits i = Z.ones (nth_default 0 limb_widths i). + Proof. + repeat match goal with + | |- _ => progress (cbv [BaseSystem.decode]; intros) + | |- _ => progress autorewrite with Ztestbit + | |- _ => unique pose proof c_pos + | |- _ => unique pose proof modulus_pos + | |- _ => unique assert (modulus = 2 ^ k - c) by (cbv [c]; ring) + | |- _ => break_if + | |- _ => rewrite decode_modulus_digits + | |- _ => rewrite Z.testbit_pow2_mod + by eauto using nth_default_limb_widths_nonneg + | |- _ => rewrite Z.ones_spec by eauto using nth_default_limb_widths_nonneg + | |- _ => erewrite digit_select + by (eauto; apply bounded_encodeZ; eauto; omega) + | |- Z.testbit (2 ^ k - c) _ = _ => + rewrite testbit_sub_pow2 by (try omega; cbv [k]; + pose proof (sum_firstn_prefix_le limb_widths (S i) (length limb_widths)); + specialize_by (eauto || omega); + rewrite sum_firstn_succ_default in *; split; zero_bounds; eauto) + | |- Z.pow2_mod _ _ = Z.ones _ => apply Z.bits_inj' + | |- Z.testbit modulus ?i = true => transitivity (Z.testbit (2 ^ k - c) i) + | |- _ => congruence + end. + + replace (c - 1) with ((c - 1) mod 2 ^ nth_default 0 limb_widths 0) by (apply Z.mod_small; omega). + rewrite Z.mod_pow2_bits_high; auto. + pose proof (sum_firstn_prefix_le limb_widths 1 i). + specialize_by (eauto || omega). + rewrite !sum_firstn_succ_default, !sum_firstn_0 in *. + split; zero_bounds; eauto using nth_default_limb_widths_nonneg. + Qed. + + Lemma bounded_le_modulus_digits : forall us i, length us = length limb_widths -> + bounded limb_widths us -> (0 < i < length limb_widths)%nat -> + nth_default 0 us i <= nth_default 0 modulus_digits i. + Proof. + intros until 0; rewrite bounded_iff; intros. + rewrite modulus_digits_ones by omega. + specialize (H0 i). + rewrite Z.ones_equiv. + omega. + Qed. + + Lemma ge_modulus'_compare' : forall us, length us = length limb_widths -> bounded limb_widths us -> + forall i, (i < length limb_widths)%nat -> + (ge_modulus' us true i = false <-> compare' us modulus_digits (S i) = Lt). + Proof. + induction i; + repeat match goal with + | |- _ => progress intros + | |- _ => progress (simpl ge_modulus'; simpl compare' in *) + | |- _ => progress specialize_by omega + | |- _ => progress rewrite ?Z.compare_eq_iff, + ?Z.compare_gt_iff, ?Z.compare_lt_iff in * + | |- _ => break_if + | |- _ => rewrite Nat.sub_0_r + | |- _ => rewrite ge_modulus'_false + | |- _ => rewrite Bool.andb_true_r + | |- _ => rewrite Z.leb_compare; break_match + | |- _ => rewrite Z.eqb_compare; break_match + | |- _ => split; (congruence || omega) + | |- _ => assumption + end; + pose proof (bounded_le_modulus_digits us (S i)); + specialize_by (auto || omega); omega. + Qed. + + Lemma ge_modulus_spec : forall u, length u = length limb_widths -> + bounded limb_widths u -> + (ge_modulus u = false <-> 0 <= BaseSystem.decode base u < modulus). + Proof. + cbv [ge_modulus]; intros. + assert (0 < length limb_widths)%nat + by (pose proof limb_widths_nonnil; destruct limb_widths; + distr_length; omega || congruence). + rewrite ge_modulus'_compare' by (auto || omega). + replace (S (length limb_widths - 1)) with (length limb_widths) by omega. + rewrite <-decode_compare' + by (try (apply length_modulus_digits || apply bounded_encodeZ); eauto; + pose proof modulus_pos; omega). + rewrite Z.compare_lt_iff. + rewrite decode_modulus_digits. + repeat (split; intros; eauto using decode_nonneg). + cbv [BaseSystem.decode] in *. omega. + Qed. + Lemma conditional_subtract_modulus_spec : forall u cond, length u = length limb_widths -> BaseSystem.decode base (conditional_subtract_modulus u cond) = |