From a86f8004a280dcf5cb5c2ad15b902d63119430bb Mon Sep 17 00:00:00 2001 From: jadep Date: Mon, 13 Jun 2016 14:59:17 -0400 Subject: progress on second stage (conditional constant-time subtraction) of canonicalization proofs --- src/ModularArithmetic/ModularBaseSystemProofs.v | 819 +++++++++--------------- src/Util/ListUtil.v | 34 + src/Util/NatUtil.v | 11 + 3 files changed, 350 insertions(+), 514 deletions(-) (limited to 'src') diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index 7c430417b..fcda7b750 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -1,7 +1,7 @@ Require Import Zpower ZArith. Require Import Coq.Numbers.Natural.Peano.NPeano. Require Import List. -Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil Crypto.Util.NatUtil. Require Import VerdiTactics. Require Crypto.BaseSystem. Require Import Crypto.ModularArithmetic.ModularBaseSystem Crypto.ModularArithmetic.PrimeFieldTheorems. @@ -1194,14 +1194,6 @@ Section CanonicalizationProofs. (* END proofs about third carry loop *) - Lemma nth_error_combine : forall {A B} i (x : A) (x' : B) l l', nth_error l i = Some x -> - nth_error l' i = Some x' -> nth_error (combine l l') i = Some (x, x'). - Admitted. -(* - Lemma nth_error_range : forall i r, (i < r)%nat -> - nth_error (range r) i = Some i. - Admitted. -*) Lemma carry_full_length : forall us, (length us = length base)%nat -> length (carry_full us) = length us. Proof. @@ -1223,72 +1215,7 @@ Section CanonicalizationProofs. Qed. Opaque carry_full. -(* - Lemma length_range : forall n, length (range n) = n. - Proof. - induction n; intros; auto. - simpl. - rewrite app_length, cons_length, nil_length0. - omega. - Qed. - - Lemma range0_nil : range 0 = nil. - Proof. - reflexivity. - Qed. - - Lemma range_succ : forall n, range (S n) = range n ++ n :: nil. - Proof. - reflexivity. - Qed. - - Lemma nth_default_range : forall d r n, (n < r)%nat -> nth_default d (range r) n = n. - Proof. - induction r; intro; try omega. - intros. - assert (n = r \/ n < r)%nat as cases by omega. - destruct cases; subst; rewrite range_succ, nth_default_app, length_range; break_if; try omega. - + rewrite Nat.sub_diag. - auto using nth_default_cons. - + apply IHr; omega. - Qed. - - Lemma combine_app : forall {A B} (x y : list A) (z : list B), (length (x ++ y) <= length z)%nat -> - combine (x ++ y) z = combine x z ++ combine y (skipn (length x) z). - Proof. - intros. - rewrite <- (firstn_skipn (length x) z) at 1. - rewrite combine_app_samelength by - (rewrite firstn_length, Nat.min_l; auto; rewrite app_length in *; omega). - rewrite <-combine_truncate_r; reflexivity. - Qed. - - Lemma combine_range_succ : forall l r, (S r <= length l)%nat -> - combine (range (S r)) l = (combine (range r) l) ++ (r,nth_default 0 l r) :: nil. - Proof. - intros. - simpl. - rewrite combine_app by (rewrite app_length, cons_length, length_range, nil_length0; omega). - f_equal. - rewrite length_range. - erewrite skipn_nth_default by omega. - reflexivity. - Qed. - Opaque range. - Lemma map_sub_combine_range : forall d d' f l i, (l <> nil) -> (i < length l)%nat -> - nth_default d (map (fun x => snd x - f (fst x)) (combine (range (length l)) l)) i = - nth_default d' l i - f i. - Proof. - intros until 1. - intros lt_i_length. - destruct (nth_error_length_exists_value i l lt_i_length). - erewrite nth_error_value_eq_nth_default; auto. - erewrite map_nth_error; - [ | apply nth_error_combine; try apply nth_error_range; eauto]. - erewrite nth_error_value_eq_nth_default; eauto. - Qed. -*) Lemma isFull'_false : forall us n, isFull' us false n = false. Proof. unfold isFull'; induction n; intros; rewrite Bool.andb_false_r; auto. @@ -1315,12 +1242,6 @@ Section CanonicalizationProofs. + eauto. Qed. - Lemma isFull_lower_bound_0 : forall us, isFull us = true -> - max_bound 0 - c < nth_default 0 us 0. - Proof. - eauto using isFull'_lower_bound_0. - Qed. - Lemma isFull'_true_full : forall us i j b, (i <> 0)%nat -> (i <= j)%nat -> isFull' us b j = true -> max_bound i = nth_default 0 us i. Proof. @@ -1331,14 +1252,6 @@ Section CanonicalizationProofs. + eapply IHj; eauto. Qed. - Lemma isFull_true_full : forall i us, (length us = length base) -> - (0 < i < length base)%nat -> isFull us = true -> - max_bound i = nth_default 0 us i. - Proof. - unfold isFull; intros. - eapply isFull'_true_full with (j := (length us - 1)%nat); eauto; omega. - Qed. - (* TODO : move *) Lemma N_le_1_l : forall p, (1 <= N.pos p)%N. Proof. @@ -1394,29 +1307,7 @@ Section CanonicalizationProofs. apply IHl; auto using in_cons. Qed. Hint Resolve max_ones_nonneg. -(* - Lemma sub_land_max_bound_max_ones_lower : - forall us i, (length us = length base) -> isFull us = true -> - (i < length us)%nat -> - 0 <= nth_default 0 us i - land_max_bound max_ones i. - Proof. - unfold land_max_bound; intros. - break_if. - + subst. apply Z.le_0_sub. - etransitivity. - - apply Z_land_upper_bound_r; auto. - apply Z.le_trans with (m := c - 1); omega. - - rewrite Z.add_1_r. - apply Z.le_succ_l. - auto using isFull_lower_bound_0. - + apply Z.le_0_sub. - etransitivity. - apply Z_land_upper_bound_r; auto. - apply Z.eq_le_incl. - apply isFull_true_full; auto. - omega. - Qed. -*) + (* TODO : move *) Lemma Z_le_fold_right_max : forall low l x, (forall y, In y l -> low <= y) -> In x l -> x <= fold_right Z.max low l. @@ -1465,46 +1356,6 @@ Section CanonicalizationProofs. split; auto; omega. Qed. - Lemma land_max_ones_max_bound_sub_c : - Z.land max_ones (max_bound 0 - c + 1) = max_bound 0 - c + 1. - Proof. - apply land_max_ones_noop with (i := 0%nat). - rewrite <-max_bound_log_cap. - split; auto; try omega. - Qed. -(* - Lemma land_max_bound_pos : forall i, (i < length base)%nat -> - 0 < land_max_bound max_ones i. - Proof. - unfold land_max_bound; intros. - break_if. - + subst. - rewrite land_max_ones_max_bound_sub_c by assumption. - apply Z.lt_le_trans with (m := c); auto. omega. - + rewrite land_max_ones_max_bound by assumption. - auto using max_bound_pos. - Qed. - Local Hint Resolve land_max_bound_pos. - - - Lemma sub_land_max_bound_max_ones_upper : - forall us i, nth_default 0 us i <= max_bound i -> - (length us = length base) -> (i < length us)%nat -> - nth_default 0 us i - land_max_bound max_ones i < 2 ^ log_cap i. - Proof. - intros. - eapply Z.lt_trans. - + eapply Z.lt_sub_pos. - apply land_max_bound_pos; auto; omega. - + rewrite <-max_bound_log_cap. - omega. - Qed. - - - Lemma land_max_bound_0 : forall i, land_max_bound 0 i = 0. - Admitted. -*) - Lemma full_isFull'_true : forall j us, (length us = length base) -> ( max_bound 0 - c < nth_default 0 us 0 /\ (forall i, (0 < i <= j)%nat -> nth_default 0 us i = max_bound i)) -> @@ -1531,15 +1382,17 @@ Section CanonicalizationProofs. omega. Qed. - Opaque isFull' (* TODO isFull *) max_ones. - - (* TODO : move *) - Lemma length_nonzero_nonnil : forall {A} (l : list A), (0 < length l)%nat -> - l <> nil. + Lemma isFull'_true_step : forall us j, isFull' us true (S j) = true -> + isFull' us true j = true. Proof. - destruct l; boring; congruence. + simpl; intros ? ? succ_true. + destruct (max_bound (S j) =? nth_default 0 us (S j)); auto. + rewrite isFull'_false in succ_true. + congruence. Qed. + Opaque isFull' max_ones. + Lemma carry_full_3_length : forall us, (length us = length base) -> length (carry_full (carry_full (carry_full us))) = length us. Proof. @@ -1548,137 +1401,34 @@ Section CanonicalizationProofs. Qed. Local Hint Resolve carry_full_3_length. - Lemma freeze_in_bounds : forall us, - pre_carry_bounds us -> (length us = length base) -> - carry_done (freeze us). - Proof. - unfold freeze; intros. - rewrite carry_done_bounds; intro i. - destruct (lt_dec i (length us)). - + rewrite map_sub_combine_range with (d' := 0) by (try apply length_nonzero_nonnil; - (repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto)); auto; try omega). - break_if. - - split; [apply sub_land_max_bound_max_ones_lower - |apply sub_land_max_bound_max_ones_upper ]; - rewrite ?carry_full_3_length; auto. - apply carry_full_3_bounds; auto; omega. - - rewrite land_max_bound_0, <-max_bound_log_cap, Z.lt_succ_r, Z.sub_0_r. - apply carry_full_3_bounds; auto; omega. - + rewrite nth_default_out_of_bounds; [ split; auto; omega | ]. - rewrite map_length, combine_length, length_range, Nat.min_id. - repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto). - omega. - Qed. + Lemma nth_default_map2 : forall {A B C} (f : A -> B -> C) ls1 ls2 i d d1 d2, + nth_default d (map2 f ls1 ls2) i = + if lt_dec i (min (length ls1) (length ls2)) + then f (nth_default d1 ls1 i) (nth_default d2 ls2 i) + else d. + Admitted. - Lemma freeze_length : forall us, (length us = length base) -> - length (freeze us) = length us. - Proof. - unfold freeze; intros. - rewrite map_length, combine_length, length_range, Nat.min_id. - auto. - Qed. + Lemma map2_length : forall A B C (f : A -> B -> C) ls1 ls2, + length (map2 f ls1 ls2) = min (length ls1) (length ls2). + Admitted. - (* TODO : move *) - Lemma nth_default_same_lists_same : (* TODO : rename if this works *) - forall {A} d (l' l : list A), (length l = length l') -> - (forall i, nth_default d l i = nth_default d l' i) -> - l = l'. - Proof. - induction l'; intros until 0; intros lengths_equal nth_default_match. - + apply length0_nil; auto. - + destruct l; rewrite ?nil_length0, !cons_length in lengths_equal; - [congruence | ]. - pose proof (nth_default_match 0%nat) as nth_default_match_0. - rewrite !nth_default_cons in nth_default_match_0. - f_equal; auto. - apply IHl'; [ omega | ]. - intros. - specialize (nth_default_match (S i)). - rewrite !nth_default_cons_S in nth_default_match. - assumption. - Qed. + Lemma modulus_digits_length : length modulus_digits = length base. + Admitted. - Lemma not_full_no_change : forall us, length us = length base -> - map (fun x : nat * Z => snd x - land_max_bound 0 (fst x)) - (combine (range (length us)) us) = us. - Proof. - intros ? lengths_eq. - apply nth_default_same_lists_same with (d := 0). - + rewrite map_length, combine_length, length_range, Nat.min_id; auto. - + intros. - destruct (lt_dec i (length us)). - - erewrite map_sub_combine_range by (auto; intro false_eq; subst; - rewrite nil_length0 in lengths_eq; omega). - rewrite land_max_bound_0. - apply Z.sub_0_r. - - rewrite !nth_default_out_of_bounds; try omega. - rewrite map_length, combine_length, length_range, Nat.min_id; omega. - Qed. + (* Helps with solving goals of the form [x = y -> min x y = x] or [x = y -> min x y = y] *) + Local Hint Resolve Nat.eq_le_incl eq_le_incl_rev. - (* TODO : move *) - Lemma map_cons : forall {A B} (f : A -> B) x xs, map f (x :: xs) = f x :: (map f xs). - Proof. - auto. - Qed. - - (* TODO : move *) - Lemma firstn_firstn : forall {A} m n (l : list A), (n <= m)%nat -> - firstn n (firstn m l) = firstn n l. - Proof. - induction m; destruct n; intros; try omega; auto. - destruct l; auto. - simpl. - f_equal. - apply IHm; omega. - Qed. + Hint Rewrite app_length cons_length map2_length modulus_digits_length length_zeros + map_length combine_length firstn_length map_app : lengths. + Ltac simpl_lengths := autorewrite with lengths; + repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto); + auto using Min.min_l; auto using Min.min_r. - (* TODO : move *) - Lemma firstn_succ : forall n l, (n < length l)%nat -> - firstn (S n) l = (firstn n l) ++ nth_default 0 l n :: nil. + Lemma freeze_length : forall us, (length us = length base) -> + length (freeze us) = length us. Proof. - induction n; destruct l; rewrite ?(@nil_length0 Z); intros; try omega. - + rewrite nth_default_cons; auto. - + simpl. - rewrite nth_default_cons_S. - rewrite <-IHn by (rewrite cons_length in *; omega). - reflexivity. + unfold freeze; intros; simpl_lengths. Qed. -(* -Print BaseSystem.accumulate. -SearchAbout combine range. -mapi : forall {A B}, (nat -> A -> B) -> list A -> list B -mapi (fun x y => (x, y)) ls -map2 : forall {A B C}, (A -> B -> C) -> list A -> list B -> list C - -BaseSystem.decode u (map2 (fun x y => x - y) v w) -= BaseSystem.decode u v - BaseSystem.decode u w - -map2 f ls1 ls2 = map (fun xy => f (fst xy) (snd xy)) (combine ls1 ls2) - -map2 f (map g ls1) ls2 = map2 (fun x y => f (g x) y) ls1 ls2 -map2 f ls1 (map g ls2) = map2 (fun x y => f x (g y)) ls1 ls2 - -Locate mapi. -*) -Print map. - - Fixpoint mapi' {A B} (f : nat -> A -> B) i (l : list A) : list B := - match l with - | nil => nil - | x :: l' => f i x :: mapi' f (S i) l' - end. - - Definition mapi {A B} (f : nat -> A -> B) (l : list A) : list B := mapi' f 0%nat l. - - - Fixpoint map2 {A B C} (f : A -> B -> C) (la : list A) (lb : list B) : list C := - match la with - | nil => nil - | a :: la' => match lb with - | nil => nil - | b :: lb' => f a b :: map2 f la' lb' - end - end. Lemma map2_combine : forall {A B C} (f : A -> B -> C) ls1 ls2, map2 f ls1 ls2 = map (fun xy => f (fst xy) (snd xy)) (combine ls1 ls2). @@ -1692,119 +1442,114 @@ Print map. map2 f ls1 (map g ls2) = map2 (fun x y => f x (g y)) ls1 ls2. Admitted. - (* TODO : rewrite using the above? *) - - Hint Rewrite app_length cons_length map_length combine_length length_range firstn_length map_app : lengths. - - Lemma decode_subtract_elementwise: forall f r l, (length l = length base) -> - (r <= length l)%nat -> - BaseSystem.decode (firstn r base) (map (fun x => snd x - f (fst x)) (combine (range r) l)) = - BaseSystem.decode (firstn r base) (firstn r l) - BaseSystem.decode (firstn r base) (map f (range r)). - Proof. - induction r; intros. - + rewrite range0_nil. - cbv [combine map BaseSystem.decode sum_firstn firstn fold_right]. - rewrite decode_nil. - auto. - + rewrite combine_range_succ by assumption. - rewrite (firstn_succ _ l) by omega. - rewrite range_succ. - rewrite !map_app, !decode'_splice. - autorewrite with lengths. - rewrite Min.min_l, firstn_firstn, firstn_succ by omega. - rewrite skipn_app_sharp by (rewrite firstn_length, Nat.min_l; omega). - simpl. - rewrite !decode'_cons, decode_nil, IHr by omega. - unfold BaseSystem.decode. - ring. + Lemma decode_firstn_succ : forall n us, (length us = length base) -> + (n < length base)%nat -> + BaseSystem.decode' (firstn (S n) base) (firstn (S n) us) = + BaseSystem.decode' (firstn n base) (firstn n us) + + nth_default 0 base n * nth_default 0 us n. + Proof. + intros. + rewrite !firstn_succ with (d := 0) by omega. + rewrite base_app, firstn_app. + autorewrite with lengths; rewrite !Min.min_l by omega. + rewrite Nat.sub_diag, firstn_firstn, firstn0, app_nil_r by omega. + rewrite skipn_app_sharp by (autorewrite with lengths; apply Min.min_l; omega). + rewrite decode'_cons, decode_nil, Z.add_0_r. + reflexivity. Qed. - Definition modulus_digit i := if (eq_nat_dec i 0) then max_bound i - c + 1 else max_bound i. - (* TODO : maybe use this more? *) + Local Hint Resolve sum_firstn_limb_widths_nonneg. + Local Hint Resolve limb_widths_nonneg. + Local Hint Resolve nth_error_value_In. - Lemma modulus_digit_nonneg : forall i, 0 <= modulus_digit i. - Proof. - unfold modulus_digit; intros; break_if; auto; subst; omega. - Qed. - Hint Resolve modulus_digit_nonneg. - - Lemma modulus_digit_lt_cap : forall i, - modulus_digit i < 2 ^ log_cap i. + (* TODO : move *) + Lemma sum_firstn_all_succ : forall n l, (length l <= n)%nat -> + sum_firstn l (S n) = sum_firstn l n. + Admitted. + + Lemma decode_carry_done_upper_bound' : forall n us, carry_done us -> + (length us = length base) -> + BaseSystem.decode (firstn n base) (firstn n us) < 2 ^ (sum_firstn limb_widths n). Proof. - unfold modulus_digit; intros; rewrite <- max_bound_log_cap; break_if; omega. + induction n; intros; [ cbv; congruence | ]. + destruct (lt_dec n (length base)) as [ n_lt_length | ? ]. + + rewrite decode_firstn_succ; auto. + rewrite base_length in n_lt_length. + destruct (nth_error_length_exists_value _ _ n_lt_length). + erewrite sum_firstn_succ; eauto. + rewrite Z.pow_add_r; eauto. + rewrite nth_default_base by (rewrite base_length; assumption). + rewrite Z.lt_add_lt_sub_r. + eapply Z.lt_le_trans; eauto. + rewrite Z.mul_comm at 1. + rewrite <-Z.mul_sub_distr_l. + rewrite <-Z.mul_1_r at 1. + apply Z.mul_le_mono_nonneg_l; [ apply Z.pow_nonneg; omega | ]. + replace 1 with (Z.succ 0) by reflexivity. + rewrite Z.le_succ_l, Z.lt_0_sub. + match goal with H : carry_done us |- _ => rewrite carry_done_bounds in H; specialize (H n) end. + replace x with (log_cap n); try intuition. + rewrite log_cap_eq. + apply nth_error_value_eq_nth_default; auto. + + repeat erewrite firstn_all_strong by omega. + rewrite sum_firstn_all_succ by (rewrite <-base_length; omega). + eapply Z.le_lt_trans; [ | eauto]. + repeat erewrite firstn_all_strong by omega. + omega. Qed. - Hint Resolve modulus_digit_lt_cap. - Lemma modulus_digit_land_max_bound_max_ones : forall i, - land_max_bound max_ones i = modulus_digit i. + Lemma decode_carry_done_upper_bound : forall us, carry_done us -> + (length us = length base) -> BaseSystem.decode base us < 2 ^ k. Proof. - unfold land_max_bound; intros. - eapply land_max_ones_noop; eauto. + unfold k; intros. + rewrite <-(firstn_all_strong base (length limb_widths)) by (rewrite <-base_length; auto). + rewrite <-(firstn_all_strong us (length limb_widths)) by (rewrite <-base_length; auto). + auto using decode_carry_done_upper_bound'. Qed. - Lemma decode_modulus_digit_partial : forall n, (0 < n <= length base)%nat -> - BaseSystem.decode (firstn n base) (map modulus_digit (range (length base))) = - 2 ^ (sum_firstn limb_widths n) - c. - Proof. - induction n; intros; try omega. - rewrite firstn_succ by omega. - rewrite base_app. - rewrite decode'_truncate, firstn_length, Min.min_l in * by omega. - rewrite firstn_firstn by omega. - rewrite skipn_nth_default with (d := 0) by (autorewrite with lengths; omega). - rewrite decode'_cons, decode_base_nil, Z.add_0_r. - erewrite map_nth_default with (y := 0) (x := 0%nat) by - (autorewrite with lengths; omega). - rewrite nth_default_range by (autorewrite with lengths; omega). - rewrite nth_default_base by omega. - unfold modulus_digit at 2; break_if. - + subst. - clear IHn. - cbv [firstn BaseSystem.decode' combine fold_right]. - destruct (nth_error_length_exists_value 0 limb_widths); try (rewrite <-base_length; omega). - erewrite sum_firstn_succ; eauto. - replace (max_bound 0) with (2 ^ log_cap 0 - 1) by (rewrite <-max_bound_log_cap; omega). - rewrite log_cap_eq. - erewrite nth_error_value_eq_nth_default; eauto. - rewrite Z.pow_add_r by (auto using sum_firstn_limb_widths_nonneg; apply limb_widths_nonneg; - auto using (nth_error_value_In 0)). - cbv [sum_firstn firstn fold_right]. - ring. - + rewrite IHn by (auto; omega). - replace (max_bound n) with (2 ^ log_cap n - 1) by (rewrite <-max_bound_log_cap; omega). - rewrite log_cap_eq. - destruct (nth_error_length_exists_value n limb_widths); try (rewrite <- base_length; omega). - erewrite sum_firstn_succ; eauto. - erewrite nth_error_value_eq_nth_default; eauto. - rewrite Z.pow_add_r by (auto using sum_firstn_limb_widths_nonneg; apply limb_widths_nonneg; - auto using (nth_error_value_In n)). - ring. + Lemma decode_carry_done_lower_bound' : forall n us, carry_done us -> + (length us = length base) -> + 0 <= BaseSystem.decode (firstn n base) (firstn n us). + Proof. + induction n; intros; [ cbv; congruence | ]. + destruct (lt_dec n (length base)) as [ n_lt_length | ? ]. + + rewrite decode_firstn_succ by auto. + zero_bounds. + - rewrite nth_default_base by assumption. + apply Z.pow_nonneg; omega. + - match goal with H : carry_done us |- _ => rewrite carry_done_bounds in H; specialize (H n) end. + intuition. + + eapply Z.le_trans; [ apply IHn; eauto | ]. + repeat rewrite firstn_all_strong by omega. + omega. Qed. - Lemma decode_map_modulus_digit : - BaseSystem.decode base (map modulus_digit (range (length base))) = modulus. + Lemma decode_carry_done_lower_bound : forall us, carry_done us -> + (length us = length base) -> 0 <= BaseSystem.decode base us. Proof. - erewrite <-(firstn_all _ base) at 1 by reflexivity. - rewrite decode_modulus_digit_partial by omega. - rewrite base_length. - fold k; unfold c. - ring. + intros. + rewrite <-(firstn_all_strong base (length limb_widths)) by (rewrite <-base_length; auto). + rewrite <-(firstn_all_strong us (length limb_widths)) by (rewrite <-base_length; auto). + auto using decode_carry_done_lower_bound'. Qed. - Lemma decode_subtract_modulus_elementwise : forall us, (length us = length base) -> - BaseSystem.decode base - (map (fun x0 : nat * Z => snd x0 - land_max_bound max_ones (fst x0)) - (combine (range (length us)) us)) = BaseSystem.decode base us - modulus. + + Lemma nth_default_modulus_digits : forall d i, + nth_default d modulus_digits i = + if lt_dec i (length base) + then (if (eq_nat_dec i 0) then max_bound i - c + 1 else max_bound i) + else d. + Admitted. + + Lemma carry_done_modulus_digits : carry_done modulus_digits. Proof. + apply carry_done_bounds. intros. - replace base with (firstn (length us) base) at 1 by (auto using firstn_all). - rewrite decode_subtract_elementwise by omega. - rewrite !firstn_all by auto. - f_equal. - erewrite map_ext; [ | eapply modulus_digit_land_max_bound_max_ones ]. - replace (length us) with (length base) by assumption. - exact decode_map_modulus_digit. + rewrite nth_default_modulus_digits. + break_if; [ | split; auto; omega]. + break_if; subst; split; auto; try rewrite <- max_bound_log_cap; omega. Qed. + Hint Resolve carry_done_modulus_digits. (* TODO : move *) Lemma decode_mod : forall us vs x, (length us = length base) -> (length vs = length base) -> @@ -1818,6 +1563,37 @@ Print map. assumption. Qed. + Lemma decode_map2_sub : forall us vs, + (length us = length base) -> (length vs = length base) -> + BaseSystem.decode base (map2 (fun x y => x - y) us vs) + = BaseSystem.decode base us - BaseSystem.decode base vs. + Admitted. + + Lemma decode_modulus_digits : BaseSystem.decode base modulus_digits = modulus. + Admitted. + + Lemma map_land_max_ones_modulus_digits : map (Z.land max_ones) modulus_digits = modulus_digits. + Admitted. + + Lemma map_land_zero : forall ls, map (Z.land 0) ls = BaseSystem.zeros (length ls). + Admitted. + + Lemma carry_full_preserves_Fdecode : forall us x, (length us = length base) -> + decode us = x -> decode (carry_full us) = x. + Proof. + intros. + apply carry_full_preserves_rep; auto. + unfold rep; auto. + Qed. + + Lemma Fdecode_decode_mod : forall us x, (length us = length base) -> + decode us = x -> BaseSystem.decode base us mod modulus = x. + Proof. + unfold decode; intros ? ? ? decode_us. + rewrite <-decode_us. + apply FieldToZ_ZToField. + Qed. + Lemma freeze_preserves_rep : forall us x, (length us = length base) -> rep us x -> rep (freeze us) x. Proof. @@ -1828,19 +1604,26 @@ Print map. + apply decode_mod with (us := carry_full (carry_full (carry_full us))). - rewrite carry_full_3_length; auto. - autorewrite with lengths. - rewrite Nat.min_id. - rewrite carry_full_3_length; auto. + apply Min.min_r. + simpl_lengths; omega. - repeat apply carry_full_preserves_rep; repeat rewrite carry_full_length; auto. unfold rep; intuition. - - rewrite decode_subtract_modulus_elementwise by (rewrite carry_full_3_length; auto). + - rewrite decode_map2_sub by (simpl_lengths; omega). + rewrite map_land_max_ones_modulus_digits. + rewrite decode_modulus_digits. destruct (Z_eq_dec modulus 0); [ subst; rewrite !Zmod_0_r; reflexivity | ]. rewrite <-Z.add_opp_r. replace (-modulus) with (-1 * modulus) by ring. symmetry; auto using Z.mod_add. - + rewrite not_full_no_change by (rewrite carry_full_3_length; auto). - repeat (apply carry_full_preserves_rep; repeat rewrite carry_full_length; auto). - unfold rep; auto. + + eapply decode_mod; eauto. + simpl_lengths. + rewrite map_land_zero, decode_map2_sub, zeros_rep, Z.sub_0_r by simpl_lengths. + match goal with H : decode ?us = ?x |- _ => erewrite Fdecode_decode_mod; eauto; + do 3 apply carry_full_preserves_Fdecode in H; simpl_lengths + end. + erewrite Fdecode_decode_mod; eauto; simpl_lengths. Qed. + Hint Resolve freeze_preserves_rep. Lemma isFull_true_iff : forall us, (length us = length base) -> (isFull us = true <-> max_bound 0 - c < nth_default 0 us 0 @@ -1863,22 +1646,6 @@ Print map. (in our context, this is the most significant end). *) Definition compare us vs := compare' us vs (length us). - Lemma decode_firstn_succ : forall n us, (length us = length base) -> - (n < length base)%nat -> - BaseSystem.decode' (firstn (S n) base) (firstn (S n) us) = - BaseSystem.decode' (firstn n base) (firstn n us) + - nth_default 0 base n * nth_default 0 us n. - Proof. - intros. - rewrite !firstn_succ by omega. - rewrite base_app, firstn_app. - autorewrite with lengths; rewrite !Min.min_l by omega. - rewrite Nat.sub_diag, firstn_firstn, firstn0, app_nil_r by omega. - rewrite skipn_app_sharp by (autorewrite with lengths; apply Min.min_l; omega). - rewrite decode'_cons, decode_nil, Z.add_0_r. - reflexivity. - Qed. - Lemma decode_lt_next_digit : forall us n, (length us = length base) -> (n < length base)%nat -> (n < length us)%nat -> carry_done us -> @@ -1906,24 +1673,25 @@ Print map. Lemma highest_digit_determines : forall us vs n x, (x < 0) -> (length us = length base) -> + (length vs = length base) -> (n < length us)%nat -> carry_done us -> (n < length vs)%nat -> carry_done vs -> - BaseSystem.decode' (firstn n base) (firstn n us) + + BaseSystem.decode (firstn n base) (firstn n us) + nth_default 0 base n * x - - BaseSystem.decode' (firstn n base) (firstn n vs) < 0. + BaseSystem.decode (firstn n base) (firstn n vs) < 0. Proof. intros. eapply Z.le_lt_trans. - apply Z.le_sub_nonneg. - admit. (* TODO : decode' is nonnegative *) - eapply Z.le_lt_trans. - eapply Z.add_le_mono with (q := nth_default 0 base n * -1); [ apply Z.le_refl | ]. - apply Z.mul_le_mono_nonneg_l; try omega. - admit. (* TODO : 0 <= nth_default 0 base n *) - ring_simplify. - apply Z.lt_sub_0. - apply decode_lt_next_digit; auto. - omega. + + apply Z.le_sub_nonneg. + apply decode_carry_done_lower_bound'; auto. + + eapply Z.le_lt_trans. + - eapply Z.add_le_mono with (q := nth_default 0 base n * -1); [ apply Z.le_refl | ]. + apply Z.mul_le_mono_nonneg_l; try omega. + rewrite nth_default_base by omega; apply Z.pow_nonneg; omega. + - ring_simplify. + apply Z.lt_sub_0. + apply decode_lt_next_digit; auto. + omega. Qed. Lemma Z_compare_decode_step_eq : forall n us vs, @@ -1938,7 +1706,7 @@ Print map. Proof. intros until 3; intro nth_default_eq. destruct (lt_dec n (length us)); try omega. - rewrite firstn_succ, !base_app by omega. + rewrite firstn_succ with (d := 0), !base_app by omega. autorewrite with lengths; rewrite Min.min_l by omega. do 2 (rewrite skipn_nth_default with (d := 0) by omega; rewrite decode'_cons, decode_base_nil, Z.add_0_r). @@ -2021,8 +1789,6 @@ Print map. + assumption. Qed. - Transparent isFull'. - Print compare'. Lemma compare'_succ : forall us j vs, compare' us vs (S j) = if Z.eq_dec (nth_default 0 us j) (nth_default 0 vs j) then compare' us vs j @@ -2032,14 +1798,15 @@ Print map. Qed. - Lemma compare'_firstn_r : forall us j vs, (j <= length vs)%nat -> + Lemma compare'_firstn_r_small_index : forall us j vs, (j <= length vs)%nat -> compare' us vs j = compare' us (firstn j vs) j. Proof. induction j; intros; auto. - rewrite !compare'_succ. + rewrite !compare'_succ by omega. rewrite firstn_succ by omega. rewrite nth_default_app. - autorewrite with lengths; rewrite Min.min_l by omega. + simpl_lengths. + rewrite Min.min_l by omega. destruct (lt_dec j j); try omega. rewrite Nat.sub_diag. rewrite nth_default_cons. @@ -2050,13 +1817,15 @@ Print map. apply IHj; omega. Qed. - Lemma isFull'_true_step : forall us j, isFull' us true (S j) = true -> - isFull' us true j = true. + Lemma compare'_firstn_r : forall us j vs, + compare' us vs j = compare' us (firstn j vs) j. Proof. - simpl; intros ? ? succ_true. - destruct (max_bound (S j) =? nth_default 0 us (S j)); auto. - rewrite isFull'_false in succ_true. - congruence. + intros. + destruct (le_dec j (length vs)). + + auto using compare'_firstn_r_small_index. + + f_equal. symmetry. + apply firstn_all_strong. + omega. Qed. Lemma compare'_not_Lt : forall us vs j, j <> 0%nat -> @@ -2080,67 +1849,47 @@ Print map. specialize (H j) end; omega. Qed. - Lemma nth_default_map_range : forall f n r, (n < r)%nat -> - nth_default 0 (map f (range r)) n = f n. - Proof. - intros. - rewrite map_nth_default with (x := 0%nat) by (autorewrite with lengths; omega). - rewrite nth_default_range by omega. - reflexivity. - Qed. - - Lemma isFull'_compare' : forall us j, j <> 0%nat -> (length us = length base) -> carry_done us -> - (isFull' us true (j - 1) = true <-> - compare' us (map modulus_digit (range j)) j <> Lt). + Lemma isFull'_compare' : forall us j, j <> 0%nat -> (length us = length base) -> + (j <= length base)%nat -> carry_done us -> + (isFull' us true (j - 1) = true <-> compare' us modulus_digits j <> Lt). Proof. unfold compare; induction j; intros; try congruence. replace (S j - 1)%nat with j by omega. - (* rewrite isFull'_true_iff by assumption; *) split; intros. + simpl. - break_if. - - rewrite compare'_firstn_r by (autorewrite with lengths; omega). - rewrite range_succ, map_app, firstn_app. - autorewrite with lengths. - rewrite Nat.sub_diag, app_nil_r. - rewrite firstn_all by (autorewrite with lengths; reflexivity). - destruct (eq_nat_dec j 0); [ subst; simpl; try congruence | ]. - apply IHj; auto. + break_if; [destruct (eq_nat_dec j 0) | ]. + - subst. cbv; congruence. + - apply IHj; auto; try omega. apply isFull'_true_step. replace (S (j - 1)) with j by omega; auto. - - match goal with |- appcontext[?a ?= ?b] => case_eq (a ?= b) end; - intros compare_eq; try congruence. - rewrite Z.compare_lt_iff in compare_eq. - rewrite nth_default_map_range in * by omega. - match goal with H : isFull' _ _ _ = true |- _ => apply isFull'_true_iff in H; auto; destruct H end. - - destruct (eq_nat_dec j 0). - * subst. cbv [modulus_digit] in compare_eq. - break_if; try congruence. omega. - * assert (0 < j <= j)%nat as j_range by omega. - specialize (H3 j j_range). - unfold modulus_digit in n. - break_if; omega. + - rewrite nth_default_modulus_digits in *. + repeat (break_if; try omega). + * subst. + match goal with H : isFull' _ _ _ = true |- _ => + apply isFull'_lower_bound_0 in H end. + apply Z.compare_ge_iff. + omega. + * match goal with H : isFull' _ _ _ = true |- _ => + apply isFull'_true_iff in H; try assumption; destruct H as [? eq_max_bound] end. + specialize (eq_max_bound j). + omega. + apply isFull'_true_iff; try assumption. match goal with H : compare' _ _ _ <> Lt |- _ => apply compare'_not_Lt in H; [ destruct H as [Hdigit0 Hnonzero] | | ] end. - - rewrite nth_default_map_range in * by omega. - split; [ unfold modulus_digit in *; break_if; omega | ]. - intros i i_range. - assert (0 < i < S j)%nat as i_range' by omega. - specialize (Hnonzero i i_range'). - rewrite nth_default_map_range in * by omega. - unfold modulus_digit in Hnonzero; break_if; omega. + - split; [ | intros i i_range; assert (0 < i < S j)%nat as i_range' by omega; + specialize (Hnonzero i i_range')]; + rewrite nth_default_modulus_digits in *; + repeat (break_if; try omega). - congruence. - - intros; rewrite nth_default_map_range by omega. - unfold modulus_digit; break_if; try omega. + - intros. + rewrite nth_default_modulus_digits. + repeat (break_if; try omega). rewrite <-Z.lt_succ_r with (m := max_bound i). rewrite max_bound_log_cap; apply carry_done_bounds. assumption. Qed. Lemma isFull_compare : forall us, (length us = length base) -> carry_done us -> - (isFull us = true <-> - compare us (map modulus_digit (range (length base))) <> Lt). + (isFull us = true <-> compare us modulus_digits <> Lt). Proof. unfold compare, isFull; intros ? lengths_eq. intros. rewrite lengths_eq. @@ -2150,85 +1899,127 @@ Print map. Lemma isFull_decode : forall us, (length us = length base) -> carry_done us -> (isFull us = true <-> - (BaseSystem.decode base us ?= BaseSystem.decode base (map modulus_digit (range (length base)))) <> Lt). + (BaseSystem.decode base us ?= BaseSystem.decode base modulus_digits <> Lt)). Proof. intros. - rewrite decode_compare; autorewrite with lengths; auto; - [ apply isFull_compare; auto | ]. - rewrite carry_done_bounds; intro i. - destruct (lt_dec i (length base)). - + rewrite nth_default_map_range; auto. - + rewrite nth_default_out_of_bounds by (autorewrite with lengths; omega). - split; auto; omega. + rewrite decode_compare; autorewrite with lengths; auto. + apply isFull_compare; auto. Qed. - Lemma isFull_false_upper_bound : forall us, (length us = length base) -> carry_done us -> - isFull us = false -> + Lemma isFull_false_upper_bound : forall us, (length us = length base) -> + carry_done us -> isFull us = false -> BaseSystem.decode base us < modulus. Proof. intros. destruct (Z_lt_dec (BaseSystem.decode base us) modulus) as [? | nlt_modulus]; [assumption | exfalso]. apply Z.compare_nlt_iff in nlt_modulus. - rewrite <-decode_map_modulus_digit in nlt_modulus at 2. + rewrite <-decode_modulus_digits in nlt_modulus at 2. apply isFull_decode in nlt_modulus; try assumption; congruence. Qed. -(* Road map: - * x prove isFull us = false -> us < modulus - * _ prove (carry_full^3 us) < 2 * modulus - *) - - Definition twoKMinusOne := mapi (fun _ => max_bound i - - Lemma bounded_digits_lt_2modulus : forall us, (length us = length base) -> carry_done us -> - BaseSystem.decode base us < 2 ^ k. + Lemma isFull_true_lower_bound : forall us, (length us = length base) -> + carry_done us -> isFull us = true -> + modulus <= BaseSystem.decode base us. Proof. - unfold k. - SearchAbout sum_firstn limb_widths. + intros. + rewrite <-decode_modulus_digits at 1. + apply Z.compare_ge_iff. + apply isFull_decode; auto. Qed. - Lemma bounded_digits_lt_2modulus : forall us, (length us = length base) -> carry_done us -> - BaseSystem.decode base us < 2 * modulus. - Proof. - + Lemma land_ones_modulus_digits : map (Z.land max_ones) modulus_digits = modulus_digits. + Admitted. + + Lemma nth_default_map_land_zero : forall l i, nth_default 0 (map (Z.land 0) l) i = 0. + Admitted. - SearchAbout (carry_full (carry_full (carry_full _))). + Lemma freeze_in_bounds : forall us, + pre_carry_bounds us -> (length us = length base) -> + carry_done (freeze us). + Proof. + unfold freeze; intros ? PCB lengths_eq. + rewrite carry_done_bounds; intro i. + rewrite nth_default_map2 with (d1 := 0) (d2 := 0). + simpl_lengths. + break_if; [ | split; (omega || auto)]. + break_if. + + rewrite land_ones_modulus_digits. + apply isFull_true_iff in Heqb; [ | simpl_lengths]. + destruct Heqb as [first_digit high_digits]. + destruct (eq_nat_dec i 0). + - subst. + clear high_digits. + rewrite nth_default_modulus_digits. + repeat (break_if; try omega). + pose proof (carry_full_3_done us PCB lengths_eq) as cf3_done. + rewrite carry_done_bounds in cf3_done. + specialize (cf3_done 0%nat). + omega. + - assert ((0 < i <= length (carry_full (carry_full (carry_full us))) - 1)%nat) as i_range by + (simpl_lengths; apply lt_min_l in l; omega). + specialize (high_digits i i_range). + clear first_digit i_range. + rewrite high_digits. + rewrite <-max_bound_log_cap. + rewrite nth_default_modulus_digits. + repeat (break_if; try omega). + * rewrite Z.sub_diag. + split; try omega. + apply Z.lt_succ_r; auto. + * rewrite Z.lt_succ_r, Z.sub_0_r. split; (omega || auto). + + rewrite nth_default_map_land_zero. + rewrite Z.sub_0_r. + apply carry_done_bounds. + auto using carry_full_3_done. + Qed. + Hint Resolve freeze_in_bounds. + + Lemma two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus. + Proof. + rewrite Z.sub_le_mono_r with (p := 2 ^ k). + rewrite Z.sub_diag. + replace (2 * modulus - 2 ^ k) with (2 ^ k - (2 * c)) by (unfold c; ring). + (* TODO : maybe just require this to be computed? seems doable but annoying to prove this way *) + Admitted. + Local Hint Resolve carry_full_3_done. - Lemma freeze_minimal_rep : forall us, minimal_rep (freeze us). + Lemma freeze_minimal_rep : forall us, pre_carry_bounds us -> (length us = length base) -> + minimal_rep (freeze us). Proof. unfold minimal_rep, freeze. intros. symmetry. apply Z.mod_small. - split. - + admit. - + break_if. - remember (carry_full (carry_full (carry_full us))) as cf3us. - rewrite decode_subtract_modulus_elementwise. - apply isFull_true_ + split; break_if; rewrite decode_map2_sub; simpl_lengths. + + rewrite land_ones_modulus_digits, decode_modulus_digits. + apply Z.le_0_sub. + apply isFull_true_lower_bound; simpl_lengths. + + rewrite map_land_zero, zeros_rep, Z.sub_0_r. + apply decode_carry_done_lower_bound; simpl_lengths. + + rewrite land_ones_modulus_digits, decode_modulus_digits. + rewrite Z.lt_sub_lt_add_r. + apply Z.lt_le_trans with (m := 2 * modulus); try omega. + eapply Z.lt_le_trans; [ | apply two_pow_k_le_2modulus ]. + apply decode_carry_done_upper_bound; simpl_lengths. + + rewrite map_land_zero, zeros_rep, Z.sub_0_r. + apply isFull_false_upper_bound; simpl_lengths. Qed. Hint Resolve freeze_minimal_rep. - Lemma minimal_rep_unique_if_bounded : forall us vs, - minimal_rep us -> minimal_rep vs -> - (forall i, 0 <= nth_default 0 us i < 2 ^ log_cap i) -> - (forall i, 0 <= nth_default 0 vs i < 2 ^ log_cap i) -> + Lemma minimal_rep_unique : forall us vs x, + rep us x -> minimal_rep us -> carry_done us -> + rep vs x -> minimal_rep vs -> carry_done vs -> us = vs. Proof. - Admitted. - Lemma freeze_canonical : forall us vs x y, c_carry_constraint -> + Lemma freeze_canonical : forall us vs x, pre_carry_bounds us -> (length us = length base) -> rep us x -> - pre_carry_bounds vs -> (length vs = length base) -> rep vs y -> - (x mod modulus = y mod modulus) -> + pre_carry_bounds vs -> (length vs = length base) -> rep vs x -> freeze us = freeze vs. Proof. - unfold rep; intros. - apply minimal_rep_unique_if_bounded; auto. - intros. apply freeze_in_bounds; auto. - intros. apply freeze_in_bounds; auto. + intros; eapply minimal_rep_unique; eauto. Qed. End CanonicalizationProofs. \ No newline at end of file diff --git a/src/Util/ListUtil.v b/src/Util/ListUtil.v index dd1e6a5c5..cbd7bd58c 100644 --- a/src/Util/ListUtil.v +++ b/src/Util/ListUtil.v @@ -582,3 +582,37 @@ Lemma In_firstn : forall {T} n l (x : T), In x (firstn n l) -> In x l. Proof. induction n; destruct l; boring. Qed. + +Lemma firstn_firstn : forall {A} m n (l : list A), (n <= m)%nat -> + firstn n (firstn m l) = firstn n l. +Proof. + induction m; destruct n; intros; try omega; auto. + destruct l; auto. + simpl. + f_equal. + apply IHm; omega. +Qed. + +Lemma firstn_succ : forall {A} (d : A) n l, (n < length l)%nat -> + firstn (S n) l = (firstn n l) ++ nth_default d l n :: nil. +Proof. + induction n; destruct l; rewrite ?(@nil_length0 A); intros; try omega. + + rewrite nth_default_cons; auto. + + simpl. + rewrite nth_default_cons_S. + rewrite <-IHn by (rewrite cons_length in *; omega). + reflexivity. +Qed. + +Lemma firstn_all_strong : forall {A} (xs : list A) n, (length xs <= n)%nat -> + firstn n xs = xs. +Proof. + induction xs; intros; try apply firstn_nil. + destruct n; + match goal with H : (length (_ :: _) <= _)%nat |- _ => + simpl in H; try omega end. + simpl. + f_equal. + apply IHxs. + omega. +Qed. diff --git a/src/Util/NatUtil.v b/src/Util/NatUtil.v index 1f69b04d2..59898be7a 100644 --- a/src/Util/NatUtil.v +++ b/src/Util/NatUtil.v @@ -57,7 +57,18 @@ Proof. } Qed. +Lemma lt_min_l : forall x a b, (x < min a b)%nat -> (x < a)%nat. +Proof. + intros ? ? ? lt_min. + apply Nat.min_glb_lt_iff in lt_min. + destruct lt_min; assumption. +Qed. +(* useful for hints *) +Lemma eq_le_incl_rev : forall a b, a = b -> b <= a. +Proof. + intros; omega. +Qed. Lemma beq_nat_eq_nat_dec {R} (x y : nat) (a b : R) : (if EqNat.beq_nat x y then a else b) = (if eq_nat_dec x y then a else b). -- cgit v1.2.3