From 71820cce3ba80acf0a09d7506c49ba2dd6e32d95 Mon Sep 17 00:00:00 2001 From: jadep Date: Thu, 14 Mar 2019 12:07:28 -0400 Subject: split up Arithmetic (imports etc. not yet fixed, does not build) --- src/Arithmetic/BaseConversion.v | 310 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 310 insertions(+) create mode 100644 src/Arithmetic/BaseConversion.v (limited to 'src/Arithmetic/BaseConversion.v') diff --git a/src/Arithmetic/BaseConversion.v b/src/Arithmetic/BaseConversion.v new file mode 100644 index 000000000..a22aa0c0b --- /dev/null +++ b/src/Arithmetic/BaseConversion.v @@ -0,0 +1,310 @@ + +(* TODO: prune these *) +Require Import Crypto.Algebra.Nsatz. +Require Import Coq.ZArith.ZArith Coq.micromega.Lia Crypto.Algebra.Nsatz. +Require Import Coq.Sorting.Mergesort Coq.Structures.Orders. +Require Import Coq.Sorting.Permutation. +Require Import Coq.derive.Derive. +Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. (* For MontgomeryReduction *) +Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. (* For MontgomeryReduction *) +Require Import Crypto.Util.Tactics.UniquePose Crypto.Util.Decidable. +Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn. +Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil. +Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. +Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. +Require Import Crypto.Arithmetic.BarrettReduction.Generalized. +Require Import Crypto.Arithmetic.ModularArithmeticTheorems. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.Tactics.RunTacticAsConstr. +Require Import Crypto.Util.Tactics.Head. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.OptionList. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.Sum. +Require Import Crypto.Util.Bool. +Require Import Crypto.Util.Sigma. +Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div Crypto.Util.ZUtil.Hints.Core. +Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall. +Require Import Crypto.Util.ZUtil.Tactics.PeelLe. +Require Import Crypto.Util.ZUtil.Tactics.LinearSubstitute. +Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. +Require Import Crypto.Util.ZUtil.Modulo.PullPush. +Require Import Crypto.Util.ZUtil.Opp. +Require Import Crypto.Util.ZUtil.Log2. +Require Import Crypto.Util.ZUtil.Le. +Require Import Crypto.Util.ZUtil.Hints.PullPush. +Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.Tactics.SplitInContext. +Require Import Crypto.Util.Tactics.SubstEvars. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.Sorting. +Require Import Crypto.Util.ZUtil.CC Crypto.Util.ZUtil.Rshi. +Require Import Crypto.Util.ZUtil.Zselect Crypto.Util.ZUtil.AddModulo. +Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. +Require Import Crypto.Util.ZUtil.Hints.Core. +Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div. +Require Import Crypto.Util.ZUtil.Hints.PullPush. +Require Import Crypto.Util.ZUtil.EquivModulo. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.CPSNotations. +Require Import Crypto.Util.Equality. +Require Import Crypto.Util.Tactics.SetEvars. +Import Coq.Lists.List ListNotations. Local Open Scope Z_scope. + +Module BaseConversion. + Import Positional. Import Partition. + Section BaseConversion. + Hint Resolve Z.positive_is_nonzero Z.lt_gt Z.gt_lt. + Context (sw dw : nat -> Z) (* source/destination weight functions *) + {swprops : @weight_properties sw} + {dwprops : @weight_properties dw}. + + Definition convert_bases (sn dn : nat) (p : list Z) : list Z := + let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in + chained_carries_no_reduce dw dn p' (seq 0 (pred dn)). + + Lemma eval_convert_bases sn dn p : + (dn <> 0%nat) -> length p = sn -> + eval dw dn (convert_bases sn dn p) = eval sw sn p. + Proof using dwprops. + cbv [convert_bases]; intros. + rewrite eval_chained_carries_no_reduce by auto. + rewrite eval_from_associational; auto. + Qed. + + Lemma length_convert_bases sn dn p + : length (convert_bases sn dn p) = dn. + Proof using Type. + cbv [convert_bases]; now repeat autorewrite with distr_length. + Qed. + Hint Rewrite length_convert_bases : distr_length. + + Lemma convert_bases_partitions sn dn p + (dw_unique : forall i j : nat, (i <= dn)%nat -> (j <= dn)%nat -> dw i = dw j -> i = j) + (p_bounded : 0 <= eval sw sn p < dw dn) + : convert_bases sn dn p = partition dw dn (eval sw sn p). + Proof using dwprops. + apply list_elementwise_eq; intro i. + destruct (lt_dec i dn); [ | now rewrite !nth_error_length_error by distr_length ]. + erewrite !(@nth_error_Some_nth_default _ _ 0) by (break_match; distr_length). + apply f_equal. + cbv [convert_bases partition]. + unshelve erewrite map_nth_default, nth_default_chained_carries_no_reduce_pred; + repeat first [ progress autorewrite with distr_length push_eval + | rewrite eval_from_associational, eval_to_associational + | rewrite nth_default_seq_inbounds + | apply dwprops + | destruct dwprops; now auto with zarith ]. + Qed. + + Hint Rewrite + @Rows.eval_from_associational + @Associational.eval_carry + @Associational.eval_mul + @Positional.eval_to_associational + Associational.eval_carryterm + @eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval. + + Ltac push_eval := intros; autorewrite with push_eval; auto with zarith. + + (* convert from positional in one weight to the other, then to associational *) + Definition to_associational n m p : list (Z * Z) := + let p' := convert_bases n m p in + Positional.to_associational dw m p'. + + (* TODO : move to Associational? *) + Section reorder. + Definition reordering_carry (w fw : Z) (p : list (Z * Z)) := + fold_right (fun t acc => + let r := Associational.carryterm w fw t in + if fst t =? w then acc ++ r else r ++ acc) nil p. + + Lemma eval_reordering_carry w fw p (_:fw<>0): + Associational.eval (reordering_carry w fw p) = Associational.eval p. + Proof using Type. + cbv [reordering_carry]. induction p; [reflexivity |]. + autorewrite with push_fold_right. break_match; push_eval. + Qed. + End reorder. + Hint Rewrite eval_reordering_carry using solve [auto using Z.positive_is_nonzero] : push_eval. + + (* carry at specified indices in dw, then use Rows.flatten to convert to Positional with sw *) + Definition from_associational idxs n (p : list (Z * Z)) : list Z := + (* important not to use Positional.carry here; we don't want to accumulate yet *) + let p' := fold_right (fun i acc => reordering_carry (dw i) (dw (S i) / dw i) acc) (Associational.bind_snd p) (rev idxs) in + fst (Rows.flatten sw n (Rows.from_associational sw n p')). + + Lemma eval_carries p idxs : + Associational.eval (fold_right (fun i acc => reordering_carry (dw i) (dw (S i) / dw i) acc) p idxs) = + Associational.eval p. + Proof using dwprops. apply fold_right_invariant; push_eval. Qed. + Hint Rewrite eval_carries: push_eval. + + Lemma eval_to_associational n m p : + m <> 0%nat -> length p = n -> + Associational.eval (to_associational n m p) = Positional.eval sw n p. + Proof using dwprops. cbv [to_associational]; push_eval. Qed. + Hint Rewrite eval_to_associational using solve [push_eval; distr_length] : push_eval. + + Lemma eval_from_associational idxs n p : + n <> 0%nat -> 0 <= Associational.eval p < sw n -> + Positional.eval sw n (from_associational idxs n p) = Associational.eval p. + Proof using dwprops swprops. + cbv [from_associational]; intros. + rewrite Rows.flatten_mod by eauto using Rows.length_from_associational. + rewrite Associational.bind_snd_correct. + push_eval. + Qed. + Hint Rewrite eval_from_associational using solve [push_eval; distr_length] : push_eval. + + Lemma from_associational_partitions n idxs p (_:n<>0%nat): + from_associational idxs n p = partition sw n (Associational.eval p). + Proof using dwprops swprops. + intros. cbv [from_associational]. + rewrite Rows.flatten_correct with (n:=n) by eauto using Rows.length_from_associational. + rewrite Associational.bind_snd_correct. + push_eval. + Qed. + + Derive from_associational_inlined + SuchThat (forall idxs n p, + from_associational_inlined idxs n p = from_associational idxs n p) + As from_associational_inlined_correct. + Proof. + intros. + cbv beta iota delta [from_associational reordering_carry Associational.carryterm]. + cbv beta iota delta [Let_In]. (* inlines all shifts/lands from carryterm *) + cbv beta iota delta [from_associational Rows.from_associational Columns.from_associational]. + cbv beta iota delta [Let_In]. (* inlines the shifts from place *) + subst from_associational_inlined; reflexivity. + Qed. + + Derive to_associational_inlined + SuchThat (forall n m p, + to_associational_inlined n m p = to_associational n m p) + As to_associational_inlined_correct. + Proof. + intros. + cbv beta iota delta [ to_associational convert_bases + Positional.to_associational + Positional.from_associational + chained_carries_no_reduce + carry + Associational.carry + Associational.carryterm + ]. + cbv beta iota delta [Let_In]. + subst to_associational_inlined; reflexivity. + Qed. + + (* carry chain that aligns terms in the intermediate weight with the final weight *) + Definition aligned_carries (log_dw_sw nout : nat) + := (map (fun i => ((log_dw_sw * (i + 1)) - 1))%nat (seq 0 nout)). + + Section mul_converted. + Definition mul_converted + n1 n2 (* lengths in original format *) + m1 m2 (* lengths in converted format *) + (n3 : nat) (* final length *) + (idxs : list nat) (* carries to do -- this helps preemptively line up weights *) + (p1 p2 : list Z) := + let p1_a := to_associational n1 m1 p1 in + let p2_a := to_associational n2 m2 p2 in + let p3_a := Associational.mul p1_a p2_a in + from_associational idxs n3 p3_a. + + Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + length p1 = n1 -> length p2 = n2 -> + 0 <= (Positional.eval sw n1 p1 * Positional.eval sw n2 p2) < sw n3 -> + Positional.eval sw n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval sw n1 p1) * (Positional.eval sw n2 p2). + Proof using dwprops swprops. cbv [mul_converted]; push_eval. Qed. + Hint Rewrite eval_mul_converted : push_eval. + + Lemma mul_converted_partitions n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + length p1 = n1 -> length p2 = n2 -> + mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = partition sw n3 (Positional.eval sw n1 p1 * Positional.eval sw n2 p2). + Proof using dwprops swprops. + intros; cbv [mul_converted]. + rewrite from_associational_partitions by auto. push_eval. + Qed. + End mul_converted. + End BaseConversion. + Hint Rewrite length_convert_bases : distr_length. + + (* multiply two (n*k)-bit numbers by converting them to n k-bit limbs each, multiplying, then converting back *) + Section widemul. + Context (log2base : Z) (log2base_pos : 0 < log2base). + Context (m n : nat) (m_nz : m <> 0%nat) (n_nz : n <> 0%nat) (n_le_log2base : Z.of_nat n <= log2base). + Let dw : nat -> Z := weight (log2base / Z.of_nat n) 1. + Let sw : nat -> Z := weight log2base 1. + Let mn := (m * n)%nat. + Let nout := (m * 2)%nat. + + Local Lemma mn_nonzero : mn <> 0%nat. Proof. subst mn. apply Nat.neq_mul_0. auto. Qed. + Local Hint Resolve mn_nonzero. + Local Lemma nout_nonzero : nout <> 0%nat. Proof. subst nout. apply Nat.neq_mul_0. auto. Qed. + Local Hint Resolve nout_nonzero. + Local Lemma base_bounds : 0 < 1 <= log2base. Proof using log2base_pos. clear -log2base_pos; auto with zarith. Qed. + Local Lemma dbase_bounds : 0 < 1 <= log2base / Z.of_nat n. Proof using n_nz n_le_log2base. clear -n_nz n_le_log2base; auto with zarith. Qed. + Let dwprops : @weight_properties dw := wprops (log2base / Z.of_nat n) 1 dbase_bounds. + Let swprops : @weight_properties sw := wprops log2base 1 base_bounds. + Local Notation deval := (Positional.eval dw). + Local Notation seval := (Positional.eval sw). + + Hint Resolve Z.gt_lt Z.positive_is_nonzero Nat2Z.is_nonneg. + + Definition widemul a b := mul_converted sw dw m m mn mn nout (aligned_carries n nout) a b. + + Lemma widemul_correct a b : + length a = m -> + length b = m -> + widemul a b = Partition.partition sw nout (seval m a * seval m b). + Proof. apply mul_converted_partitions; auto with zarith. Qed. + + Derive widemul_inlined + SuchThat (forall a b, + length a = m -> + length b = m -> + widemul_inlined a b = Partition.partition sw nout (seval m a * seval m b)) + As widemul_inlined_correct. + Proof. + intros. + rewrite <-widemul_correct by auto. + cbv beta iota delta [widemul mul_converted]. + rewrite <-to_associational_inlined_correct with (p:=a). + rewrite <-to_associational_inlined_correct with (p:=b). + rewrite <-from_associational_inlined_correct. + subst widemul_inlined; reflexivity. + Qed. + + Derive widemul_inlined_reverse + SuchThat (forall a b, + length a = m -> + length b = m -> + widemul_inlined_reverse a b = Partition.partition sw nout (seval m a * seval m b)) + As widemul_inlined_reverse_correct. + Proof. + intros. + rewrite <-widemul_inlined_correct by assumption. + cbv [widemul_inlined]. + match goal with |- _ = from_associational_inlined sw dw ?idxs ?n ?p => + transitivity (from_associational_inlined sw dw idxs n (rev p)); + [ | transitivity (from_associational sw dw idxs n p); [ | reflexivity ] ](* reverse to make addc chains line up *) + end. + { subst widemul_inlined_reverse; reflexivity. } + { rewrite from_associational_inlined_correct by auto. + cbv [from_associational]. + rewrite !Rows.flatten_correct by eauto using Rows.length_from_associational. + rewrite !Rows.eval_from_associational by auto. + f_equal. + rewrite !eval_carries, !Associational.bind_snd_correct, !Associational.eval_rev by auto. + reflexivity. } + Qed. + End widemul. +End BaseConversion. \ No newline at end of file -- cgit v1.2.3