diff options
author | jadep <jade.philipoom@gmail.com> | 2016-06-11 09:17:09 -0400 |
---|---|---|
committer | jadep <jade.philipoom@gmail.com> | 2016-06-11 09:17:09 -0400 |
commit | a37eb79f7750a419346211abb58ec45b79975da0 (patch) | |
tree | ba62e0ff6176a8dc208f65bfae52c8415076d107 /src/ModularArithmetic | |
parent | 2e566c32baf2a140cd7820c4f06437ee5c43ac44 (diff) |
starting rewrite using different definition of map
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystem.v | 46 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemProofs.v | 1161 | ||||
-rw-r--r-- | src/ModularArithmetic/PseudoMersenneBaseParamProofs.v | 7 | ||||
-rw-r--r-- | src/ModularArithmetic/PseudoMersenneBaseParams.v | 2 |
4 files changed, 1140 insertions, 76 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index 558b9a5a2..a48ec2536 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -76,42 +76,44 @@ Section Canonicalization. Definition full_carry_chain := make_chain (length limb_widths). (* compute at compile time *) - Definition max_ones := Z.ones - ((fix loop current_max lw := - match lw with - | nil => current_max - | w :: lw' => loop (Z.max w current_max) lw' - end - ) 0 limb_widths). + Definition max_ones := Z.ones (fold_right Z.max 0 limb_widths). (* compute at compile time? *) Definition carry_full := carry_sequence full_carry_chain. Definition max_bound i := Z.ones (log_cap i). - Definition isFull us := - (fix loop full i := - match i with - | O => full (* don't test 0; the test for 0 is the initial value of [full]. *) - | S i' => loop (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i' - end - ) (Z.ltb (max_bound 0 - (c + 1)) (nth_default 0 us 0)) (length us - 1)%nat. - - Fixpoint range' n m := - match m with - | O => nil - | S m' => (n - m)%nat :: range' n m' + Fixpoint isFull' us full i := + match i with + | O => andb (Z.ltb (max_bound 0 - c) (nth_default 0 us 0)) full + | S i' => isFull' us (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i' end. - Definition range n := range' n n. + Definition isFull us := isFull' us true (length us - 1)%nat. + + Fixpoint modulus_digits' i := + match i with + | O => max_bound i - c + 1 :: nil + | S i' => max_bound i :: modulus_digits' i' + end. - Definition land_max_bound and_term i := Z.land and_term (max_bound i). + (* compute at compile time *) + Definition modulus_digits := modulus_digits' (length base). + + Fixpoint map2 {A B C} (f : A -> B -> C) (la : list A) (lb : list B) : list C := + match la with + | nil => nil + | a :: la' => match lb with + | nil => nil + | b :: lb' => f a b :: map2 f la' lb' + end + end. Definition freeze us := let us' := carry_full (carry_full (carry_full us)) in let and_term := if isFull us' then max_ones else 0 in (* [and_term] is all ones if us' is full, so the subtractions subtract q overall. Otherwise, it's all zeroes, and the subtractions do nothing. *) - map (fun x => (snd x) - land_max_bound and_term (fst x)) (combine (range (length us')) us'). + map2 (fun x y => x - y) us' (map (Z.land and_term) modulus_digits). End Canonicalization. diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index 274acff5a..7c430417b 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -398,7 +398,15 @@ Section CarryProofs. End CarryProofs. Section CanonicalizationProofs. - Context `{prm : PseudoMersenneBaseParams} (lt_1_length_base : (1 < length base)%nat) (c_pos : 0 < c) {B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B). + Context `{prm : PseudoMersenneBaseParams} (lt_1_length_base : (1 < length base)%nat) + {B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B) + (c_pos : 0 < c) + (* on the first reduce step, we add at most one bit of width to the first digit *) + (c_reduce1 : c * (Z.ones (B - log_cap (pred (length base)))) < max_bound 0 + 1) + (* on the second reduce step, we add at most one bit of width to the first digit, + and leave room to carry c one more time after the highest bit is carried *) + (c_reduce2 : c <= max_bound 0 - c). +(* TODO (c_reduce2: max_bound 0 + c < 2 ^ (log_cap 0 + 1)). c < max_bound 0 + 2*) (* TODO : move *) Lemma set_nth_nth_default : forall {A} (d:A) n x l i, (0 <= i < length l)%nat -> @@ -488,6 +496,26 @@ Section CanonicalizationProofs. omega. Qed. + (* TODO : move *) + Lemma Z_ones_pos_pos : forall i, (0 < i) -> 0 < Z.ones i. + Proof. + intros. + unfold Z.ones. + rewrite Z.shiftl_1_l. + apply Z.lt_succ_lt_pred. + apply Z.pow_gt_1; omega. + Qed. + + Lemma max_bound_pos : forall i, (i < length base)%nat -> 0 < max_bound i. + Proof. + unfold max_bound, log_cap; intros; apply Z_ones_pos_pos. + apply limb_widths_pos. + rewrite nth_default_eq. + apply nth_In. + rewrite <-base_length; assumption. + Qed. + Local Hint Resolve max_bound_pos. + Lemma max_bound_nonneg : forall i, 0 <= max_bound i. Proof. unfold max_bound; intros; auto using Z_ones_nonneg. @@ -550,6 +578,11 @@ Section CanonicalizationProofs. omega. Qed. + Lemma log_cap_eq : forall i, log_cap i = nth_default 0 limb_widths i. + Proof. + reflexivity. + Qed. + (* END groundwork proofs *) Opaque pow2_mod log_cap max_bound. @@ -565,11 +598,6 @@ Section CanonicalizationProofs. (* BEGIN defs *) - Definition c_carry_constraint : Prop := - (c * (Z.ones (B - log_cap (pred (length base)))) < max_bound 0 + 1) - /\ (max_bound 0 + c < 2 ^ (log_cap 0 + 1)) - /\ (c <= max_bound 0 - c). - Definition pre_carry_bounds us := forall i, 0 <= nth_default 0 us i < if (eq_nat_dec i 0) then 2 ^ B else 2 ^ B - 2 ^ (B - log_cap (pred i)). @@ -602,6 +630,10 @@ Section CanonicalizationProofs. destruct i0; add_set_nth; subst; rewrite ?Z.add_0_l; auto. Qed. + Lemma carry_done_bounds : forall us, carry_done us <-> + (forall i, 0 <= nth_default 0 us i < 2 ^ log_cap i). + Admitted. + (* END defs *) (* BEGIN proofs about first carry loop *) @@ -810,7 +842,6 @@ Section CanonicalizationProofs. omega. Qed. - Lemma carry_sequence_no_overflow : forall i us, pre_carry_bounds us -> (length us = length base) -> nth_default 0 (carry_sequence (make_chain i) us) i < 2 ^ B. @@ -879,7 +910,7 @@ Section CanonicalizationProofs. (* BEGIN proofs about second carry loop *) - Lemma carry_sequence_carry_full_bounds_same : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_sequence_carry_full_bounds_same : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full us)) i <= 2 ^ log_cap i. Proof. @@ -905,16 +936,14 @@ Section CanonicalizationProofs. eapply Z.le_lt_trans; [ apply carry_full_bounds_0; auto | ]. replace (2 ^ log_cap 0 * 2) with (2 ^ log_cap 0 + 2 ^ log_cap 0) by ring. rewrite <-max_bound_log_cap, <-Z.add_1_l. - apply Z.add_lt_le_mono; try omega. - unfold c_carry_constraint in *. - intuition. + apply Z.add_lt_le_mono; omega. * eapply Z.le_lt_trans; [ apply IHi; auto; omega | ]. apply Z.lt_mul_diag_r; auto; omega. - rewrite carry_sequence_unaffected by carry_length_conditions. apply carry_full_bounds; auto; omega. Qed. - Lemma carry_full_2_bounds_0 : forall us, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds_0 : forall us, pre_carry_bounds us -> (length us = length base)%nat -> (1 < length base)%nat -> 0 <= nth_default 0 (carry_full (carry_full us)) 0 <= max_bound 0 + c. Proof. @@ -945,7 +974,7 @@ Section CanonicalizationProofs. omega. Qed. - Lemma carry_full_2_bounds_succ : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds_succ : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < pred (length base))%nat -> ((0 < i < length base)%nat -> 0 <= nth_default 0 @@ -954,13 +983,13 @@ Section CanonicalizationProofs. 0 <= nth_default 0 (carry_simple i (carry_sequence (make_chain i) (carry_full (carry_full us)))) (S i) <= 2 ^ log_cap (S i). Proof. - unfold carry_simple; intros ? ? PCB CCC length_eq ? IH. + unfold carry_simple; intros ? ? PCB length_eq ? IH. add_set_nth. split. + apply Z.add_nonneg_nonneg. apply Z.shiftr_nonneg. destruct i; - [ simpl; pose proof (carry_full_2_bounds_0 us PCB CCC length_eq); omega | ]. + [ simpl; pose proof (carry_full_2_bounds_0 us PCB length_eq); omega | ]. - assert (0 < S i < length base)%nat as IHpre by omega. specialize (IH IHpre). omega. @@ -978,7 +1007,7 @@ Section CanonicalizationProofs. intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions. Qed. - Lemma carry_full_2_bounds_same : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds_same : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) i <= 2 ^ log_cap i. Proof. @@ -999,14 +1028,13 @@ Section CanonicalizationProofs. - apply Z_div_floor; auto. eapply Z.le_lt_trans; [ eapply carry_full_2_bounds_0; eauto | ]. replace (Z.succ 1) with (2 ^ 1) by ring. - rewrite <-Z.pow_add_r by (omega || auto). - unfold c_carry_constraint in *. - intuition. + rewrite <-max_bound_log_cap. + ring_simplify. omega. - apply carry_full_bounds; carry_length_conditions. intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions. Qed. - Lemma carry_full_2_bounds' : forall us i j, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds' : forall us i j, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> (i + j < length base)%nat -> (j <> 0)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain (i + j)) (carry_full (carry_full us))) i <= max_bound i. Proof. @@ -1017,7 +1045,7 @@ Section CanonicalizationProofs. + apply nth_default_carry_bound_upper; carry_length_conditions. Qed. - Lemma carry_full_2_bounds : forall us i j, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds : forall us i j, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> (i < j < length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain j) (carry_full (carry_full us))) i <= max_bound i. Proof. @@ -1026,12 +1054,12 @@ Section CanonicalizationProofs. eapply carry_full_2_bounds'; eauto; omega. Qed. - Lemma carry_carry_full_2_bounds_0_lower : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_carry_full_2_bounds_0_lower : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> (0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0). Proof. induction i; try omega. - intros ? ? length_eq ?; simpl. + intros ? length_eq ?; simpl. destruct i. + unfold carry. break_if; @@ -1045,27 +1073,25 @@ Section CanonicalizationProofs. intuition. Qed. - Lemma carry_full_2_bounds_lower :forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds_lower :forall us i, pre_carry_bounds us -> (length us = length base)%nat -> 0 <= nth_default 0 (carry_full (carry_full us)) i. Proof. - intros. - destruct i. + intros; destruct i. + apply carry_full_2_bounds_0; auto. + apply carry_full_bounds; try solve [carry_length_conditions]. - intro j. - destruct j. + intro j; destruct j. - apply carry_full_bounds_0; auto. - apply carry_full_bounds; carry_length_conditions. Qed. - Lemma carry_carry_full_2_bounds_0_upper : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_carry_full_2_bounds_0_upper : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> (nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0 <= max_bound 0 - c) \/ carry_done (carry_sequence (make_chain i) (carry_full (carry_full us))). Proof. induction i; try omega. - intros ? ? length_eq ?; simpl. + intros ? length_eq ?; simpl. destruct i. + destruct (Z_le_dec (nth_default 0 (carry_full (carry_full us)) 0) (max_bound 0)). - right. @@ -1086,20 +1112,18 @@ Section CanonicalizationProofs. unfold carry_simple. add_set_nth. left. remember ((nth_default 0 (carry_full (carry_full us)) 0)) as x. - apply Z.le_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)). - * replace x with ((x - 2 ^ log_cap 0) + (1 * 2 ^ log_cap 0)) by ring. - rewrite pow2_mod_spec by auto. - rewrite Z.mod_add by (pose proof (pow_2_log_cap_pos 0); omega). - rewrite <-max_bound_log_cap, <-Z.add_1_l, Z.mod_small. - apply Z.sub_le_mono_r. - subst; apply carry_full_2_bounds_0; auto. - split; try omega. - pose proof carry_full_2_bounds_0. - apply Z.le_lt_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)); - [ apply Z.sub_le_mono_r; subst x; apply carry_full_2_bounds_0; auto; - ring_simplify; unfold c_carry_constraint in *; omega | ]. - ring_simplify; unfold c_carry_constraint in *; omega. - * ring_simplify; unfold c_carry_constraint in *; omega. + apply Z.le_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)); try omega. + replace x with ((x - 2 ^ log_cap 0) + (1 * 2 ^ log_cap 0)) by ring. + rewrite pow2_mod_spec by auto. + rewrite Z.mod_add by (pose proof (pow_2_log_cap_pos 0); omega). + rewrite <-max_bound_log_cap, <-Z.add_1_l, Z.mod_small. + apply Z.sub_le_mono_r. + subst; apply carry_full_2_bounds_0; auto. + split; try omega. + pose proof carry_full_2_bounds_0. + apply Z.le_lt_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)); + [ apply Z.sub_le_mono_r; subst x; apply carry_full_2_bounds_0; auto; + ring_simplify | ]; omega. + rewrite carry_unaffected_low by carry_length_conditions. assert (0 < S i < length base)%nat by omega. intuition. @@ -1119,7 +1143,7 @@ Section CanonicalizationProofs. (* BEGIN proofs about third carry loop *) - Lemma carry_full_3_bounds : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_3_bounds : forall us i, pre_carry_bounds us -> (length us = length base)%nat ->(i < length base)%nat -> 0 <= nth_default 0 (carry_full (carry_full (carry_full us))) i <= max_bound i. Proof. @@ -1156,24 +1180,1055 @@ Section CanonicalizationProofs. apply shiftr_eq_0_max_bound; auto; omega. Qed. + Lemma carry_full_3_done : forall us, pre_carry_bounds us -> + (length us = length base)%nat -> + carry_done (carry_full (carry_full (carry_full us))). + Proof. + intros. + apply carry_done_bounds; intro i. + destruct (lt_dec i (length base)). + + rewrite <-max_bound_log_cap, Z.lt_succ_r. + auto using carry_full_3_bounds. + + rewrite nth_default_out_of_bounds; carry_length_conditions. + Qed. + + (* END proofs about third carry loop *) + Lemma nth_error_combine : forall {A B} i (x : A) (x' : B) l l', nth_error l i = Some x -> nth_error l' i = Some x' -> nth_error (combine l l') i = Some (x, x'). Admitted. - - Lemma nth_error_range : forall {A} i (l : list A), (i < length l)%nat -> - nth_error (range (length l)) i = Some i. +(* + Lemma nth_error_range : forall i r, (i < r)%nat -> + nth_error (range r) i = Some i. Admitted. +*) + Lemma carry_full_length : forall us, (length us = length base)%nat -> + length (carry_full us) = length us. + Proof. + intros; carry_length_conditions. + Qed. + + (* TODO : move? *) + Lemma make_chain_lt : forall x i : nat, In i (make_chain x) -> (i < x)%nat. + Proof. + induction x; simpl; intuition. + Qed. + + Lemma carry_full_preserves_rep : forall us x, (length us = length base)%nat -> + rep us x -> rep (carry_full us) x. + Proof. + unfold carry_full; intros. + apply carry_sequence_rep; auto. + unfold full_carry_chain; rewrite base_length; apply make_chain_lt. + Qed. - (* END proofs about third carry loop *) Opaque carry_full. +(* + Lemma length_range : forall n, length (range n) = n. + Proof. + induction n; intros; auto. + simpl. + rewrite app_length, cons_length, nil_length0. + omega. + Qed. + + Lemma range0_nil : range 0 = nil. + Proof. + reflexivity. + Qed. + + Lemma range_succ : forall n, range (S n) = range n ++ n :: nil. + Proof. + reflexivity. + Qed. - Lemma freeze_in_bounds : forall us i, (us <> nil)%nat -> - 0 <= nth_default 0 (freeze us) i < 2 ^ log_cap i. + Lemma nth_default_range : forall d r n, (n < r)%nat -> nth_default d (range r) n = n. Proof. + induction r; intro; try omega. + intros. + assert (n = r \/ n < r)%nat as cases by omega. + destruct cases; subst; rewrite range_succ, nth_default_app, length_range; break_if; try omega. + + rewrite Nat.sub_diag. + auto using nth_default_cons. + + apply IHr; omega. + Qed. + + Lemma combine_app : forall {A B} (x y : list A) (z : list B), (length (x ++ y) <= length z)%nat -> + combine (x ++ y) z = combine x z ++ combine y (skipn (length x) z). + Proof. + intros. + rewrite <- (firstn_skipn (length x) z) at 1. + rewrite combine_app_samelength by + (rewrite firstn_length, Nat.min_l; auto; rewrite app_length in *; omega). + rewrite <-combine_truncate_r; reflexivity. + Qed. + + Lemma combine_range_succ : forall l r, (S r <= length l)%nat -> + combine (range (S r)) l = (combine (range r) l) ++ (r,nth_default 0 l r) :: nil. + Proof. + intros. + simpl. + rewrite combine_app by (rewrite app_length, cons_length, length_range, nil_length0; omega). + f_equal. + rewrite length_range. + erewrite skipn_nth_default by omega. + reflexivity. + Qed. + Opaque range. + + Lemma map_sub_combine_range : forall d d' f l i, (l <> nil) -> (i < length l)%nat -> + nth_default d (map (fun x => snd x - f (fst x)) (combine (range (length l)) l)) i = + nth_default d' l i - f i. + Proof. + intros until 1. + intros lt_i_length. + destruct (nth_error_length_exists_value i l lt_i_length). + erewrite nth_error_value_eq_nth_default; auto. + erewrite map_nth_error; + [ | apply nth_error_combine; try apply nth_error_range; eauto]. + erewrite nth_error_value_eq_nth_default; eauto. + Qed. +*) + Lemma isFull'_false : forall us n, isFull' us false n = false. + Proof. + unfold isFull'; induction n; intros; rewrite Bool.andb_false_r; auto. + Qed. + + Lemma isFull'_last : forall us b j, (j <> 0)%nat -> isFull' us b j = true -> + max_bound j = nth_default 0 us j. + Proof. + induction j; simpl; intros; try omega. + match goal with + | [H : isFull' _ ((?comp ?a ?b) && _) _ = true |- _ ] => + case_eq (comp a b); rewrite ?Z.eqb_eq; intro comp_eq; try assumption; + rewrite comp_eq, Bool.andb_false_l, isFull'_false in H; congruence + end. + Qed. + + Lemma isFull'_lower_bound_0 : forall j us b, isFull' us b j = true -> + max_bound 0 - c < nth_default 0 us 0. + Proof. + induction j; intros. + + match goal with H : isFull' _ _ 0 = _ |- _ => cbv [isFull'] in H; + apply Bool.andb_true_iff in H; destruct H end. + apply Z.ltb_lt; assumption. + + eauto. + Qed. + + Lemma isFull_lower_bound_0 : forall us, isFull us = true -> + max_bound 0 - c < nth_default 0 us 0. + Proof. + eauto using isFull'_lower_bound_0. + Qed. + + Lemma isFull'_true_full : forall us i j b, (i <> 0)%nat -> (i <= j)%nat -> isFull' us b j = true -> + max_bound i = nth_default 0 us i. + Proof. + induction j; intros; try omega. + assert (i = S j \/ i <= j)%nat as cases by omega. + destruct cases. + + subst. eapply isFull'_last; eauto. + + eapply IHj; eauto. + Qed. + + Lemma isFull_true_full : forall i us, (length us = length base) -> + (0 < i < length base)%nat -> isFull us = true -> + max_bound i = nth_default 0 us i. + Proof. + unfold isFull; intros. + eapply isFull'_true_full with (j := (length us - 1)%nat); eauto; omega. + Qed. + + (* TODO : move *) + Lemma N_le_1_l : forall p, (1 <= N.pos p)%N. + Proof. + destruct p; cbv; congruence. + Qed. + + (* TODO : move *) + Lemma Pos_land_upper_bound_l : forall a b, (Pos.land a b <= N.pos a)%N. + Proof. + induction a; destruct b; intros; try solve [cbv; congruence]; + simpl; specialize (IHa b); case_eq (Pos.land a b); intro; simpl; + try (apply N_le_1_l || apply N.le_0_l); intro land_eq; + rewrite land_eq in *; unfold N.le, N.compare in *; + rewrite ?Pos.compare_xI_xI, ?Pos.compare_xO_xI, ?Pos.compare_xO_xO; + try assumption. + destruct (p ?=a)%positive; cbv; congruence. + Qed. + + Lemma Zneg_nonneg_false : forall p, 0 <= Z.neg p -> False. Admitted. + Hint Resolve Zneg_nonneg_false. - Lemma freeze_canonical : forall us vs x, rep us x -> rep vs x -> - freeze us = freeze vs. + (* TODO : move *) + Lemma Z_land_upper_bound_l : forall a b, (0 <= a) -> (0 <= b) -> + Z.land a b <= a. + Proof. + intros. + destruct a, b; try solve [exfalso; auto]; try solve [cbv; congruence]. + cbv [Z.land]. + rewrite <-N2Z.inj_pos, <-N2Z.inj_le. + auto using Pos_land_upper_bound_l. + Qed. + + (* TODO : move *) + Lemma Z_land_upper_bound_r : forall a b, (0 <= a) -> (0 <= b) -> + Z.land a b <= b. + Proof. + intros. + rewrite Z.land_comm. + auto using Z_land_upper_bound_l. + Qed. + + Lemma max_ones_nonneg : 0 <= max_ones. + Proof. + unfold max_ones. + apply Z_ones_nonneg. + pose proof limb_widths_nonneg. + induction limb_widths. + cbv; congruence. + simpl. + apply Z.max_le_iff. + right. + apply IHl; auto using in_cons. + Qed. + Hint Resolve max_ones_nonneg. +(* + Lemma sub_land_max_bound_max_ones_lower : + forall us i, (length us = length base) -> isFull us = true -> + (i < length us)%nat -> + 0 <= nth_default 0 us i - land_max_bound max_ones i. + Proof. + unfold land_max_bound; intros. + break_if. + + subst. apply Z.le_0_sub. + etransitivity. + - apply Z_land_upper_bound_r; auto. + apply Z.le_trans with (m := c - 1); omega. + - rewrite Z.add_1_r. + apply Z.le_succ_l. + auto using isFull_lower_bound_0. + + apply Z.le_0_sub. + etransitivity. + apply Z_land_upper_bound_r; auto. + apply Z.eq_le_incl. + apply isFull_true_full; auto. + omega. + Qed. +*) + (* TODO : move *) + Lemma Z_le_fold_right_max : forall low l x, (forall y, In y l -> low <= y) -> + In x l -> x <= fold_right Z.max low l. + Proof. + induction l; intros ? lower_bound In_list; [cbv [In] in *; intuition | ]. + simpl. + destruct (in_inv In_list); subst. + + apply Z.le_max_l. + + etransitivity. + - apply IHl; auto; intuition. + - apply Z.le_max_r. + Qed. + + (* TODO : move *) + Lemma Z_le_fold_right_max_initial : forall low l, low <= fold_right Z.max low l. + Proof. + induction l; intros; try reflexivity. + etransitivity; [ apply IHl | apply Z.le_max_r ]. + Qed. + + Lemma land_max_ones_noop : forall x i, 0 <= x < 2 ^ log_cap i -> Z.land max_ones x = x. + Proof. + unfold max_ones. + intros ? ? x_range. + rewrite Z.land_comm. + rewrite Z.land_ones by apply Z_le_fold_right_max_initial. + apply Z.mod_small. + split; try omega. + eapply Z.lt_le_trans; try eapply x_range. + apply Z.pow_le_mono_r; try omega. + rewrite log_cap_eq. + destruct (lt_dec i (length limb_widths)). + + apply Z_le_fold_right_max. + - apply limb_widths_nonneg. + - rewrite nth_default_eq. + auto using nth_In. + + rewrite nth_default_out_of_bounds by omega. + apply Z_le_fold_right_max_initial. + Qed. + + Lemma land_max_ones_max_bound : forall i, Z.land max_ones (max_bound i) = max_bound i. + Proof. + intros. + apply land_max_ones_noop with (i := i). + rewrite <-max_bound_log_cap. + split; auto; omega. + Qed. + + Lemma land_max_ones_max_bound_sub_c : + Z.land max_ones (max_bound 0 - c + 1) = max_bound 0 - c + 1. + Proof. + apply land_max_ones_noop with (i := 0%nat). + rewrite <-max_bound_log_cap. + split; auto; try omega. + Qed. +(* + Lemma land_max_bound_pos : forall i, (i < length base)%nat -> + 0 < land_max_bound max_ones i. + Proof. + unfold land_max_bound; intros. + break_if. + + subst. + rewrite land_max_ones_max_bound_sub_c by assumption. + apply Z.lt_le_trans with (m := c); auto. omega. + + rewrite land_max_ones_max_bound by assumption. + auto using max_bound_pos. + Qed. + Local Hint Resolve land_max_bound_pos. + + + Lemma sub_land_max_bound_max_ones_upper : + forall us i, nth_default 0 us i <= max_bound i -> + (length us = length base) -> (i < length us)%nat -> + nth_default 0 us i - land_max_bound max_ones i < 2 ^ log_cap i. + Proof. + intros. + eapply Z.lt_trans. + + eapply Z.lt_sub_pos. + apply land_max_bound_pos; auto; omega. + + rewrite <-max_bound_log_cap. + omega. + Qed. + + + Lemma land_max_bound_0 : forall i, land_max_bound 0 i = 0. Admitted. +*) + + Lemma full_isFull'_true : forall j us, (length us = length base) -> + ( max_bound 0 - c < nth_default 0 us 0 + /\ (forall i, (0 < i <= j)%nat -> nth_default 0 us i = max_bound i)) -> + isFull' us true j = true. + Proof. + induction j; intros. + + cbv [isFull']; apply Bool.andb_true_iff. + rewrite Z.ltb_lt; intuition. + + intuition. + simpl. + match goal with H : forall j, _ -> ?b j = ?a j |- appcontext[?a ?i =? ?b ?i] => + replace (a i =? b i) with true by (symmetry; apply Z.eqb_eq; symmetry; apply H; omega) end. + apply IHj; auto; intuition. + Qed. + + Lemma isFull'_true_iff : forall j us, (length us = length base) -> (isFull' us true j = true <-> + max_bound 0 - c < nth_default 0 us 0 + /\ (forall i, (0 < i <= j)%nat -> nth_default 0 us i = max_bound i)). + Proof. + intros; split; intros; auto using full_isFull'_true. + split; eauto using isFull'_lower_bound_0. + intros. + symmetry; eapply isFull'_true_full; [ omega | | eauto]. + omega. + Qed. + + Opaque isFull' (* TODO isFull *) max_ones. + + (* TODO : move *) + Lemma length_nonzero_nonnil : forall {A} (l : list A), (0 < length l)%nat -> + l <> nil. + Proof. + destruct l; boring; congruence. + Qed. + + Lemma carry_full_3_length : forall us, (length us = length base) -> + length (carry_full (carry_full (carry_full us))) = length us. + Proof. + intros. + repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto); auto. + Qed. + Local Hint Resolve carry_full_3_length. + + Lemma freeze_in_bounds : forall us, + pre_carry_bounds us -> (length us = length base) -> + carry_done (freeze us). + Proof. + unfold freeze; intros. + rewrite carry_done_bounds; intro i. + destruct (lt_dec i (length us)). + + rewrite map_sub_combine_range with (d' := 0) by (try apply length_nonzero_nonnil; + (repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto)); auto; try omega). + break_if. + - split; [apply sub_land_max_bound_max_ones_lower + |apply sub_land_max_bound_max_ones_upper ]; + rewrite ?carry_full_3_length; auto. + apply carry_full_3_bounds; auto; omega. + - rewrite land_max_bound_0, <-max_bound_log_cap, Z.lt_succ_r, Z.sub_0_r. + apply carry_full_3_bounds; auto; omega. + + rewrite nth_default_out_of_bounds; [ split; auto; omega | ]. + rewrite map_length, combine_length, length_range, Nat.min_id. + repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto). + omega. + Qed. + + Lemma freeze_length : forall us, (length us = length base) -> + length (freeze us) = length us. + Proof. + unfold freeze; intros. + rewrite map_length, combine_length, length_range, Nat.min_id. + auto. + Qed. + + (* TODO : move *) + Lemma nth_default_same_lists_same : (* TODO : rename if this works *) + forall {A} d (l' l : list A), (length l = length l') -> + (forall i, nth_default d l i = nth_default d l' i) -> + l = l'. + Proof. + induction l'; intros until 0; intros lengths_equal nth_default_match. + + apply length0_nil; auto. + + destruct l; rewrite ?nil_length0, !cons_length in lengths_equal; + [congruence | ]. + pose proof (nth_default_match 0%nat) as nth_default_match_0. + rewrite !nth_default_cons in nth_default_match_0. + f_equal; auto. + apply IHl'; [ omega | ]. + intros. + specialize (nth_default_match (S i)). + rewrite !nth_default_cons_S in nth_default_match. + assumption. + Qed. + + Lemma not_full_no_change : forall us, length us = length base -> + map (fun x : nat * Z => snd x - land_max_bound 0 (fst x)) + (combine (range (length us)) us) = us. + Proof. + intros ? lengths_eq. + apply nth_default_same_lists_same with (d := 0). + + rewrite map_length, combine_length, length_range, Nat.min_id; auto. + + intros. + destruct (lt_dec i (length us)). + - erewrite map_sub_combine_range by (auto; intro false_eq; subst; + rewrite nil_length0 in lengths_eq; omega). + rewrite land_max_bound_0. + apply Z.sub_0_r. + - rewrite !nth_default_out_of_bounds; try omega. + rewrite map_length, combine_length, length_range, Nat.min_id; omega. + Qed. + + (* TODO : move *) + Lemma map_cons : forall {A B} (f : A -> B) x xs, map f (x :: xs) = f x :: (map f xs). + Proof. + auto. + Qed. + + (* TODO : move *) + Lemma firstn_firstn : forall {A} m n (l : list A), (n <= m)%nat -> + firstn n (firstn m l) = firstn n l. + Proof. + induction m; destruct n; intros; try omega; auto. + destruct l; auto. + simpl. + f_equal. + apply IHm; omega. + Qed. + + (* TODO : move *) + Lemma firstn_succ : forall n l, (n < length l)%nat -> + firstn (S n) l = (firstn n l) ++ nth_default 0 l n :: nil. + Proof. + induction n; destruct l; rewrite ?(@nil_length0 Z); intros; try omega. + + rewrite nth_default_cons; auto. + + simpl. + rewrite nth_default_cons_S. + rewrite <-IHn by (rewrite cons_length in *; omega). + reflexivity. + Qed. +(* +Print BaseSystem.accumulate. +SearchAbout combine range. +mapi : forall {A B}, (nat -> A -> B) -> list A -> list B +mapi (fun x y => (x, y)) ls +map2 : forall {A B C}, (A -> B -> C) -> list A -> list B -> list C + +BaseSystem.decode u (map2 (fun x y => x - y) v w) += BaseSystem.decode u v - BaseSystem.decode u w + +map2 f ls1 ls2 = map (fun xy => f (fst xy) (snd xy)) (combine ls1 ls2) + +map2 f (map g ls1) ls2 = map2 (fun x y => f (g x) y) ls1 ls2 +map2 f ls1 (map g ls2) = map2 (fun x y => f x (g y)) ls1 ls2 + +Locate mapi. +*) +Print map. + + Fixpoint mapi' {A B} (f : nat -> A -> B) i (l : list A) : list B := + match l with + | nil => nil + | x :: l' => f i x :: mapi' f (S i) l' + end. + + Definition mapi {A B} (f : nat -> A -> B) (l : list A) : list B := mapi' f 0%nat l. + + + Fixpoint map2 {A B C} (f : A -> B -> C) (la : list A) (lb : list B) : list C := + match la with + | nil => nil + | a :: la' => match lb with + | nil => nil + | b :: lb' => f a b :: map2 f la' lb' + end + end. + + Lemma map2_combine : forall {A B C} (f : A -> B -> C) ls1 ls2, + map2 f ls1 ls2 = map (fun xy => f (fst xy) (snd xy)) (combine ls1 ls2). + Admitted. + + Lemma map2_map_l : forall {A' A B C} (f : A -> B -> C) (g : A' -> A) ls1 ls2, + map2 f (map g ls1) ls2 = map2 (fun x y => f (g x) y) ls1 ls2. + Admitted. + + Lemma map2_map_r :forall {B' A B C} (f : A -> B -> C) (g : B' -> B) ls1 ls2, + map2 f ls1 (map g ls2) = map2 (fun x y => f x (g y)) ls1 ls2. + Admitted. + + (* TODO : rewrite using the above? *) + + Hint Rewrite app_length cons_length map_length combine_length length_range firstn_length map_app : lengths. + + Lemma decode_subtract_elementwise: forall f r l, (length l = length base) -> + (r <= length l)%nat -> + BaseSystem.decode (firstn r base) (map (fun x => snd x - f (fst x)) (combine (range r) l)) = + BaseSystem.decode (firstn r base) (firstn r l) - BaseSystem.decode (firstn r base) (map f (range r)). + Proof. + induction r; intros. + + rewrite range0_nil. + cbv [combine map BaseSystem.decode sum_firstn firstn fold_right]. + rewrite decode_nil. + auto. + + rewrite combine_range_succ by assumption. + rewrite (firstn_succ _ l) by omega. + rewrite range_succ. + rewrite !map_app, !decode'_splice. + autorewrite with lengths. + rewrite Min.min_l, firstn_firstn, firstn_succ by omega. + rewrite skipn_app_sharp by (rewrite firstn_length, Nat.min_l; omega). + simpl. + rewrite !decode'_cons, decode_nil, IHr by omega. + unfold BaseSystem.decode. + ring. + Qed. + + Definition modulus_digit i := if (eq_nat_dec i 0) then max_bound i - c + 1 else max_bound i. + (* TODO : maybe use this more? *) + + Lemma modulus_digit_nonneg : forall i, 0 <= modulus_digit i. + Proof. + unfold modulus_digit; intros; break_if; auto; subst; omega. + Qed. + Hint Resolve modulus_digit_nonneg. + + Lemma modulus_digit_lt_cap : forall i, + modulus_digit i < 2 ^ log_cap i. + Proof. + unfold modulus_digit; intros; rewrite <- max_bound_log_cap; break_if; omega. + Qed. + Hint Resolve modulus_digit_lt_cap. + + Lemma modulus_digit_land_max_bound_max_ones : forall i, + land_max_bound max_ones i = modulus_digit i. + Proof. + unfold land_max_bound; intros. + eapply land_max_ones_noop; eauto. + Qed. + + Lemma decode_modulus_digit_partial : forall n, (0 < n <= length base)%nat -> + BaseSystem.decode (firstn n base) (map modulus_digit (range (length base))) = + 2 ^ (sum_firstn limb_widths n) - c. + Proof. + induction n; intros; try omega. + rewrite firstn_succ by omega. + rewrite base_app. + rewrite decode'_truncate, firstn_length, Min.min_l in * by omega. + rewrite firstn_firstn by omega. + rewrite skipn_nth_default with (d := 0) by (autorewrite with lengths; omega). + rewrite decode'_cons, decode_base_nil, Z.add_0_r. + erewrite map_nth_default with (y := 0) (x := 0%nat) by + (autorewrite with lengths; omega). + rewrite nth_default_range by (autorewrite with lengths; omega). + rewrite nth_default_base by omega. + unfold modulus_digit at 2; break_if. + + subst. + clear IHn. + cbv [firstn BaseSystem.decode' combine fold_right]. + destruct (nth_error_length_exists_value 0 limb_widths); try (rewrite <-base_length; omega). + erewrite sum_firstn_succ; eauto. + replace (max_bound 0) with (2 ^ log_cap 0 - 1) by (rewrite <-max_bound_log_cap; omega). + rewrite log_cap_eq. + erewrite nth_error_value_eq_nth_default; eauto. + rewrite Z.pow_add_r by (auto using sum_firstn_limb_widths_nonneg; apply limb_widths_nonneg; + auto using (nth_error_value_In 0)). + cbv [sum_firstn firstn fold_right]. + ring. + + rewrite IHn by (auto; omega). + replace (max_bound n) with (2 ^ log_cap n - 1) by (rewrite <-max_bound_log_cap; omega). + rewrite log_cap_eq. + destruct (nth_error_length_exists_value n limb_widths); try (rewrite <- base_length; omega). + erewrite sum_firstn_succ; eauto. + erewrite nth_error_value_eq_nth_default; eauto. + rewrite Z.pow_add_r by (auto using sum_firstn_limb_widths_nonneg; apply limb_widths_nonneg; + auto using (nth_error_value_In n)). + ring. + Qed. + + Lemma decode_map_modulus_digit : + BaseSystem.decode base (map modulus_digit (range (length base))) = modulus. + Proof. + erewrite <-(firstn_all _ base) at 1 by reflexivity. + rewrite decode_modulus_digit_partial by omega. + rewrite base_length. + fold k; unfold c. + ring. + Qed. + + Lemma decode_subtract_modulus_elementwise : forall us, (length us = length base) -> + BaseSystem.decode base + (map (fun x0 : nat * Z => snd x0 - land_max_bound max_ones (fst x0)) + (combine (range (length us)) us)) = BaseSystem.decode base us - modulus. + Proof. + intros. + replace base with (firstn (length us) base) at 1 by (auto using firstn_all). + rewrite decode_subtract_elementwise by omega. + rewrite !firstn_all by auto. + f_equal. + erewrite map_ext; [ | eapply modulus_digit_land_max_bound_max_ones ]. + replace (length us) with (length base) by assumption. + exact decode_map_modulus_digit. + Qed. + + (* TODO : move *) + Lemma decode_mod : forall us vs x, (length us = length base) -> (length vs = length base) -> + decode us = x -> + BaseSystem.decode base us mod modulus = BaseSystem.decode base vs mod modulus -> + decode vs = x. + Proof. + unfold decode; intros until 2; intros decode_us_x BSdecode_eq. + rewrite ZToField_mod in decode_us_x |- *. + rewrite <-BSdecode_eq. + assumption. + Qed. + + Lemma freeze_preserves_rep : forall us x, (length us = length base) -> + rep us x -> rep (freeze us) x. + Proof. + unfold rep; intros. + intuition; rewrite ?freeze_length; auto. + unfold freeze. + break_if. + + apply decode_mod with (us := carry_full (carry_full (carry_full us))). + - rewrite carry_full_3_length; auto. + - autorewrite with lengths. + rewrite Nat.min_id. + rewrite carry_full_3_length; auto. + - repeat apply carry_full_preserves_rep; repeat rewrite carry_full_length; auto. + unfold rep; intuition. + - rewrite decode_subtract_modulus_elementwise by (rewrite carry_full_3_length; auto). + destruct (Z_eq_dec modulus 0); [ subst; rewrite !Zmod_0_r; reflexivity | ]. + rewrite <-Z.add_opp_r. + replace (-modulus) with (-1 * modulus) by ring. + symmetry; auto using Z.mod_add. + + rewrite not_full_no_change by (rewrite carry_full_3_length; auto). + repeat (apply carry_full_preserves_rep; repeat rewrite carry_full_length; auto). + unfold rep; auto. + Qed. + + Lemma isFull_true_iff : forall us, (length us = length base) -> (isFull us = true <-> + max_bound 0 - c < nth_default 0 us 0 + /\ (forall i, (0 < i <= length us - 1)%nat -> nth_default 0 us i = max_bound i)). + Proof. + unfold isFull; intros; auto using isFull'_true_iff. + Qed. + + Definition minimal_rep us := BaseSystem.decode base us = (BaseSystem.decode base us) mod modulus. + + Fixpoint compare' us vs i := + match i with + | O => Eq + | S i' => if Z_eq_dec (nth_default 0 us i') (nth_default 0 vs i') + then compare' us vs i' + else Z.compare (nth_default 0 us i') (nth_default 0 vs i') + end. + + (* Lexicographically compare two vectors of equal length, starting from the END of the list + (in our context, this is the most significant end). *) + Definition compare us vs := compare' us vs (length us). + + Lemma decode_firstn_succ : forall n us, (length us = length base) -> + (n < length base)%nat -> + BaseSystem.decode' (firstn (S n) base) (firstn (S n) us) = + BaseSystem.decode' (firstn n base) (firstn n us) + + nth_default 0 base n * nth_default 0 us n. + Proof. + intros. + rewrite !firstn_succ by omega. + rewrite base_app, firstn_app. + autorewrite with lengths; rewrite !Min.min_l by omega. + rewrite Nat.sub_diag, firstn_firstn, firstn0, app_nil_r by omega. + rewrite skipn_app_sharp by (autorewrite with lengths; apply Min.min_l; omega). + rewrite decode'_cons, decode_nil, Z.add_0_r. + reflexivity. + Qed. + + Lemma decode_lt_next_digit : forall us n, (length us = length base) -> + (n < length base)%nat -> (n < length us)%nat -> + carry_done us -> + BaseSystem.decode' (firstn n base) (firstn n us) < + (nth_default 0 base n). + Proof. + induction n; intros ? ? ? bounded. + + cbv [firstn]. + rewrite decode_base_nil. + apply Z.gt_lt; auto using nth_default_base_positive. + + rewrite decode_firstn_succ by (auto || omega). + rewrite nth_default_base_succ by omega. + eapply Z.lt_le_trans. + - apply Z.add_lt_mono_r. + apply IHn; auto; omega. + - rewrite <-(Z.mul_1_r (nth_default 0 base n)) at 1. + rewrite <-Z.mul_add_distr_l, Z.mul_comm. + apply Z.mul_le_mono_pos_r. + * apply Z.gt_lt. apply nth_default_base_positive; omega. + * rewrite Z.add_1_l. + apply Z.le_succ_l. + rewrite carry_done_bounds in bounded. + apply bounded. + Qed. + + Lemma highest_digit_determines : forall us vs n x, (x < 0) -> + (length us = length base) -> + (n < length us)%nat -> carry_done us -> + (n < length vs)%nat -> carry_done vs -> + BaseSystem.decode' (firstn n base) (firstn n us) + + nth_default 0 base n * x - + BaseSystem.decode' (firstn n base) (firstn n vs) < 0. + Proof. + intros. + eapply Z.le_lt_trans. + apply Z.le_sub_nonneg. + admit. (* TODO : decode' is nonnegative *) + eapply Z.le_lt_trans. + eapply Z.add_le_mono with (q := nth_default 0 base n * -1); [ apply Z.le_refl | ]. + apply Z.mul_le_mono_nonneg_l; try omega. + admit. (* TODO : 0 <= nth_default 0 base n *) + ring_simplify. + apply Z.lt_sub_0. + apply decode_lt_next_digit; auto. + omega. + Qed. + + Lemma Z_compare_decode_step_eq : forall n us vs, + (length us = length base) -> + (length us = length vs) -> + (S n <= length base)%nat -> + (nth_default 0 us n = nth_default 0 vs n) -> + (BaseSystem.decode (firstn (S n) base) us ?= + BaseSystem.decode (firstn (S n) base) vs) = + (BaseSystem.decode (firstn n base) us ?= + BaseSystem.decode (firstn n base) vs). + Proof. + intros until 3; intro nth_default_eq. + destruct (lt_dec n (length us)); try omega. + rewrite firstn_succ, !base_app by omega. + autorewrite with lengths; rewrite Min.min_l by omega. + do 2 (rewrite skipn_nth_default with (d := 0) by omega; + rewrite decode'_cons, decode_base_nil, Z.add_0_r). + rewrite Z.compare_sub, nth_default_eq, Z.add_add_simpl_r_r. + rewrite BaseSystem.decode'_truncate with (us := us). + rewrite BaseSystem.decode'_truncate with (us := vs). + rewrite firstn_length, Min.min_l, <-Z.compare_sub by omega. + reflexivity. + Qed. + + Lemma Z_compare_decode_step_lt : forall n us vs, + (length us = length base) -> + (length us = length vs) -> + (S n <= length base)%nat -> + carry_done us -> carry_done vs -> + (nth_default 0 us n < nth_default 0 vs n) -> + (BaseSystem.decode (firstn (S n) base) us ?= + BaseSystem.decode (firstn (S n) base) vs) = Lt. + Proof. + intros until 5; intro nth_default_lt. + destruct (lt_dec n (length us)). + + rewrite firstn_succ by omega. + rewrite !base_app. + autorewrite with lengths; rewrite Min.min_l by omega. + do 2 (rewrite skipn_nth_default with (d := 0) by omega; + rewrite decode'_cons, decode_base_nil, Z.add_0_r). + rewrite Z.compare_sub. + apply Z.compare_lt_iff. + ring_simplify. + rewrite <-Z.add_sub_assoc. + rewrite <-Z.mul_sub_distr_l. + apply highest_digit_determines; auto; omega. + + rewrite !nth_default_out_of_bounds in nth_default_lt; omega. + Qed. + + Lemma Z_compare_decode_step_neq : forall n us vs, + (length us = length base) -> (length us = length vs) -> + (S n <= length base)%nat -> + carry_done us -> carry_done vs -> + (nth_default 0 us n <> nth_default 0 vs n) -> + (BaseSystem.decode (firstn (S n) base) us ?= + BaseSystem.decode (firstn (S n) base) vs) = + (nth_default 0 us n ?= nth_default 0 vs n). + Proof. + intros. + destruct (Z_dec (nth_default 0 us n) (nth_default 0 vs n)) as [[?|Hgt]|?]; try congruence. + + etransitivity; try apply Z_compare_decode_step_lt; auto. + + match goal with |- (?a ?= ?b) = (?c ?= ?d) => + rewrite (Z.compare_antisym b a); rewrite (Z.compare_antisym d c) end. + apply CompOpp_inj; rewrite !CompOpp_involutive. + apply gt_lt_symmetry in Hgt. + etransitivity; try apply Z_compare_decode_step_lt; auto; omega. + Qed. + + Lemma decode_compare' : forall n us vs, + (length us = length base) -> + (length us = length vs) -> + (n <= length base)%nat -> + carry_done us -> carry_done vs -> + (BaseSystem.decode (firstn n base) us ?= BaseSystem.decode (firstn n base) vs) + = compare' us vs n. + Proof. + induction n; intros. + + cbv [firstn compare']; rewrite !decode_base_nil; auto. + + unfold compare'; fold compare'. + break_if. + - rewrite Z_compare_decode_step_eq by (auto || omega). + apply IHn; auto; omega. + - rewrite Z_compare_decode_step_neq; (auto || omega). + Qed. + + Lemma decode_compare : forall us vs, + (length us = length base) -> carry_done us -> + (length vs = length base) -> carry_done vs -> + Z.compare (BaseSystem.decode base us) (BaseSystem.decode base vs) = compare us vs. + Proof. + unfold compare; intros. + erewrite <-(firstn_all _ base). + + apply decode_compare'; auto; omega. + + assumption. + Qed. + + Transparent isFull'. + Print compare'. + Lemma compare'_succ : forall us j vs, compare' us vs (S j) = + if Z.eq_dec (nth_default 0 us j) (nth_default 0 vs j) + then compare' us vs j + else nth_default 0 us j ?= nth_default 0 vs j. + Proof. + reflexivity. + Qed. + + + Lemma compare'_firstn_r : forall us j vs, (j <= length vs)%nat -> + compare' us vs j = compare' us (firstn j vs) j. + Proof. + induction j; intros; auto. + rewrite !compare'_succ. + rewrite firstn_succ by omega. + rewrite nth_default_app. + autorewrite with lengths; rewrite Min.min_l by omega. + destruct (lt_dec j j); try omega. + rewrite Nat.sub_diag. + rewrite nth_default_cons. + break_if; try reflexivity. + rewrite IHj with (vs := firstn j vs ++ nth_default 0 vs j :: nil) by + (autorewrite with lengths; rewrite Min.min_l; omega). + rewrite firstn_app_sharp by (autorewrite with lengths; apply Min.min_l; omega). + apply IHj; omega. + Qed. + + Lemma isFull'_true_step : forall us j, isFull' us true (S j) = true -> + isFull' us true j = true. + Proof. + simpl; intros ? ? succ_true. + destruct (max_bound (S j) =? nth_default 0 us (S j)); auto. + rewrite isFull'_false in succ_true. + congruence. + Qed. + + Lemma compare'_not_Lt : forall us vs j, j <> 0%nat -> + (forall i, (0 < i < j)%nat -> 0 <= nth_default 0 us i <= nth_default 0 vs i) -> + compare' us vs j <> Lt -> + nth_default 0 vs 0 <= nth_default 0 us 0 /\ + (forall i : nat, (0 < i < j)%nat -> nth_default 0 us i = nth_default 0 vs i). + Proof. + induction j; try congruence. + rewrite compare'_succ. + intros; destruct (eq_nat_dec j 0). + + break_if; subst; split; intros; try omega. + rewrite Z.compare_ge_iff in *; omega. + + break_if. + - split; intros; [ | destruct (eq_nat_dec i j); subst; auto ]; + apply IHj; auto; intros; try omega; + match goal with H : forall i, _ -> 0 <= ?f i <= ?g i |- 0 <= ?f _ <= ?g _ => + apply H; omega end. + - exfalso. rewrite Z.compare_ge_iff in *. + match goal with H : forall i, ?P -> 0 <= ?f i <= ?g i |- _ => + specialize (H j) end; omega. + Qed. + + Lemma nth_default_map_range : forall f n r, (n < r)%nat -> + nth_default 0 (map f (range r)) n = f n. + Proof. + intros. + rewrite map_nth_default with (x := 0%nat) by (autorewrite with lengths; omega). + rewrite nth_default_range by omega. + reflexivity. + Qed. + + Lemma isFull'_compare' : forall us j, j <> 0%nat -> (length us = length base) -> carry_done us -> + (isFull' us true (j - 1) = true <-> + compare' us (map modulus_digit (range j)) j <> Lt). + Proof. + unfold compare; induction j; intros; try congruence. + replace (S j - 1)%nat with j by omega. + (* rewrite isFull'_true_iff by assumption; *) + split; intros. + + simpl. + break_if. + - rewrite compare'_firstn_r by (autorewrite with lengths; omega). + rewrite range_succ, map_app, firstn_app. + autorewrite with lengths. + rewrite Nat.sub_diag, app_nil_r. + rewrite firstn_all by (autorewrite with lengths; reflexivity). + destruct (eq_nat_dec j 0); [ subst; simpl; try congruence | ]. + apply IHj; auto. + apply isFull'_true_step. + replace (S (j - 1)) with j by omega; auto. + - match goal with |- appcontext[?a ?= ?b] => case_eq (a ?= b) end; + intros compare_eq; try congruence. + rewrite Z.compare_lt_iff in compare_eq. + rewrite nth_default_map_range in * by omega. + match goal with H : isFull' _ _ _ = true |- _ => apply isFull'_true_iff in H; auto; destruct H end. + + destruct (eq_nat_dec j 0). + * subst. cbv [modulus_digit] in compare_eq. + break_if; try congruence. omega. + * assert (0 < j <= j)%nat as j_range by omega. + specialize (H3 j j_range). + unfold modulus_digit in n. + break_if; omega. + + apply isFull'_true_iff; try assumption. + match goal with H : compare' _ _ _ <> Lt |- _ => apply compare'_not_Lt in H; [ destruct H as [Hdigit0 Hnonzero] | | ] end. + - rewrite nth_default_map_range in * by omega. + split; [ unfold modulus_digit in *; break_if; omega | ]. + intros i i_range. + assert (0 < i < S j)%nat as i_range' by omega. + specialize (Hnonzero i i_range'). + rewrite nth_default_map_range in * by omega. + unfold modulus_digit in Hnonzero; break_if; omega. + - congruence. + - intros; rewrite nth_default_map_range by omega. + unfold modulus_digit; break_if; try omega. + rewrite <-Z.lt_succ_r with (m := max_bound i). + rewrite max_bound_log_cap; apply carry_done_bounds. + assumption. + Qed. + + Lemma isFull_compare : forall us, (length us = length base) -> carry_done us -> + (isFull us = true <-> + compare us (map modulus_digit (range (length base))) <> Lt). + Proof. + unfold compare, isFull; intros ? lengths_eq. intros. + rewrite lengths_eq. + apply isFull'_compare'; try omega. + assumption. + Qed. + + Lemma isFull_decode : forall us, (length us = length base) -> carry_done us -> + (isFull us = true <-> + (BaseSystem.decode base us ?= BaseSystem.decode base (map modulus_digit (range (length base)))) <> Lt). + Proof. + intros. + rewrite decode_compare; autorewrite with lengths; auto; + [ apply isFull_compare; auto | ]. + rewrite carry_done_bounds; intro i. + destruct (lt_dec i (length base)). + + rewrite nth_default_map_range; auto. + + rewrite nth_default_out_of_bounds by (autorewrite with lengths; omega). + split; auto; omega. + Qed. + + Lemma isFull_false_upper_bound : forall us, (length us = length base) -> carry_done us -> + isFull us = false -> + BaseSystem.decode base us < modulus. + Proof. + intros. + destruct (Z_lt_dec (BaseSystem.decode base us) modulus) as [? | nlt_modulus]; + [assumption | exfalso]. + apply Z.compare_nlt_iff in nlt_modulus. + rewrite <-decode_map_modulus_digit in nlt_modulus at 2. + apply isFull_decode in nlt_modulus; try assumption; congruence. + Qed. + +(* Road map: + * x prove isFull us = false -> us < modulus + * _ prove (carry_full^3 us) < 2 * modulus + *) + + Definition twoKMinusOne := mapi (fun _ => max_bound i + + Lemma bounded_digits_lt_2modulus : forall us, (length us = length base) -> carry_done us -> + BaseSystem.decode base us < 2 ^ k. + Proof. + unfold k. + SearchAbout sum_firstn limb_widths. + Qed. + + Lemma bounded_digits_lt_2modulus : forall us, (length us = length base) -> carry_done us -> + BaseSystem.decode base us < 2 * modulus. + Proof. + + + SearchAbout (carry_full (carry_full (carry_full _))). + + + Lemma freeze_minimal_rep : forall us, minimal_rep (freeze us). + Proof. + unfold minimal_rep, freeze. + intros. + symmetry. apply Z.mod_small. + split. + + admit. + + break_if. + remember (carry_full (carry_full (carry_full us))) as cf3us. + rewrite decode_subtract_modulus_elementwise. + apply isFull_true_ + Qed. + Hint Resolve freeze_minimal_rep. + + Lemma minimal_rep_unique_if_bounded : forall us vs, + minimal_rep us -> minimal_rep vs -> + (forall i, 0 <= nth_default 0 us i < 2 ^ log_cap i) -> + (forall i, 0 <= nth_default 0 vs i < 2 ^ log_cap i) -> + us = vs. + Proof. + + Admitted. + + Lemma freeze_canonical : forall us vs x y, c_carry_constraint -> + pre_carry_bounds us -> (length us = length base) -> rep us x -> + pre_carry_bounds vs -> (length vs = length base) -> rep vs y -> + (x mod modulus = y mod modulus) -> + freeze us = freeze vs. + Proof. + unfold rep; intros. + apply minimal_rep_unique_if_bounded; auto. + intros. apply freeze_in_bounds; auto. + intros. apply freeze_in_bounds; auto. + Qed. End CanonicalizationProofs.
\ No newline at end of file diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v index 1a7b3316e..10bbdf33d 100644 --- a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v +++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v @@ -89,6 +89,13 @@ Section PseudoMersenneBaseParamProofs. - rewrite IHl by auto; ring. Qed. + Lemma limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w. + Proof. + intros. + apply Z.lt_le_incl. + auto using limb_widths_pos. + Qed. + Lemma sum_firstn_limb_widths_nonneg : forall n, 0 <= sum_firstn limb_widths n. Proof. unfold sum_firstn; intros. diff --git a/src/ModularArithmetic/PseudoMersenneBaseParams.v b/src/ModularArithmetic/PseudoMersenneBaseParams.v index 3914d6219..e20a7ed09 100644 --- a/src/ModularArithmetic/PseudoMersenneBaseParams.v +++ b/src/ModularArithmetic/PseudoMersenneBaseParams.v @@ -7,7 +7,7 @@ Definition sum_firstn l n := fold_right Z.add 0 (firstn n l). Class PseudoMersenneBaseParams (modulus : Z) := { limb_widths : list Z; - limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w; + limb_widths_pos : forall w, In w limb_widths -> 0 < w; limb_widths_nonnil : limb_widths <> nil; limb_widths_good : forall i j, (i + j < length limb_widths)%nat -> sum_firstn limb_widths (i + j) <= |