aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/NewPipeline/Arithmetic.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Experiments/NewPipeline/Arithmetic.v')
-rw-r--r--src/Experiments/NewPipeline/Arithmetic.v1962
1 files changed, 1962 insertions, 0 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v
new file mode 100644
index 000000000..d7fdf0306
--- /dev/null
+++ b/src/Experiments/NewPipeline/Arithmetic.v
@@ -0,0 +1,1962 @@
+(* Following http://adam.chlipala.net/theses/andreser.pdf chapter 3 *)
+Require Import Coq.ZArith.ZArith Coq.micromega.Lia Crypto.Algebra.Nsatz.
+Require Import Coq.derive.Derive.
+Require Import Crypto.Util.Tactics.UniquePose Crypto.Util.Decidable.
+Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn.
+Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil.
+Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil.
+Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop.
+Require Import Crypto.Arithmetic.BarrettReduction.Generalized.
+Require Import Crypto.Arithmetic.MontgomeryReduction.Definition.
+Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs.
+Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo.
+Require Import Crypto.Util.ZRange.
+Require Import Crypto.Util.ZRange.Operations.
+Require Import Crypto.Util.Tactics.RunTacticAsConstr.
+Require Import Crypto.Util.Tactics.Head.
+Require Import Crypto.Util.Option.
+Require Import Crypto.Util.OptionList.
+Require Import Crypto.Util.Sum.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div Crypto.Util.ZUtil.Hints.Core.
+Require Import Crypto.Util.ZUtil.Hints.PullPush.
+Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit.
+Require Import Crypto.Util.ZUtil.Tactics.LtbToLt.
+Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo.
+Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem.
+Require Import Crypto.Util.Tactics.SpecializeBy.
+Require Import Crypto.Util.Tactics.SplitInContext.
+Require Import Crypto.Util.Tactics.SubstEvars.
+Require Import Crypto.Util.Notations.
+Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.ZUtil.CC Crypto.Util.ZUtil.Rshi.
+Require Import Crypto.Util.ZUtil.Zselect Crypto.Util.ZUtil.AddModulo.
+Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit.
+Require Import Crypto.Util.ZUtil Crypto.Util.ZUtil.Hints.Core.
+Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div.
+Require Import Crypto.Util.ZUtil.Hints.PullPush.
+Require Import Crypto.Util.ZUtil.EquivModulo.
+Require Import Crypto.Util.Tactics.DebugPrint.
+Require Import Crypto.Util.CPSNotations.
+Require Import Crypto.Util.Equality.
+Import ListNotations. Local Open Scope Z_scope.
+
+Module Associational.
+ Definition eval (p:list (Z*Z)) : Z :=
+ fold_right (fun x y => x + y) 0%Z (map (fun t => fst t * snd t) p).
+
+ Lemma eval_nil : eval nil = 0.
+ Proof. trivial. Qed.
+ Lemma eval_cons p q : eval (p::q) = fst p * snd p + eval q.
+ Proof. trivial. Qed.
+ Lemma eval_app p q: eval (p++q) = eval p + eval q.
+ Proof. induction p; rewrite <-?List.app_comm_cons;
+ rewrite ?eval_nil, ?eval_cons; nsatz. Qed.
+
+ Hint Rewrite eval_nil eval_cons eval_app : push_eval.
+ Local Ltac push := autorewrite with
+ push_eval push_map push_partition push_flat_map
+ push_fold_right push_nth_default cancel_pair.
+
+ Lemma eval_map_mul (a x:Z) (p:list (Z*Z))
+ : eval (List.map (fun t => (a*fst t, x*snd t)) p) = a*x*eval p.
+ Proof. induction p; push; nsatz. Qed.
+ Hint Rewrite eval_map_mul : push_eval.
+
+ Definition mul (p q:list (Z*Z)) : list (Z*Z) :=
+ flat_map (fun t =>
+ map (fun t' =>
+ (fst t * fst t', snd t * snd t'))
+ q) p.
+ Lemma eval_mul p q : eval (mul p q) = eval p * eval q.
+ Proof. induction p; cbv [mul]; push; nsatz. Qed.
+ Hint Rewrite eval_mul : push_eval.
+
+ Definition negate_snd (p:list (Z*Z)) : list (Z*Z) :=
+ map (fun cx => (fst cx, -snd cx)) p.
+ Lemma eval_negate_snd p : eval (negate_snd p) = - eval p.
+ Proof. induction p; cbv [negate_snd]; push; nsatz. Qed.
+ Hint Rewrite eval_negate_snd : push_eval.
+
+ Example base10_2digit_mul (a0:Z) (a1:Z) (b0:Z) (b1:Z) :
+ {ab| eval ab = eval [(10,a1);(1,a0)] * eval [(10,b1);(1,b0)]}.
+ eexists ?[ab].
+ (* Goal: eval ?ab = eval [(10,a1);(1,a0)] * eval [(10,b1);(1,b0)] *)
+ rewrite <-eval_mul.
+ (* Goal: eval ?ab = eval (mul [(10,a1);(1,a0)] [(10,b1);(1,b0)]) *)
+ cbv -[Z.mul eval]; cbn -[eval].
+ (* Goal: eval ?ab = eval [(100,(a1*b1));(10,a1*b0);(10,a0*b1);(1,a0*b0)]%RT *)
+ trivial. Defined.
+
+ Definition split (s:Z) (p:list (Z*Z)) : list (Z*Z) * list (Z*Z)
+ := let hi_lo := partition (fun t => fst t mod s =? 0) p in
+ (snd hi_lo, map (fun t => (fst t / s, snd t)) (fst hi_lo)).
+ Lemma eval_split s p (s_nz:s<>0) :
+ eval (fst (split s p)) + s * eval (snd (split s p)) = eval p.
+ Proof. cbv [Let_In split]; induction p;
+ repeat match goal with
+ | |- context[?a/?b] =>
+ unique pose proof (Z_div_exact_full_2 a b ltac:(trivial) ltac:(trivial))
+ | _ => progress push
+ | _ => progress break_match
+ | _ => progress nsatz end. Qed.
+
+ Lemma reduction_rule a b s c (modulus_nz:s-c<>0) :
+ (a + s * b) mod (s - c) = (a + c * b) mod (s - c).
+ Proof. replace (a + s * b) with ((a + c*b) + b*(s-c)) by nsatz.
+ rewrite Z.add_mod,Z_mod_mult,Z.add_0_r,Z.mod_mod;trivial. Qed.
+
+ Definition reduce (s:Z) (c:list _) (p:list _) : list (Z*Z) :=
+ let lo_hi := split s p in fst lo_hi ++ mul c (snd lo_hi).
+
+ Lemma eval_reduce s c p (s_nz:s<>0) (modulus_nz:s-eval c<>0) :
+ eval (reduce s c p) mod (s - eval c) = eval p mod (s - eval c).
+ Proof. cbv [reduce]; push.
+ rewrite <-reduction_rule, eval_split; trivial. Qed.
+ Hint Rewrite eval_reduce : push_eval.
+
+ Definition bind_snd (p : list (Z*Z)) :=
+ map (fun t => dlet_nd t2 := snd t in (fst t, t2)) p.
+
+ Lemma bind_snd_correct p : bind_snd p = p.
+ Proof.
+ cbv [bind_snd]; induction p as [| [? ?] ];
+ push; [|rewrite IHp]; reflexivity.
+ Qed.
+
+ Lemma eval_rev p : eval (rev p) = eval p.
+ Proof. induction p; cbn [rev]; push; lia. Qed.
+
+ Section Carries.
+ Definition carryterm (w fw:Z) (t:Z * Z) :=
+ if (Z.eqb (fst t) w)
+ then dlet_nd t2 := snd t in
+ dlet_nd d2 := t2 / fw in
+ dlet_nd m2 := t2 mod fw in
+ [(w * fw, d2);(w,m2)]
+ else [t].
+
+ Lemma eval_carryterm w fw (t:Z * Z) (fw_nonzero:fw<>0):
+ eval (carryterm w fw t) = eval [t].
+ Proof using Type*.
+ cbv [carryterm Let_In]; break_match; push; [|trivial].
+ pose proof (Z.div_mod (snd t) fw fw_nonzero).
+ rewrite Z.eqb_eq in *.
+ nsatz.
+ Qed. Hint Rewrite eval_carryterm using auto : push_eval.
+
+ Definition carry (w fw:Z) (p:list (Z * Z)):=
+ flat_map (carryterm w fw) p.
+
+ Lemma eval_carry w fw p (fw_nonzero:fw<>0):
+ eval (carry w fw p) = eval p.
+ Proof using Type*. cbv [carry]; induction p; push; nsatz. Qed.
+ Hint Rewrite eval_carry using auto : push_eval.
+ End Carries.
+End Associational.
+
+Module Positional. Section Positional.
+ Context (weight : nat -> Z)
+ (weight_0 : weight 0%nat = 1)
+ (weight_nz : forall i, weight i <> 0).
+
+ Definition to_associational (n:nat) (xs:list Z) : list (Z*Z)
+ := combine (map weight (List.seq 0 n)) xs.
+ Definition eval n x := Associational.eval (@to_associational n x).
+ Lemma eval_to_associational n x :
+ Associational.eval (@to_associational n x) = eval n x.
+ Proof. trivial. Qed.
+ Hint Rewrite @eval_to_associational : push_eval.
+ Lemma eval_nil n : eval n [] = 0.
+ Proof. cbv [eval to_associational]. rewrite combine_nil_r. reflexivity. Qed.
+ Hint Rewrite eval_nil : push_eval.
+ Lemma eval0 p : eval 0 p = 0.
+ Proof. cbv [eval to_associational]. reflexivity. Qed.
+ Hint Rewrite eval0 : push_eval.
+
+ Lemma eval_snoc n m x y : n = length x -> m = S n -> eval m (x ++ [y]) = eval n x + weight n * y.
+ Proof.
+ cbv [eval to_associational]; intros; subst n m.
+ rewrite seq_snoc, map_app.
+ rewrite combine_app_samelength by distr_length.
+ autorewrite with push_eval. simpl.
+ autorewrite with push_eval cancel_pair; ring.
+ Qed.
+
+ (* SKIP over this: zeros, add_to_nth *)
+ Local Ltac push := autorewrite with push_eval push_map distr_length
+ push_flat_map push_fold_right push_nth_default cancel_pair natsimplify.
+ Definition zeros n : list Z := repeat 0 n.
+ Lemma length_zeros n : length (zeros n) = n. Proof. cbv [zeros]; distr_length. Qed.
+ Hint Rewrite length_zeros : distr_length.
+ Lemma eval_zeros n : eval n (zeros n) = 0.
+ Proof.
+ cbv [eval Associational.eval to_associational zeros].
+ rewrite <- (seq_length n 0) at 2.
+ generalize dependent (List.seq 0 n); intro xs.
+ induction xs; simpl; nsatz. Qed.
+ Definition add_to_nth i x (ls : list Z) : list Z
+ := ListUtil.update_nth i (fun y => x + y) ls.
+ Lemma length_add_to_nth i x ls : length (add_to_nth i x ls) = length ls.
+ Proof. cbv [add_to_nth]; distr_length. Qed.
+ Hint Rewrite length_add_to_nth : distr_length.
+ Lemma eval_add_to_nth (n:nat) (i:nat) (x:Z) (xs:list Z) (H:(i<length xs)%nat)
+ (Hn : length xs = n) (* N.B. We really only need [i < Nat.min n (length xs)] *) :
+ eval n (add_to_nth i x xs) = weight i * x + eval n xs.
+ Proof.
+ subst n.
+ cbv [eval to_associational add_to_nth].
+ rewrite ListUtil.combine_update_nth_r at 1.
+ rewrite <-(update_nth_id i (List.combine _ _)) at 2.
+ rewrite <-!(ListUtil.splice_nth_equiv_update_nth_update _ _
+ (weight 0, 0)) by (push; lia); cbv [ListUtil.splice_nth id].
+ repeat match goal with
+ | _ => progress push
+ | _ => progress break_match
+ | _ => progress (apply Zminus_eq; ring_simplify)
+ | _ => rewrite <-ListUtil.map_nth_default_always
+ end; lia. Qed.
+ Hint Rewrite @eval_add_to_nth eval_zeros : push_eval.
+
+ Definition place (t:Z*Z) (i:nat) : nat * Z :=
+ nat_rect
+ (fun _ => (nat * Z)%type)
+ (O, fst t * snd t)
+ (fun i' place_i'
+ => let i := S i' in
+ if (fst t mod weight i =? 0)
+ then (i, let c := fst t / weight i in c * snd t)
+ else place_i')
+ i.
+
+ Lemma place_in_range (t:Z*Z) (n:nat) : (fst (place t n) < S n)%nat.
+ Proof. induction n; cbv [place nat_rect] in *; break_match; autorewrite with cancel_pair; try omega. Qed.
+ Lemma weight_place t i : weight (fst (place t i)) * snd (place t i) = fst t * snd t.
+ Proof. induction i; cbv [place nat_rect] in *; break_match; push;
+ repeat match goal with |- context[?a/?b] =>
+ unique pose proof (Z_div_exact_full_2 a b ltac:(auto) ltac:(auto))
+ end; nsatz. Qed.
+ Hint Rewrite weight_place : push_eval.
+
+ Definition from_associational n (p:list (Z*Z)) :=
+ List.fold_right (fun t ls =>
+ dlet_nd p := place t (pred n) in
+ add_to_nth (fst p) (snd p) ls ) (zeros n) p.
+ Lemma eval_from_associational n p (n_nz:n<>O \/ p = nil) :
+ eval n (from_associational n p) = Associational.eval p.
+ Proof. destruct n_nz; [ induction p | subst p ];
+ cbv [from_associational Let_In] in *; push; try
+ pose proof place_in_range a (pred n); try omega; try nsatz;
+ apply fold_right_invariant; cbv [zeros add_to_nth];
+ intros; rewrite ?map_length, ?List.repeat_length, ?seq_length, ?length_update_nth;
+ try omega. Qed.
+ Hint Rewrite @eval_from_associational : push_eval.
+ Lemma length_from_associational n p : length (from_associational n p) = n.
+ Proof. cbv [from_associational Let_In]. apply fold_right_invariant; intros; distr_length. Qed.
+ Hint Rewrite length_from_associational : distr_length.
+
+ Section mulmod.
+ Context (s:Z) (s_nz:s <> 0)
+ (c:list (Z*Z))
+ (m_nz:s - Associational.eval c <> 0).
+ Definition mulmod (n:nat) (a b:list Z) : list Z
+ := let a_a := to_associational n a in
+ let b_a := to_associational n b in
+ let ab_a := Associational.mul a_a b_a in
+ let abm_a := Associational.reduce s c ab_a in
+ from_associational n abm_a.
+ Lemma eval_mulmod n (f g:list Z)
+ (Hf : length f = n) (Hg : length g = n) :
+ eval n (mulmod n f g) mod (s - Associational.eval c)
+ = (eval n f * eval n g) mod (s - Associational.eval c).
+ Proof. cbv [mulmod]; push; trivial.
+ destruct f, g; simpl in *; [ right; subst n | left; try omega.. ].
+ clear; cbv -[Associational.reduce].
+ induction c as [|?? IHc]; simpl; trivial. Qed.
+ End mulmod.
+ Hint Rewrite @eval_mulmod : push_eval.
+
+ Definition add (n:nat) (a b:list Z) : list Z
+ := let a_a := to_associational n a in
+ let b_a := to_associational n b in
+ from_associational n (a_a ++ b_a).
+ Lemma eval_add n (f g:list Z)
+ (Hf : length f = n) (Hg : length g = n) :
+ eval n (add n f g) = (eval n f + eval n g).
+ Proof. cbv [add]; push; trivial. destruct n; auto. Qed.
+ Hint Rewrite @eval_add : push_eval.
+ Lemma length_add n f g
+ (Hf : length f = n) (Hg : length g = n) :
+ length (add n f g) = n.
+ Proof. clear -Hf Hf; cbv [add]; distr_length. Qed.
+ Hint Rewrite @length_add : distr_length.
+
+ Section Carries.
+ Definition carry n m (index:nat) (p:list Z) : list Z :=
+ from_associational
+ m (@Associational.carry (weight index)
+ (weight (S index) / weight index)
+ (to_associational n p)).
+
+ Lemma length_carry n m index p : length (carry n m index p) = m.
+ Proof. cbv [carry]; distr_length. Qed.
+ Lemma eval_carry n m i p: (n <> 0%nat) -> (m <> 0%nat) ->
+ weight (S i) / weight i <> 0 ->
+ eval m (carry n m i p) = eval n p.
+ Proof.
+ cbv [carry]; intros; push; [|tauto].
+ rewrite @Associational.eval_carry by eauto.
+ apply eval_to_associational.
+ Qed. Hint Rewrite @eval_carry : push_eval.
+
+ Definition carry_reduce n (s:Z) (c:list (Z * Z))
+ (index:nat) (p : list Z) :=
+ from_associational
+ n (Associational.reduce
+ s c (to_associational (S n) (@carry n (S n) index p))).
+
+ Lemma eval_carry_reduce n s c index p :
+ (s <> 0) -> (s - Associational.eval c <> 0) -> (n <> 0%nat) ->
+ (weight (S index) / weight index <> 0) ->
+ eval n (carry_reduce n s c index p) mod (s - Associational.eval c)
+ = eval n p mod (s - Associational.eval c).
+ Proof. cbv [carry_reduce]; intros; push; auto. Qed.
+ Hint Rewrite @eval_carry_reduce : push_eval.
+ Lemma length_carry_reduce n s c index p
+ : length p = n -> length (carry_reduce n s c index p) = n.
+ Proof. cbv [carry_reduce]; distr_length. Qed.
+ Hint Rewrite @length_carry_reduce : distr_length.
+
+ (* N.B. It is important to reverse [idxs] here, because fold_right is
+ written such that the first terms in the list are actually used
+ last in the computation. For example, running:
+
+ `Eval cbv - [Z.add] in (fun a b c d => fold_right Z.add d [a;b;c]).`
+
+ will produce [fun a b c d => (a + (b + (c + d)))].*)
+ Definition chained_carries n s c p (idxs : list nat) :=
+ fold_right (fun a b => carry_reduce n s c a b) p (rev idxs).
+
+ Lemma eval_chained_carries n s c p idxs :
+ (s <> 0) -> (s - Associational.eval c <> 0) -> (n <> 0%nat) ->
+ (forall i, In i idxs -> weight (S i) / weight i <> 0) ->
+ eval n (chained_carries n s c p idxs) mod (s - Associational.eval c)
+ = eval n p mod (s - Associational.eval c).
+ Proof using Type*.
+ cbv [chained_carries]; intros; push.
+ apply fold_right_invariant; [|intro; rewrite <-in_rev];
+ destruct n; intros; push; auto.
+ Qed. Hint Rewrite @eval_chained_carries : push_eval.
+ Lemma length_chained_carries n s c p idxs
+ : length p = n -> length (@chained_carries n s c p idxs) = n.
+ Proof.
+ intros; cbv [chained_carries]; induction (rev idxs) as [|x xs IHxs];
+ cbn [fold_right]; distr_length.
+ Qed. Hint Rewrite @length_chained_carries : distr_length.
+
+ (* carries without modular reduction; useful for converting between bases *)
+ Definition chained_carries_no_reduce n p (idxs : list nat) :=
+ fold_right (fun a b => carry n n a b) p (rev idxs).
+ Lemma eval_chained_carries_no_reduce n p idxs:
+ (forall i, In i idxs -> weight (S i) / weight i <> 0) ->
+ eval n (chained_carries_no_reduce n p idxs) = eval n p.
+ Proof.
+ cbv [chained_carries_no_reduce]; intros.
+ destruct n; [push;reflexivity|].
+ apply fold_right_invariant; [|intro; rewrite <-in_rev];
+ intros; push; auto.
+ Qed. Hint Rewrite @eval_chained_carries_no_reduce : push_eval.
+
+ (* Reverse of [eval]; translate from Z to basesystem by putting
+ everything in first digit and then carrying. *)
+ Definition encode n s c (x : Z) : list Z :=
+ chained_carries n s c (from_associational n [(1,x)]) (seq 0 n).
+ Lemma eval_encode n s c x :
+ (s <> 0) -> (s - Associational.eval c <> 0) -> (n <> 0%nat) ->
+ (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) ->
+ eval n (encode n s c x) mod (s - Associational.eval c)
+ = x mod (s - Associational.eval c).
+ Proof using Type*. cbv [encode]; intros; push; auto; f_equal; omega. Qed.
+ Lemma length_encode n s c x
+ : length (encode n s c x) = n.
+ Proof. cbv [encode]; repeat distr_length. Qed.
+
+ End Carries.
+ Hint Rewrite @eval_encode : push_eval.
+ Hint Rewrite @length_encode : distr_length.
+
+ Section sub.
+ Context (n:nat)
+ (s:Z) (s_nz:s <> 0)
+ (c:list (Z * Z))
+ (m_nz:s - Associational.eval c <> 0)
+ (coef:Z).
+
+ Definition negate_snd (a:list Z) : list Z
+ := let A := to_associational n a in
+ let negA := Associational.negate_snd A in
+ from_associational n negA.
+
+ Definition scmul (x:Z) (a:list Z) : list Z
+ := let A := to_associational n a in
+ let R := Associational.mul A [(1, x)] in
+ from_associational n R.
+
+ Definition balance : list Z
+ := scmul coef (encode n s c (s - Associational.eval c)).
+
+ Definition sub (a b:list Z) : list Z
+ := let ca := add n balance a in
+ let _b := negate_snd b in
+ add n ca _b.
+ Lemma eval_sub a b
+ : (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) ->
+ (List.length a = n) -> (List.length b = n) ->
+ eval n (sub a b) mod (s - Associational.eval c)
+ = (eval n a - eval n b) mod (s - Associational.eval c).
+ Proof.
+ destruct (zerop n); subst; try reflexivity.
+ intros; cbv [sub balance scmul negate_snd]; push; repeat distr_length;
+ eauto with omega.
+ push_Zmod; push; pull_Zmod; push_Zmod; pull_Zmod; distr_length; eauto.
+ Qed.
+ Hint Rewrite eval_sub : push_eval.
+ Lemma length_sub a b
+ : length a = n -> length b = n ->
+ length (sub a b) = n.
+ Proof. intros; cbv [sub balance scmul negate_snd]; repeat distr_length. Qed.
+ Hint Rewrite length_sub : distr_length.
+ Definition opp (a:list Z) : list Z
+ := sub (zeros n) a.
+ Lemma eval_opp
+ (a:list Z)
+ : (length a = n) ->
+ (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) ->
+ eval n (opp a) mod (s - Associational.eval c)
+ = (- eval n a) mod (s - Associational.eval c).
+ Proof. intros; cbv [opp]; push; distr_length; auto. Qed.
+ Lemma length_opp a
+ : length a = n -> length (opp a) = n.
+ Proof. cbv [opp]; intros; repeat distr_length. Qed.
+ End sub.
+ Hint Rewrite @eval_opp @eval_sub : push_eval.
+ Hint Rewrite @length_sub @length_opp : distr_length.
+End Positional.
+(* Hint Rewrite disappears after the end of a section *)
+Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_sub @length_opp : distr_length.
+End Positional.
+
+Record weight_properties {weight : nat -> Z} :=
+ {
+ weight_0 : weight 0%nat = 1;
+ weight_positive : forall i, 0 < weight i;
+ weight_multiples : forall i, weight (S i) mod weight i = 0;
+ weight_divides : forall i : nat, 0 < weight (S i) / weight i;
+ }.
+Hint Resolve weight_0 weight_positive weight_multiples weight_divides.
+
+Section mod_ops.
+ Import Positional.
+ Local Coercion Z.of_nat : nat >-> Z.
+ Local Coercion QArith_base.inject_Z : Z >-> Q.
+ (* Design constraints:
+ - inputs must be [Z] (b/c reification does not support Q)
+ - internal structure must not match on the arguments (b/c reification does not support [positive]) *)
+ Context (limbwidth_num limbwidth_den : Z)
+ (limbwidth_good : 0 < limbwidth_den <= limbwidth_num)
+ (s : Z)
+ (c : list (Z*Z))
+ (n : nat)
+ (len_c : nat)
+ (idxs : list nat)
+ (len_idxs : nat)
+ (m_nz:s - Associational.eval c <> 0) (s_nz:s <> 0)
+ (Hn_nz : n <> 0%nat)
+ (Hc : length c = len_c)
+ (Hidxs : length idxs = len_idxs).
+ Definition weight (i : nat)
+ := 2^(-(-(limbwidth_num * i) / limbwidth_den)).
+
+ Local Ltac Q_cbv :=
+ cbv [Qceiling inject_Z Qle Qfloor Qdiv Qnum Qden Qmult Qinv Qopp].
+
+ Local Lemma weight_ZQ_correct i
+ (limbwidth := (limbwidth_num / limbwidth_den)%Q)
+ : weight i = 2^Qceiling(limbwidth*i).
+ Proof.
+ clear -limbwidth_good.
+ cbv [limbwidth weight]; Q_cbv.
+ destruct limbwidth_num, limbwidth_den, i; try reflexivity;
+ repeat rewrite ?Pos.mul_1_l, ?Pos.mul_1_r, ?Z.mul_0_l, ?Zdiv_0_l, ?Zdiv_0_r, ?Z.mul_1_l, ?Z.mul_1_r, <- ?Z.opp_eq_mul_m1, ?Pos2Z.opp_pos;
+ try reflexivity; try lia.
+ Qed.
+
+ Local Ltac t_weight_with lem :=
+ clear -limbwidth_good;
+ intros; rewrite !weight_ZQ_correct;
+ apply lem;
+ try omega; Q_cbv; destruct limbwidth_den; cbn; try lia.
+
+ Definition wprops : @weight_properties weight.
+ Proof.
+ constructor.
+ { cbv [weight Z.of_nat]; autorewrite with zsimplify_fast; reflexivity. }
+ { intros; apply Z.gt_lt. t_weight_with (@pow_ceil_mul_nat_pos 2). }
+ { t_weight_with (@pow_ceil_mul_nat_multiples 2). }
+ { intros; apply Z.gt_lt. t_weight_with (@pow_ceil_mul_nat_divide 2). }
+ Defined.
+ Local Hint Immediate (weight_0 wprops).
+ Local Hint Immediate (weight_positive wprops).
+ Local Hint Immediate (weight_multiples wprops).
+ Local Hint Immediate (weight_divides wprops).
+ Local Hint Resolve Z.positive_is_nonzero Z.lt_gt.
+
+ Local Lemma weight_1_gt_1 : weight 1 > 1.
+ Proof.
+ clear -limbwidth_good.
+ cut (1 < weight 1); [ lia | ].
+ cbv [weight Z.of_nat]; autorewrite with zsimplify_fast.
+ apply Z.pow_gt_1; [ omega | ].
+ Z.div_mod_to_quot_rem; nia.
+ Qed.
+
+ Derive carry_mulmod
+ SuchThat (forall (f g : list Z)
+ (Hf : length f = n)
+ (Hg : length g = n),
+ (eval weight n (carry_mulmod f g)) mod (s - Associational.eval c)
+ = (eval weight n f * eval weight n g) mod (s - Associational.eval c))
+ As eval_carry_mulmod.
+ Proof.
+ intros.
+ rewrite <-eval_mulmod with (s:=s) (c:=c) by auto.
+ etransitivity;
+ [ | rewrite <- @eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs)
+ by auto; reflexivity ].
+ eapply f_equal2; [|trivial]. eapply f_equal.
+ subst carry_mulmod; reflexivity.
+ Qed.
+
+ Derive carrymod
+ SuchThat (forall (f : list Z)
+ (Hf : length f = n),
+ (eval weight n (carrymod f)) mod (s - Associational.eval c)
+ = (eval weight n f) mod (s - Associational.eval c))
+ As eval_carrymod.
+ Proof.
+ intros.
+ etransitivity;
+ [ | rewrite <- @eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs)
+ by auto; reflexivity ].
+ eapply f_equal2; [|trivial]. eapply f_equal.
+ subst carrymod; reflexivity.
+ Qed.
+
+ Derive addmod
+ SuchThat (forall (f g : list Z)
+ (Hf : length f = n)
+ (Hg : length g = n),
+ (eval weight n (addmod f g)) mod (s - Associational.eval c)
+ = (eval weight n f + eval weight n g) mod (s - Associational.eval c))
+ As eval_addmod.
+ Proof.
+ intros.
+ rewrite <-eval_add by auto.
+ eapply f_equal2; [|trivial]. eapply f_equal.
+ subst addmod; reflexivity.
+ Qed.
+
+ Derive submod
+ SuchThat (forall (coef:Z)
+ (f g : list Z)
+ (Hf : length f = n)
+ (Hg : length g = n),
+ (eval weight n (submod coef f g)) mod (s - Associational.eval c)
+ = (eval weight n f - eval weight n g) mod (s - Associational.eval c))
+ As eval_submod.
+ Proof.
+ intros.
+ rewrite <-eval_sub with (coef:=coef) by auto.
+ eapply f_equal2; [|trivial]. eapply f_equal.
+ subst submod; reflexivity.
+ Qed.
+
+ Derive oppmod
+ SuchThat (forall (coef:Z)
+ (f: list Z)
+ (Hf : length f = n),
+ (eval weight n (oppmod coef f)) mod (s - Associational.eval c)
+ = (- eval weight n f) mod (s - Associational.eval c))
+ As eval_oppmod.
+ Proof.
+ intros.
+ rewrite <-eval_opp with (coef:=coef) by auto.
+ eapply f_equal2; [|trivial]. eapply f_equal.
+ subst oppmod; reflexivity.
+ Qed.
+
+ Derive encodemod
+ SuchThat (forall (f:Z),
+ (eval weight n (encodemod f)) mod (s - Associational.eval c)
+ = f mod (s - Associational.eval c))
+ As eval_encodemod.
+ Proof.
+ intros.
+ etransitivity.
+ 2:rewrite <-@eval_encode with (weight:=weight) (n:=n) by auto; reflexivity.
+ eapply f_equal2; [|trivial]. eapply f_equal.
+ subst encodemod; reflexivity.
+ Qed.
+End mod_ops.
+
+Module Saturated.
+ Hint Resolve weight_positive weight_0 weight_multiples weight_divides.
+ Hint Resolve Z.positive_is_nonzero Z.lt_gt Nat2Z.is_nonneg.
+
+ Section Weight.
+ Context weight {wprops : @weight_properties weight}.
+
+ Lemma weight_multiples_full' j : forall i, weight (i+j) mod weight i = 0.
+ Proof.
+ induction j; intros;
+ repeat match goal with
+ | _ => rewrite Nat.add_succ_r
+ | _ => rewrite IHj
+ | |- context [weight (S ?x) mod weight _] =>
+ rewrite (Z.div_mod (weight (S x)) (weight x)), weight_multiples by auto
+ | _ => progress autorewrite with push_Zmod natsimplify zsimplify_fast
+ | _ => reflexivity
+ end.
+ Qed.
+
+ Lemma weight_multiples_full j i : (i <= j)%nat -> weight j mod weight i = 0.
+ Proof.
+ intros; replace j with (i + (j - i))%nat by omega.
+ apply weight_multiples_full'.
+ Qed.
+
+ Lemma weight_divides_full j i : (i <= j)%nat -> 0 < weight j / weight i.
+ Proof. auto using Z.gt_lt, Z.div_positive_gt_0, weight_multiples_full. Qed.
+
+ Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i).
+ Proof. intros. apply Z.div_exact; auto using weight_multiples_full. Qed.
+ End Weight.
+
+ Module Associational.
+ Section Associational.
+
+ Definition sat_multerm s (t t' : (Z * Z)) : list (Z * Z) :=
+ dlet_nd xy := Z.mul_split s (snd t) (snd t') in
+ [(fst t * fst t', fst xy); (fst t * fst t' * s, snd xy)].
+
+ Definition sat_mul s (p q : list (Z * Z)) : list (Z * Z) :=
+ flat_map (fun t => flat_map (fun t' => sat_multerm s t t') q) p.
+
+ Lemma eval_map_sat_multerm s a q (s_nonzero:s<>0):
+ Associational.eval (flat_map (sat_multerm s a) q) = fst a * snd a * Associational.eval q.
+ Proof.
+ cbv [sat_multerm Let_In]; induction q;
+ repeat match goal with
+ | _ => progress autorewrite with cancel_pair push_eval to_div_mod in *
+ | _ => progress simpl flat_map
+ | _ => rewrite IHq
+ | _ => rewrite Z.mod_eq by assumption
+ | _ => ring_simplify; omega
+ end.
+ Qed.
+ Hint Rewrite eval_map_sat_multerm using (omega || assumption) : push_eval.
+
+ Lemma eval_sat_mul s p q (s_nonzero:s<>0):
+ Associational.eval (sat_mul s p q) = Associational.eval p * Associational.eval q.
+ Proof.
+ cbv [sat_mul]; induction p; [reflexivity|].
+ repeat match goal with
+ | _ => progress (autorewrite with push_flat_map push_eval in * )
+ | _ => rewrite IHp
+ | _ => ring_simplify; omega
+ end.
+ Qed.
+ Hint Rewrite eval_sat_mul : push_eval.
+
+ Definition sat_multerm_const s (t t' : (Z * Z)) : list (Z * Z) :=
+ if snd t =? 1
+ then [(fst t * fst t', snd t')]
+ else if snd t =? -1
+ then [(fst t * fst t', - snd t')]
+ else if snd t =? 0
+ then nil
+ else dlet_nd xy := Z.mul_split s (snd t) (snd t') in
+ [(fst t * fst t', fst xy); (fst t * fst t' * s, snd xy)].
+
+ Definition sat_mul_const s (p q : list (Z * Z)) : list (Z * Z) :=
+ flat_map (fun t => flat_map (fun t' => sat_multerm_const s t t') q) p.
+
+ Lemma eval_map_sat_multerm_const s a q (s_nonzero:s<>0):
+ Associational.eval (flat_map (sat_multerm_const s a) q) = fst a * snd a * Associational.eval q.
+ Proof.
+ cbv [sat_multerm_const Let_In]; induction q;
+ repeat match goal with
+ | _ => progress autorewrite with cancel_pair push_eval to_div_mod in *
+ | _ => progress simpl flat_map
+ | H : _ = 1 |- _ => rewrite H
+ | H : _ = -1 |- _ => rewrite H
+ | H : _ = 0 |- _ => rewrite H
+ | _ => progress break_match; Z.ltb_to_lt
+ | _ => rewrite IHq
+ | _ => rewrite Z.mod_eq by assumption
+ | _ => ring_simplify; omega
+ end.
+ Qed.
+ Hint Rewrite eval_map_sat_multerm_const using (omega || assumption) : push_eval.
+
+ Lemma eval_sat_mul_const s p q (s_nonzero:s<>0):
+ Associational.eval (sat_mul_const s p q) = Associational.eval p * Associational.eval q.
+ Proof.
+ cbv [sat_mul_const]; induction p; [reflexivity|].
+ repeat match goal with
+ | _ => progress (autorewrite with push_flat_map push_eval in * )
+ | _ => rewrite IHp
+ | _ => ring_simplify; omega
+ end.
+ Qed.
+ Hint Rewrite eval_sat_mul_const : push_eval.
+ End Associational.
+ End Associational.
+
+ Section DivMod.
+ Lemma mod_step a b c d: 0 < a -> 0 < b ->
+ c mod a + a * ((c / a + d) mod b) = (a * d + c) mod (a * b).
+ Proof.
+ intros; rewrite Z.rem_mul_r by omega. push_Zmod.
+ autorewrite with zsimplify pull_Zmod. repeat (f_equal; try ring).
+ Qed.
+
+ Lemma div_step a b c d : 0 < a -> 0 < b ->
+ (c / a + d) / b = (a * d + c) / (a * b).
+ Proof. intros; Z.div_mod_to_quot_rem; nia. Qed.
+
+ Lemma add_mod_div_multiple a b n m:
+ n > 0 ->
+ 0 <= m / n ->
+ m mod n = 0 ->
+ (a / n + b) mod (m / n) = (a + n * b) mod m / n.
+ Proof.
+ intros. rewrite <-!Z.div_add' by auto using Z.positive_is_nonzero.
+ rewrite Z.mod_pull_div, Z.mul_div_eq' by auto using Z.gt_lt.
+ repeat (f_equal; try omega).
+ Qed.
+
+ Lemma add_mod_l_multiple a b n m:
+ 0 < n / m -> m <> 0 -> n mod m = 0 ->
+ (a mod n + b) mod m = (a + b) mod m.
+ Proof.
+ intros.
+ rewrite (proj2 (Z.div_exact n m ltac:(auto))) by auto.
+ rewrite Z.rem_mul_r by auto.
+ push_Zmod. autorewrite with zsimplify.
+ pull_Zmod. reflexivity.
+ Qed.
+
+ Definition is_div_mod {T} (evalf : T -> Z) dm y n :=
+ evalf (fst dm) = y mod n /\ snd dm = y / n.
+
+ Lemma is_div_mod_step {T} evalf1 evalf2 dm1 dm2 y1 y2 n1 n2 x :
+ n1 > 0 ->
+ 0 < n2 / n1 ->
+ n2 mod n1 = 0 ->
+ evalf2 (fst dm2) = evalf1 (fst dm1) + n1 * ((snd dm1 + x) mod (n2 / n1)) ->
+ snd dm2 = (snd dm1 + x) / (n2 / n1) ->
+ y2 = y1 + n1 * x ->
+ @is_div_mod T evalf1 dm1 y1 n1 ->
+ @is_div_mod T evalf2 dm2 y2 n2.
+ Proof.
+ intros; subst y2; cbv [is_div_mod] in *.
+ repeat match goal with
+ | H: _ /\ _ |- _ => destruct H
+ | H: ?LHS = _ |- _ => match LHS with context [dm2] => rewrite H end
+ | H: ?LHS = _ |- _ => match LHS with context [dm1] => rewrite H end
+ | _ => rewrite mod_step by omega
+ | _ => rewrite div_step by omega
+ | _ => rewrite Z.mul_div_eq_full by omega
+ end.
+ split; f_equal; omega.
+ Qed.
+
+ Lemma is_div_mod_result_equal {T} evalf dm y1 y2 n :
+ y1 = y2 ->
+ @is_div_mod T evalf dm y1 n ->
+ @is_div_mod T evalf dm y2 n.
+ Proof. congruence. Qed.
+ End DivMod.
+End Saturated.
+
+Module Columns.
+ Import Saturated.
+ Section Columns.
+ Context weight {wprops : @weight_properties weight}.
+
+ Definition eval n (x : list (list Z)) : Z := Positional.eval weight n (map sum x).
+
+ Lemma eval_nil n : eval n [] = 0.
+ Proof. cbv [eval]; simpl. apply Positional.eval_nil. Qed.
+ Hint Rewrite eval_nil : push_eval.
+ Lemma eval_snoc n x y : n = length x -> eval (S n) (x ++ [y]) = eval n x + weight n * sum y.
+ Proof.
+ cbv [eval]; intros; subst. rewrite map_app. simpl map.
+ apply Positional.eval_snoc; distr_length.
+ Qed. Hint Rewrite eval_snoc using (solve [distr_length]) : push_eval.
+
+ Hint Rewrite <- Z.div_add' using omega : pull_Zdiv.
+
+ Ltac cases :=
+ match goal with
+ | |- _ /\ _ => split
+ | H: _ /\ _ |- _ => destruct H
+ | H: _ \/ _ |- _ => destruct H
+ | _ => progress break_match; try discriminate
+ end.
+
+ Section Flatten.
+ Section flatten_column.
+ Context (fw : Z). (* maximum size of the result *)
+
+ (* Outputs (sum, carry) *)
+ Definition flatten_column (digit: list Z) : (Z * Z) :=
+ list_rect (fun _ => (Z * Z)%type) (0,0)
+ (fun xx tl flatten_column_tl =>
+ list_rect
+ (fun _ => (Z * Z)%type) (xx mod fw, xx / fw)
+ (fun yy tl' _ =>
+ list_rect
+ (fun _ => (Z * Z)%type) (dlet_nd x := xx in dlet_nd y := yy in Z.add_get_carry_full fw x y)
+ (fun _ _ _ =>
+ dlet_nd x := xx in
+ dlet_nd rec := flatten_column_tl in (* recursively get the sum and carry *)
+ dlet_nd sum_carry := Z.add_get_carry_full fw x (fst rec) in (* add the new value to the sum *)
+ dlet_nd carry' := snd sum_carry + snd rec in (* add the two carries together *)
+ (fst sum_carry, carry'))
+ tl')
+ tl)
+ digit.
+ End flatten_column.
+
+ Definition flatten_step (digit:list Z) (acc_carry:list Z * Z) : list Z * Z :=
+ dlet sum_carry := flatten_column (weight (S (length (fst acc_carry))) / weight (length (fst acc_carry))) (snd acc_carry::digit) in
+ (fst acc_carry ++ fst sum_carry :: nil, snd sum_carry).
+
+ Definition flatten (xs : list (list Z)) : list Z * Z :=
+ fold_right (fun a b => flatten_step a b) (nil,0) (rev xs).
+
+ Ltac push_fast :=
+ repeat match goal with
+ | _ => progress cbv [Let_In]
+ | |- context [list_rect _ _ _ ?ls] => rewrite single_list_rect_to_match; destruct ls
+ | _ => progress (unfold flatten_step in *; fold flatten_step in * )
+ | _ => rewrite Nat.add_1_r
+ | _ => rewrite Z.mul_div_eq_full by (auto; omega)
+ | _ => rewrite weight_multiples
+ | _ => reflexivity
+ | _ => solve [repeat (f_equal; try ring)]
+ | _ => congruence
+ | _ => progress cases
+ end.
+ Ltac push :=
+ repeat match goal with
+ | _ => progress push_fast
+ | _ => progress autorewrite with cancel_pair to_div_mod
+ | _ => progress autorewrite with push_sum push_fold_right push_nth_default in *
+ | _ => progress autorewrite with pull_Zmod pull_Zdiv zsimplify_fast
+ | _ => progress autorewrite with list distr_length push_eval
+ end.
+
+ Lemma flatten_column_mod fw (xs : list Z) :
+ fst (flatten_column fw xs) = sum xs mod fw.
+ Proof.
+ induction xs; simpl flatten_column; cbv [Let_In];
+ repeat match goal with
+ | _ => rewrite IHxs
+ | _ => progress push
+ end.
+ Qed. Hint Rewrite flatten_column_mod : to_div_mod.
+
+ Lemma flatten_column_div fw (xs : list Z) (fw_nz : fw <> 0) :
+ snd (flatten_column fw xs) = sum xs / fw.
+ Proof.
+ induction xs; simpl flatten_column; cbv [Let_In];
+ repeat match goal with
+ | _ => rewrite IHxs
+ | _ => rewrite Z.mul_div_eq_full by omega
+ | _ => progress push
+ end.
+ Qed. Hint Rewrite flatten_column_div using auto with zarith : to_div_mod.
+
+ Hint Rewrite Positional.eval_nil : push_eval.
+ Hint Resolve Z.gt_lt.
+
+ Lemma length_flatten_step digit state :
+ length (fst (flatten_step digit state)) = S (length (fst state)).
+ Proof. cbv [flatten_step]; push. Qed.
+ Hint Rewrite length_flatten_step : distr_length.
+ Lemma length_flatten inp : length (fst (flatten inp)) = length inp.
+ Proof. cbv [flatten]. induction inp using rev_ind; push. Qed.
+ Hint Rewrite length_flatten : distr_length.
+
+ Lemma flatten_div_mod n inp :
+ length inp = n ->
+ (Positional.eval weight n (fst (flatten inp))
+ = (eval n inp) mod (weight n))
+ /\ (snd (flatten inp) = eval n inp / weight n).
+ Proof.
+ (* to make the invariant take the right form, we make everything depend on output length, not input length *)
+ intro. subst n. rewrite <-(length_flatten inp). cbv [flatten].
+ induction inp using rev_ind; intros; [push|].
+ repeat match goal with
+ | _ => rewrite Nat.add_1_r
+ | _ => progress (fold (flatten inp) in * )
+ | _ => erewrite Positional.eval_snoc by (distr_length; reflexivity)
+ | H: _ = _ mod (weight _) |- _ => rewrite H
+ | H: _ = _ / (weight _) |- _ => rewrite H
+ | _ => progress rewrite ?mod_step, ?div_step by auto
+ | _ => progress autorewrite with cancel_pair to_div_mod push_sum list push_fold_right push_eval
+ | _ => progress (distr_length; push_fast)
+ end.
+ Qed.
+
+ Lemma flatten_mod {n} inp :
+ length inp = n ->
+ (Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n)).
+ Proof. apply flatten_div_mod. Qed.
+ Hint Rewrite @flatten_mod : push_eval.
+
+ Lemma flatten_div {n} inp :
+ length inp = n -> snd (flatten inp) = eval n inp / weight n.
+ Proof. apply flatten_div_mod. Qed.
+ Hint Rewrite @flatten_div : push_eval.
+
+ Lemma flatten_snoc x inp : flatten (inp ++ [x]) = flatten_step x (flatten inp).
+ Proof. cbv [flatten]. rewrite rev_unit. reflexivity. Qed.
+
+ Lemma flatten_partitions inp:
+ forall n i, length inp = n -> (i < n)%nat ->
+ nth_default 0 (fst (flatten inp)) i = ((eval n inp) mod (weight (S i))) / weight i.
+ Proof.
+ induction inp using rev_ind; intros; destruct n; distr_length.
+ rewrite flatten_snoc.
+ push; distr_length;
+ [rewrite IHinp with (n:=n) by omega; rewrite weight_div_mod with (j:=n) (i:=S i) by (eauto; omega); push_Zmod; push |].
+ repeat match goal with
+ | _ => progress replace (length inp) with n by omega
+ | _ => progress replace i with n by omega
+ | _ => progress push
+ | _ => erewrite flatten_div by eauto
+ | _ => rewrite <-Z.div_add' by auto
+ | _ => rewrite Z.mul_div_eq' by auto
+ | _ => rewrite Z.mod_pull_div by auto using Z.lt_le_incl
+ | _ => progress autorewrite with push_nth_default natsimplify
+ end.
+ Qed.
+ End Flatten.
+
+ Section FromAssociational.
+ (* nils *)
+ Definition nils n : list (list Z) := repeat nil n.
+ Lemma length_nils n : length (nils n) = n. Proof. cbv [nils]. distr_length. Qed.
+ Hint Rewrite length_nils : distr_length.
+ Lemma eval_nils n : eval n (nils n) = 0.
+ Proof.
+ erewrite <-Positional.eval_zeros by eauto.
+ cbv [eval nils]; rewrite List.map_repeat; reflexivity.
+ Qed. Hint Rewrite eval_nils : push_eval.
+
+ (* cons_to_nth *)
+ Definition cons_to_nth i x (xs : list (list Z)) : list (list Z) :=
+ ListUtil.update_nth i (fun y => cons x y) xs.
+ Lemma length_cons_to_nth i x xs : length (cons_to_nth i x xs) = length xs.
+ Proof. cbv [cons_to_nth]. distr_length. Qed.
+ Hint Rewrite length_cons_to_nth : distr_length.
+ Lemma cons_to_nth_add_to_nth xs : forall i x,
+ map sum (cons_to_nth i x xs) = Positional.add_to_nth i x (map sum xs).
+ Proof.
+ cbv [cons_to_nth]; induction xs as [|? ? IHxs];
+ intros i x; destruct i; simpl; rewrite ?IHxs; reflexivity.
+ Qed.
+ Lemma eval_cons_to_nth n i x xs : (i < length xs)%nat -> length xs = n ->
+ eval n (cons_to_nth i x xs) = weight i * x + eval n xs.
+ Proof using Type.
+ cbv [eval]; intros. rewrite cons_to_nth_add_to_nth.
+ apply Positional.eval_add_to_nth; distr_length.
+ Qed. Hint Rewrite eval_cons_to_nth using (solve [distr_length]) : push_eval.
+
+ Hint Rewrite Positional.eval_zeros : push_eval.
+ Hint Rewrite Positional.length_from_associational : distr_length.
+ Hint Rewrite Positional.eval_add_to_nth using (solve [distr_length]): push_eval.
+
+ (* from_associational *)
+ Definition from_associational n (p:list (Z*Z)) : list (list Z) :=
+ List.fold_right (fun t ls =>
+ dlet_nd p := Positional.place weight t (pred n) in
+ cons_to_nth (fst p) (snd p) ls ) (nils n) p.
+ Lemma length_from_associational n p : length (from_associational n p) = n.
+ Proof. cbv [from_associational Let_In]. apply fold_right_invariant; intros; distr_length. Qed.
+ Hint Rewrite length_from_associational: distr_length.
+ Lemma eval_from_associational n p (n_nonzero:n<>0%nat\/p=nil):
+ eval n (from_associational n p) = Associational.eval p.
+ Proof.
+ erewrite <-Positional.eval_from_associational by eauto.
+ induction p; [ autorewrite with push_eval; solve [auto] |].
+ cbv [from_associational Positional.from_associational]; autorewrite with push_fold_right.
+ fold (from_associational n p); fold (Positional.from_associational weight n p).
+ cbv [Let_In].
+ match goal with |- context [Positional.place _ ?x ?n] =>
+ pose proof (Positional.place_in_range weight x n) end.
+ repeat match goal with
+ | _ => rewrite Nat.succ_pred in * by auto
+ | _ => rewrite IHp by auto
+ | _ => progress autorewrite with push_eval
+ | _ => progress cases
+ | _ => congruence
+ end.
+ Qed.
+
+ Lemma from_associational_step n t p :
+ from_associational n (t :: p) =
+ cons_to_nth (fst (Positional.place weight t (Nat.pred n)))
+ (snd (Positional.place weight t (Nat.pred n)))
+ (from_associational n p).
+ Proof. reflexivity. Qed.
+ End FromAssociational.
+ End Columns.
+End Columns.
+
+Module Rows.
+ Import Saturated.
+ Section Rows.
+ Context weight {wprops : @weight_properties weight}.
+
+ Local Notation rows := (list (list Z)) (only parsing).
+ Local Notation cols := (list (list Z)) (only parsing).
+
+ Hint Rewrite Positional.eval_nil Positional.eval0 @Positional.eval_snoc
+ Positional.eval_to_associational
+ Columns.eval_nil Columns.eval_snoc using (auto; solve [distr_length]) : push_eval.
+ Hint Resolve in_eq in_cons.
+
+ Definition eval n (inp : rows) :=
+ sum (map (Positional.eval weight n) inp).
+ Lemma eval_nil n : eval n nil = 0.
+ Proof. cbv [eval]. rewrite map_nil, sum_nil; reflexivity. Qed.
+ Hint Rewrite eval_nil : push_eval.
+ Lemma eval0 x : eval 0 x = 0.
+ Proof. cbv [eval]. induction x; autorewrite with push_map push_sum push_eval; omega. Qed.
+ Hint Rewrite eval0 : push_eval.
+ Lemma eval_cons n r inp : eval n (r :: inp) = Positional.eval weight n r + eval n inp.
+ Proof. cbv [eval]; autorewrite with push_map push_sum; reflexivity. Qed.
+ Hint Rewrite eval_cons : push_eval.
+ Lemma eval_app n x y : eval n (x ++ y) = eval n x + eval n y.
+ Proof. cbv [eval]; autorewrite with push_map push_sum; reflexivity. Qed.
+ Hint Rewrite eval_app : push_eval.
+
+ Ltac In_cases :=
+ repeat match goal with
+ | H: In _ (_ ++ _) |- _ => apply in_app_or in H; destruct H
+ | H: In _ (_ :: _) |- _ => apply in_inv in H; destruct H
+ | H: In _ nil |- _ => contradiction H
+ | H: forall x, In x (?y :: ?ls) -> ?P |- _ =>
+ unique pose proof (H y ltac:(apply in_eq));
+ unique assert (forall x, In x ls -> P) by auto
+ | H: forall x, In x (?ls ++ ?y :: nil) -> ?P |- _ =>
+ unique pose proof (H y ltac:(auto using in_or_app, in_eq));
+ unique assert (forall x, In x ls -> P) by eauto using in_or_app
+ end.
+
+ Section FromAssociational.
+ (* extract row *)
+ Definition extract_row (inp : cols) : cols * list Z := (map (fun c => tl c) inp, map (fun c => hd 0 c) inp).
+
+ Lemma eval_extract_row (inp : cols): forall n,
+ length inp = n ->
+ Positional.eval weight n (snd (extract_row inp)) = Columns.eval weight n inp - Columns.eval weight n (fst (extract_row inp)) .
+ Proof.
+ cbv [extract_row].
+ induction inp using rev_ind; [ | destruct n ];
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress distr_length
+ | _ => rewrite Positional.eval_snoc with (n:=n) by distr_length
+ | _ => progress autorewrite with cancel_pair push_eval push_map in *
+ | _ => ring
+ end.
+ rewrite IHinp by distr_length.
+ destruct x; cbn [hd tl]; rewrite ?sum_nil, ?sum_cons; ring.
+ Qed. Hint Rewrite eval_extract_row using (solve [distr_length]) : push_eval.
+
+ Lemma length_fst_extract_row n (inp : cols) :
+ length inp = n -> length (fst (extract_row inp)) = n.
+ Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed.
+ Hint Rewrite length_fst_extract_row : distr_length.
+
+ Lemma length_snd_extract_row n (inp : cols) :
+ length inp = n -> length (snd (extract_row inp)) = n.
+ Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed.
+ Hint Rewrite length_snd_extract_row : distr_length.
+
+ (* max column size *)
+ Definition max_column_size (x:cols) := fold_right (fun a b => Nat.max a b) 0%nat (map (fun c => length c) x).
+
+ (* TODO: move to where list is defined *)
+ Hint Rewrite @app_nil_l : list.
+ Hint Rewrite <-@app_comm_cons: list.
+
+ Lemma max_column_size_nil : max_column_size nil = 0%nat.
+ Proof. reflexivity. Qed. Hint Rewrite max_column_size_nil : push_max_column_size.
+ Lemma max_column_size_cons col (inp : cols) :
+ max_column_size (col :: inp) = Nat.max (length col) (max_column_size inp).
+ Proof. reflexivity. Qed. Hint Rewrite max_column_size_cons : push_max_column_size.
+ Lemma max_column_size_app (x y : cols) :
+ max_column_size (x ++ y) = Nat.max (max_column_size x) (max_column_size y).
+ Proof. induction x; autorewrite with list push_max_column_size; lia. Qed.
+ Hint Rewrite max_column_size_app : push_max_column_size.
+ Lemma max_column_size0 (inp : cols) :
+ forall n,
+ length inp = n -> (* this is not needed to make the lemma true, but prevents reliance on the implementation of Columns.eval*)
+ max_column_size inp = 0%nat -> Columns.eval weight n inp = 0.
+ Proof.
+ induction inp as [|x inp] using rev_ind; destruct n; try destruct x; intros;
+ autorewrite with push_max_column_size push_eval push_sum distr_length in *; try lia.
+ rewrite IHinp; distr_length; lia.
+ Qed.
+
+ (* from_columns *)
+ Definition from_columns' n start_state : cols * rows :=
+ fold_right (fun _ (state : cols * rows) =>
+ let cols'_row := extract_row (fst state) in
+ (fst cols'_row, snd state ++ [snd cols'_row])
+ ) start_state (repeat 0 n).
+
+ Definition from_columns (inp : cols) : rows := snd (from_columns' (max_column_size inp) (inp, [])).
+
+ Lemma eval_from_columns'_with_length m st n:
+ (length (fst st) = n) ->
+ length (fst (from_columns' m st)) = n /\
+ ((forall r, In r (snd st) -> length r = n) ->
+ forall r, In r (snd (from_columns' m st)) -> length r = n) /\
+ eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st)
+ - Columns.eval weight n (fst (from_columns' m st)).
+ Proof.
+ cbv [from_columns']; intros.
+ apply fold_right_invariant; intros;
+ repeat match goal with
+ | _ => progress (intros; subst)
+ | _ => progress autorewrite with cancel_pair push_eval
+ | _ => progress In_cases
+ | _ => split; try omega
+ | H: _ /\ _ |- _ => destruct H
+ | _ => solve [auto using length_fst_extract_row, length_snd_extract_row]
+ end.
+ Qed.
+ Lemma length_fst_from_columns' m st :
+ length (fst (from_columns' m st)) = length (fst st).
+ Proof. apply eval_from_columns'_with_length; reflexivity. Qed.
+ Hint Rewrite length_fst_from_columns' : distr_length.
+ Lemma length_snd_from_columns' m st :
+ (forall r, In r (snd st) -> length r = length (fst st)) ->
+ forall r, In r (snd (from_columns' m st)) -> length r = length (fst st).
+ Proof. apply eval_from_columns'_with_length. reflexivity. Qed.
+ Hint Rewrite length_snd_from_columns' : distr_length.
+ Lemma eval_from_columns' m st n :
+ (length (fst st) = n) ->
+ eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st)
+ - Columns.eval weight n (fst (from_columns' m st)).
+ Proof. apply eval_from_columns'_with_length. Qed.
+ Hint Rewrite eval_from_columns' using (auto; solve [distr_length]) : push_eval.
+
+ Lemma max_column_size_extract_row inp :
+ max_column_size (fst (extract_row inp)) = (max_column_size inp - 1)%nat.
+ Proof.
+ cbv [extract_row]. autorewrite with cancel_pair.
+ induction inp; [ reflexivity | ].
+ autorewrite with push_max_column_size push_map distr_length.
+ rewrite IHinp. auto using Nat.sub_max_distr_r.
+ Qed.
+ Hint Rewrite max_column_size_extract_row : push_max_column_size.
+
+ Lemma max_column_size_from_columns' m st :
+ max_column_size (fst (from_columns' m st)) = (max_column_size (fst st) - m)%nat.
+ Proof.
+ cbv [from_columns']; induction m; intros; cbn - [max_column_size extract_row];
+ autorewrite with push_max_column_size; lia.
+ Qed.
+ Hint Rewrite max_column_size_from_columns' : push_max_column_size.
+
+ Lemma eval_from_columns (inp : cols) :
+ forall n, length inp = n -> eval n (from_columns inp) = Columns.eval weight n inp.
+ Proof.
+ intros; cbv [from_columns];
+ repeat match goal with
+ | _ => progress autorewrite with cancel_pair push_eval push_max_column_size
+ | _ => rewrite max_column_size0 with (inp := fst (from_columns' _ _)) by
+ (autorewrite with push_max_column_size; distr_length)
+ | _ => omega
+ end.
+ Qed.
+ Hint Rewrite eval_from_columns using (auto; solve [distr_length]) : push_eval.
+
+ Lemma length_from_columns inp:
+ forall r, In r (from_columns inp) -> length r = length inp.
+ Proof.
+ cbv [from_columns]; intros.
+ change inp with (fst (inp, @nil (list Z))).
+ eapply length_snd_from_columns'; eauto.
+ autorewrite with cancel_pair; intros; In_cases.
+ Qed.
+ Hint Rewrite length_from_columns : distr_length.
+
+ (* from associational *)
+ Definition from_associational n (p : list (Z * Z)) := from_columns (Columns.from_associational weight n p).
+
+ Lemma eval_from_associational n p: (n <> 0%nat \/ p = nil) ->
+ eval n (from_associational n p) = Associational.eval p.
+ Proof.
+ intros. cbv [from_associational].
+ rewrite eval_from_columns by auto using Columns.length_from_associational.
+ auto using Columns.eval_from_associational.
+ Qed.
+
+ Lemma length_from_associational n p :
+ forall r, In r (from_associational n p) -> length r = n.
+ Proof.
+ cbv [from_associational]; intros.
+ match goal with H: _ |- _ => apply length_from_columns in H end.
+ rewrite Columns.length_from_associational in *; auto.
+ Qed.
+
+ (* TODO : move *)
+ Lemma max_0_iff a b : Nat.max a b = 0%nat <-> (a = 0%nat /\ b = 0%nat).
+ Proof.
+ destruct a, b; try tauto.
+ rewrite <-Nat.succ_max_distr.
+ split; [ | destruct 1]; congruence.
+ Qed.
+ Lemma max_column_size_zero_iff x :
+ max_column_size x = 0%nat <-> (forall c, In c x -> c = nil).
+ Proof.
+ cbv [max_column_size]; induction x; intros; [ cbn; tauto | ].
+ autorewrite with push_fold_right push_map.
+ rewrite max_0_iff, IHx.
+ split; intros; [ | rewrite length_zero_iff_nil; solve [auto] ].
+ match goal with H : _ /\ _ |- _ => destruct H end.
+ In_cases; subst; auto using length0_nil.
+ Qed.
+
+ Lemma max_column_size_Columns_from_associational n p :
+ n <> 0%nat -> p <> nil ->
+ max_column_size (Columns.from_associational weight n p) <> 0%nat.
+ Proof.
+ intros.
+ rewrite max_column_size_zero_iff.
+ intro. destruct p; [congruence | ].
+ rewrite Columns.from_associational_step in *.
+ cbv [Columns.cons_to_nth] in *.
+ match goal with H : forall c, In c (update_nth ?n ?f ?ls) -> _ |- _ =>
+ assert (n < length (update_nth n f ls))%nat;
+ [ | specialize (H (nth n (update_nth n f ls) nil) ltac:(auto using nth_In)) ]
+ end.
+ { distr_length.
+ rewrite Columns.length_from_associational.
+ remember (Nat.pred n) as m. replace n with (S m) by omega.
+ apply Positional.place_in_range. }
+ rewrite <-nth_default_eq in *.
+ autorewrite with push_nth_default in *.
+ rewrite eq_nat_dec_refl in *.
+ congruence.
+ Qed.
+
+ Lemma from_associational_nonnil n p :
+ n <> 0%nat -> p <> nil ->
+ from_associational n p <> nil.
+ Proof.
+ intros; cbv [from_associational from_columns from_columns'].
+ pose proof (max_column_size_Columns_from_associational n p ltac:(auto) ltac:(auto)).
+ case_eq (max_column_size (Columns.from_associational weight n p)); [omega|].
+ intros; cbn.
+ rewrite <-length_zero_iff_nil. distr_length.
+ Qed.
+ End FromAssociational.
+
+ Section Flatten.
+ Local Notation fw := (fun i => weight (S i) / weight i) (only parsing).
+
+ Section SumRows.
+ Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z * nat :=
+ fold_right (fun next (state : list Z * Z * nat) =>
+ let i := snd state in
+ let low_high' :=
+ let low_high := fst state in
+ let low := fst low_high in
+ let high := snd low_high in
+ dlet_nd sum_carry := Z.add_with_get_carry_full (fw i) high (fst next) (snd next) in
+ (low ++ [fst sum_carry], snd sum_carry) in
+ (low_high', S i)) start_state (rev (combine row1 row2)).
+ Definition sum_rows row1 row2 := fst (sum_rows' (nil, 0, 0%nat) row1 row2).
+
+ Ltac push :=
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress cbv [Let_In]
+ | _ => rewrite Nat.add_1_r
+ | _ => erewrite Positional.eval_snoc by eauto
+ | H : length _ = _ |- _ => rewrite H
+ | H: 0%nat = _ |- _ => rewrite <-H
+ | [p := _ |- _] => subst p
+ | _ => progress autorewrite with cancel_pair natsimplify push_sum_rows list push_nth_default
+ | _ => progress autorewrite with cancel_pair in *
+ | _ => progress distr_length
+ | _ => progress break_match
+ | _ => ring
+ | _ => solve [ repeat (f_equal; try ring) ]
+ | _ => tauto
+ | _ => solve [eauto]
+ end.
+
+ Lemma sum_rows'_cons state x1 row1 x2 row2 :
+ sum_rows' state (x1 :: row1) (x2 :: row2) =
+ sum_rows' (fst (fst state) ++ [(snd (fst state) + x1 + x2) mod (fw (snd state))],
+ (snd (fst state) + x1 + x2) / fw (snd state),
+ S (snd state)) row1 row2.
+ Proof.
+ cbv [sum_rows' Let_In]; autorewrite with push_combine.
+ rewrite !fold_left_rev_right. cbn [fold_left].
+ autorewrite with cancel_pair to_div_mod. congruence.
+ Qed.
+
+ Lemma sum_rows'_nil state :
+ sum_rows' state nil nil = state.
+ Proof. reflexivity. Qed.
+
+ Hint Rewrite sum_rows'_cons sum_rows'_nil : push_sum_rows.
+
+ Lemma sum_rows'_div_mod_length row1 :
+ forall nm start_state row2 row1' row2',
+ let m := snd start_state in
+ let n := length row1 in
+ length row2 = n ->
+ length row1' = m ->
+ length row2' = m ->
+ length (fst (fst start_state)) = m ->
+ (nm = n + m)%nat ->
+ let eval := Positional.eval weight in
+ is_div_mod (eval m) (fst start_state) (eval m row1' + eval m row2') (weight m) ->
+ length (fst (fst (sum_rows' start_state row1 row2))) = nm
+ /\ is_div_mod (eval nm) (fst (sum_rows' start_state row1 row2))
+ (eval nm (row1' ++ row1) + eval nm (row2' ++ row2))
+ (weight nm).
+ Proof.
+ induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; [ ].
+ rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
+ apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length in *; try omega.
+ eapply is_div_mod_step with (x := x1 + x2); try eassumption; push.
+ Qed.
+
+ Lemma sum_rows_div_mod n row1 row2 :
+ length row1 = n -> length row2 = n ->
+ let eval := Positional.eval weight in
+ is_div_mod (eval n) (sum_rows row1 row2) (eval n row1 + eval n row2) (weight n).
+ Proof.
+ cbv [sum_rows]; intros.
+ apply sum_rows'_div_mod_length with (row1':=nil) (row2':=nil);
+ cbv [is_div_mod]; autorewrite with cancel_pair push_eval zsimplify; distr_length.
+ Qed.
+
+ Lemma sum_rows_mod n row1 row2 :
+ length row1 = n -> length row2 = n ->
+ Positional.eval weight n (fst (sum_rows row1 row2))
+ = (Positional.eval weight n row1 + Positional.eval weight n row2) mod (weight n).
+ Proof. apply sum_rows_div_mod. Qed.
+ Lemma sum_rows_div row1 row2 n:
+ length row1 = n -> length row2 = n ->
+ snd (sum_rows row1 row2)
+ = (Positional.eval weight n row1 + Positional.eval weight n row2) / (weight n).
+ Proof. apply sum_rows_div_mod. Qed.
+
+ Lemma sum_rows'_partitions row1 :
+ forall nm start_state row2 row1' row2',
+ let m := snd start_state in
+ let n := length row1 in
+ length row2 = n ->
+ length row1' = m ->
+ length row2' = m ->
+ length (fst (fst start_state)) = m ->
+ nm = (n + m)%nat ->
+ let eval := Positional.eval weight in
+ snd (fst start_state) = (eval m row1' + eval m row2') / weight m ->
+ (forall j, (j < m)%nat ->
+ nth_default 0 (fst (fst start_state)) j = ((eval m row1' + eval m row2') mod (weight (S j))) / (weight j)) ->
+ forall i, (i < nm)%nat ->
+ nth_default 0 (fst (fst (sum_rows' start_state row1 row2))) i
+ = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i).
+ Proof.
+ induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; [].
+
+ rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
+ apply IHrow1; clear IHrow1; push;
+ repeat match goal with
+ | H : ?LHS = _ |- _ =>
+ match LHS with context [start_state] => rewrite H end
+ | H : context [nth_default 0 (fst (fst start_state))] |- _ => rewrite H by omega
+ | _ => rewrite <-(Z.add_assoc _ x1 x2)
+ end.
+ { rewrite div_step by auto using Z.gt_lt.
+ rewrite Z.mul_div_eq_full by auto; rewrite weight_multiples by auto. push. }
+ { rewrite weight_div_mod with (j:=snd start_state) (i:=S j) by (auto; omega).
+ push_Zmod. autorewrite with zsimplify_fast. reflexivity. }
+ { push. replace (snd start_state) with j in * by omega.
+ push. rewrite add_mod_div_multiple by auto using Z.lt_le_incl.
+ push. }
+ Qed.
+
+ Lemma sum_rows_partitions row1: forall row2 n i,
+ length row1 = n -> length row2 = n -> (i < n)%nat ->
+ nth_default 0 (fst (sum_rows row1 row2)) i
+ = ((Positional.eval weight n row1 + Positional.eval weight n row2) mod weight (S i)) / (weight i).
+ Proof.
+ cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
+ rewrite <-(app_nil_l row1), <-(app_nil_l row2).
+ apply sum_rows'_partitions; intros;
+ autorewrite with cancel_pair push_eval zsimplify_fast push_nth_default; distr_length.
+ Qed.
+
+ Lemma length_sum_rows row1 row2 n:
+ length row1 = n -> length row2 = n ->
+ length (fst (sum_rows row1 row2)) = n.
+ Proof.
+ cbv [sum_rows]; intros.
+ eapply sum_rows'_div_mod_length; cbv [is_div_mod];
+ autorewrite with cancel_pair; distr_length; auto using nil_length0.
+ Qed. Hint Rewrite length_sum_rows : distr_length.
+ End SumRows.
+ Hint Resolve length_sum_rows.
+ Hint Rewrite sum_rows_mod using (auto; solve [distr_length; auto]) : push_eval.
+
+ Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z :=
+ fold_right (fun next_row (state : list Z * Z)=>
+ let out_carry := sum_rows next_row (fst state) in
+ (fst out_carry, snd state + snd out_carry)) start_state inp.
+
+ (* In order for the output to have the right length and bounds,
+ we insert rows of zeroes if there are fewer than two rows. *)
+ Definition flatten n (inp : rows) : list Z * Z :=
+ let default := Positional.zeros n in
+ flatten' (hd default inp, 0) (hd default (tl inp) :: tl (tl inp)).
+
+ Lemma flatten'_cons state r inp :
+ flatten' state (r :: inp) = (fst (sum_rows r (fst (flatten' state inp))), snd (flatten' state inp) + snd (sum_rows r (fst (flatten' state inp)))).
+ Proof. cbv [flatten']; autorewrite with list push_fold_right. reflexivity. Qed.
+ Lemma flatten'_snoc state r inp :
+ flatten' state (inp ++ r :: nil) = flatten' (fst (sum_rows r (fst state)), snd state + snd (sum_rows r (fst state))) inp.
+ Proof. cbv [flatten']; autorewrite with list push_fold_right. reflexivity. Qed.
+ Lemma flatten'_nil state : flatten' state [] = state. Proof. reflexivity. Qed.
+ Hint Rewrite flatten'_cons flatten'_snoc flatten'_nil : push_flatten.
+
+ Ltac push :=
+ repeat match goal with
+ | _ => progress intros
+ | H: length ?x = ?n |- context [snd (sum_rows ?x _)] => rewrite sum_rows_div with (n:=n) by (distr_length; eauto)
+ | H: length ?x = ?n |- context [snd (sum_rows _ ?x)] => rewrite sum_rows_div with (n:=n) by (distr_length; eauto)
+ | H: length _ = _ |- _ => rewrite H
+ | _ => progress autorewrite with cancel_pair push_flatten push_eval distr_length zsimplify_fast
+ | _ => progress In_cases
+ | |- _ /\ _ => split
+ | |- context [?x mod ?y] => unique pose proof (Z.mul_div_eq_full x y ltac:(auto)); lia
+ | _ => apply length_sum_rows
+ | _ => solve [repeat (ring_simplify; f_equal; try ring)]
+ | _ => congruence
+ | _ => solve [eauto]
+ end.
+
+ Lemma flatten'_div_mod_length n inp : forall start_state,
+ length (fst start_state) = n ->
+ (forall row, In row inp -> length row = n) ->
+ length (fst (flatten' start_state inp)) = n
+ /\ (inp <> nil ->
+ is_div_mod (Positional.eval weight n) (flatten' start_state inp)
+ (Positional.eval weight n (fst start_state) + eval n inp + weight n * snd start_state)
+ (weight n)).
+ Proof.
+ induction inp using rev_ind; push; [apply IHinp; push|].
+ destruct (dec (inp = nil)); [subst inp; cbv [is_div_mod]
+ | eapply is_div_mod_result_equal; try apply IHinp]; push.
+ { autorewrite with zsimplify; push. }
+ { rewrite Z.div_add' by auto; push. }
+ Qed.
+
+ Hint Rewrite (@Positional.length_zeros weight) : distr_length.
+ Hint Rewrite (@Positional.eval_zeros weight) using auto : push_eval.
+
+ Lemma flatten_div_mod inp n :
+ (forall row, In row inp -> length row = n) ->
+ is_div_mod (Positional.eval weight n) (flatten n inp) (eval n inp) (weight n).
+ Proof.
+ intros; cbv [flatten].
+ destruct inp; [|destruct inp]; cbn [hd tl].
+ { cbv [is_div_mod]; push.
+ erewrite sum_rows_div by (distr_length; reflexivity).
+ push. }
+ { cbv [is_div_mod]; push. }
+ { eapply is_div_mod_result_equal; try apply flatten'_div_mod_length; push. }
+ Qed.
+
+ Lemma flatten_mod inp n :
+ (forall row, In row inp -> length row = n) ->
+ Positional.eval weight n (fst (flatten n inp)) = (eval n inp) mod (weight n).
+ Proof. apply flatten_div_mod. Qed.
+ Lemma flatten_div inp n :
+ (forall row, In row inp -> length row = n) ->
+ snd (flatten n inp) = (eval n inp) / (weight n).
+ Proof. apply flatten_div_mod. Qed.
+
+ Lemma length_flatten' n start_state inp :
+ length (fst start_state) = n ->
+ (forall row, In row inp -> length row = n) ->
+ length (fst (flatten' start_state inp)) = n.
+ Proof. apply flatten'_div_mod_length. Qed.
+ Hint Rewrite length_flatten' : distr_length.
+
+ Lemma length_flatten n inp :
+ (forall row, In row inp -> length row = n) ->
+ length (fst (flatten n inp)) = n.
+ Proof.
+ intros.
+ apply length_flatten'; push;
+ destruct inp as [|? [|? ?] ]; try congruence; cbn [hd tl] in *; push;
+ subst row; distr_length.
+ Qed. Hint Rewrite length_flatten : distr_length.
+
+ Lemma flatten'_partitions n inp : forall start_state,
+ inp <> nil ->
+ length (fst start_state) = n ->
+ (forall row, In row inp -> length row = n) ->
+ forall i, (i < n)%nat ->
+ nth_default 0 (fst (flatten' start_state inp)) i
+ = ((Positional.eval weight n (fst start_state) + eval n inp) mod weight (S i)) / (weight i).
+ Proof.
+ induction inp using rev_ind; push.
+ destruct (dec (inp = nil)).
+ { subst inp; push. rewrite sum_rows_partitions with (n:=n) by eauto. push. }
+ { erewrite IHinp; push.
+ rewrite add_mod_l_multiple by auto using weight_divides_full, weight_multiples_full.
+ push. }
+ Qed.
+
+ Lemma flatten_partitions inp n :
+ (forall row, In row inp -> length row = n) ->
+ forall i, (i < n)%nat ->
+ nth_default 0 (fst (flatten n inp)) i = (eval n inp mod weight (S i)) / (weight i).
+ Proof.
+ intros; cbv [flatten].
+ intros; destruct inp as [| ? [| ? ?] ]; try congruence; cbn [hd tl] in *; try solve [push].
+ { cbn. autorewrite with push_nth_default.
+ rewrite sum_rows_partitions with (n:=n) by distr_length.
+ autorewrite with push_eval zsimplify_fast.
+ auto with zarith. }
+ { push. rewrite sum_rows_partitions with (n:=n) by distr_length; push. }
+ { rewrite flatten'_partitions with (n:=n); push. }
+ Qed.
+
+ Definition partition n x :=
+ map (fun i => (x mod weight (S i)) / weight i) (seq 0 n).
+
+ Lemma nth_default_partitions x : forall p n,
+ (forall i, (i < n)%nat -> nth_default 0 p i = (x mod weight (S i)) / weight i) ->
+ length p = n ->
+ p = partition n x.
+ Proof.
+ cbv [partition]; induction p using rev_ind; intros; distr_length; subst n; [reflexivity|].
+ rewrite Nat.add_1_r, seq_snoc.
+ autorewrite with natsimplify push_map.
+ rewrite <-IHp; auto; intros;
+ match goal with H : context [nth_default _ (p ++ [ _ ])] |- _ =>
+ rewrite <-H by omega end.
+ { autorewrite with push_nth_default natsimplify. reflexivity. }
+ { autorewrite with push_nth_default natsimplify.
+ break_match; omega. }
+ Qed.
+
+ Lemma partition_step n x :
+ partition (S n) x = partition n x ++ [(x mod weight (S n)) / weight n].
+ Proof.
+ cbv [partition]. rewrite seq_snoc.
+ autorewrite with natsimplify push_map. reflexivity.
+ Qed.
+
+ Lemma length_partition n x : length (partition n x) = n.
+ Proof. cbv [partition]; distr_length. Qed.
+ Hint Rewrite length_partition : distr_length.
+
+ Lemma eval_partition n x :
+ Positional.eval weight n (partition n x) = x mod (weight n).
+ Proof.
+ induction n; intros.
+ { cbn. rewrite (weight_0); auto with zarith. }
+ { rewrite (Z.div_mod (x mod weight (S n)) (weight n)) by auto.
+ rewrite <-Znumtheory.Zmod_div_mod by (try apply Z.mod_divide; auto).
+ rewrite partition_step, Positional.eval_snoc with (n:=n) by distr_length.
+ omega. }
+ Qed.
+
+ Lemma flatten_partitions' inp n :
+ (forall row, In row inp -> length row = n) ->
+ fst (flatten n inp) = partition n (eval n inp).
+ Proof. auto using nth_default_partitions, flatten_partitions, length_flatten. Qed.
+ End Flatten.
+
+ Section Ops.
+ Definition add n p q := flatten n [p; q].
+
+ (* TODO: Although cleaner, using Positional.negate snd inserts
+ dlets which prevent add-opp=>sub transformation in partial
+ evaluation. Should probably either make partial evaluation
+ handle that or remove the dlet in
+ Positional.from_associational. *)
+ Definition sub n p q := flatten n [p; map (fun x => dlet y := x in Z.opp y) q].
+
+ Hint Rewrite eval_cons eval_nil using solve [auto] : push_eval.
+
+ Definition mul base n m (p q : list Z) :=
+ let p_a := Positional.to_associational weight n p in
+ let q_a := Positional.to_associational weight n q in
+ let pq_a := Associational.sat_mul base p_a q_a in
+ flatten m (from_associational m pq_a).
+
+ (* TODO : move sat_reduce and repeat_sat_reduce to Saturated.Associational *)
+ Definition sat_reduce base s c (p : list (Z * Z)) :=
+ let lo_hi := Associational.split s p in
+ fst lo_hi ++ (Associational.sat_mul_const base c (snd lo_hi)).
+
+ Definition repeat_sat_reduce base s c (p : list (Z * Z)) n :=
+ fold_right (fun _ q => sat_reduce base s c q) p (seq 0 n).
+
+ Definition mulmod base s c n nreductions (p q : list Z) :=
+ let p_a := Positional.to_associational weight n p in
+ let q_a := Positional.to_associational weight n q in
+ let pq_a := Associational.sat_mul base p_a q_a in
+ let r_a := repeat_sat_reduce base s c pq_a nreductions in
+ flatten n (from_associational n r_a).
+
+ Hint Rewrite Associational.eval_sat_mul_const Associational.eval_sat_mul Associational.eval_split using solve [auto] : push_eval.
+ Hint Rewrite eval_from_associational using solve [auto] : push_eval.
+ Hint Rewrite eval_partition using solve [auto] : push_eval.
+ Ltac solver :=
+ intros; cbv [sub add mul mulmod sat_reduce];
+ rewrite ?flatten_partitions' by (intros; In_cases; subst; distr_length; eauto using length_from_associational);
+ rewrite ?flatten_div by (intros; In_cases; subst; distr_length; eauto using length_from_associational);
+ autorewrite with push_eval; ring_simplify_subterms;
+ try reflexivity.
+
+ Lemma add_partitions n p q :
+ n <> 0%nat -> length p = n -> length q = n ->
+ fst (add n p q) = partition n (Positional.eval weight n p + Positional.eval weight n q).
+ Proof. solver. Qed.
+
+ Lemma add_div n p q :
+ n <> 0%nat -> length p = n -> length q = n ->
+ snd (add n p q) = (Positional.eval weight n p + Positional.eval weight n q) / weight n.
+ Proof. solver. Qed.
+
+ Lemma eval_map_opp q :
+ forall n, length q = n ->
+ Positional.eval weight n (map Z.opp q) = - Positional.eval weight n q.
+ Proof.
+ induction q using rev_ind; intros;
+ repeat match goal with
+ | _ => progress autorewrite with push_map push_eval
+ | _ => erewrite !Positional.eval_snoc with (n:=length q) by distr_length
+ | _ => rewrite IHq by auto
+ | _ => ring
+ end.
+ Qed. Hint Rewrite eval_map_opp using solve [auto]: push_eval.
+
+ Lemma sub_partitions n p q :
+ n <> 0%nat -> length p = n -> length q = n ->
+ fst (sub n p q) = partition n (Positional.eval weight n p - Positional.eval weight n q).
+ Proof. solver. Qed.
+
+ Lemma sub_div n p q :
+ n <> 0%nat -> length p = n -> length q = n ->
+ snd (sub n p q) = (Positional.eval weight n p - Positional.eval weight n q) / weight n.
+ Proof. solver. Qed.
+
+ Lemma mul_partitions base n m p q :
+ base <> 0 -> n <> 0%nat -> m <> 0%nat -> length p = n -> length q = n ->
+ fst (mul base n m p q) = partition m (Positional.eval weight n p * Positional.eval weight n q).
+ Proof. solver. Qed.
+
+ Lemma eval_sat_reduce base s c p :
+ base <> 0 -> s - Associational.eval c <> 0 -> s <> 0 ->
+ Associational.eval (sat_reduce base s c p) mod (s - Associational.eval c)
+ = Associational.eval p mod (s - Associational.eval c).
+ Proof.
+ intros; cbv [sat_reduce].
+ autorewrite with push_eval.
+ rewrite <-Associational.reduction_rule by omega.
+ autorewrite with push_eval; reflexivity.
+ Qed.
+ Hint Rewrite eval_sat_reduce using auto : push_eval.
+
+ Lemma eval_repeat_sat_reduce base s c p n :
+ base <> 0 -> s - Associational.eval c <> 0 -> s <> 0 ->
+ Associational.eval (repeat_sat_reduce base s c p n) mod (s - Associational.eval c)
+ = Associational.eval p mod (s - Associational.eval c).
+ Proof.
+ intros; cbv [repeat_sat_reduce].
+ apply fold_right_invariant; intros; autorewrite with push_eval; auto.
+ Qed.
+ Hint Rewrite eval_repeat_sat_reduce using auto : push_eval.
+
+ Lemma eval_mulmod base s c n nreductions p q :
+ base <> 0 -> s <> 0 -> s - Associational.eval c <> 0 ->
+ n <> 0%nat -> length p = n -> length q = n ->
+ (Positional.eval weight n (fst (mulmod base s c n nreductions p q))
+ + weight n * (snd (mulmod base s c n nreductions p q))) mod (s - Associational.eval c)
+ = (Positional.eval weight n p * Positional.eval weight n q) mod (s - Associational.eval c).
+ Proof.
+ solver.
+ rewrite <-Z.div_mod'' by auto.
+ autorewrite with push_eval; reflexivity.
+ Qed.
+ End Ops.
+ End Rows.
+End Rows.
+
+Module BaseConversion.
+ Import Positional.
+ Section BaseConversion.
+ Hint Resolve Z.gt_lt.
+ Context (sw dw : nat -> Z) (* source/destination weight functions *)
+ {swprops : @weight_properties sw}
+ {dwprops : @weight_properties dw}.
+
+ Definition convert_bases (sn dn : nat) (p : list Z) : list Z :=
+ let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in
+ chained_carries_no_reduce dw dn p' (seq 0 (pred dn)).
+
+ Lemma eval_convert_bases sn dn p :
+ (dn <> 0%nat) -> length p = sn ->
+ eval dw dn (convert_bases sn dn p) = eval sw sn p.
+ Proof.
+ cbv [convert_bases]; intros.
+ rewrite eval_chained_carries_no_reduce; auto using ZUtil.Z.positive_is_nonzero.
+ rewrite eval_from_associational; auto.
+ Qed.
+
+ Hint Rewrite
+ @Rows.eval_from_associational
+ @Associational.eval_carry
+ @Associational.eval_mul
+ @Positional.eval_to_associational
+ Associational.eval_carryterm
+ @eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval.
+
+ Ltac push_eval := intros; autorewrite with push_eval; auto with zarith.
+
+ (* convert from positional in one weight to the other, then to associational *)
+ Definition to_associational n m p : list (Z * Z) :=
+ let p' := convert_bases n m p in
+ Positional.to_associational dw m p'.
+
+ (* TODO : move to Associational? *)
+ Section reorder.
+ Definition reordering_carry (w fw : Z) (p : list (Z * Z)) :=
+ fold_right (fun t acc =>
+ let r := Associational.carryterm w fw t in
+ if fst t =? w then acc ++ r else r ++ acc) nil p.
+
+ Lemma eval_reordering_carry w fw p (_:fw<>0):
+ Associational.eval (reordering_carry w fw p) = Associational.eval p.
+ Proof.
+ cbv [reordering_carry]. induction p; [reflexivity |].
+ autorewrite with push_fold_right. break_match; push_eval.
+ Qed.
+ End reorder.
+ Hint Rewrite eval_reordering_carry using solve [auto using Z.positive_is_nonzero] : push_eval.
+
+ (* carry at specified indices in dw, then use Rows.flatten to convert to Positional with sw *)
+ Definition from_associational idxs n (p : list (Z * Z)) : list Z :=
+ (* important not to use Positional.carry here; we don't want to accumulate yet *)
+ let p' := fold_right (fun i acc => reordering_carry (dw i) (dw (S i) / dw i) acc) (Associational.bind_snd p) (rev idxs) in
+ fst (Rows.flatten sw n (Rows.from_associational sw n p')).
+
+ Lemma eval_carries p idxs :
+ Associational.eval (fold_right (fun i acc => reordering_carry (dw i) (dw (S i) / dw i) acc) p idxs) =
+ Associational.eval p.
+ Proof. apply fold_right_invariant; push_eval. Qed.
+ Hint Rewrite eval_carries: push_eval.
+
+ Lemma eval_to_associational n m p :
+ m <> 0%nat -> length p = n ->
+ Associational.eval (to_associational n m p) = Positional.eval sw n p.
+ Proof. cbv [to_associational]; push_eval. Qed.
+ Hint Rewrite eval_to_associational using solve [push_eval; distr_length] : push_eval.
+
+ Lemma eval_from_associational idxs n p :
+ n <> 0%nat -> 0 <= Associational.eval p < sw n ->
+ Positional.eval sw n (from_associational idxs n p) = Associational.eval p.
+ Proof.
+ cbv [from_associational]; intros.
+ rewrite Rows.flatten_mod by eauto using Rows.length_from_associational.
+ rewrite Associational.bind_snd_correct.
+ push_eval.
+ Qed.
+ Hint Rewrite eval_from_associational using solve [push_eval; distr_length] : push_eval.
+
+ Lemma from_associational_partitions n idxs p (_:n<>0%nat):
+ forall i, (i < n)%nat ->
+ nth_default 0 (from_associational idxs n p) i = (Associational.eval p) mod (sw (S i)) / sw i.
+ Proof.
+ intros; cbv [from_associational].
+ rewrite Rows.flatten_partitions with (n:=n) by (eauto using Rows.length_from_associational; omega).
+ rewrite Associational.bind_snd_correct.
+ push_eval.
+ Qed.
+
+ Lemma from_associational_eq n idxs p (_:n<>0%nat):
+ from_associational idxs n p = Rows.partition sw n (Associational.eval p).
+ Proof.
+ intros. cbv [from_associational].
+ rewrite Rows.flatten_partitions' with (n:=n) by eauto using Rows.length_from_associational.
+ rewrite Associational.bind_snd_correct.
+ push_eval.
+ Qed.
+
+ Derive from_associational_inlined
+ SuchThat (forall idxs n p,
+ from_associational_inlined idxs n p = from_associational idxs n p)
+ As from_associational_inlined_correct.
+ Proof.
+ intros.
+ cbv beta iota delta [from_associational reordering_carry Associational.carryterm].
+ cbv beta iota delta [Let_In]. (* inlines all shifts/lands from carryterm *)
+ cbv beta iota delta [from_associational Rows.from_associational Columns.from_associational].
+ cbv beta iota delta [Let_In]. (* inlines the shifts from place *)
+ subst from_associational_inlined; reflexivity.
+ Qed.
+
+ Derive to_associational_inlined
+ SuchThat (forall n m p,
+ to_associational_inlined n m p = to_associational n m p)
+ As to_associational_inlined_correct.
+ Proof.
+ intros.
+ cbv beta iota delta [ to_associational convert_bases
+ Positional.to_associational
+ Positional.from_associational
+ chained_carries_no_reduce
+ carry
+ Associational.carry
+ Associational.carryterm
+ ].
+ cbv beta iota delta [Let_In].
+ subst to_associational_inlined; reflexivity.
+ Qed.
+
+ (* carry chain that aligns terms in the intermediate weight with the final weight *)
+ Definition aligned_carries (log_dw_sw nout : nat)
+ := (map (fun i => ((log_dw_sw * (i + 1)) - 1))%nat (seq 0 nout)).
+
+ Section mul_converted.
+ Definition mul_converted
+ n1 n2 (* lengths in original format *)
+ m1 m2 (* lengths in converted format *)
+ (n3 : nat) (* final length *)
+ (idxs : list nat) (* carries to do -- this helps preemptively line up weights *)
+ (p1 p2 : list Z) :=
+ let p1_a := to_associational n1 m1 p1 in
+ let p2_a := to_associational n2 m2 p2 in
+ let p3_a := Associational.mul p1_a p2_a in
+ from_associational idxs n3 p3_a.
+
+ Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
+ length p1 = n1 -> length p2 = n2 ->
+ 0 <= (Positional.eval sw n1 p1 * Positional.eval sw n2 p2) < sw n3 ->
+ Positional.eval sw n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval sw n1 p1) * (Positional.eval sw n2 p2).
+ Proof. cbv [mul_converted]; push_eval. Qed.
+ Hint Rewrite eval_mul_converted : push_eval.
+
+ Lemma mul_converted_partitions n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
+ length p1 = n1 -> length p2 = n2 ->
+ mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = Rows.partition sw n3 (Positional.eval sw n1 p1 * Positional.eval sw n2 p2).
+ Proof.
+ intros; cbv [mul_converted].
+ rewrite from_associational_eq by auto. push_eval.
+ Qed.
+ End mul_converted.
+ End BaseConversion.
+
+ (* multiply two (n*k)-bit numbers by converting them to n k-bit limbs each, multiplying, then converting back *)
+ Section widemul.
+ Context (log2base : Z) (log2base_pos : 0 < log2base).
+ Context (n : nat) (n_nz : n <> 0%nat) (n_le_log2base : Z.of_nat n <= log2base)
+ (nout : nat) (nout_2 : nout = 2%nat). (* nout is always 2, but partial evaluation is overeager if it's a constant *)
+ Let dw : nat -> Z := weight (log2base / Z.of_nat n) 1.
+ Let sw : nat -> Z := weight log2base 1.
+
+ Local Lemma base_bounds : 0 < 1 <= log2base. Proof. auto with zarith. Qed.
+ Local Lemma dbase_bounds : 0 < 1 <= log2base / Z.of_nat n. Proof. auto with zarith. Qed.
+ Let dwprops : @weight_properties dw := wprops (log2base / Z.of_nat n) 1 dbase_bounds.
+ Let swprops : @weight_properties sw := wprops log2base 1 base_bounds.
+
+ Hint Resolve Z.gt_lt Z.positive_is_nonzero Nat2Z.is_nonneg.
+
+ Definition widemul a b := mul_converted sw dw 1 1 n n nout (aligned_carries n nout) [a] [b].
+
+ Lemma widemul_correct a b :
+ 0 <= a * b < 2^log2base * 2^log2base ->
+ widemul a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base].
+ Proof.
+ cbv [widemul]; intros.
+ rewrite mul_converted_partitions by auto with zarith.
+ subst nout sw; cbv [weight]; cbn.
+ autorewrite with zsimplify.
+ rewrite Z.pow_mul_r, Z.pow_2_r by omega.
+ Z.rewrite_mod_small. reflexivity.
+ Qed.
+
+ Derive widemul_inlined
+ SuchThat (forall a b,
+ 0 <= a * b < 2^log2base * 2^log2base ->
+ widemul_inlined a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base])
+ As widemul_inlined_correct.
+ Proof.
+ intros.
+ rewrite <-widemul_correct by auto.
+ cbv beta iota delta [widemul mul_converted].
+ rewrite <-to_associational_inlined_correct with (p:=[a]).
+ rewrite <-to_associational_inlined_correct with (p:=[b]).
+ rewrite <-from_associational_inlined_correct.
+ subst widemul_inlined; reflexivity.
+ Qed.
+
+ Derive widemul_inlined_reverse
+ SuchThat (forall a b,
+ 0 <= a * b < 2^log2base * 2^log2base ->
+ widemul_inlined_reverse a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base])
+ As widemul_inlined_reverse_correct.
+ Proof.
+ intros.
+ rewrite <-widemul_inlined_correct by assumption.
+ cbv [widemul_inlined].
+ match goal with |- _ = from_associational_inlined sw dw ?idxs ?n ?p =>
+ transitivity (from_associational_inlined sw dw idxs n (rev p));
+ [ | transitivity (from_associational sw dw idxs n p); [ | reflexivity ] ](* reverse to make addc chains line up *)
+ end.
+ Focus 2. {
+ rewrite from_associational_inlined_correct by (subst nout; auto).
+ cbv [from_associational].
+ rewrite !Rows.flatten_partitions' by eauto using Rows.length_from_associational.
+ rewrite !Rows.eval_from_associational by (subst nout; auto).
+ f_equal.
+ rewrite !eval_carries, !Associational.bind_snd_correct, !Associational.eval_rev by auto.
+ reflexivity. } Unfocus.
+ subst widemul_inlined_reverse; reflexivity.
+ Qed.
+ End widemul.
+End BaseConversion.