diff options
author | jadep <jade.philipoom@gmail.com> | 2018-09-15 11:12:24 -0400 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2018-09-17 21:34:36 -0400 |
commit | cbe68953e46548709b36f5560dc8ecefa017efe5 (patch) | |
tree | b6cbc242ab23b86d4fbc165ddc0bdcfd9878c65a /src | |
parent | 34c169ed01f33bb4f8b561651e9cd403aa9a076d (diff) |
use recursive partition to prove eval_div axiom
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/NewPipeline/Arithmetic.v | 99 |
1 files changed, 85 insertions, 14 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v index 2fc78bad1..83364be11 100644 --- a/src/Experiments/NewPipeline/Arithmetic.v +++ b/src/Experiments/NewPipeline/Arithmetic.v @@ -1456,7 +1456,39 @@ Module Partition. rewrite weight_0 by auto; autorewrite with zsimplify_fast. reflexivity. Qed. + + Lemma length_recursive_partition n : forall i x, + length (recursive_partition n i x) = n. + Proof. + induction n; cbn [recursive_partition]; [reflexivity | ]. + intros; distr_length; auto. + Qed. + Hint Rewrite length_recursive_partition : distr_length. + + (* TODO : move *) + Hint Rewrite Nat.sub_0_l : natsimplify. + + (* TODO : remove if this ends up being unused *) + Lemma recursive_partition_skipn : forall m n i j x, + (m < n)%nat -> + (j = i + m)%nat -> + skipn m (recursive_partition n i (x / weight i)) = recursive_partition (n-m) j (x / weight j). + Proof. + induction m; destruct n; intros; subst; + repeat match goal with + | _ => progress cbn [recursive_partition] + | _ => rewrite weight_0 by assumption + | _ => rewrite weight_multiples by assumption + | _ => rewrite Z.div_div by auto + | _ => rewrite Z.mul_div_eq by auto + | _ => progress autorewrite with zsimplify_fast natsimplify push_skipn + | _ => reflexivity + end. + erewrite IHm by (auto; omega). + repeat (f_equal; try omega). + Qed. End Partition. + Hint Rewrite length_partition length_recursive_partition : distr_length. End Partition. Module Columns. @@ -3018,9 +3050,24 @@ Module UniformWeight. Proof using Type. now cbv [uweight weight]; autorewrite with zsimplify_fast. Qed. Lemma uweight_eq_alt lgr (Hr : 0 <= lgr) n : uweight lgr n = (2^lgr)^Z.of_nat n. Proof using Type. now rewrite uweight_eq_alt', Z.pow_mul_r by lia. Qed. + Lemma uweight_eval_shift lgr (Hr : 0 <= lgr) xs : + forall n, + length xs = n -> + Positional.eval (fun i => uweight lgr (S i)) n xs = + (uweight lgr 1) * Positional.eval (uweight lgr) n xs. + Proof. + induction xs using rev_ind; destruct n; distr_length; + intros; [cbn; ring | ]. + rewrite !Positional.eval_snoc with (n:=n) by distr_length. + rewrite IHxs, !uweight_eq_alt by omega. + autorewrite with push_Zof_nat push_Zpow. + rewrite !Z.pow_succ_r by auto using Nat2Z.is_nonneg. + ring. + Qed. End UniformWeight. Module WordByWordMontgomery. + Import Partition. Section with_args. Context (lgr : Z) (m : Z). @@ -3116,7 +3163,7 @@ Module WordByWordMontgomery. Let R := (r^Z.of_nat R_numlimbs). Transparent T. Definition small {n} (v : T n) : Prop - := v = Partition.partition weight n (eval v). + := v = partition weight n (eval v). Context (small_N : small N) (N_lt_R : eval N < R) (N_nz : 0 < eval N) @@ -3140,18 +3187,18 @@ Module WordByWordMontgomery. Lemma length_small {n v} : @small n v -> length v = n. Proof using Type. clear; cbv [small]; intro H; rewrite H; autorewrite with distr_length; reflexivity. Qed. - Let partition_Proper := (@Partition.partition_Proper _ wprops). + Let partition_Proper := (@partition_Proper _ wprops). Local Existing Instance partition_Proper. Lemma eval_nonzero n A : @small n A -> nonzero A = 0 <-> @eval n A = 0. Proof using lgr_big. clear -lgr_big partition_Proper. cbv [nonzero eval small]; intro Heq. do 2 rewrite Heq. - rewrite !Partition.eval_partition, Z.mod_mod by auto. + rewrite !eval_partition, Z.mod_mod by auto. generalize (Positional.eval weight n A); clear Heq A. induction n as [|n IHn]. { cbn; rewrite weight_0 by auto; intros; autorewrite with zsimplify_const; omega. } - { intro; rewrite Partition.partition_step. + { intro; rewrite partition_step. rewrite fold_right_snoc, Z.lor_comm, <- fold_right_push, Z.lor_eq_0_iff by auto using Z.lor_assoc. assert (Heq : Z.equiv_modulo (weight n) (z mod weight (S n)) (z mod (weight n))). { cbv [Z.equiv_modulo]. @@ -3214,10 +3261,34 @@ Module WordByWordMontgomery. Qed. Local Lemma small_zero : forall n, small (@zero n). Proof using Type. - etransitivity; [ eapply Positional.zeros_ext_map | rewrite eval_zero ]; cbv [Partition.partition]; cbn; try reflexivity; autorewrite with distr_length; reflexivity. + etransitivity; [ eapply Positional.zeros_ext_map | rewrite eval_zero ]; cbv [partition]; cbn; try reflexivity; autorewrite with distr_length; reflexivity. Qed. Local Hint Immediate small_zero. - Local Axiom eval_div : forall n v, small v -> eval (fst (@divmod n v)) = eval v / r. + + (* TODO : relocate? *) + Lemma weight_1 : weight 1 = r. + Proof. + clear - lgr_big. subst r. + rewrite UniformWeight.uweight_eq_alt by omega. + cbn; ring. + Qed. + + Lemma eval_div : forall n v, small v -> eval (fst (@divmod n v)) = eval v / r. + Proof. + pose proof r_big as r_big. + clear - r_big lgr_big; autounfold with loc; intros. + repeat match goal with + | _ => progress cbn [Rows.divmod fst recursive_partition tl] + | H : _ = partition _ _ _ |- _ => rewrite H; clear H + | _ => rewrite recursive_partition_equiv by auto using UniformWeight.uwprops + | _ => progress rewrite weight_0, weight_1 by auto; + autorewrite with zsimplify_fast + | _ => rewrite Positional.eval_cons by distr_length + | _ => rewrite UniformWeight.uweight_eval_shift by distr_length + end. + autorewrite with zsimplify. + reflexivity. + Qed. Local Axiom eval_mod : forall n v, small v -> snd (@divmod n v) = eval v mod r. Local Axiom small_div : forall n v, small v -> small (fst (@divmod n v)). Local Lemma eval_scmul: forall n a v, small v -> 0 <= a < r -> 0 <= eval v < r^Z.of_nat n -> eval (@scmul n a v) = a * eval v. @@ -3264,7 +3335,7 @@ Module WordByWordMontgomery. clear -lgr_big Ha Hb. autounfold with loc in *; destruct (zerop n); subst. { destruct a as [| ? [|] ], b; cbn; try omega. - cbv [Partition.partition seq eval map] in Ha. + cbv [partition seq eval map] in Ha. cbn in Ha. rewrite (weight_0 wprops) in *. rewrite Z.add_with_get_carry_full_mod. @@ -3530,7 +3601,7 @@ Module WordByWordMontgomery. Lemma fst_redc_body : (eval (fst (redc_body A_S))) = eval (fst A_S) / r. - Proof using small_S small_A S_bound. + Proof using small_S small_A S_bound lgr_big. destruct A_S; simpl; repeat autounfold with word_by_word_montgomery; simpl. autorewrite with push_mont_eval. reflexivity. @@ -3862,7 +3933,7 @@ Module WordByWordMontgomery. Let r := 2^bitwidth. Local Notation weight := (UniformWeight.uweight bitwidth). Local Notation eval := (@eval bitwidth n). - Let m_enc := Partition.partition weight n m. + Let m_enc := partition weight n m. Local Coercion Z.of_nat : nat >-> Z. Context (r' : Z) (m' : Z) @@ -3916,7 +3987,7 @@ Module WordByWordMontgomery. | _ => lia | _ => exact small_m_enc | [ H : small ?x |- context[eval ?x] ] - => rewrite H; cbv [eval]; rewrite Partition.eval_partition by auto + => rewrite H; cbv [eval]; rewrite eval_partition by auto | [ |- context[weight _] ] => rewrite UniformWeight.uweight_eq_alt by auto with omega | _=> progress Z.rewrite_mod_small | _ => progress Z.zero_bounds @@ -3947,7 +4018,7 @@ Module WordByWordMontgomery. t_fin. Qed. - Definition onemod : list Z := Partition.partition weight n 1. + Definition onemod : list Z := partition weight n 1. Definition onemod_correct : eval onemod = 1 /\ valid onemod. Proof using n_nz m_big bitwidth_big. @@ -3955,7 +4026,7 @@ Module WordByWordMontgomery. cbv [valid small onemod eval]; autorewrite with push_eval; t_fin. Qed. - Definition R2mod : list Z := Partition.partition weight n ((r^n * r^n) mod m). + Definition R2mod : list Z := partition weight n ((r^n * r^n) mod m). Definition R2mod_correct : eval R2mod mod m = (r^n*r^n) mod m /\ valid R2mod. Proof using n_nz m_small m_big m'_correct bitwidth_big. @@ -4010,7 +4081,7 @@ Module WordByWordMontgomery. Qed. Definition encodemod (v : Z) : list Z - := mulmod (Partition.partition weight n v) R2mod. + := mulmod (partition weight n v) R2mod. Local Ltac t_valid v := cbv [valid]; repeat apply conj; @@ -4104,7 +4175,7 @@ Module WordByWordMontgomery. Lemma to_bytesmod_correct : (forall a (_ : valid a), Positional.eval (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (to_bytesmod a) = eval a mod m) - /\ (forall a (_ : valid a), to_bytesmod a = Partition.partition (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (eval a mod m)). + /\ (forall a (_ : valid a), to_bytesmod a = partition (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (eval a mod m)). Proof using n_nz m_small bitwidth_big. clear -n_nz m_small bitwidth_big. generalize (@length_small bitwidth n); |