aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2018-09-18 12:00:37 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2018-12-21 10:22:41 -0500
commit943d76f97fcbb02ca5d417266ad9dcd7a9561a73 (patch)
tree35113960fab7845b39314544f974220307cbcd33
parent9883e4e7a60ba7d1e1487f7a4501363bd3958fde (diff)
prove [eval_conditional_sub]
-rw-r--r--src/Experiments/NewPipeline/Arithmetic.v175
-rw-r--r--src/Util/ListUtil.v1
2 files changed, 124 insertions, 52 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v
index 7ad7bb1a3..06461a63c 100644
--- a/src/Experiments/NewPipeline/Arithmetic.v
+++ b/src/Experiments/NewPipeline/Arithmetic.v
@@ -689,40 +689,42 @@ Module Positional.
Definition drop_high_to_length (n : nat) (p:list Z) : list Z :=
firstn n p.
- (*
- Lemma eval_drop_high_to_length n m p :
- (forall i, weight (S i) mod weight i = 0) -> length p = m -> (n <= m)%nat ->
+
+ (* Helper for eval_drop_high_to_length *)
+ Lemma eval_drop_high_to_length' q
+ (weight_multiples: forall i, weight (S i) mod weight i = 0)
+ (weight_positive: forall i, 0 < weight i) :
+ forall n p m,
+ length p = n -> length q = m ->
+ eval (n+m) (p ++ q) mod weight n
+ = eval n p mod weight n.
+ Proof using Type.
+ induction q using rev_ind; intros; distr_length.
+ { subst m; rewrite app_nil_r.
+ autorewrite with natsimplify; reflexivity. }
+ { rewrite app_assoc.
+ rewrite eval_snoc with (n:=(n+pred m)%nat) by distr_length.
+ rewrite weight_div_mod with (j:=(n+pred m)%nat) (i:=n) by auto with omega.
+ push_Zmod. rewrite IHq by lia.
+ autorewrite with zsimplify_fast.
+ reflexivity. }
+ Qed.
+ Lemma eval_drop_high_to_length n m p
+ (weight_multiples: forall i, weight (S i) mod weight i = 0)
+ (weight_positive: forall i, 0 < weight i) :
+ length p = m -> (n <= m)%nat ->
eval n (drop_high_to_length n p) mod weight n
= eval m p mod weight n.
Proof using Type.
- cbv [eval drop_high_to_length to_associational]; intros.
- replace m with (n + (m - n))%nat in * by (f_equal; omega).
- generalize dependent (m - n)%nat; clear m; intro m; intros H' H''.
- rewrite seq_add, map_app, <- (firstn_skipn n p), combine_app_samelength, firstn_skipn, Associational.eval_app;
- push; try omega **.
- rewrite <- (Z.add_0_r (Associational.eval _)) at 1.
- apply Z.add_mod_Proper; [ reflexivity | cbv [Z.equiv_modulo] ].
- generalize (skipn_length n p); rewrite H', minus_plus.
- generalize (skipn n p); clear dependent p; clear H''; intros p Hp.
- rewrite Zmod_0_l.
- subst.
- cbv [Associational.eval].
- revert n; induction p as [|p ps IHps]; intro; [ reflexivity | ].
- cbn in *.
- push_Zmod; pull_Zmod; autorewrite with zsimplify_const.
- rewrite <- IHps.
- { cbn; reflexivity
- Search (0 mod _).
- rewrite Z.mod_0_l
- Search (?x + ?y - ?x)%nat.
- Search Z.equiv_modulo Proper.
- pose proof H as H''.
- rewrite <- (firstn_skipn n p) in H''.
- distr_length.
-
- rewrite Nat.min_l in H'' by omega.
- Qed.
- Hint Rewrite eval_drop_high_to_length : push_eval.*)
+ cbv [drop_high_to_length]; intros.
+ rewrite <-(firstn_skipn n p).
+ rewrite firstn_app_inleft by (distr_length; lia).
+ rewrite firstn_firstn. autorewrite with natsimplify.
+ replace m with (n + (m-n))%nat by omega.
+ rewrite eval_drop_high_to_length' by (auto; distr_length; lia).
+ reflexivity.
+ Qed.
+ Hint Rewrite eval_drop_high_to_length : push_eval.
Lemma length_drop_high_to_length n p :
length (drop_high_to_length n p) = Nat.min n (length p).
Proof using Type. clear; cbv [drop_high_to_length]; intros; distr_length. Qed.
@@ -972,20 +974,30 @@ Module Positional.
Lemma length_zselect mask cond p :
length (zselect mask cond p) = length p.
Proof using Type. clear dependent weight. cbv [zselect Let_In]; break_match; intros; distr_length. Qed.
+
+ (* We need an explicit equality proof here, because sometimes it
+ matters that we retain the same bounds when selecting. *)
+ Lemma select_eq cond n : forall p q,
+ length p = n -> length q = n ->
+ select cond p q = if dec (cond = 0) then p else q.
+ Proof using weight.
+ cbv [select]; induction n; intros;
+ destruct p; distr_length;
+ destruct q; distr_length;
+ repeat match goal with
+ | _ => progress autorewrite with push_combine push_map
+ | _ => rewrite IHn by distr_length
+ | _ => rewrite Z.zselect_correct
+ | _ => break_match; reflexivity
+ end.
+ Qed.
Lemma eval_select n cond p q :
length p = n -> length q = n
-> eval n (select cond p q) =
if dec (cond = 0) then eval n p else eval n q.
- Proof using Type.
- cbv [select Let_In]; intro; subst.
- rewrite <- (List.rev_involutive q), <- (List.rev_involutive p).
- generalize (rev p) (rev q); clear p q; intros p q; revert q.
- induction p as [|p ps IHps], q as [|q qs]; cbn [length map combine rev]; distr_length; rewrite ?Nat.add_1_r; try omega.
- { break_match; reflexivity. }
- { intro; rewrite !combine_snoc, !map_app by (distr_length; omega).
- cbn [map].
- rewrite !eval_snoc with (n:=length ps), IHps by (distr_length; omega* ).
- rewrite !Z.zselect_correct; break_match; reflexivity. }
+ Proof using weight.
+ intros; erewrite select_eq by eauto.
+ break_match; reflexivity.
Qed.
Lemma length_select_min cond p q :
length (select cond p q) = Nat.min (length p) (length q).
@@ -999,7 +1011,7 @@ Module Positional.
End Positional.
(* Hint Rewrite disappears after the end of a section *)
Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_sub @length_opp @length_select @length_zselect @length_select_min @length_extend_to_length @length_drop_high_to_length : distr_length.
-Hint Rewrite @eval_zeros @eval_nil @eval_snoc_S @eval_select @eval_zselect @eval_extend_to_length (*@eval_drop_high_to_length*) using (solve [auto; distr_length]): push_eval.
+Hint Rewrite @eval_zeros @eval_nil @eval_snoc_S @eval_select @eval_zselect @eval_extend_to_length @eval_drop_high_to_length using (solve [auto; distr_length]): push_eval.
Section Positional_nonuniform.
Context (weight weight' : nat -> Z).
@@ -1402,7 +1414,16 @@ Module Partition.
induction n; cbn [recursive_partition]; [reflexivity | ].
intros; distr_length; auto.
Qed.
- Hint Rewrite length_recursive_partition : distr_length.
+
+ Lemma drop_high_to_length_partition n m x :
+ (n <= m)%nat ->
+ Positional.drop_high_to_length n (partition m x) = partition n x.
+ Proof.
+ cbv [Positional.drop_high_to_length partition]; intros.
+ autorewrite with push_firstn.
+ rewrite Nat.min_l by omega.
+ reflexivity.
+ Qed.
End Partition.
Hint Rewrite length_partition length_recursive_partition : distr_length.
Hint Rewrite eval_partition using (solve [auto; distr_length]) : push_eval.
@@ -2233,6 +2254,25 @@ Module Rows.
snd (sub n p q) = (Positional.eval weight n p - Positional.eval weight n q) / weight n.
Proof using wprops. solver. Qed.
+ Lemma conditional_sub_partitions n p q
+ (Hp : p = partition weight n (Positional.eval weight n p)) :
+ length q = n ->
+ 0 <= Positional.eval weight n q < weight n ->
+ conditional_sub n p q = partition weight n (if Positional.eval weight n q <=? Positional.eval weight n p then Positional.eval weight n p - Positional.eval weight n q else Positional.eval weight n p).
+ Proof.
+ cbv [conditional_sub]; intros.
+ rewrite (surjective_pairing (sub _ _ _)).
+ assert (length p = n) by (rewrite Hp; distr_length).
+ assert (0 <= Positional.eval weight n p < weight n) by
+ (rewrite Hp; autorewrite with push_eval; auto using Z.mod_pos_bound).
+ rewrite sub_partitions, sub_div; distr_length.
+ erewrite Positional.select_eq by (distr_length; eauto).
+ rewrite Z.div_sub_small, Z.ltb_antisym by omega.
+ destruct (Positional.eval weight n q <=? Positional.eval weight n p);
+ cbn [negb]; autorewrite with zsimplify_fast;
+ break_match; congruence.
+ Qed.
+
Lemma mul_partitions base n m p q :
base <> 0 -> m <> 0%nat -> length p = n -> length q = n ->
fst (mul base n m p q) = partition weight m (Positional.eval weight n p * Positional.eval weight n q).
@@ -2881,6 +2921,13 @@ Module UniformWeight.
rewrite !Z.pow_succ_r by auto using Nat2Z.is_nonneg.
ring.
Qed.
+ Lemma uweight_S lgr (Hr : 0 <= lgr) n : uweight lgr (S n) = 2 ^ lgr * uweight lgr n.
+ Proof.
+ rewrite !uweight_eq_alt by auto.
+ autorewrite with push_Zof_nat.
+ rewrite Z.pow_succ_r by auto using Nat2Z.is_nonneg.
+ reflexivity.
+ Qed.
Lemma uweight_1 lgr : uweight lgr 1 = 2^lgr.
Proof using Type.
@@ -3186,11 +3233,9 @@ Module WordByWordMontgomery.
autounfold with loc; intro n; destruct (zerop n).
{ cbn; intros; subst; cbn; rewrite Z.add_with_get_carry_full_mod; cbn; omega. }
intros; repeat t_step.
- repeat first [ reflexivity
- | rewrite UniformWeight.uweight_eq_alt by omega
- | progress autorewrite with push_Zof_nat
- | rewrite Z.pow_succ_r by lia
- | progress Z.rewrite_mod_small ].
+ rewrite UniformWeight.uweight_S by omega.
+ rewrite UniformWeight.uweight_eq_alt by omega.
+ Z.rewrite_mod_small; reflexivity.
Qed.
Local Lemma small_scmul : forall n a v, small v -> 0 <= a < r -> 0 <= eval v < r^Z.of_nat n -> small (@scmul n a v).
Proof using lgr_big.
@@ -3232,9 +3277,7 @@ Module WordByWordMontgomery.
apply Z.lt_le_trans with (m:=2 * weight n).
{ rewrite <-Z.add_diag.
auto using Z.add_lt_mono with zarith. }
- { rewrite !UniformWeight.uweight_eq_alt by omega.
- autorewrite with push_Zof_nat push_Zpow.
- rewrite Z.pow_succ_r by auto.
+ { rewrite UniformWeight.uweight_S by omega.
(* In versions newer than 8.7, auto with zarith is sufficient
to solve this from here. *)
apply Z.mul_le_mono_nonneg_r.
@@ -3273,7 +3316,35 @@ Module WordByWordMontgomery.
autorewrite with push_eval.
auto using partition_eq_mod with zarith.
Qed.
- Local Axiom eval_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> eval (conditional_sub v N) = eval v + if eval N <=? eval v then -eval N else 0.
+ Local Lemma eval_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> eval (conditional_sub v N) = eval v + if eval N <=? eval v then -eval N else 0.
+ Proof using small_N lgr_big R_numlimbs_nz N_nz N_lt_R.
+ clear - small_N ri lgr_big R_numlimbs_nz N_nz N_lt_R.
+ intros; autounfold with loc; cbv [conditional_sub].
+ repeat match goal with H : small _ |- _ =>
+ rewrite H; clear H end.
+ autorewrite with push_eval.
+ assert (eval N mod weight R_numlimbs < weight (S R_numlimbs))
+ by (rewrite UniformWeight.uweight_S, !UniformWeight.uweight_eq_alt by omega; subst r R; rewrite Z.mod_small by omega; assert (0 < 2 ^ lgr) by auto with zarith; nia).
+ rewrite Rows.conditional_sub_partitions
+ by (repeat (autorewrite with distr_length push_eval; auto using partition_eq_mod with zarith)).
+ rewrite drop_high_to_length_partition by omega.
+ (* TODO : do we need eval_drop_high_to_length? *)
+ autorewrite with push_eval.
+ assert (eval N < weight R_numlimbs) by
+ (subst r R; rewrite UniformWeight.uweight_eq_alt; omega).
+ assert (weight R_numlimbs = R) by (rewrite UniformWeight.uweight_eq_alt by omega; subst R; reflexivity).
+ assert (eval v < weight (S R_numlimbs)).
+ { apply Z.lt_le_trans with (m:=2 * R); [ lia | ].
+ subst r R; rewrite UniformWeight.uweight_S, UniformWeight.uweight_eq_alt by omega.
+ apply Z.mul_le_mono_nonneg_r; [auto with zarith | ].
+ transitivity (2 ^ 1); [ reflexivity | ];
+ apply Z.pow_le_mono_r; omega. }
+ Z.rewrite_mod_small.
+ break_match; autorewrite with zsimplify_fast; Z.ltb_to_lt.
+ { rewrite Z.add_opp_r. fold (eval N).
+ auto using Z.mod_small with lia. }
+ { auto using Z.mod_small with lia. }
+ Qed.
Local Axiom small_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> small (conditional_sub v N).
Local Axiom eval_sub_then_maybe_add : forall a b, small a -> small b -> 0 <= eval a < eval N -> 0 <= eval b < eval N -> eval (sub_then_maybe_add a b) = eval a - eval b + if eval a - eval b <? 0 then eval N else 0.
Local Axiom small_sub_then_maybe_add : forall a b, small (sub_then_maybe_add a b).
diff --git a/src/Util/ListUtil.v b/src/Util/ListUtil.v
index 9f6658c42..cb11eca3b 100644
--- a/src/Util/ListUtil.v
+++ b/src/Util/ListUtil.v
@@ -1297,6 +1297,7 @@ Proof.
revert k a; induction b as [|? IHb], k; simpl; try reflexivity.
intros; rewrite IHb; reflexivity.
Qed.
+Hint Rewrite @firstn_seq : push_firstn.
Lemma skipn_seq k a b
: skipn k (seq a b) = seq (k + a) (b - k).