aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2019-01-02 11:23:10 -0500
committerGravatar Jason Gross <jasongross9@gmail.com>2019-01-03 00:22:13 -0500
commit1f8b428d03c7d448680245f5752004a32ce77c20 (patch)
treed954d42e1aa89454cd3a695f0c3eaea8ffe027e0
parent2238baafd6ad626bbec17b85cbeb79f848475d6e (diff)
Prove that convert_bases partitions
Unfortunately, the proof is rather slow :-( After | File Name | Before || Change | % Change -------------------------------------------------------------------------------------------------------------------- 22m46.26s | Total | 22m03.36s || +0m42.90s | +3.24% -------------------------------------------------------------------------------------------------------------------- 1m57.06s | Experiments/NewPipeline/Arithmetic.vo | 1m15.08s || +0m41.98s | +55.91% 0m42.39s | p521_32.c | 0m39.92s || +0m02.46s | +6.18% 0m35.40s | p521_64.c | 0m33.02s || +0m02.37s | +7.20% 3m15.80s | p384_32.c | 3m17.10s || -0m01.29s | -0.65% 6m19.46s | Experiments/NewPipeline/SlowPrimeSynthesisExamples.vo | 6m18.91s || +0m00.54s | +0.14% 4m34.80s | Experiments/NewPipeline/Toplevel1.vo | 4m35.56s || -0m00.75s | -0.27% 1m36.64s | Experiments/NewPipeline/Toplevel2.vo | 1m36.64s || +0m00.00s | +0.00% 0m43.22s | Experiments/NewPipeline/ExtractionHaskell/word_by_word_montgomery | 0m44.18s || -0m00.96s | -2.17% 0m29.24s | Experiments/NewPipeline/ExtractionHaskell/unsaturated_solinas | 0m29.83s || -0m00.58s | -1.97% 0m22.83s | Experiments/NewPipeline/ExtractionHaskell/saturated_solinas | 0m22.87s || -0m00.04s | -0.17% 0m16.11s | Experiments/NewPipeline/ExtractionOCaml/word_by_word_montgomery | 0m16.73s || -0m00.62s | -3.70% 0m14.24s | p256_32.c | 0m13.75s || +0m00.49s | +3.56% 0m14.22s | secp256k1_32.c | 0m14.19s || +0m00.03s | +0.21% 0m11.01s | p384_64.c | 0m10.50s || +0m00.50s | +4.85% 0m09.71s | Experiments/NewPipeline/ExtractionOCaml/word_by_word_montgomery.ml | 0m09.86s || -0m00.14s | -1.52% 0m09.69s | Experiments/NewPipeline/ExtractionOCaml/unsaturated_solinas | 0m10.35s || -0m00.66s | -6.37% 0m07.41s | Experiments/NewPipeline/ExtractionOCaml/saturated_solinas | 0m07.74s || -0m00.33s | -4.26% 0m06.65s | Experiments/NewPipeline/ExtractionOCaml/unsaturated_solinas.ml | 0m06.52s || +0m00.13s | +1.99% 0m06.50s | p224_32.c | 0m06.47s || +0m00.03s | +0.46% 0m06.24s | Experiments/NewPipeline/ExtractionHaskell/word_by_word_montgomery.hs | 0m06.38s || -0m00.13s | -2.19% 0m05.10s | Experiments/NewPipeline/ExtractionOCaml/saturated_solinas.ml | 0m04.97s || +0m00.12s | +2.61% 0m04.82s | Experiments/NewPipeline/ExtractionHaskell/unsaturated_solinas.hs | 0m05.04s || -0m00.21s | -4.36% 0m03.95s | Experiments/NewPipeline/ExtractionHaskell/saturated_solinas.hs | 0m04.02s || -0m00.06s | -1.74% 0m02.31s | curve25519_32.c | 0m02.24s || +0m00.06s | +3.12% 0m02.03s | p224_64.c | 0m01.92s || +0m00.10s | +5.72% 0m01.91s | secp256k1_64.c | 0m02.00s || -0m00.09s | -4.50% 0m01.90s | p256_64.c | 0m01.86s || +0m00.03s | +2.15% 0m01.54s | curve25519_64.c | 0m01.61s || -0m00.07s | -4.34% 0m01.51s | Experiments/NewPipeline/CLI.vo | 0m01.49s || +0m00.02s | +1.34% 0m01.29s | Experiments/NewPipeline/StandaloneHaskellMain.vo | 0m01.29s || +0m00.00s | +0.00% 0m01.28s | Experiments/NewPipeline/StandaloneOCamlMain.vo | 0m01.32s || -0m00.04s | -3.03%
-rw-r--r--src/Experiments/NewPipeline/Arithmetic.v574
1 files changed, 533 insertions, 41 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v
index f3c6fa709..5af73875b 100644
--- a/src/Experiments/NewPipeline/Arithmetic.v
+++ b/src/Experiments/NewPipeline/Arithmetic.v
@@ -651,6 +651,68 @@ Module Positional.
unique pose proof (Z_div_exact_full_2 a b ltac:(auto) ltac:(auto))
end; nsatz. Qed.
Hint Rewrite weight_place : push_eval.
+ Lemma weight_add_mod (weight_mul : forall i, weight (S i) mod weight i = 0) i j
+ : weight (i + j) mod weight i = 0.
+ Proof using weight_nz.
+ rewrite Nat.add_comm.
+ induction j as [|[|j] IHj]; cbn [Nat.add] in *;
+ eauto using Z_mod_same_full, Z.mod_mod_trans.
+ Qed.
+ Lemma weight_mul_iff (weight_pos : forall i, 0 < weight i) (weight_mul : forall i, weight (S i) mod weight i = 0) i j
+ : weight i mod weight j = 0 <-> ((j < i)%nat \/ forall k, (i <= k <= j)%nat -> weight k = weight j).
+ Proof using weight_nz.
+ split.
+ { destruct (dec (j < i)%nat); [ left; omega | intro H; right; revert H ].
+ assert (j = (j - i) + i)%nat by omega.
+ generalize dependent (j - i)%nat; intro jmi; intros ? H0.
+ subst j.
+ destruct jmi as [|j]; [ intros k ?; assert (k = i) by omega; subst; f_equal; omega | ].
+ induction j as [|j IH]; cbn [Nat.add] in *.
+ { intros k ?; assert (k = i \/ k = S i) by omega; destruct_head'_or; subst;
+ eauto using Z.mod_mod_0_0_eq_pos. }
+ { specialize_by omega.
+ { pose proof (weight_mul (S (j + i))) as H.
+ specialize_by eauto using Z.mod_mod_trans with omega.
+ intros k H'; destruct (dec (k = S (S (j + i)))); subst;
+ try rewrite IH by eauto using Z.mod_mod_trans with omega;
+ eauto using Z.mod_mod_trans, Z.mod_mod_0_0_eq_pos with omega.
+ rewrite (IH i) in * by omega.
+ eauto using Z.mod_mod_trans, Z.mod_mod_0_0_eq_pos with omega. } } }
+ { destruct (dec (j < i)%nat) as [H|H]; [ intros _ | intros [H'|H']; try omega ].
+ { assert (i = j + (i - j))%nat by omega.
+ generalize dependent (i - j)%nat; intro imj; intros.
+ subst i.
+ apply weight_add_mod; auto. }
+ { erewrite H', Z_mod_same_full by omega; omega. } }
+ Qed.
+ Lemma weight_div_from_pos_mul (weight_pos : forall i, 0 < weight i) (weight_mul : forall i, weight (S i) mod weight i = 0)
+ : forall i, 0 < weight (S i) / weight i.
+ Proof using weight_nz.
+ intro i; generalize (weight_mul i) (weight_mul (S i)).
+ Z.div_mod_to_quot_rem; nia.
+ Qed.
+ Lemma place_weight n (weight_pos : forall i, 0 < weight i) (weight_mul : forall i, weight (S i) mod weight i = 0)
+ (weight_unique : forall i j, (i <= n)%nat -> (j <= n)%nat -> weight i = weight j -> i = j)
+ i x
+ : (place (weight i, x) n) = (Nat.min i n, (weight i / weight (Nat.min i n)) * x).
+ Proof using weight_0 weight_nz.
+ cbv [place].
+ induction n as [|n IHn]; cbn; [ destruct i; cbn; rewrite ?weight_0; autorewrite with zsimplify_const; reflexivity | ].
+ destruct (dec (i < S n)%nat);
+ break_innermost_match; cbn [fst snd] in *; Z.ltb_to_lt; [ | rewrite IHn | | rewrite IHn ];
+ break_innermost_match;
+ rewrite ?Min.min_l in * by omega;
+ rewrite ?Min.min_r in * by omega;
+ eauto with omega.
+ { rewrite weight_mul_iff in * by auto.
+ destruct_head'_or; try omega.
+ assert (S n = i).
+ { apply weight_unique; try omega.
+ symmetry; eauto with omega. }
+ subst; reflexivity. }
+ { rewrite weight_mul_iff in * by auto.
+ exfalso; intuition eauto with omega. }
+ Qed.
Definition from_associational n (p:list (Z*Z)) :=
List.fold_right (fun t ls =>
@@ -669,6 +731,29 @@ Module Positional.
Proof using Type. cbv [from_associational Let_In]. apply fold_right_invariant; intros; distr_length. Qed.
Hint Rewrite length_from_associational : distr_length.
+ Lemma nth_default_from_associational v n p i (n_nz : n <> 0%nat) :
+ nth_default v (from_associational n p) i
+ = fold_right Z.add (nth_default v (zeros n) i)
+ (map (fun t => dlet p : nat * Z := place t (pred n) in
+ if dec (fst p = i) then snd p else 0) p).
+ Proof.
+ subst; cbv [from_associational Let_In].
+ induction p as [|p ps IHps]; [ reflexivity | ]; cbn [fold_right map]; rewrite <- IHps; clear IHps.
+ cbv [add_to_nth].
+ match goal with
+ | [ |- context[place ?p ?i] ]
+ => pose proof (place_in_range p i)
+ end.
+ rewrite update_nth_nth_default_full; break_match; try omega;
+ rewrite nth_default_out_of_bounds by omega; try omega.
+ match goal with
+ | [ H : context[length (fold_right ?f ?v ?ps)] |- _ ]
+ => replace (length (fold_right f v ps)) with (length v) in H
+ by (apply fold_right_invariant; intros; distr_length; auto)
+ end.
+ distr_length; auto.
+ Qed.
+
Definition extend_to_length (n_in n_out : nat) (p:list Z) : list Z :=
p ++ zeros (n_out - n_in).
Lemma eval_extend_to_length n_in n_out p :
@@ -760,6 +845,127 @@ Module Positional.
apply eval_to_associational.
Qed. Hint Rewrite @eval_carry : push_eval.
+ (** TODO: figure out a way to make this proof shorter and faster *)
+ Lemma nth_default_carry upper n m index p
+ (weight_mul : forall i, weight (S i) mod weight i = 0)
+ (weight_pos : forall i, 0 < weight i)
+ (weight_unique : forall i j, (i <= upper)%nat -> (j <= upper)%nat -> weight i = weight j -> i = j)
+ (Hn : (n <= upper)%nat)
+ (Hm : (0 < m <= upper)%nat)
+ (Hnm : (n <= m)%nat)
+ (Hidx : (index <= upper)%nat) :
+ length p = n ->
+ forall i, nth_default 0 (carry n m index p) i
+ = if dec (m <= i)%nat
+ then 0
+ else if dec (i = S index)
+ then nth_default 0 p i + ((nth_default 0 p index) / (weight (S index) / weight index))
+ else if dec (i = index)
+ then if dec (S index <> n \/ n <> m)
+ then ((nth_default 0 p i) mod (weight (S index) / weight index))
+ else nth_default 0 p i
+ else nth_default 0 p i.
+ Proof using weight_0 weight_nz.
+ assert (weight_unique_iff : forall i j, (i <= upper)%nat -> (j <= upper)%nat -> weight i = weight j <-> i = j)
+ by (split; subst; auto).
+ pose proof (weight_div_from_pos_mul weight_pos weight_mul) as weight_div_pos.
+ assert (weight_div_nz : forall i, weight (S i) / weight i <> 0) by (intro i; specialize (weight_div_pos i); omega).
+ intro; subst.
+ intro i.
+ destruct (dec (m <= i)%nat) as [Hmi|Hmi];
+ [ rewrite (@nth_default_out_of_bounds _ i (carry _ _ _ _)) by (distr_length; omega); reflexivity | ].
+ cbv [carry to_associational Associational.carry Let_In Associational.carryterm].
+ rewrite combine_map_l, flat_map_map; cbn [fst snd].
+ rewrite nth_default_from_associational, map_flat_map by omega; cbn [map].
+ cbv [zeros]; rewrite nth_default_repeat.
+ replace (if (dec (i < m)%nat) then 0 else 0) with 0 by (break_match; reflexivity).
+ set (init := 0) at 1.
+ lazymatch goal with |- ?LHS = ?RHS => rewrite <- (Z.add_0_l RHS : init + RHS = RHS) end.
+ clearbody init.
+ revert Hn i init Hmi Hnm Hidx.
+ rewrite <- (rev_involutive p); generalize (rev p); clear p; intro p; rewrite rev_length.
+ induction p as [|p ps IHps]; cbn [length]; intros Hn i init Hmi Hnm Hidx.
+ { cbn; cbv [zeros]; break_innermost_match; cbn;
+ rewrite ?nth_default_repeat, ?nth_default_nil; break_innermost_match; autorewrite with zsimplify_const; reflexivity. }
+ { specialize_by omega.
+ rewrite seq_snoc, rev_cons, combine_app_samelength by distr_length.
+ rewrite flat_map_app, fold_right_app, IHps by omega; clear IHps.
+ cbn [combine fold_right fst snd flat_map map].
+ rewrite Nat.add_0_l.
+ cbv [Let_In]; cbn [fst snd].
+ rewrite ?nth_default_app; distr_length.
+ destruct (dec (i = index)), (dec (i = S index)); try (subst; omega).
+ { all:subst; break_innermost_match; Z.ltb_to_lt;
+ match goal with
+ | [ H : context[weight ?x = weight ?y] |- _ ] => rewrite (weight_unique_iff x y) in H by omega
+ end; destruct_head'_or; try (subst; omega).
+ all:repeat first [ progress cbn [fst snd app map fold_right]
+ | progress Z.ltb_to_lt
+ | progress subst
+ | progress destruct_head'_or
+ | progress rewrite ?Z.mul_div_eq_full, ?weight_mul, ?Z.sub_0_r by eauto with omega
+ | progress rewrite ?place_weight by eauto with omega
+ | rewrite !Nat.sub_diag
+ | rewrite !Min.min_l by omega
+ | rewrite !nth_default_cons
+ | rewrite Z.div_same by eauto with omega
+ | progress break_innermost_match
+ | progress autorewrite with zsimplify_const
+ | lia
+ | match goal with
+ | [ H : context[weight ?x = weight ?y] |- _ ] => rewrite (weight_unique_iff x y) in H by omega
+ | [ |- context[nth_default ?d ?ls ?i] ] => rewrite (@nth_default_out_of_bounds _ i ls d) by (distr_length; omega)
+ | [ H : ?x = ?x |- _ ] => clear H
+ end
+ | progress handle_min_max_for_omega_case ]. }
+ { subst; break_innermost_match; Z.ltb_to_lt;
+ match goal with
+ | [ H : context[weight ?x = weight ?y] |- _ ] => rewrite (weight_unique_iff x y) in H by omega
+ end; destruct_head'_or; try (subst; omega).
+ all:repeat first [ progress cbn [fst snd app map fold_right]
+ | progress Z.ltb_to_lt
+ | progress subst
+ | progress destruct_head'_or
+ | progress rewrite ?Z.mul_div_eq_full, ?weight_mul, ?Z.sub_0_r by eauto with omega
+ | progress rewrite ?place_weight by eauto with omega
+ | rewrite !Nat.sub_diag
+ | rewrite !Min.min_l by omega
+ | rewrite !nth_default_cons
+ | rewrite Z.div_same by eauto with omega
+ | progress break_innermost_match
+ | progress autorewrite with zsimplify_const
+ | lia
+ | match goal with
+ | [ H : context[weight ?x = weight ?y] |- _ ] => rewrite (weight_unique_iff x y) in H by omega
+ | [ |- context[nth_default ?d ?ls ?i] ] => rewrite (@nth_default_out_of_bounds _ i ls d) by (distr_length; omega)
+ | [ H : ?x = ?x |- _ ] => clear H
+ end
+ | progress handle_min_max_for_omega_case ]. }
+ { subst; break_innermost_match; Z.ltb_to_lt;
+ match goal with
+ | [ H : context[weight ?x = weight ?y] |- _ ] => rewrite (weight_unique_iff x y) in H by omega
+ end; destruct_head'_or; try (subst; omega).
+ all:repeat first [ progress cbn [fst snd app map fold_right]
+ | progress Z.ltb_to_lt
+ | progress subst
+ | progress destruct_head'_or
+ | progress rewrite ?Z.mul_div_eq_full, ?weight_mul, ?Z.sub_0_r by eauto with omega
+ | progress rewrite ?place_weight by eauto with omega
+ | rewrite !Nat.sub_diag
+ | rewrite !Min.min_l by omega
+ | rewrite !nth_default_cons
+ | rewrite Z.div_same by eauto with omega
+ | progress break_innermost_match
+ | progress autorewrite with zsimplify_const
+ | lia
+ | match goal with
+ | [ H : context[weight ?x = weight ?y] |- _ ] => rewrite (weight_unique_iff x y) in H by omega
+ | [ |- context[nth_default ?d ?ls ?i] ] => rewrite (@nth_default_out_of_bounds _ i ls d) by (distr_length; omega)
+ | [ H : ?x = ?x |- _ ] => clear H
+ end
+ | progress handle_min_max_for_omega_case ]. } }
+ Qed.
+
Definition carry_reduce n (s:Z) (c:list (Z * Z))
(index:nat) (p : list Z) :=
from_associational
@@ -823,6 +1029,230 @@ Module Positional.
intros; cbv [chained_carries_no_reduce]; induction (rev idxs) as [|x xs IHxs];
cbn [fold_right]; distr_length.
Qed. Hint Rewrite @length_chained_carries_no_reduce : distr_length.
+ (** TODO: figure out a way to make this proof shorter and faster *)
+ Lemma nth_default_chained_carries_no_reduce_app n m inp1 inp2
+ (weight_mul : forall i, weight (S i) mod weight i = 0)
+ (weight_pos : forall i, 0 < weight i)
+ (weight_div : forall i : nat, 0 < weight (S i) / weight i)
+ (weight_unique : forall i j, (i <= n)%nat -> (j <= n)%nat -> weight i = weight j -> i = j) :
+ length inp1 = m -> (length inp1 + length inp2 = n)%nat
+ -> (List.length inp2 <> 0%nat \/ 0 <= eval m inp1 < weight m)
+ -> forall i,
+ nth_default 0 (chained_carries_no_reduce n (inp1 ++ inp2) (seq 0 m)) i
+ = if dec (i < m)%nat
+ then ((eval m inp1) mod weight (S i)) / weight i
+ else if dec (i = m)
+ then match inp2 with
+ | nil => 0
+ | cons x xs
+ => x + (eval m inp1) / weight m
+ end
+ else nth_default 0 inp2 (i - m).
+ Proof using weight_0 weight_nz.
+ intro; subst m.
+ rewrite <- (rev_involutive inp1); generalize (List.rev inp1); clear inp1; intro inp1; rewrite rev_length.
+ revert inp2; induction inp1 as [|x xs IHxs]; intros.
+ { destruct inp2; cbn; autorewrite with zsimplify_const; intros; destruct i; reflexivity. }
+ destruct (lt_dec i n);
+ [
+ | break_match; cbn [List.length] in *; try lia;
+ rewrite ?nth_default_out_of_bounds by (repeat autorewrite with distr_length; lia);
+ reflexivity ].
+ cbv [chained_carries_no_reduce] in *.
+ repeat first [ progress cbn [List.length List.app List.rev fold_right] in *
+ | reflexivity
+ | assumption
+ | progress intros
+ | rewrite <- List.app_assoc
+ | rewrite seq_snoc
+ | rewrite rev_unit
+ | rewrite Nat.add_0_l
+ | rewrite eval_snoc_S in * by distr_length
+ | rewrite app_length
+ | rewrite rev_length
+ | erewrite nth_default_carry; try eassumption
+ | rewrite !IHxs; clear IHxs
+ | lia
+ | match goal with
+ | [ |- length (fold_right _ ?p (rev ?idxs)) = ?n ]
+ => apply (length_chained_carries_no_reduce n p idxs)
+ | [ |- context[_ mod weight (S ?n) / weight ?n] ]
+ => rewrite (Z.div_mod' (weight (S n)) (weight n)), weight_mul, Z.add_0_r, <- Z.mod_pull_div, ?Z.div_mul, ?Z.div_add', ?Z.mul_div_eq', ?weight_mul, ?Z.sub_0_r
+ by solve [ assumption
+ | now apply Z.lt_le_incl, weight_div
+ | now apply Z.lt_gt, weight_pos
+ | (idtac + symmetry); now apply Z.lt_neq, weight_pos ]
+ | [ |- context[?x + ?y] ]
+ => match goal with
+ | [ |- context[y + x] ]
+ => progress replace (y + x) with (x + y) by lia
+ end
+ end ].
+ break_match; try (exfalso; lia).
+ all: repeat first [ rewrite nth_default_app
+ | rewrite nth_default_carry
+ | rewrite Nat.sub_diag
+ | rewrite minus_S_diag
+ | rewrite Nat.sub_succ_r
+ | rewrite nth_default_cons
+ | rewrite nth_default_cons_S
+ | progress subst
+ | now apply weight_0
+ | now apply weight_mul
+ | now apply weight_pos
+ | reflexivity
+ | progress intros
+ | (idtac + symmetry); now apply Z.lt_neq, weight_pos
+ | rewrite Z.div_add' by ((idtac + symmetry); now apply Z.lt_neq, weight_pos)
+ | progress destruct_head'_and
+ | progress destruct_head'_or
+ | progress cbn [List.length] in *
+ | match goal with
+ | [ |- context[?x + ?y] ]
+ => match goal with
+ | [ |- context[y + x] ]
+ => progress replace (y + x) with (x + y) by lia
+ end
+ | [ H : List.length ?x = 0%nat |- _ ] => is_var x; destruct x
+ | [ H : not (or _ _) |- _ ] => apply Decidable.not_or in H
+ | [ H : ?x = ?x |- _ ] => clear H
+ | [ H : not (?x < ?x) |- _ ] => clear H
+ | [ H : not (?x < ?x)%nat |- _ ] => clear H
+ | [ H : not (S ?x < ?x)%nat |- _ ] => clear H
+ | [ H : ~(S ?x + _ <= ?x)%nat |- _ ] => clear H
+ | [ H : (?x < S ?x + _)%nat |- _ ] => clear H
+ | [ H : ?x <> S ?x |- _ ] => clear H
+ | [ H : ?x <> (?x + ?y)%nat |- _ ] => assert (0 < y)%nat by lia; clear H
+ | [ H : (?x < ?x + ?y)%nat |- _ ] => assert (0 < y)%nat by lia; clear H
+ | [ H : ~(?x + ?y <= ?x)%nat |- _ ] => assert (0 < y)%nat by lia; clear H
+ | [ H : ~(?x <> ?y) |- _ ] => assert (x = y) by lia; clear H
+ | [ H : (?x = ?x + ?y)%nat |- _ ] => assert (y = 0%nat) by lia; clear H
+ | [ H : ~(?x <= ?y)%nat |- _ ] => assert (y < x)%nat by lia; clear H
+ | [ H : ~(?x < ?y)%nat |- _ ] => assert (y <= x)%nat by lia; clear H
+ | [ H : (?x <= ?y)%nat, H' : ?x <> ?y |- _ ] => assert (x < y)%nat by lia; clear H H'
+ | [ H : (?x <= ?y)%nat, H' : ?y <> ?x |- _ ] => assert (x < y)%nat by lia; clear H H'
+ | [ H : (?x < ?y)%nat |- context[nth_default _ _ (?y - ?x)] ]
+ => destruct (y - x)%nat eqn:?
+ | [ |- context[nth_default _ (_ :: _) ?n] ] => is_var n; destruct n
+ | [ H : ?T, H' : ?T |- _ ] => clear H'
+ | [ |- (?x + ?y) mod ?z = (?y + ?x) mod ?z ] => apply f_equal2
+ | [ |- ?x + _ = ?x + _ ] => apply f_equal
+ | [ H0 : 0 <= ?e + ?w * ?x, H1 : ?e + ?w * ?x < ?w'
+ |- ?x + ?e / ?w = (?x + ?e / ?w) mod (?w' / ?w) ]
+ => rewrite (Z.mod_small (x + e / w) (w' / w))
+ | [ H : (?i < ?n)%nat |- context[(_ + weight ?n * _) / weight ?i] ]
+ => rewrite (Z.div_mod (weight n) (weight (S i))), weight_multiples_full, Z.add_0_r,
+ (Z.div_mod (weight (S i)) (weight i)), weight_mul, Z.add_0_r,
+ <- !Z.mul_assoc, Z.div_add', ?Z.div_mul', ?Z.mul_div_eq_full, ?weight_mul, ?Z.sub_0_r
+ by solve [ assumption
+ | now apply Z.lt_le_incl, weight_div
+ | now apply Z.lt_gt, weight_pos
+ | now apply Nat.lt_le_incl
+ | (idtac + symmetry); now apply Z.lt_neq, weight_pos ];
+ push_Zmod; pull_Zmod
+ end
+ | progress autorewrite with distr_length in *
+ | lia
+ | progress autorewrite with zsimplify_const
+ | break_innermost_match_step
+ | match goal with
+ | [ |- context[weight (S ?n) / weight ?n] ]
+ => unique pose proof (@weight_mul n)
+ end
+ | Z.div_mod_to_quot_rem; nia ].
+ Qed.
+
+ Lemma nth_default_chained_carries_no_reduce n inp
+ (weight_mul : forall i, weight (S i) mod weight i = 0)
+ (weight_pos : forall i, 0 < weight i)
+ (weight_div : forall i : nat, 0 < weight (S i) / weight i)
+ (weight_unique : forall i j, (i <= n)%nat -> (j <= n)%nat -> weight i = weight j -> i = j) :
+ length inp = n -> 0 <= eval n inp < weight n
+ -> forall i,
+ nth_default 0 (chained_carries_no_reduce n inp (seq 0 n)) i
+ = ((eval n inp) mod weight (S i)) / weight i.
+ Proof using weight_0 weight_nz.
+ pose proof (weight_divides_full weight ltac:(assumption) ltac:(assumption)) as weight_div_full.
+ pose proof (weight_multiples_full weight ltac:(assumption) ltac:(assumption)) as weight_mul_full.
+ assert (weight_le_full : forall j i, (i <= j)%nat -> weight i <= weight j)
+ by (intros j i pf; specialize (weight_div_full j i pf); specialize (weight_mul_full j i pf); Z.div_mod_to_quot_rem; nia).
+ intros ? ? i.
+ pose proof (weight_le_full (S n) n ltac:(lia)).
+ pose proof (weight_le_full (S i) i ltac:(lia)).
+ pose proof (weight_le_full i n).
+ intros; rewrite <- (app_nil_r inp).
+ rewrite (@nth_default_chained_carries_no_reduce_app n n inp nil), app_nil_r by (cbn [List.length]; auto with lia).
+ break_innermost_match; try reflexivity; rewrite ?nth_default_nil.
+ all: rewrite Z.mod_small by lia.
+ all: rewrite Z.div_small by lia.
+ all: reflexivity.
+ Qed.
+
+ Lemma nth_default_chained_carries_no_reduce_pred n inp
+ (weight_mul : forall i, weight (S i) mod weight i = 0)
+ (weight_pos : forall i, 0 < weight i)
+ (weight_div : forall i : nat, 0 < weight (S i) / weight i)
+ (weight_unique : forall i j, (i <= n)%nat -> (j <= n)%nat -> weight i = weight j -> i = j) :
+ length inp = n -> 0 <= eval n inp < weight n
+ -> forall i,
+ nth_default 0 (chained_carries_no_reduce n inp (seq 0 (pred n))) i
+ = ((eval n inp) mod weight (S i)) / weight i.
+ Proof using weight_0 weight_nz.
+ pose proof (weight_divides_full weight ltac:(assumption) ltac:(assumption)) as weight_div_full.
+ pose proof (weight_multiples_full weight ltac:(assumption) ltac:(assumption)) as weight_mul_full.
+ assert (weight_le_full : forall j i, (i <= j)%nat -> weight i <= weight j)
+ by (intros j i pf; specialize (weight_div_full j i pf); specialize (weight_mul_full j i pf); Z.div_mod_to_quot_rem; nia).
+ destruct n as [|n]; [ now apply nth_default_chained_carries_no_reduce | ].
+ intros ? ? i.
+ pose proof (weight_le_full (S n) n ltac:(lia)).
+ pose proof (weight_le_full (S i) i ltac:(lia)).
+ pose proof (weight_le_full i n).
+ pose proof (weight_le_full (S i) (S n)).
+ pose proof (weight_le_full i (S n)).
+ cbn [pred].
+ revert dependent inp; intro inp.
+ rewrite <- (rev_involutive inp); generalize (rev inp); clear inp; intro inp.
+ rewrite rev_length; intros.
+ destruct inp as [|x inp]; cbn [List.length List.rev] in *; [ lia | ].
+ rewrite (@nth_default_chained_carries_no_reduce_app (S n) n (List.rev inp) (x::nil)) by (cbn [List.length]; autorewrite with distr_length; auto with lia).
+ rewrite eval_snoc_S in * by distr_length.
+ break_innermost_match; try reflexivity.
+ all: repeat first [ progress autorewrite with zsimplify_const
+ | reflexivity
+ | progress Z.rewrite_mod_small
+ | rewrite Z.div_add' by ((idtac + symmetry); now apply Z.lt_neq, weight_pos)
+ | lia
+ | match goal with
+ | [ |- context[_ mod weight (S ?n) / weight ?n] ]
+ => rewrite (Z.div_mod' (weight (S n)) (weight n)), weight_mul, Z.add_0_r, <- Z.mod_pull_div, ?Z.div_mul, ?Z.div_add', ?Z.mul_div_eq', ?weight_mul, ?Z.sub_0_r
+ by solve [ assumption
+ | now apply Z.lt_le_incl, weight_div
+ | now apply Z.lt_gt, weight_pos
+ | (idtac + symmetry); now apply Z.lt_neq, weight_pos ]
+ | [ |- context[(_ + weight ?n * _) / weight ?i] ]
+ => rewrite (Z.div_mod (weight n) (weight (S i))), weight_multiples_full, Z.add_0_r,
+ (Z.div_mod (weight (S i)) (weight i)), weight_mul, Z.add_0_r,
+ <- !Z.mul_assoc, Z.div_add', ?Z.div_mul', ?Z.mul_div_eq_full, ?weight_mul, ?Z.sub_0_r
+ by solve [ assumption
+ | now apply Z.lt_le_incl, weight_div
+ | now apply Z.lt_gt, weight_pos
+ | now apply Nat.lt_le_incl
+ | (idtac + symmetry); now apply Z.lt_neq, weight_pos ];
+ push_Zmod; pull_Zmod
+ end
+ | rewrite nth_default_cons
+ | rewrite nth_default_cons_S
+ | rewrite nth_default_nil
+ | rewrite Z.div_small by lia
+ | lia
+ | match goal with
+ | [ H : ~(?x < ?y)%nat |- _ ] => assert (y <= x)%nat by lia; clear H
+ | [ H : (?x <= ?y)%nat, H' : ?x <> ?y |- _ ] => assert (x < y)%nat by lia; clear H H'
+ | [ H : (?x <= ?y)%nat, H' : ?y <> ?x |- _ ] => assert (x < y)%nat by lia; clear H H'
+ | [ H : (?x < ?y)%nat |- context[nth_default _ _ (?y - ?x)] ]
+ => destruct (y - x)%nat eqn:?
+ end ].
+ Qed.
(* Reverse of [eval]; translate from Z to basesystem by putting
everything in first digit and then carrying. *)
@@ -976,7 +1406,7 @@ Module Positional.
End select.
End Positional.
(* Hint Rewrite disappears after the end of a section *)
-Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_encode_no_reduce @length_sub @length_opp @length_select @length_zselect @length_select_min @length_extend_to_length @length_drop_high_to_length : distr_length.
+Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_carry @length_chained_carries @length_chained_carries_no_reduce @length_encode @length_encode_no_reduce @length_sub @length_opp @length_select @length_zselect @length_select_min @length_extend_to_length @length_drop_high_to_length : distr_length.
Hint Rewrite @eval_zeros @eval_nil @eval_snoc_S @eval_select @eval_zselect @eval_extend_to_length using solve [auto; distr_length]: push_eval.
Section Positional_nonuniform.
Context (weight weight' : nat -> Z).
@@ -1082,6 +1512,19 @@ Section mod_ops.
Z.div_mod_to_quot_rem_in_goal; nia.
Qed.
+ Lemma weight_unique_iff : forall i j, (i <= n)%nat -> (j <= n)%nat -> weight i = weight j <-> i = j.
+ Proof using limbwidth_good.
+ clear Hn_nz; clear dependent c.
+ cbv [weight]; split; intro H'; subst; trivial; [].
+ apply (f_equal (fun x => limbwidth_den * (- Z.log2 x))) in H'.
+ rewrite !Z.log2_pow2, !Z.opp_involutive in H' by (Z.div_mod_to_quot_rem; nia).
+ Z.div_mod_to_quot_rem.
+ destruct i as [|i], j as [|j]; autorewrite with zsimplify_const in *; [ reflexivity | exfalso; nia.. | ].
+ nia.
+ Qed.
+ Lemma weight_unique : forall i j, (i <= n)%nat -> (j <= n)%nat -> weight i = weight j -> i = j.
+ Proof using limbwidth_good. apply weight_unique_iff. Qed.
+
Derive carry_mulmod
SuchThat (forall (f g : list Z)
(Hf : length f = n)
@@ -2355,6 +2798,31 @@ Module BaseConversion.
rewrite eval_from_associational; auto.
Qed.
+ Lemma length_convert_bases sn dn p
+ : length (convert_bases sn dn p) = dn.
+ Proof using Type.
+ cbv [convert_bases]; now repeat autorewrite with distr_length.
+ Qed.
+ Hint Rewrite length_convert_bases : distr_length.
+
+ Lemma convert_bases_partitions sn dn p
+ (dw_unique : forall i j : nat, (i <= dn)%nat -> (j <= dn)%nat -> dw i = dw j -> i = j)
+ (p_bounded : 0 <= eval sw sn p < dw dn)
+ : convert_bases sn dn p = partition dw dn (eval sw sn p).
+ Proof using dwprops.
+ apply list_elementwise_eq; intro i.
+ destruct (lt_dec i dn); [ | now rewrite !nth_error_length_error by distr_length ].
+ erewrite !(@nth_error_Some_nth_default _ _ 0) by (break_match; distr_length).
+ apply f_equal.
+ cbv [convert_bases partition].
+ unshelve erewrite map_nth_default, nth_default_chained_carries_no_reduce_pred;
+ repeat first [ progress autorewrite with distr_length push_eval
+ | rewrite eval_from_associational, eval_to_associational
+ | rewrite nth_default_seq_inbounds
+ | apply dwprops
+ | destruct dwprops; now auto with zarith ].
+ Qed.
+
Hint Rewrite
@Rows.eval_from_associational
@Associational.eval_carry
@@ -2487,6 +2955,7 @@ Module BaseConversion.
Qed.
End mul_converted.
End BaseConversion.
+ Hint Rewrite length_convert_bases : distr_length.
(* multiply two (n*k)-bit numbers by converting them to n k-bit limbs each, multiplying, then converting back *)
Section widemul.
@@ -2702,6 +3171,9 @@ Section freeze_mod_ops.
Definition wprops_bytes := (@wprops 8 1 ltac:(clear; lia)).
Local Notation wprops := (@wprops limbwidth_num limbwidth_den limbwidth_good).
+ Local Notation wunique := (@weight_unique limbwidth_num limbwidth_den limbwidth_good).
+ Local Notation wunique_bytes := (@weight_unique 8 1 ltac:(clear; lia)).
+
Local Hint Immediate (wprops).
Local Hint Immediate (wprops_bytes).
Local Hint Immediate (weight_0 wprops).
@@ -2712,28 +3184,31 @@ Section freeze_mod_ops.
Local Hint Immediate (weight_positive wprops_bytes).
Local Hint Immediate (weight_multiples wprops_bytes).
Local Hint Immediate (weight_divides wprops_bytes).
+ Local Hint Immediate (wunique) (wunique_bytes).
+ Local Hint Resolve (wunique) (wunique_bytes).
Local Hint Resolve Z.positive_is_nonzero Z.lt_gt.
Definition bytes_n
:= Eval cbv [Qceiling Qdiv inject_Z Qfloor Qmult Qopp Qnum Qden Qinv Pos.mul]
in Z.to_nat (Qceiling (Z.log2_up (weight n) / 8)).
- Definition to_bytes' (v : list Z)
+ Lemma weight_bytes_weight_matches
+ : weight n <= bytes_weight bytes_n.
+ Proof using limbwidth_good.
+ clear -limbwidth_good.
+ cbv [weight bytes_n].
+ autorewrite with zsimplify_const.
+ rewrite Z.log2_up_pow2, !Z2Nat.id, !Z.opp_involutive by (Z.div_mod_to_quot_rem; nia).
+ Z.peel_le.
+ Z.div_mod_to_quot_rem; nia.
+ Qed.
+
+ Definition to_bytes (v : list Z)
:= BaseConversion.convert_bases weight bytes_weight n bytes_n v.
Definition from_bytes (v : list Z)
:= BaseConversion.convert_bases bytes_weight weight bytes_n n v.
- (** TODO: We should probably prove that BaseConversion.convert_bases
- partitions, so that we don't end up doing a needless [flatten ∘
- from_associational ∘ to_associational] just be be able to prove
- that the result partitions. See
- https://github.com/JasonGross/fiat-crypto/tree/zzz-wip-better-arith-proofs
- for some partial work in this direction. *)
- Definition to_bytes (f : list Z) : list Z
- := let v := to_bytes' f in
- fst (Rows.flatten bytes_weight bytes_n (Rows.from_associational bytes_weight bytes_n (Positional.to_associational bytes_weight bytes_n v))).
-
Definition freeze_to_bytesmod (f : list Z) : list Z
:= to_bytes (freeze weight n (Z.ones bitwidth) m_enc f).
@@ -2768,48 +3243,29 @@ Section freeze_mod_ops.
Z.div_mod_to_quot_rem; nia.
Qed.
- Lemma eval_to_bytes_mod
+ Lemma eval_to_bytes
: forall (f : list Z)
(Hf : length f = n),
- eval bytes_weight bytes_n (to_bytes f) = eval weight n f mod (bytes_weight bytes_n).
+ eval bytes_weight bytes_n (to_bytes f) = eval weight n f.
Proof using limbwidth_good Hn_nz.
generalize wprops wprops_bytes; clear -Hn_nz limbwidth_good.
intros.
cbv [to_bytes].
- rewrite Rows.flatten_mod by (assumption || apply Rows.length_from_associational).
- rewrite Rows.eval_from_associational by eauto using bytes_nz with omega.
- rewrite eval_to_associational.
- cbv [to_bytes'].
rewrite BaseConversion.eval_convert_bases
by (auto using bytes_nz; distr_length; auto using wprops).
reflexivity.
Qed.
- Lemma eval_to_bytes
+ Lemma to_bytes_partitions
: forall (f : list Z)
(Hf : length f = n)
(Hf_small : 0 <= eval weight n f < weight n),
- eval bytes_weight bytes_n (to_bytes f) = eval weight n f.
- Proof using Hn_nz limbwidth_good.
- generalize bytes_n_big. clear -Hn_nz limbwidth_good.
- intros; rewrite eval_to_bytes_mod by assumption.
- rewrite Z.mod_small by omega; reflexivity.
- Qed.
-
- Lemma to_bytes_partitions
- : forall (f : list Z)
- (Hf : length f = n),
to_bytes f = Partition.partition bytes_weight bytes_n (Positional.eval weight n f).
Proof using Hn_nz limbwidth_good.
clear -Hn_nz limbwidth_good.
intros; cbv [to_bytes].
- rewrite Rows.flatten_correct by (apply wprops_bytes || apply Rows.length_from_associational).
- rewrite Rows.eval_from_associational by eauto using bytes_nz with omega.
- rewrite eval_to_associational.
- cbv [to_bytes'].
- rewrite BaseConversion.eval_convert_bases
- by (auto using wprops_bytes, bytes_nz; distr_length; auto using wprops).
- reflexivity.
+ pose proof weight_bytes_weight_matches.
+ apply BaseConversion.convert_bases_partitions; eauto; lia.
Qed.
Lemma eval_to_bytesmod
@@ -2856,17 +3312,53 @@ Section freeze_mod_ops.
intros; now apply eval_freeze_to_bytesmod_and_partitions.
Qed.
+ Lemma eval_from_bytes
+ : forall (f : list Z)
+ (Hf : length f = bytes_n),
+ eval weight n (from_bytes f) = eval bytes_weight bytes_n f.
+ Proof using limbwidth_good Hn_nz.
+ generalize wprops wprops_bytes; clear -Hn_nz limbwidth_good.
+ intros.
+ cbv [from_bytes].
+ rewrite BaseConversion.eval_convert_bases
+ by (auto using bytes_nz; distr_length; auto using wprops).
+ reflexivity.
+ Qed.
+
+ Lemma from_bytes_partitions
+ : forall (f : list Z)
+ (Hf_small : 0 <= eval bytes_weight bytes_n f < weight n),
+ from_bytes f = Partition.partition weight n (Positional.eval bytes_weight bytes_n f).
+ Proof using limbwidth_good.
+ clear -limbwidth_good.
+ intros; cbv [from_bytes].
+ pose proof weight_bytes_weight_matches.
+ apply BaseConversion.convert_bases_partitions; eauto; lia.
+ Qed.
+
Lemma eval_from_bytesmod
: forall (f : list Z)
(Hf : length f = bytes_n),
eval weight n (from_bytesmod f) = eval bytes_weight bytes_n f.
- Proof using Hn_nz limbwidth_good.
- cbv [from_bytesmod from_bytes]; intros.
- rewrite BaseConversion.eval_convert_bases by eauto using wprops.
- reflexivity.
+ Proof using Hn_nz limbwidth_good. apply eval_from_bytes. Qed.
+
+ Lemma from_bytesmod_partitions
+ : forall (f : list Z)
+ (Hf_small : 0 <= eval bytes_weight bytes_n f < weight n),
+ from_bytesmod f = Partition.partition weight n (Positional.eval bytes_weight bytes_n f).
+ Proof using limbwidth_good. apply from_bytes_partitions. Qed.
+
+ Lemma eval_from_bytesmod_and_partitions
+ : forall (f : list Z)
+ (Hf : length f = bytes_n)
+ (Hf_small : 0 <= eval bytes_weight bytes_n f < weight n),
+ eval weight n (from_bytesmod f) = eval bytes_weight bytes_n f
+ /\ from_bytesmod f = Partition.partition weight n (Positional.eval bytes_weight bytes_n f).
+ Proof using limbwidth_good Hn_nz.
+ now (split; [ apply eval_from_bytesmod | apply from_bytes_partitions ]).
Qed.
End freeze_mod_ops.
-Hint Rewrite eval_freeze_to_bytesmod : push_eval.
+Hint Rewrite eval_freeze_to_bytesmod eval_to_bytes eval_to_bytesmod eval_from_bytes eval_from_bytesmod : push_eval.
Section primitives.
Definition mulx (bitwidth : Z) := Eval cbv [Z.mul_split_at_bitwidth] in Z.mul_split_at_bitwidth bitwidth.