diff options
-rw-r--r-- | src/Experiments/NewPipeline/Arithmetic.v | 574 |
1 files changed, 533 insertions, 41 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v index f3c6fa709..5af73875b 100644 --- a/src/Experiments/NewPipeline/Arithmetic.v +++ b/src/Experiments/NewPipeline/Arithmetic.v @@ -651,6 +651,68 @@ Module Positional. unique pose proof (Z_div_exact_full_2 a b ltac:(auto) ltac:(auto)) end; nsatz. Qed. Hint Rewrite weight_place : push_eval. + Lemma weight_add_mod (weight_mul : forall i, weight (S i) mod weight i = 0) i j + : weight (i + j) mod weight i = 0. + Proof using weight_nz. + rewrite Nat.add_comm. + induction j as [|[|j] IHj]; cbn [Nat.add] in *; + eauto using Z_mod_same_full, Z.mod_mod_trans. + Qed. + Lemma weight_mul_iff (weight_pos : forall i, 0 < weight i) (weight_mul : forall i, weight (S i) mod weight i = 0) i j + : weight i mod weight j = 0 <-> ((j < i)%nat \/ forall k, (i <= k <= j)%nat -> weight k = weight j). + Proof using weight_nz. + split. + { destruct (dec (j < i)%nat); [ left; omega | intro H; right; revert H ]. + assert (j = (j - i) + i)%nat by omega. + generalize dependent (j - i)%nat; intro jmi; intros ? H0. + subst j. + destruct jmi as [|j]; [ intros k ?; assert (k = i) by omega; subst; f_equal; omega | ]. + induction j as [|j IH]; cbn [Nat.add] in *. + { intros k ?; assert (k = i \/ k = S i) by omega; destruct_head'_or; subst; + eauto using Z.mod_mod_0_0_eq_pos. } + { specialize_by omega. + { pose proof (weight_mul (S (j + i))) as H. + specialize_by eauto using Z.mod_mod_trans with omega. + intros k H'; destruct (dec (k = S (S (j + i)))); subst; + try rewrite IH by eauto using Z.mod_mod_trans with omega; + eauto using Z.mod_mod_trans, Z.mod_mod_0_0_eq_pos with omega. + rewrite (IH i) in * by omega. + eauto using Z.mod_mod_trans, Z.mod_mod_0_0_eq_pos with omega. } } } + { destruct (dec (j < i)%nat) as [H|H]; [ intros _ | intros [H'|H']; try omega ]. + { assert (i = j + (i - j))%nat by omega. + generalize dependent (i - j)%nat; intro imj; intros. + subst i. + apply weight_add_mod; auto. } + { erewrite H', Z_mod_same_full by omega; omega. } } + Qed. + Lemma weight_div_from_pos_mul (weight_pos : forall i, 0 < weight i) (weight_mul : forall i, weight (S i) mod weight i = 0) + : forall i, 0 < weight (S i) / weight i. + Proof using weight_nz. + intro i; generalize (weight_mul i) (weight_mul (S i)). + Z.div_mod_to_quot_rem; nia. + Qed. + Lemma place_weight n (weight_pos : forall i, 0 < weight i) (weight_mul : forall i, weight (S i) mod weight i = 0) + (weight_unique : forall i j, (i <= n)%nat -> (j <= n)%nat -> weight i = weight j -> i = j) + i x + : (place (weight i, x) n) = (Nat.min i n, (weight i / weight (Nat.min i n)) * x). + Proof using weight_0 weight_nz. + cbv [place]. + induction n as [|n IHn]; cbn; [ destruct i; cbn; rewrite ?weight_0; autorewrite with zsimplify_const; reflexivity | ]. + destruct (dec (i < S n)%nat); + break_innermost_match; cbn [fst snd] in *; Z.ltb_to_lt; [ | rewrite IHn | | rewrite IHn ]; + break_innermost_match; + rewrite ?Min.min_l in * by omega; + rewrite ?Min.min_r in * by omega; + eauto with omega. + { rewrite weight_mul_iff in * by auto. + destruct_head'_or; try omega. + assert (S n = i). + { apply weight_unique; try omega. + symmetry; eauto with omega. } + subst; reflexivity. } + { rewrite weight_mul_iff in * by auto. + exfalso; intuition eauto with omega. } + Qed. Definition from_associational n (p:list (Z*Z)) := List.fold_right (fun t ls => @@ -669,6 +731,29 @@ Module Positional. Proof using Type. cbv [from_associational Let_In]. apply fold_right_invariant; intros; distr_length. Qed. Hint Rewrite length_from_associational : distr_length. + Lemma nth_default_from_associational v n p i (n_nz : n <> 0%nat) : + nth_default v (from_associational n p) i + = fold_right Z.add (nth_default v (zeros n) i) + (map (fun t => dlet p : nat * Z := place t (pred n) in + if dec (fst p = i) then snd p else 0) p). + Proof. + subst; cbv [from_associational Let_In]. + induction p as [|p ps IHps]; [ reflexivity | ]; cbn [fold_right map]; rewrite <- IHps; clear IHps. + cbv [add_to_nth]. + match goal with + | [ |- context[place ?p ?i] ] + => pose proof (place_in_range p i) + end. + rewrite update_nth_nth_default_full; break_match; try omega; + rewrite nth_default_out_of_bounds by omega; try omega. + match goal with + | [ H : context[length (fold_right ?f ?v ?ps)] |- _ ] + => replace (length (fold_right f v ps)) with (length v) in H + by (apply fold_right_invariant; intros; distr_length; auto) + end. + distr_length; auto. + Qed. + Definition extend_to_length (n_in n_out : nat) (p:list Z) : list Z := p ++ zeros (n_out - n_in). Lemma eval_extend_to_length n_in n_out p : @@ -760,6 +845,127 @@ Module Positional. apply eval_to_associational. Qed. Hint Rewrite @eval_carry : push_eval. + (** TODO: figure out a way to make this proof shorter and faster *) + Lemma nth_default_carry upper n m index p + (weight_mul : forall i, weight (S i) mod weight i = 0) + (weight_pos : forall i, 0 < weight i) + (weight_unique : forall i j, (i <= upper)%nat -> (j <= upper)%nat -> weight i = weight j -> i = j) + (Hn : (n <= upper)%nat) + (Hm : (0 < m <= upper)%nat) + (Hnm : (n <= m)%nat) + (Hidx : (index <= upper)%nat) : + length p = n -> + forall i, nth_default 0 (carry n m index p) i + = if dec (m <= i)%nat + then 0 + else if dec (i = S index) + then nth_default 0 p i + ((nth_default 0 p index) / (weight (S index) / weight index)) + else if dec (i = index) + then if dec (S index <> n \/ n <> m) + then ((nth_default 0 p i) mod (weight (S index) / weight index)) + else nth_default 0 p i + else nth_default 0 p i. + Proof using weight_0 weight_nz. + assert (weight_unique_iff : forall i j, (i <= upper)%nat -> (j <= upper)%nat -> weight i = weight j <-> i = j) + by (split; subst; auto). + pose proof (weight_div_from_pos_mul weight_pos weight_mul) as weight_div_pos. + assert (weight_div_nz : forall i, weight (S i) / weight i <> 0) by (intro i; specialize (weight_div_pos i); omega). + intro; subst. + intro i. + destruct (dec (m <= i)%nat) as [Hmi|Hmi]; + [ rewrite (@nth_default_out_of_bounds _ i (carry _ _ _ _)) by (distr_length; omega); reflexivity | ]. + cbv [carry to_associational Associational.carry Let_In Associational.carryterm]. + rewrite combine_map_l, flat_map_map; cbn [fst snd]. + rewrite nth_default_from_associational, map_flat_map by omega; cbn [map]. + cbv [zeros]; rewrite nth_default_repeat. + replace (if (dec (i < m)%nat) then 0 else 0) with 0 by (break_match; reflexivity). + set (init := 0) at 1. + lazymatch goal with |- ?LHS = ?RHS => rewrite <- (Z.add_0_l RHS : init + RHS = RHS) end. + clearbody init. + revert Hn i init Hmi Hnm Hidx. + rewrite <- (rev_involutive p); generalize (rev p); clear p; intro p; rewrite rev_length. + induction p as [|p ps IHps]; cbn [length]; intros Hn i init Hmi Hnm Hidx. + { cbn; cbv [zeros]; break_innermost_match; cbn; + rewrite ?nth_default_repeat, ?nth_default_nil; break_innermost_match; autorewrite with zsimplify_const; reflexivity. } + { specialize_by omega. + rewrite seq_snoc, rev_cons, combine_app_samelength by distr_length. + rewrite flat_map_app, fold_right_app, IHps by omega; clear IHps. + cbn [combine fold_right fst snd flat_map map]. + rewrite Nat.add_0_l. + cbv [Let_In]; cbn [fst snd]. + rewrite ?nth_default_app; distr_length. + destruct (dec (i = index)), (dec (i = S index)); try (subst; omega). + { all:subst; break_innermost_match; Z.ltb_to_lt; + match goal with + | [ H : context[weight ?x = weight ?y] |- _ ] => rewrite (weight_unique_iff x y) in H by omega + end; destruct_head'_or; try (subst; omega). + all:repeat first [ progress cbn [fst snd app map fold_right] + | progress Z.ltb_to_lt + | progress subst + | progress destruct_head'_or + | progress rewrite ?Z.mul_div_eq_full, ?weight_mul, ?Z.sub_0_r by eauto with omega + | progress rewrite ?place_weight by eauto with omega + | rewrite !Nat.sub_diag + | rewrite !Min.min_l by omega + | rewrite !nth_default_cons + | rewrite Z.div_same by eauto with omega + | progress break_innermost_match + | progress autorewrite with zsimplify_const + | lia + | match goal with + | [ H : context[weight ?x = weight ?y] |- _ ] => rewrite (weight_unique_iff x y) in H by omega + | [ |- context[nth_default ?d ?ls ?i] ] => rewrite (@nth_default_out_of_bounds _ i ls d) by (distr_length; omega) + | [ H : ?x = ?x |- _ ] => clear H + end + | progress handle_min_max_for_omega_case ]. } + { subst; break_innermost_match; Z.ltb_to_lt; + match goal with + | [ H : context[weight ?x = weight ?y] |- _ ] => rewrite (weight_unique_iff x y) in H by omega + end; destruct_head'_or; try (subst; omega). + all:repeat first [ progress cbn [fst snd app map fold_right] + | progress Z.ltb_to_lt + | progress subst + | progress destruct_head'_or + | progress rewrite ?Z.mul_div_eq_full, ?weight_mul, ?Z.sub_0_r by eauto with omega + | progress rewrite ?place_weight by eauto with omega + | rewrite !Nat.sub_diag + | rewrite !Min.min_l by omega + | rewrite !nth_default_cons + | rewrite Z.div_same by eauto with omega + | progress break_innermost_match + | progress autorewrite with zsimplify_const + | lia + | match goal with + | [ H : context[weight ?x = weight ?y] |- _ ] => rewrite (weight_unique_iff x y) in H by omega + | [ |- context[nth_default ?d ?ls ?i] ] => rewrite (@nth_default_out_of_bounds _ i ls d) by (distr_length; omega) + | [ H : ?x = ?x |- _ ] => clear H + end + | progress handle_min_max_for_omega_case ]. } + { subst; break_innermost_match; Z.ltb_to_lt; + match goal with + | [ H : context[weight ?x = weight ?y] |- _ ] => rewrite (weight_unique_iff x y) in H by omega + end; destruct_head'_or; try (subst; omega). + all:repeat first [ progress cbn [fst snd app map fold_right] + | progress Z.ltb_to_lt + | progress subst + | progress destruct_head'_or + | progress rewrite ?Z.mul_div_eq_full, ?weight_mul, ?Z.sub_0_r by eauto with omega + | progress rewrite ?place_weight by eauto with omega + | rewrite !Nat.sub_diag + | rewrite !Min.min_l by omega + | rewrite !nth_default_cons + | rewrite Z.div_same by eauto with omega + | progress break_innermost_match + | progress autorewrite with zsimplify_const + | lia + | match goal with + | [ H : context[weight ?x = weight ?y] |- _ ] => rewrite (weight_unique_iff x y) in H by omega + | [ |- context[nth_default ?d ?ls ?i] ] => rewrite (@nth_default_out_of_bounds _ i ls d) by (distr_length; omega) + | [ H : ?x = ?x |- _ ] => clear H + end + | progress handle_min_max_for_omega_case ]. } } + Qed. + Definition carry_reduce n (s:Z) (c:list (Z * Z)) (index:nat) (p : list Z) := from_associational @@ -823,6 +1029,230 @@ Module Positional. intros; cbv [chained_carries_no_reduce]; induction (rev idxs) as [|x xs IHxs]; cbn [fold_right]; distr_length. Qed. Hint Rewrite @length_chained_carries_no_reduce : distr_length. + (** TODO: figure out a way to make this proof shorter and faster *) + Lemma nth_default_chained_carries_no_reduce_app n m inp1 inp2 + (weight_mul : forall i, weight (S i) mod weight i = 0) + (weight_pos : forall i, 0 < weight i) + (weight_div : forall i : nat, 0 < weight (S i) / weight i) + (weight_unique : forall i j, (i <= n)%nat -> (j <= n)%nat -> weight i = weight j -> i = j) : + length inp1 = m -> (length inp1 + length inp2 = n)%nat + -> (List.length inp2 <> 0%nat \/ 0 <= eval m inp1 < weight m) + -> forall i, + nth_default 0 (chained_carries_no_reduce n (inp1 ++ inp2) (seq 0 m)) i + = if dec (i < m)%nat + then ((eval m inp1) mod weight (S i)) / weight i + else if dec (i = m) + then match inp2 with + | nil => 0 + | cons x xs + => x + (eval m inp1) / weight m + end + else nth_default 0 inp2 (i - m). + Proof using weight_0 weight_nz. + intro; subst m. + rewrite <- (rev_involutive inp1); generalize (List.rev inp1); clear inp1; intro inp1; rewrite rev_length. + revert inp2; induction inp1 as [|x xs IHxs]; intros. + { destruct inp2; cbn; autorewrite with zsimplify_const; intros; destruct i; reflexivity. } + destruct (lt_dec i n); + [ + | break_match; cbn [List.length] in *; try lia; + rewrite ?nth_default_out_of_bounds by (repeat autorewrite with distr_length; lia); + reflexivity ]. + cbv [chained_carries_no_reduce] in *. + repeat first [ progress cbn [List.length List.app List.rev fold_right] in * + | reflexivity + | assumption + | progress intros + | rewrite <- List.app_assoc + | rewrite seq_snoc + | rewrite rev_unit + | rewrite Nat.add_0_l + | rewrite eval_snoc_S in * by distr_length + | rewrite app_length + | rewrite rev_length + | erewrite nth_default_carry; try eassumption + | rewrite !IHxs; clear IHxs + | lia + | match goal with + | [ |- length (fold_right _ ?p (rev ?idxs)) = ?n ] + => apply (length_chained_carries_no_reduce n p idxs) + | [ |- context[_ mod weight (S ?n) / weight ?n] ] + => rewrite (Z.div_mod' (weight (S n)) (weight n)), weight_mul, Z.add_0_r, <- Z.mod_pull_div, ?Z.div_mul, ?Z.div_add', ?Z.mul_div_eq', ?weight_mul, ?Z.sub_0_r + by solve [ assumption + | now apply Z.lt_le_incl, weight_div + | now apply Z.lt_gt, weight_pos + | (idtac + symmetry); now apply Z.lt_neq, weight_pos ] + | [ |- context[?x + ?y] ] + => match goal with + | [ |- context[y + x] ] + => progress replace (y + x) with (x + y) by lia + end + end ]. + break_match; try (exfalso; lia). + all: repeat first [ rewrite nth_default_app + | rewrite nth_default_carry + | rewrite Nat.sub_diag + | rewrite minus_S_diag + | rewrite Nat.sub_succ_r + | rewrite nth_default_cons + | rewrite nth_default_cons_S + | progress subst + | now apply weight_0 + | now apply weight_mul + | now apply weight_pos + | reflexivity + | progress intros + | (idtac + symmetry); now apply Z.lt_neq, weight_pos + | rewrite Z.div_add' by ((idtac + symmetry); now apply Z.lt_neq, weight_pos) + | progress destruct_head'_and + | progress destruct_head'_or + | progress cbn [List.length] in * + | match goal with + | [ |- context[?x + ?y] ] + => match goal with + | [ |- context[y + x] ] + => progress replace (y + x) with (x + y) by lia + end + | [ H : List.length ?x = 0%nat |- _ ] => is_var x; destruct x + | [ H : not (or _ _) |- _ ] => apply Decidable.not_or in H + | [ H : ?x = ?x |- _ ] => clear H + | [ H : not (?x < ?x) |- _ ] => clear H + | [ H : not (?x < ?x)%nat |- _ ] => clear H + | [ H : not (S ?x < ?x)%nat |- _ ] => clear H + | [ H : ~(S ?x + _ <= ?x)%nat |- _ ] => clear H + | [ H : (?x < S ?x + _)%nat |- _ ] => clear H + | [ H : ?x <> S ?x |- _ ] => clear H + | [ H : ?x <> (?x + ?y)%nat |- _ ] => assert (0 < y)%nat by lia; clear H + | [ H : (?x < ?x + ?y)%nat |- _ ] => assert (0 < y)%nat by lia; clear H + | [ H : ~(?x + ?y <= ?x)%nat |- _ ] => assert (0 < y)%nat by lia; clear H + | [ H : ~(?x <> ?y) |- _ ] => assert (x = y) by lia; clear H + | [ H : (?x = ?x + ?y)%nat |- _ ] => assert (y = 0%nat) by lia; clear H + | [ H : ~(?x <= ?y)%nat |- _ ] => assert (y < x)%nat by lia; clear H + | [ H : ~(?x < ?y)%nat |- _ ] => assert (y <= x)%nat by lia; clear H + | [ H : (?x <= ?y)%nat, H' : ?x <> ?y |- _ ] => assert (x < y)%nat by lia; clear H H' + | [ H : (?x <= ?y)%nat, H' : ?y <> ?x |- _ ] => assert (x < y)%nat by lia; clear H H' + | [ H : (?x < ?y)%nat |- context[nth_default _ _ (?y - ?x)] ] + => destruct (y - x)%nat eqn:? + | [ |- context[nth_default _ (_ :: _) ?n] ] => is_var n; destruct n + | [ H : ?T, H' : ?T |- _ ] => clear H' + | [ |- (?x + ?y) mod ?z = (?y + ?x) mod ?z ] => apply f_equal2 + | [ |- ?x + _ = ?x + _ ] => apply f_equal + | [ H0 : 0 <= ?e + ?w * ?x, H1 : ?e + ?w * ?x < ?w' + |- ?x + ?e / ?w = (?x + ?e / ?w) mod (?w' / ?w) ] + => rewrite (Z.mod_small (x + e / w) (w' / w)) + | [ H : (?i < ?n)%nat |- context[(_ + weight ?n * _) / weight ?i] ] + => rewrite (Z.div_mod (weight n) (weight (S i))), weight_multiples_full, Z.add_0_r, + (Z.div_mod (weight (S i)) (weight i)), weight_mul, Z.add_0_r, + <- !Z.mul_assoc, Z.div_add', ?Z.div_mul', ?Z.mul_div_eq_full, ?weight_mul, ?Z.sub_0_r + by solve [ assumption + | now apply Z.lt_le_incl, weight_div + | now apply Z.lt_gt, weight_pos + | now apply Nat.lt_le_incl + | (idtac + symmetry); now apply Z.lt_neq, weight_pos ]; + push_Zmod; pull_Zmod + end + | progress autorewrite with distr_length in * + | lia + | progress autorewrite with zsimplify_const + | break_innermost_match_step + | match goal with + | [ |- context[weight (S ?n) / weight ?n] ] + => unique pose proof (@weight_mul n) + end + | Z.div_mod_to_quot_rem; nia ]. + Qed. + + Lemma nth_default_chained_carries_no_reduce n inp + (weight_mul : forall i, weight (S i) mod weight i = 0) + (weight_pos : forall i, 0 < weight i) + (weight_div : forall i : nat, 0 < weight (S i) / weight i) + (weight_unique : forall i j, (i <= n)%nat -> (j <= n)%nat -> weight i = weight j -> i = j) : + length inp = n -> 0 <= eval n inp < weight n + -> forall i, + nth_default 0 (chained_carries_no_reduce n inp (seq 0 n)) i + = ((eval n inp) mod weight (S i)) / weight i. + Proof using weight_0 weight_nz. + pose proof (weight_divides_full weight ltac:(assumption) ltac:(assumption)) as weight_div_full. + pose proof (weight_multiples_full weight ltac:(assumption) ltac:(assumption)) as weight_mul_full. + assert (weight_le_full : forall j i, (i <= j)%nat -> weight i <= weight j) + by (intros j i pf; specialize (weight_div_full j i pf); specialize (weight_mul_full j i pf); Z.div_mod_to_quot_rem; nia). + intros ? ? i. + pose proof (weight_le_full (S n) n ltac:(lia)). + pose proof (weight_le_full (S i) i ltac:(lia)). + pose proof (weight_le_full i n). + intros; rewrite <- (app_nil_r inp). + rewrite (@nth_default_chained_carries_no_reduce_app n n inp nil), app_nil_r by (cbn [List.length]; auto with lia). + break_innermost_match; try reflexivity; rewrite ?nth_default_nil. + all: rewrite Z.mod_small by lia. + all: rewrite Z.div_small by lia. + all: reflexivity. + Qed. + + Lemma nth_default_chained_carries_no_reduce_pred n inp + (weight_mul : forall i, weight (S i) mod weight i = 0) + (weight_pos : forall i, 0 < weight i) + (weight_div : forall i : nat, 0 < weight (S i) / weight i) + (weight_unique : forall i j, (i <= n)%nat -> (j <= n)%nat -> weight i = weight j -> i = j) : + length inp = n -> 0 <= eval n inp < weight n + -> forall i, + nth_default 0 (chained_carries_no_reduce n inp (seq 0 (pred n))) i + = ((eval n inp) mod weight (S i)) / weight i. + Proof using weight_0 weight_nz. + pose proof (weight_divides_full weight ltac:(assumption) ltac:(assumption)) as weight_div_full. + pose proof (weight_multiples_full weight ltac:(assumption) ltac:(assumption)) as weight_mul_full. + assert (weight_le_full : forall j i, (i <= j)%nat -> weight i <= weight j) + by (intros j i pf; specialize (weight_div_full j i pf); specialize (weight_mul_full j i pf); Z.div_mod_to_quot_rem; nia). + destruct n as [|n]; [ now apply nth_default_chained_carries_no_reduce | ]. + intros ? ? i. + pose proof (weight_le_full (S n) n ltac:(lia)). + pose proof (weight_le_full (S i) i ltac:(lia)). + pose proof (weight_le_full i n). + pose proof (weight_le_full (S i) (S n)). + pose proof (weight_le_full i (S n)). + cbn [pred]. + revert dependent inp; intro inp. + rewrite <- (rev_involutive inp); generalize (rev inp); clear inp; intro inp. + rewrite rev_length; intros. + destruct inp as [|x inp]; cbn [List.length List.rev] in *; [ lia | ]. + rewrite (@nth_default_chained_carries_no_reduce_app (S n) n (List.rev inp) (x::nil)) by (cbn [List.length]; autorewrite with distr_length; auto with lia). + rewrite eval_snoc_S in * by distr_length. + break_innermost_match; try reflexivity. + all: repeat first [ progress autorewrite with zsimplify_const + | reflexivity + | progress Z.rewrite_mod_small + | rewrite Z.div_add' by ((idtac + symmetry); now apply Z.lt_neq, weight_pos) + | lia + | match goal with + | [ |- context[_ mod weight (S ?n) / weight ?n] ] + => rewrite (Z.div_mod' (weight (S n)) (weight n)), weight_mul, Z.add_0_r, <- Z.mod_pull_div, ?Z.div_mul, ?Z.div_add', ?Z.mul_div_eq', ?weight_mul, ?Z.sub_0_r + by solve [ assumption + | now apply Z.lt_le_incl, weight_div + | now apply Z.lt_gt, weight_pos + | (idtac + symmetry); now apply Z.lt_neq, weight_pos ] + | [ |- context[(_ + weight ?n * _) / weight ?i] ] + => rewrite (Z.div_mod (weight n) (weight (S i))), weight_multiples_full, Z.add_0_r, + (Z.div_mod (weight (S i)) (weight i)), weight_mul, Z.add_0_r, + <- !Z.mul_assoc, Z.div_add', ?Z.div_mul', ?Z.mul_div_eq_full, ?weight_mul, ?Z.sub_0_r + by solve [ assumption + | now apply Z.lt_le_incl, weight_div + | now apply Z.lt_gt, weight_pos + | now apply Nat.lt_le_incl + | (idtac + symmetry); now apply Z.lt_neq, weight_pos ]; + push_Zmod; pull_Zmod + end + | rewrite nth_default_cons + | rewrite nth_default_cons_S + | rewrite nth_default_nil + | rewrite Z.div_small by lia + | lia + | match goal with + | [ H : ~(?x < ?y)%nat |- _ ] => assert (y <= x)%nat by lia; clear H + | [ H : (?x <= ?y)%nat, H' : ?x <> ?y |- _ ] => assert (x < y)%nat by lia; clear H H' + | [ H : (?x <= ?y)%nat, H' : ?y <> ?x |- _ ] => assert (x < y)%nat by lia; clear H H' + | [ H : (?x < ?y)%nat |- context[nth_default _ _ (?y - ?x)] ] + => destruct (y - x)%nat eqn:? + end ]. + Qed. (* Reverse of [eval]; translate from Z to basesystem by putting everything in first digit and then carrying. *) @@ -976,7 +1406,7 @@ Module Positional. End select. End Positional. (* Hint Rewrite disappears after the end of a section *) -Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_encode_no_reduce @length_sub @length_opp @length_select @length_zselect @length_select_min @length_extend_to_length @length_drop_high_to_length : distr_length. +Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_carry @length_chained_carries @length_chained_carries_no_reduce @length_encode @length_encode_no_reduce @length_sub @length_opp @length_select @length_zselect @length_select_min @length_extend_to_length @length_drop_high_to_length : distr_length. Hint Rewrite @eval_zeros @eval_nil @eval_snoc_S @eval_select @eval_zselect @eval_extend_to_length using solve [auto; distr_length]: push_eval. Section Positional_nonuniform. Context (weight weight' : nat -> Z). @@ -1082,6 +1512,19 @@ Section mod_ops. Z.div_mod_to_quot_rem_in_goal; nia. Qed. + Lemma weight_unique_iff : forall i j, (i <= n)%nat -> (j <= n)%nat -> weight i = weight j <-> i = j. + Proof using limbwidth_good. + clear Hn_nz; clear dependent c. + cbv [weight]; split; intro H'; subst; trivial; []. + apply (f_equal (fun x => limbwidth_den * (- Z.log2 x))) in H'. + rewrite !Z.log2_pow2, !Z.opp_involutive in H' by (Z.div_mod_to_quot_rem; nia). + Z.div_mod_to_quot_rem. + destruct i as [|i], j as [|j]; autorewrite with zsimplify_const in *; [ reflexivity | exfalso; nia.. | ]. + nia. + Qed. + Lemma weight_unique : forall i j, (i <= n)%nat -> (j <= n)%nat -> weight i = weight j -> i = j. + Proof using limbwidth_good. apply weight_unique_iff. Qed. + Derive carry_mulmod SuchThat (forall (f g : list Z) (Hf : length f = n) @@ -2355,6 +2798,31 @@ Module BaseConversion. rewrite eval_from_associational; auto. Qed. + Lemma length_convert_bases sn dn p + : length (convert_bases sn dn p) = dn. + Proof using Type. + cbv [convert_bases]; now repeat autorewrite with distr_length. + Qed. + Hint Rewrite length_convert_bases : distr_length. + + Lemma convert_bases_partitions sn dn p + (dw_unique : forall i j : nat, (i <= dn)%nat -> (j <= dn)%nat -> dw i = dw j -> i = j) + (p_bounded : 0 <= eval sw sn p < dw dn) + : convert_bases sn dn p = partition dw dn (eval sw sn p). + Proof using dwprops. + apply list_elementwise_eq; intro i. + destruct (lt_dec i dn); [ | now rewrite !nth_error_length_error by distr_length ]. + erewrite !(@nth_error_Some_nth_default _ _ 0) by (break_match; distr_length). + apply f_equal. + cbv [convert_bases partition]. + unshelve erewrite map_nth_default, nth_default_chained_carries_no_reduce_pred; + repeat first [ progress autorewrite with distr_length push_eval + | rewrite eval_from_associational, eval_to_associational + | rewrite nth_default_seq_inbounds + | apply dwprops + | destruct dwprops; now auto with zarith ]. + Qed. + Hint Rewrite @Rows.eval_from_associational @Associational.eval_carry @@ -2487,6 +2955,7 @@ Module BaseConversion. Qed. End mul_converted. End BaseConversion. + Hint Rewrite length_convert_bases : distr_length. (* multiply two (n*k)-bit numbers by converting them to n k-bit limbs each, multiplying, then converting back *) Section widemul. @@ -2702,6 +3171,9 @@ Section freeze_mod_ops. Definition wprops_bytes := (@wprops 8 1 ltac:(clear; lia)). Local Notation wprops := (@wprops limbwidth_num limbwidth_den limbwidth_good). + Local Notation wunique := (@weight_unique limbwidth_num limbwidth_den limbwidth_good). + Local Notation wunique_bytes := (@weight_unique 8 1 ltac:(clear; lia)). + Local Hint Immediate (wprops). Local Hint Immediate (wprops_bytes). Local Hint Immediate (weight_0 wprops). @@ -2712,28 +3184,31 @@ Section freeze_mod_ops. Local Hint Immediate (weight_positive wprops_bytes). Local Hint Immediate (weight_multiples wprops_bytes). Local Hint Immediate (weight_divides wprops_bytes). + Local Hint Immediate (wunique) (wunique_bytes). + Local Hint Resolve (wunique) (wunique_bytes). Local Hint Resolve Z.positive_is_nonzero Z.lt_gt. Definition bytes_n := Eval cbv [Qceiling Qdiv inject_Z Qfloor Qmult Qopp Qnum Qden Qinv Pos.mul] in Z.to_nat (Qceiling (Z.log2_up (weight n) / 8)). - Definition to_bytes' (v : list Z) + Lemma weight_bytes_weight_matches + : weight n <= bytes_weight bytes_n. + Proof using limbwidth_good. + clear -limbwidth_good. + cbv [weight bytes_n]. + autorewrite with zsimplify_const. + rewrite Z.log2_up_pow2, !Z2Nat.id, !Z.opp_involutive by (Z.div_mod_to_quot_rem; nia). + Z.peel_le. + Z.div_mod_to_quot_rem; nia. + Qed. + + Definition to_bytes (v : list Z) := BaseConversion.convert_bases weight bytes_weight n bytes_n v. Definition from_bytes (v : list Z) := BaseConversion.convert_bases bytes_weight weight bytes_n n v. - (** TODO: We should probably prove that BaseConversion.convert_bases - partitions, so that we don't end up doing a needless [flatten ∘ - from_associational ∘ to_associational] just be be able to prove - that the result partitions. See - https://github.com/JasonGross/fiat-crypto/tree/zzz-wip-better-arith-proofs - for some partial work in this direction. *) - Definition to_bytes (f : list Z) : list Z - := let v := to_bytes' f in - fst (Rows.flatten bytes_weight bytes_n (Rows.from_associational bytes_weight bytes_n (Positional.to_associational bytes_weight bytes_n v))). - Definition freeze_to_bytesmod (f : list Z) : list Z := to_bytes (freeze weight n (Z.ones bitwidth) m_enc f). @@ -2768,48 +3243,29 @@ Section freeze_mod_ops. Z.div_mod_to_quot_rem; nia. Qed. - Lemma eval_to_bytes_mod + Lemma eval_to_bytes : forall (f : list Z) (Hf : length f = n), - eval bytes_weight bytes_n (to_bytes f) = eval weight n f mod (bytes_weight bytes_n). + eval bytes_weight bytes_n (to_bytes f) = eval weight n f. Proof using limbwidth_good Hn_nz. generalize wprops wprops_bytes; clear -Hn_nz limbwidth_good. intros. cbv [to_bytes]. - rewrite Rows.flatten_mod by (assumption || apply Rows.length_from_associational). - rewrite Rows.eval_from_associational by eauto using bytes_nz with omega. - rewrite eval_to_associational. - cbv [to_bytes']. rewrite BaseConversion.eval_convert_bases by (auto using bytes_nz; distr_length; auto using wprops). reflexivity. Qed. - Lemma eval_to_bytes + Lemma to_bytes_partitions : forall (f : list Z) (Hf : length f = n) (Hf_small : 0 <= eval weight n f < weight n), - eval bytes_weight bytes_n (to_bytes f) = eval weight n f. - Proof using Hn_nz limbwidth_good. - generalize bytes_n_big. clear -Hn_nz limbwidth_good. - intros; rewrite eval_to_bytes_mod by assumption. - rewrite Z.mod_small by omega; reflexivity. - Qed. - - Lemma to_bytes_partitions - : forall (f : list Z) - (Hf : length f = n), to_bytes f = Partition.partition bytes_weight bytes_n (Positional.eval weight n f). Proof using Hn_nz limbwidth_good. clear -Hn_nz limbwidth_good. intros; cbv [to_bytes]. - rewrite Rows.flatten_correct by (apply wprops_bytes || apply Rows.length_from_associational). - rewrite Rows.eval_from_associational by eauto using bytes_nz with omega. - rewrite eval_to_associational. - cbv [to_bytes']. - rewrite BaseConversion.eval_convert_bases - by (auto using wprops_bytes, bytes_nz; distr_length; auto using wprops). - reflexivity. + pose proof weight_bytes_weight_matches. + apply BaseConversion.convert_bases_partitions; eauto; lia. Qed. Lemma eval_to_bytesmod @@ -2856,17 +3312,53 @@ Section freeze_mod_ops. intros; now apply eval_freeze_to_bytesmod_and_partitions. Qed. + Lemma eval_from_bytes + : forall (f : list Z) + (Hf : length f = bytes_n), + eval weight n (from_bytes f) = eval bytes_weight bytes_n f. + Proof using limbwidth_good Hn_nz. + generalize wprops wprops_bytes; clear -Hn_nz limbwidth_good. + intros. + cbv [from_bytes]. + rewrite BaseConversion.eval_convert_bases + by (auto using bytes_nz; distr_length; auto using wprops). + reflexivity. + Qed. + + Lemma from_bytes_partitions + : forall (f : list Z) + (Hf_small : 0 <= eval bytes_weight bytes_n f < weight n), + from_bytes f = Partition.partition weight n (Positional.eval bytes_weight bytes_n f). + Proof using limbwidth_good. + clear -limbwidth_good. + intros; cbv [from_bytes]. + pose proof weight_bytes_weight_matches. + apply BaseConversion.convert_bases_partitions; eauto; lia. + Qed. + Lemma eval_from_bytesmod : forall (f : list Z) (Hf : length f = bytes_n), eval weight n (from_bytesmod f) = eval bytes_weight bytes_n f. - Proof using Hn_nz limbwidth_good. - cbv [from_bytesmod from_bytes]; intros. - rewrite BaseConversion.eval_convert_bases by eauto using wprops. - reflexivity. + Proof using Hn_nz limbwidth_good. apply eval_from_bytes. Qed. + + Lemma from_bytesmod_partitions + : forall (f : list Z) + (Hf_small : 0 <= eval bytes_weight bytes_n f < weight n), + from_bytesmod f = Partition.partition weight n (Positional.eval bytes_weight bytes_n f). + Proof using limbwidth_good. apply from_bytes_partitions. Qed. + + Lemma eval_from_bytesmod_and_partitions + : forall (f : list Z) + (Hf : length f = bytes_n) + (Hf_small : 0 <= eval bytes_weight bytes_n f < weight n), + eval weight n (from_bytesmod f) = eval bytes_weight bytes_n f + /\ from_bytesmod f = Partition.partition weight n (Positional.eval bytes_weight bytes_n f). + Proof using limbwidth_good Hn_nz. + now (split; [ apply eval_from_bytesmod | apply from_bytes_partitions ]). Qed. End freeze_mod_ops. -Hint Rewrite eval_freeze_to_bytesmod : push_eval. +Hint Rewrite eval_freeze_to_bytesmod eval_to_bytes eval_to_bytesmod eval_from_bytes eval_from_bytesmod : push_eval. Section primitives. Definition mulx (bitwidth : Z) := Eval cbv [Z.mul_split_at_bitwidth] in Z.mul_split_at_bitwidth bitwidth. |