aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-03-07 16:07:12 +0100
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-03 09:00:55 -0400
commita4fd850b6d41a20e0282685186216ea82c707ce8 (patch)
treebea9fc16faa158c391574feac8b3b94eee6076f0 /src
parent8690fc7b8054901178a0756d1fb6f47342a2bd55 (diff)
organize Rows into sections
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v934
1 files changed, 469 insertions, 465 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 48661f4c6..4ca2e89be 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -976,496 +976,500 @@ Module Rows.
Proof. cbv [eval]; autorewrite with push_map push_sum; reflexivity. Qed.
Hint Rewrite eval_app : push_eval.
- Definition extract_row (inp : cols) : cols * list Z := (map (fun c => tl c) inp, map (fun c => hd 0 c) inp).
-
- Definition max_column_size (x:cols) := fold_right (fun a b => Nat.max a b) 0%nat (map (fun c => length c) x).
-
- Definition from_columns' n start_state : cols * rows :=
- fold_right (fun _ (state : cols * rows) =>
- let cols'_row := extract_row (fst state) in
- (fst cols'_row, snd state ++ [snd cols'_row])
- ) start_state (List.repeat 0 n).
+ Section FromAssociational.
+ (* extract row *)
+ Definition extract_row (inp : cols) : cols * list Z := (map (fun c => tl c) inp, map (fun c => hd 0 c) inp).
- Definition from_columns (inp : cols) : rows := snd (from_columns' (max_column_size inp) (inp, [])).
+ Lemma eval_extract_row (inp : cols): forall n,
+ length inp = n ->
+ Positional.eval weight n (snd (extract_row inp)) = Columns.eval weight n inp - Columns.eval weight n (fst (extract_row inp)) .
+ Proof.
+ cbv [extract_row].
+ induction inp using rev_ind; [ | destruct n ];
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress distr_length
+ | _ => rewrite Positional.eval_snoc with (n:=n) by distr_length
+ | _ => progress autorewrite with cancel_pair push_eval push_map in *
+ | _ => ring
+ end.
+ rewrite IHinp by distr_length.
+ destruct x; cbn [hd tl]; rewrite ?sum_nil, ?sum_cons; ring.
+ Qed. Hint Rewrite eval_extract_row using (solve [distr_length]) : push_eval.
+
+ Lemma length_fst_extract_row n (inp : cols) :
+ length inp = n -> length (fst (extract_row inp)) = n.
+ Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed.
+ Hint Rewrite length_fst_extract_row : distr_length.
+
+ Lemma length_snd_extract_row n (inp : cols) :
+ length inp = n -> length (snd (extract_row inp)) = n.
+ Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed.
+ Hint Rewrite length_snd_extract_row : distr_length.
+
+ (* max column size *)
+ Definition max_column_size (x:cols) := fold_right (fun a b => Nat.max a b) 0%nat (map (fun c => length c) x).
+
+ (* TODO: move to where list is defined *)
+ Hint Rewrite @app_nil_l : list.
+ Hint Rewrite <-@app_comm_cons: list.
+
+ Lemma max_column_size_nil : max_column_size nil = 0%nat.
+ Proof. reflexivity. Qed. Hint Rewrite max_column_size_nil : push_max_column_size.
+ Lemma max_column_size_cons col (inp : cols) :
+ max_column_size (col :: inp) = Nat.max (length col) (max_column_size inp).
+ Proof. reflexivity. Qed. Hint Rewrite max_column_size_cons : push_max_column_size.
+ Lemma max_column_size_app (x y : cols) :
+ max_column_size (x ++ y) = Nat.max (max_column_size x) (max_column_size y).
+ Proof. induction x; autorewrite with list push_max_column_size; lia. Qed.
+ Hint Rewrite max_column_size_app : push_max_column_size.
+ Lemma max_column_size0 (inp : cols) :
+ forall n,
+ length inp = n -> (* this is not needed to make the lemma true, but prevents reliance on the implementation of Columns.eval*)
+ max_column_size inp = 0%nat -> Columns.eval weight n inp = 0.
+ Proof.
+ induction inp as [|x inp] using rev_ind; destruct n; try destruct x; intros;
+ autorewrite with push_max_column_size push_eval push_sum distr_length in *; try lia.
+ rewrite IHinp; distr_length; lia.
+ Qed.
- Definition from_associational n (p : list (Z * Z)) := from_columns (Columns.from_associational weight n p).
-
- Lemma eval_extract_row (inp : cols): forall n,
- length inp = n ->
- Positional.eval weight n (snd (extract_row inp)) = Columns.eval weight n inp - Columns.eval weight n (fst (extract_row inp)) .
- Proof.
- cbv [extract_row].
- induction inp using rev_ind; [ | destruct n ];
- repeat match goal with
- | _ => progress intros
- | _ => progress distr_length
- | _ => rewrite Positional.eval_snoc with (n:=n) by distr_length
- | _ => progress autorewrite with cancel_pair push_eval push_map in *
- | _ => ring
- end.
- rewrite IHinp by distr_length.
- destruct x; cbn [hd tl]; rewrite ?sum_nil, ?sum_cons; ring.
- Qed. Hint Rewrite eval_extract_row using (solve [distr_length]) : push_eval.
-
- Lemma length_fst_extract_row n (inp : cols) :
- length inp = n -> length (fst (extract_row inp)) = n.
- Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed.
- Hint Rewrite length_fst_extract_row : distr_length.
-
- Lemma length_snd_extract_row n (inp : cols) :
- length inp = n -> length (snd (extract_row inp)) = n.
- Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed.
- Hint Rewrite length_snd_extract_row : distr_length.
-
- (* TODO: move to where list is defined *)
- Hint Rewrite @app_nil_l : list.
- Hint Rewrite <-@app_comm_cons: list.
-
- Lemma max_column_size_nil : max_column_size nil = 0%nat.
- Proof. reflexivity. Qed. Hint Rewrite max_column_size_nil : push_max_column_size.
- Lemma max_column_size_cons col (inp : cols) :
- max_column_size (col :: inp) = Nat.max (length col) (max_column_size inp).
- Proof. reflexivity. Qed. Hint Rewrite max_column_size_cons : push_max_column_size.
- Lemma max_column_size_app (x y : cols) :
- max_column_size (x ++ y) = Nat.max (max_column_size x) (max_column_size y).
- Proof. induction x; autorewrite with list push_max_column_size; lia. Qed.
- Hint Rewrite max_column_size_app : push_max_column_size.
- Lemma max_column_size0 (inp : cols) :
- forall n,
- length inp = n -> (* this is not needed to make the lemma true, but prevents reliance on the implementation of Columns.eval*)
- max_column_size inp = 0%nat -> Columns.eval weight n inp = 0.
- Proof.
- induction inp as [|x inp] using rev_ind; destruct n; try destruct x; intros;
- autorewrite with push_max_column_size push_eval push_sum distr_length in *; try lia.
- rewrite IHinp; distr_length; lia.
- Qed.
+ (* from_columns *)
+ Definition from_columns' n start_state : cols * rows :=
+ fold_right (fun _ (state : cols * rows) =>
+ let cols'_row := extract_row (fst state) in
+ (fst cols'_row, snd state ++ [snd cols'_row])
+ ) start_state (List.repeat 0 n).
- Local Ltac In_cases :=
- repeat match goal with
- | H: In _ (_ ++ _) |- _ => apply in_app_or in H; destruct H
- | H: In _ (_ :: _) |- _ => apply in_inv in H; destruct H
- | H: In _ nil |- _ => contradiction H
- end.
+ Definition from_columns (inp : cols) : rows := snd (from_columns' (max_column_size inp) (inp, [])).
- Lemma eval_from_columns'_with_length m st n:
- (length (fst st) = n) ->
- length (fst (from_columns' m st)) = n /\
- ((forall r, In r (snd st) -> length r = n) ->
- forall r, In r (snd (from_columns' m st)) -> length r = n) /\
- eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st)
- - Columns.eval weight n (fst (from_columns' m st)).
- Proof.
- cbv [from_columns']; intros.
- apply fold_right_invariant; intros;
+ Local Ltac In_cases :=
repeat match goal with
- | _ => progress (intros; subst)
- | _ => progress autorewrite with cancel_pair push_eval
- | _ => progress In_cases
- | _ => split; try omega
- | H: _ /\ _ |- _ => destruct H
- | _ => solve [auto using length_fst_extract_row, length_snd_extract_row]
+ | H: In _ (_ ++ _) |- _ => apply in_app_or in H; destruct H
+ | H: In _ (_ :: _) |- _ => apply in_inv in H; destruct H
+ | H: In _ nil |- _ => contradiction H
end.
- Qed.
- Lemma length_fst_from_columns' m st :
- length (fst (from_columns' m st)) = length (fst st).
- Proof. apply eval_from_columns'_with_length; reflexivity. Qed.
- Hint Rewrite length_fst_from_columns' : distr_length.
- Lemma length_snd_from_columns' m st :
- (forall r, In r (snd st) -> length r = length (fst st)) ->
- forall r, In r (snd (from_columns' m st)) -> length r = length (fst st).
- Proof. apply eval_from_columns'_with_length. reflexivity. Qed.
- Hint Rewrite length_snd_from_columns' : distr_length.
- Lemma eval_from_columns' m st n :
- (length (fst st) = n) ->
- eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st)
- - Columns.eval weight n (fst (from_columns' m st)).
- Proof. apply eval_from_columns'_with_length. Qed.
- Hint Rewrite eval_from_columns' using (auto; solve [distr_length]) : push_eval.
-
- Lemma max_column_size_extract_row inp :
- max_column_size (fst (extract_row inp)) = (max_column_size inp - 1)%nat.
- Proof.
- cbv [extract_row]. autorewrite with cancel_pair.
- induction inp; [ reflexivity | ].
- autorewrite with push_max_column_size push_map distr_length.
- rewrite IHinp. auto using Nat.sub_max_distr_r.
- Qed.
- Hint Rewrite max_column_size_extract_row : push_max_column_size.
- Lemma max_column_size_from_columns' m st :
- max_column_size (fst (from_columns' m st)) = (max_column_size (fst st) - m)%nat.
- Proof.
- cbv [from_columns']; induction m; intros; cbn - [max_column_size extract_row];
- autorewrite with push_max_column_size; lia.
- Qed.
- Hint Rewrite max_column_size_from_columns' : push_max_column_size.
-
- Lemma eval_from_columns (inp : cols) :
- forall n, length inp = n -> eval n (from_columns inp) = Columns.eval weight n inp.
- Proof.
- intros; cbv [from_columns];
- repeat match goal with
- | _ => progress autorewrite with cancel_pair push_eval push_max_column_size
- | _ => rewrite max_column_size0 with (inp := fst (from_columns' _ _)) by
- (autorewrite with push_max_column_size; distr_length)
- | _ => omega
- end.
- Qed.
- Hint Rewrite eval_from_columns using (auto; solve [distr_length]) : push_eval.
-
- Lemma length_from_columns inp:
- forall r, In r (from_columns inp) -> length r = length inp.
- Proof.
- cbv [from_columns]; intros.
- change inp with (fst (inp, @nil (list Z))).
- eapply length_snd_from_columns'; eauto.
- autorewrite with cancel_pair; intros; In_cases.
- Qed.
- Hint Rewrite length_from_columns : distr_length.
-
- Lemma eval_from_associational n p: (n <> 0%nat \/ p = nil) ->
- eval n (from_associational n p) = Associational.eval p.
- Proof.
- intros. cbv [from_associational].
- rewrite eval_from_columns by auto using Columns.length_from_associational.
- auto using Columns.eval_from_associational.
- Qed.
+ Lemma eval_from_columns'_with_length m st n:
+ (length (fst st) = n) ->
+ length (fst (from_columns' m st)) = n /\
+ ((forall r, In r (snd st) -> length r = n) ->
+ forall r, In r (snd (from_columns' m st)) -> length r = n) /\
+ eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st)
+ - Columns.eval weight n (fst (from_columns' m st)).
+ Proof.
+ cbv [from_columns']; intros.
+ apply fold_right_invariant; intros;
+ repeat match goal with
+ | _ => progress (intros; subst)
+ | _ => progress autorewrite with cancel_pair push_eval
+ | _ => progress In_cases
+ | _ => split; try omega
+ | H: _ /\ _ |- _ => destruct H
+ | _ => solve [auto using length_fst_extract_row, length_snd_extract_row]
+ end.
+ Qed.
+ Lemma length_fst_from_columns' m st :
+ length (fst (from_columns' m st)) = length (fst st).
+ Proof. apply eval_from_columns'_with_length; reflexivity. Qed.
+ Hint Rewrite length_fst_from_columns' : distr_length.
+ Lemma length_snd_from_columns' m st :
+ (forall r, In r (snd st) -> length r = length (fst st)) ->
+ forall r, In r (snd (from_columns' m st)) -> length r = length (fst st).
+ Proof. apply eval_from_columns'_with_length. reflexivity. Qed.
+ Hint Rewrite length_snd_from_columns' : distr_length.
+ Lemma eval_from_columns' m st n :
+ (length (fst st) = n) ->
+ eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st)
+ - Columns.eval weight n (fst (from_columns' m st)).
+ Proof. apply eval_from_columns'_with_length. Qed.
+ Hint Rewrite eval_from_columns' using (auto; solve [distr_length]) : push_eval.
+
+ Lemma max_column_size_extract_row inp :
+ max_column_size (fst (extract_row inp)) = (max_column_size inp - 1)%nat.
+ Proof.
+ cbv [extract_row]. autorewrite with cancel_pair.
+ induction inp; [ reflexivity | ].
+ autorewrite with push_max_column_size push_map distr_length.
+ rewrite IHinp. auto using Nat.sub_max_distr_r.
+ Qed.
+ Hint Rewrite max_column_size_extract_row : push_max_column_size.
- Lemma length_from_associational n p :
- forall r, In r (from_associational n p) -> length r = n.
- Proof.
- cbv [from_associational]; intros.
- match goal with H: _ |- _ => apply length_from_columns in H end.
- rewrite Columns.length_from_associational in *; auto.
- Qed.
+ Lemma max_column_size_from_columns' m st :
+ max_column_size (fst (from_columns' m st)) = (max_column_size (fst st) - m)%nat.
+ Proof.
+ cbv [from_columns']; induction m; intros; cbn - [max_column_size extract_row];
+ autorewrite with push_max_column_size; lia.
+ Qed.
+ Hint Rewrite max_column_size_from_columns' : push_max_column_size.
- Local Notation fw := (fun i => weight (S i) / weight i) (only parsing).
-
- Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z :=
- fold_right (fun next (state : list Z * Z) =>
- let i := length (fst state) in (* length of output accumulator tells us the index of [next] *)
- dlet_nd next := next in (* makes the output correctly bind variables *)
- dlet_nd sum_carry := Z.add_with_get_carry_full (fw i) (snd state) (fst next) (snd next) in
- (fst state ++ [fst sum_carry], snd sum_carry)) start_state (rev (combine row1 row2)).
- Definition sum_rows := sum_rows' (nil,0).
-
- Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z :=
- fold_right (fun next_row (state : list Z * Z)=>
- let out_carry := sum_rows next_row (fst state) in
- (fst out_carry, snd state + snd out_carry)) start_state inp.
-
- (* For correctness if there is only one row, we add a row of
- zeroes with the same length so that the add loop still happens. *)
- Definition flatten (inp : rows) : list Z * Z :=
- let first_row := hd nil inp in
- flatten' (first_row, 0) (hd (Positional.zeros (length first_row)) (tl inp) :: tl (tl inp)).
-
- Lemma sum_rows'_div_mod_length row1 :
- forall n m start_state row2 row1' row2',
- length (fst start_state) = m -> length row1 = n -> length row2 = n ->
- length row1' = m -> length row2' = m ->
- let nm : nat := (n + m)%nat in
- let eval := Positional.eval weight in
- eval m (fst start_state) = (eval m row1' + eval m row2') mod (weight m) ->
- snd start_state = (eval m row1' + eval m row2') / weight m ->
- length (fst (sum_rows' start_state row1 row2)) = (n + m)%nat
- /\ eval nm (fst (sum_rows' start_state row1 row2))
- = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight nm)
- /\ snd (sum_rows' start_state row1 row2)
- = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / (weight nm).
- Proof.
- cbv [sum_rows' Let_In].
- induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *;
- destruct row2 as [|x2 row2]; distr_length; [ subst n | ];
+ Lemma eval_from_columns (inp : cols) :
+ forall n, length inp = n -> eval n (from_columns inp) = Columns.eval weight n inp.
+ Proof.
+ intros; cbv [from_columns];
repeat match goal with
- | _ => progress autorewrite with natsimplify list
- | _ => progress cbn [fold_left combine]
+ | _ => progress autorewrite with cancel_pair push_eval push_max_column_size
+ | _ => rewrite max_column_size0 with (inp := fst (from_columns' _ _)) by
+ (autorewrite with push_max_column_size; distr_length)
| _ => omega
end.
+ Qed.
+ Hint Rewrite eval_from_columns using (auto; solve [distr_length]) : push_eval.
- specialize (IHrow1 (pred n) (S m)).
- replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega.
- rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
- rewrite <-fold_left_rev_right.
- apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega;
- repeat match goal with
- | H : ?LHS = _ |- _ =>
- match LHS with context [start_state] => rewrite H end
- | H : length _ = _ |- _ => rewrite H
- | _ => rewrite <-Z.div_add by auto
- | _ => rewrite Z.div_div by auto using Z.gt_lt
- | _ => rewrite Z.mul_div_eq by auto
- | _ => rewrite weight_multiples
- | _ => erewrite Positional.eval_snoc by eauto
- | _ => progress autorewrite with cancel_pair distr_length to_div_mod in *
- | |- context [ ?x mod ?m + ?m * (((?x + ?a * ?m + ?b * ?m)/ ?m) mod ?c) ] =>
- replace (x mod m) with ((x + a * m + b * m) mod m) by
- (autorewrite with zsimplify; ring);
- rewrite <-Z.rem_mul_r by auto using Z.gt_lt
- | _ => f_equal; ring
- end.
- Qed.
-
- Lemma sum_rows_div_mod n row1 row2 :
- length row1 = n -> length row2 = n ->
- let eval := Positional.eval weight in
- eval n (fst (sum_rows row1 row2)) = (eval n row1 + eval n row2) mod (weight n)
- /\ snd (sum_rows row1 row2) = (eval n row1 + eval n row2) / (weight n).
- Proof.
- cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
- edestruct sum_rows'_div_mod_length as [Hlen [Hmod Hdiv] ];
- try erewrite Hmod, Hdiv; auto using nil_length0;
- autorewrite with cancel_pair push_eval zsimplify_fast; distr_length.
- Qed.
+ Lemma length_from_columns inp:
+ forall r, In r (from_columns inp) -> length r = length inp.
+ Proof.
+ cbv [from_columns]; intros.
+ change inp with (fst (inp, @nil (list Z))).
+ eapply length_snd_from_columns'; eauto.
+ autorewrite with cancel_pair; intros; In_cases.
+ Qed.
+ Hint Rewrite length_from_columns : distr_length.
- Lemma sum_rows_mod n row1 row2 :
- length row1 = n -> length row2 = n ->
- Positional.eval weight n (fst (sum_rows row1 row2))
- = (Positional.eval weight n row1 + Positional.eval weight n row2) mod (weight n).
- Proof. apply sum_rows_div_mod. Qed.
- Lemma sum_rows_div row1 row2 n:
- length row1 = n -> length row2 = n ->
- snd (sum_rows row1 row2)
- = (Positional.eval weight n row1 + Positional.eval weight n row2) / (weight n).
- Proof. apply sum_rows_div_mod. Qed.
-
- Hint Rewrite sum_rows_mod using (auto; solve [distr_length; auto]) : push_eval.
- Local Ltac mul_div_weights i j :=
- rewrite (Z.div_mod (weight i) (weight j)) by auto;
- rewrite Columns.weight_multiples_full by (auto; omega);
- autorewrite with zsimplify_fast.
-
- (* TODO: figure out where to put this and weight_multiples_full *)
- Lemma weight_divides_full j i : (i <= j)%nat -> weight j / weight i > 0.
- Proof. auto using Z.div_positive_gt_0, Columns.weight_multiples_full. Qed.
-
- Lemma sum_rows'_partitions row1 :
- forall n m start_state row2 row1' row2',
- length (fst start_state) = m -> length row1 = n -> length row2 = n ->
- length row1' = m -> length row2' = m ->
- let nm := (n + m)%nat in
- let eval := Positional.eval weight in
- snd start_state = (eval m row1' + eval m row2') / weight m ->
- (forall j, (j < m)%nat ->
- nth_default 0 (fst start_state) j = ((eval m row1' + eval m row2') mod (weight (S j))) / (weight j)) ->
- forall i, (i < nm)%nat ->
- nth_default 0 (fst (sum_rows' start_state row1 row2)) i
- = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i).
- Proof.
- cbv [sum_rows' Let_In].
- induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *;
- destruct row2 as [|x2 row2]; distr_length; [ subst n | ];
- repeat match goal with
- | _ => progress cbn [fold_left]
- | H : length _ = _ |- _ => rewrite H
- | H: _ |- _ => solve [apply H; omega]
- | _ => progress autorewrite with push_eval zsimplify_fast push_combine list natsimplify cancel_pair to_div_mod
- end.
+ (* from associational *)
+ Definition from_associational n (p : list (Z * Z)) := from_columns (Columns.from_associational weight n p).
+
+ Lemma eval_from_associational n p: (n <> 0%nat \/ p = nil) ->
+ eval n (from_associational n p) = Associational.eval p.
+ Proof.
+ intros. cbv [from_associational].
+ rewrite eval_from_columns by auto using Columns.length_from_associational.
+ auto using Columns.eval_from_associational.
+ Qed.
- specialize (IHrow1 (pred n) (S m)).
- replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega.
- rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
- rewrite <-fold_left_rev_right.
- apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega;
- repeat match goal with
- | _ => progress intros
- | _ => progress autorewrite with push_nth_default
- | _ => progress break_match
- | H : length _ = _ |- _ => rewrite H
- | H : ?LHS = _ |- _ =>
- match LHS with context [start_state] => rewrite H end
- | H : context [nth_default 0 (fst start_state)] |- _ => rewrite H by omega
- | _ => rewrite Positional.eval_snoc with (n:=m) by eauto
- end.
- { mul_div_weights (S m) m.
- rewrite <-Z.div_div by auto using Z.gt_lt.
- autorewrite with zsimplify.
- f_equal; ring. }
- { mul_div_weights m (S j).
- push_Zmod. autorewrite with zsimplify.
- lia. }
- { replace j with m by omega.
- autorewrite with push_nth_default natsimplify.
- rewrite <-!Z.div_add' by auto.
- rewrite Z.mod_pull_div, Z.mul_div_eq' by (auto using Z.lt_le_incl, Z.gt_lt).
- rewrite weight_multiples.
- autorewrite with zsimplify_fast.
- repeat (f_equal; try ring). }
- Qed.
-
- Lemma sum_rows_partitions row1: forall row2 n i,
- length row1 = n -> length row2 = n -> (i < n)%nat ->
- nth_default 0 (fst (sum_rows row1 row2)) i
- = ((Positional.eval weight n row1 + Positional.eval weight n row2) mod weight (S i)) / (weight i).
- Proof.
- cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
- rewrite <-(app_nil_l row1), <-(app_nil_l row2).
- apply sum_rows'_partitions; intros;
- autorewrite with cancel_pair push_eval zsimplify_fast push_nth_default; distr_length.
- rewrite Z.div_0_l by auto; omega.
- Qed.
+ Lemma length_from_associational n p :
+ forall r, In r (from_associational n p) -> length r = n.
+ Proof.
+ cbv [from_associational]; intros.
+ match goal with H: _ |- _ => apply length_from_columns in H end.
+ rewrite Columns.length_from_associational in *; auto.
+ Qed.
+ End FromAssociational.
- Lemma length_sum_rows row1 row2 n :
- length row1 = n -> length row2 = n ->
- length (fst (sum_rows row1 row2)) = n.
- Proof.
- cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
- eapply sum_rows'_div_mod_length; auto using nil_length0.
- Qed. Hint Rewrite length_sum_rows : distr_length.
-
-
- (* TODO: move to ListUtil *)
- (* connect state to remaining input *)
- Lemma fold_right_invariant_strong {A B} (P: list B -> A -> Type) (f: B -> A -> A):
- forall bs a0
- (Pnil : P nil a0)
- (IHfold:
- forall bs' b a',
- In b bs ->
- a' = fold_right f a0 bs' ->
- P bs' a' ->
- P (b :: bs') (f b a')),
- P bs (fold_right f a0 bs).
- Proof. induction bs; intros; autorewrite with push_fold_right; auto using in_eq,in_cons. Qed.
-
- Lemma flatten'_div_mod_length n start_state inp:
- length (fst start_state) = n ->
- (forall row, In row inp -> length row = n) ->
- length (fst (flatten' start_state inp)) = n
+ Section Flatten.
+ Local Notation fw := (fun i => weight (S i) / weight i) (only parsing).
+
+ Section SumRows.
+ Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z :=
+ fold_right (fun next (state : list Z * Z) =>
+ let i := length (fst state) in (* length of output accumulator tells us the index of [next] *)
+ dlet_nd next := next in (* makes the output correctly bind variables *)
+ dlet_nd sum_carry := Z.add_with_get_carry_full (fw i) (snd state) (fst next) (snd next) in
+ (fst state ++ [fst sum_carry], snd sum_carry)) start_state (rev (combine row1 row2)).
+ Definition sum_rows := sum_rows' (nil,0).
+
+ Lemma sum_rows'_div_mod_length row1 :
+ forall n m start_state row2 row1' row2',
+ length (fst start_state) = m -> length row1 = n -> length row2 = n ->
+ length row1' = m -> length row2' = m ->
+ let nm : nat := (n + m)%nat in
+ let eval := Positional.eval weight in
+ eval m (fst start_state) = (eval m row1' + eval m row2') mod (weight m) ->
+ snd start_state = (eval m row1' + eval m row2') / weight m ->
+ length (fst (sum_rows' start_state row1 row2)) = (n + m)%nat
+ /\ eval nm (fst (sum_rows' start_state row1 row2))
+ = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight nm)
+ /\ snd (sum_rows' start_state row1 row2)
+ = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / (weight nm).
+ Proof.
+ cbv [sum_rows' Let_In].
+ induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *;
+ destruct row2 as [|x2 row2]; distr_length; [ subst n | ];
+ repeat match goal with
+ | _ => progress autorewrite with natsimplify list
+ | _ => progress cbn [fold_left combine]
+ | _ => omega
+ end.
+
+ specialize (IHrow1 (pred n) (S m)).
+ replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega.
+ rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
+ rewrite <-fold_left_rev_right.
+ apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega;
+ repeat match goal with
+ | H : ?LHS = _ |- _ =>
+ match LHS with context [start_state] => rewrite H end
+ | H : length _ = _ |- _ => rewrite H
+ | _ => rewrite <-Z.div_add by auto
+ | _ => rewrite Z.div_div by auto using Z.gt_lt
+ | _ => rewrite Z.mul_div_eq by auto
+ | _ => rewrite weight_multiples
+ | _ => erewrite Positional.eval_snoc by eauto
+ | _ => progress autorewrite with cancel_pair distr_length to_div_mod in *
+ | |- context [ ?x mod ?m + ?m * (((?x + ?a * ?m + ?b * ?m)/ ?m) mod ?c) ] =>
+ replace (x mod m) with ((x + a * m + b * m) mod m) by
+ (autorewrite with zsimplify; ring);
+ rewrite <-Z.rem_mul_r by auto using Z.gt_lt
+ | _ => f_equal; ring
+ end.
+ Qed.
+
+ Lemma sum_rows_div_mod n row1 row2 :
+ length row1 = n -> length row2 = n ->
+ let eval := Positional.eval weight in
+ eval n (fst (sum_rows row1 row2)) = (eval n row1 + eval n row2) mod (weight n)
+ /\ snd (sum_rows row1 row2) = (eval n row1 + eval n row2) / (weight n).
+ Proof.
+ cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
+ edestruct sum_rows'_div_mod_length as [Hlen [Hmod Hdiv] ];
+ try erewrite Hmod, Hdiv; auto using nil_length0;
+ autorewrite with cancel_pair push_eval zsimplify_fast; distr_length.
+ Qed.
+
+ Lemma sum_rows_mod n row1 row2 :
+ length row1 = n -> length row2 = n ->
+ Positional.eval weight n (fst (sum_rows row1 row2))
+ = (Positional.eval weight n row1 + Positional.eval weight n row2) mod (weight n).
+ Proof. apply sum_rows_div_mod. Qed.
+ Lemma sum_rows_div row1 row2 n:
+ length row1 = n -> length row2 = n ->
+ snd (sum_rows row1 row2)
+ = (Positional.eval weight n row1 + Positional.eval weight n row2) / (weight n).
+ Proof. apply sum_rows_div_mod. Qed.
+
+ (* TODO: figure out where to put this and weight_multiples_full *)
+ Lemma weight_divides_full j i : (i <= j)%nat -> weight j / weight i > 0.
+ Proof. auto using Z.div_positive_gt_0, Columns.weight_multiples_full. Qed.
+
+ Lemma sum_rows'_partitions row1 :
+ forall n m start_state row2 row1' row2',
+ length (fst start_state) = m -> length row1 = n -> length row2 = n ->
+ length row1' = m -> length row2' = m ->
+ let nm := (n + m)%nat in
+ let eval := Positional.eval weight in
+ snd start_state = (eval m row1' + eval m row2') / weight m ->
+ (forall j, (j < m)%nat ->
+ nth_default 0 (fst start_state) j = ((eval m row1' + eval m row2') mod (weight (S j))) / (weight j)) ->
+ forall i, (i < nm)%nat ->
+ nth_default 0 (fst (sum_rows' start_state row1 row2)) i
+ = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i).
+ Proof.
+ cbv [sum_rows' Let_In].
+ induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *;
+ destruct row2 as [|x2 row2]; distr_length; [ subst n | ];
+ repeat match goal with
+ | _ => progress cbn [fold_left]
+ | H : length _ = _ |- _ => rewrite H
+ | H: _ |- _ => solve [apply H; omega]
+ | _ => progress autorewrite with push_eval zsimplify_fast push_combine list natsimplify cancel_pair to_div_mod
+ end.
+
+ specialize (IHrow1 (pred n) (S m)).
+ replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega.
+ rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
+ rewrite <-fold_left_rev_right.
+ apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega;
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress autorewrite with push_nth_default
+ | _ => progress break_match
+ | H : length _ = _ |- _ => rewrite H
+ | H : ?LHS = _ |- _ =>
+ match LHS with context [start_state] => rewrite H end
+ | H : context [nth_default 0 (fst start_state)] |- _ => rewrite H by omega
+ | _ => rewrite Positional.eval_snoc with (n:=m) by eauto
+ end.
+ { erewrite @Columns.weight_div_mod with (j:=S m) (i:=m) by eauto.
+ rewrite <-Z.div_div by auto using Z.gt_lt.
+ autorewrite with zsimplify.
+ f_equal; ring. }
+ { erewrite @Columns.weight_div_mod with (j:=m) (i:=S j) by (eauto; omega).
+ push_Zmod. autorewrite with zsimplify. lia. }
+ { replace j with m by omega.
+ autorewrite with push_nth_default natsimplify.
+ rewrite <-!Z.div_add' by auto.
+ rewrite Z.mod_pull_div, Z.mul_div_eq' by (auto using Z.lt_le_incl, Z.gt_lt).
+ rewrite weight_multiples.
+ autorewrite with zsimplify_fast.
+ repeat (f_equal; try ring). }
+ Qed.
+
+ Lemma sum_rows_partitions row1: forall row2 n i,
+ length row1 = n -> length row2 = n -> (i < n)%nat ->
+ nth_default 0 (fst (sum_rows row1 row2)) i
+ = ((Positional.eval weight n row1 + Positional.eval weight n row2) mod weight (S i)) / (weight i).
+ Proof.
+ cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
+ rewrite <-(app_nil_l row1), <-(app_nil_l row2).
+ apply sum_rows'_partitions; intros;
+ autorewrite with cancel_pair push_eval zsimplify_fast push_nth_default; distr_length.
+ rewrite Z.div_0_l by auto; omega.
+ Qed.
+
+ Lemma length_sum_rows row1 row2 n :
+ length row1 = n -> length row2 = n ->
+ length (fst (sum_rows row1 row2)) = n.
+ Proof.
+ cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
+ eapply sum_rows'_div_mod_length; auto using nil_length0.
+ Qed. Hint Rewrite length_sum_rows : distr_length.
+ End SumRows.
+ Hint Rewrite sum_rows_mod using (auto; solve [distr_length; auto]) : push_eval.
+
+
+ Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z :=
+ fold_right (fun next_row (state : list Z * Z)=>
+ let out_carry := sum_rows next_row (fst state) in
+ (fst out_carry, snd state + snd out_carry)) start_state inp.
+
+ (* For correctness if there is only one row, we add a row of
+ zeroes with the same length so that the add loop still happens. *)
+ Definition flatten (inp : rows) : list Z * Z :=
+ let first_row := hd nil inp in
+ flatten' (first_row, 0) (hd (Positional.zeros (length first_row)) (tl inp) :: tl (tl inp)).
+
+ (* TODO: move to ListUtil *)
+ (* connect state to remaining input *)
+ Lemma fold_right_invariant_strong {A B} (P: list B -> A -> Type) (f: B -> A -> A):
+ forall bs a0
+ (Pnil : P nil a0)
+ (IHfold:
+ forall bs' b a',
+ In b bs ->
+ a' = fold_right f a0 bs' ->
+ P bs' a' ->
+ P (b :: bs') (f b a')),
+ P bs (fold_right f a0 bs).
+ Proof. induction bs; intros; autorewrite with push_fold_right; auto using in_eq,in_cons. Qed.
+
+ Lemma flatten'_div_mod_length n start_state inp:
+ length (fst start_state) = n ->
+ (forall row, In row inp -> length row = n) ->
+ length (fst (flatten' start_state inp)) = n
/\ (inp <> nil ->
Positional.eval weight n (fst (flatten' start_state inp)) = (Positional.eval weight n (fst start_state) + eval n inp) mod (weight n)
/\ snd (flatten' start_state inp) = snd start_state + (Positional.eval weight n (fst start_state) + eval n inp) / weight n).
- Proof.
- intro.
- cbv [flatten'].
- apply fold_right_invariant_strong with (bs:=inp); intros; [ tauto | destruct (dec (bs' = nil)) ];
- repeat match goal with
- | _ => subst a'
- | _ => subst bs'
- | _ => rewrite sum_rows_div with (n:=n) by auto
- | _ => rewrite @fold_right_nil in *
- | _ => progress autorewrite with cancel_pair push_eval in *
- | H : _ -> _ /\ _ |- _ =>
- let X := fresh in let Y := fresh in
- destruct H as [X Y]; [ solve [auto using in_cons] | rewrite ?X, ?Y ]
- | _ => split
- | _ => progress autorewrite with pull_Zmod zsimplify_fast
- | _ => solve [ auto using length_sum_rows ]
- | _ => solve [ ring_simplify; repeat (f_equal; try ring) ]
- end.
- apply Z.mul_cancel_l with (p:=weight n); [ apply weight_nonzero |].
- autorewrite with push_Zmul.
- rewrite !Z.mul_div_eq_full by apply weight_nonzero.
- autorewrite with pull_Zmod.
- ring_simplify_subterms. ring_simplify.
- repeat (f_equal; try ring).
- Qed.
+ Proof.
+ intro.
+ cbv [flatten'].
+ apply fold_right_invariant_strong with (bs:=inp); intros; [ tauto | destruct (dec (bs' = nil)) ];
+ repeat match goal with
+ | _ => subst a'
+ | _ => subst bs'
+ | _ => rewrite sum_rows_div with (n:=n) by auto
+ | _ => rewrite @fold_right_nil in *
+ | _ => progress autorewrite with cancel_pair push_eval in *
+ | H : _ -> _ /\ _ |- _ =>
+ let X := fresh in let Y := fresh in
+ destruct H as [X Y]; [ solve [auto using in_cons] | rewrite ?X, ?Y ]
+ | _ => split
+ | _ => progress autorewrite with pull_Zmod zsimplify_fast
+ | _ => solve [ auto using length_sum_rows ]
+ | _ => solve [ ring_simplify; repeat (f_equal; try ring) ]
+ end.
+ apply Z.mul_cancel_l with (p:=weight n); [ apply weight_nonzero |].
+ autorewrite with push_Zmul.
+ rewrite !Z.mul_div_eq_full by apply weight_nonzero.
+ autorewrite with pull_Zmod.
+ ring_simplify_subterms. ring_simplify.
+ repeat (f_equal; try ring).
+ Qed.
- Hint Rewrite (@Positional.length_zeros weight) : distr_length.
- Hint Rewrite (@Positional.eval_zeros weight) using auto : push_eval.
+ Hint Rewrite (@Positional.length_zeros weight) : distr_length.
+ Hint Rewrite (@Positional.eval_zeros weight) using auto : push_eval.
- Lemma flatten_div_mod inp n :
- (forall row, In row inp -> length row = n) ->
- Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n)
- /\ snd (flatten inp) = (eval n inp) / weight n.
- Proof.
- intros; cbv [flatten].
- destruct inp; [|destruct inp]; cbn [hd tl].
- { cbn. autorewrite with push_eval. tauto. }
- { cbn.
- match goal with H : forall r, In r [?x] -> _ |- _ =>
- specialize (H x ltac:(auto)); rewrite H
- end.
- rewrite sum_rows_div with (n:=n) by distr_length.
- autorewrite with push_eval.
- split; f_equal; ring. }
- { autorewrite with push_eval.
- apply flatten'_div_mod_length; auto.
- congruence. }
- Qed.
+ Lemma flatten_div_mod inp n :
+ (forall row, In row inp -> length row = n) ->
+ Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n)
+ /\ snd (flatten inp) = (eval n inp) / weight n.
+ Proof.
+ intros; cbv [flatten].
+ destruct inp; [|destruct inp]; cbn [hd tl].
+ { cbn. autorewrite with push_eval. tauto. }
+ { cbn.
+ match goal with H : forall r, In r [?x] -> _ |- _ =>
+ specialize (H x ltac:(auto)); rewrite H
+ end.
+ rewrite sum_rows_div with (n:=n) by distr_length.
+ autorewrite with push_eval.
+ split; f_equal; ring. }
+ { autorewrite with push_eval.
+ apply flatten'_div_mod_length; auto.
+ congruence. }
+ Qed.
- Lemma flatten_mod inp n :
- (forall row, In row inp -> length row = n) ->
- Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n).
- Proof. apply flatten_div_mod. Qed.
- Lemma flatten_div inp n :
- (forall row, In row inp -> length row = n) ->
- snd (flatten inp) = (eval n inp) / (weight n).
- Proof. apply flatten_div_mod. Qed.
-
- Lemma length_flatten' n start_state inp :
- length (fst start_state) = n ->
- (forall row, In row inp -> length row = n) ->
- length (fst (flatten' start_state inp)) = n.
- Proof. apply flatten'_div_mod_length. Qed.
- Hint Rewrite length_flatten' : distr_length.
-
- Lemma length_flatten n inp :
- (forall row, In row inp -> length row = n) ->
- inp <> nil ->
- length (fst (flatten inp)) = n.
- Proof.
- intros.
- destruct inp; [congruence |]; destruct inp;
- repeat match goal with
- | _ => progress intros
- | _ => progress cbn [hd tl] in *
- | _ => progress autorewrite with cancel_pair
- | _ => apply flatten'_div_mod_length
- | H: _ |- _ => apply in_inv in H; destruct H
- | _ => solve [auto]
- end;
- subst row; distr_length; auto.
- Qed. Hint Rewrite length_flatten : distr_length.
-
- Lemma flatten'_cons state x inp :
- flatten' state (x :: inp)
- = (fst (sum_rows x (fst (flatten' state inp))), snd (flatten' state inp) + snd (sum_rows x (fst (flatten' state inp)))).
- Proof. reflexivity. Qed.
-
- Lemma flatten'_partitions n start_state inp:
- length (fst start_state) = n ->
- (forall row, In row inp -> length row = n) ->
- inp <> nil ->
- forall i, (i < n)%nat ->
- nth_default 0 (fst (flatten' start_state inp)) i
- = ((Positional.eval weight n (fst start_state) + eval n inp) mod weight (S i)) / (weight i).
- Proof.
- intro.
- induction inp; intros; [congruence|].
- rewrite flatten'_cons.
- autorewrite with push_fold_right cancel_pair.
- rewrite sum_rows_partitions with (n:=n) by (rewrite ?length_flatten' with (n:=n); auto).
- destruct (dec (inp = nil)).
- { subst inp. cbn. repeat (f_equal; try ring). }
- { edestruct flatten'_div_mod_length with (inp:=inp) as [Hlen [Hmod Hdiv] ]; eauto.
- rewrite Hmod. autorewrite with push_eval.
- rewrite Z.add_mod_full. mul_div_weights n (S i).
- rewrite Z.rem_mul_r by auto using Z.gt_lt, weight_divides_full.
- autorewrite with zsimplify. pull_Zmod.
- repeat (f_equal; try ring). }
- Qed.
+ Lemma flatten_mod inp n :
+ (forall row, In row inp -> length row = n) ->
+ Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n).
+ Proof. apply flatten_div_mod. Qed.
+ Lemma flatten_div inp n :
+ (forall row, In row inp -> length row = n) ->
+ snd (flatten inp) = (eval n inp) / (weight n).
+ Proof. apply flatten_div_mod. Qed.
- Lemma flatten_partitions inp n :
- (forall row, In row inp -> length row = n) ->
- forall i, (i < n)%nat ->
- nth_default 0 (fst (flatten inp)) i = (eval n inp mod weight (S i)) / (weight i).
- Proof.
- intros; cbv [flatten].
- destruct inp; [|destruct inp]; cbn [hd tl].
- { cbn. autorewrite with push_eval push_nth_default zsimplify. reflexivity. }
- { cbn.
- match goal with H : forall r, In r [?x] -> _ |- _ =>
- specialize (H x ltac:(auto)); rewrite H
- end.
- rewrite sum_rows_partitions with (n:=n) by distr_length.
- autorewrite with push_eval.
- repeat (f_equal; try ring). }
- { autorewrite with push_eval.
- apply flatten'_partitions; auto.
- congruence. }
+ Lemma length_flatten' n start_state inp :
+ length (fst start_state) = n ->
+ (forall row, In row inp -> length row = n) ->
+ length (fst (flatten' start_state inp)) = n.
+ Proof. apply flatten'_div_mod_length. Qed.
+ Hint Rewrite length_flatten' : distr_length.
+
+ Lemma length_flatten n inp :
+ (forall row, In row inp -> length row = n) ->
+ inp <> nil ->
+ length (fst (flatten inp)) = n.
+ Proof.
+ intros.
+ destruct inp; [congruence |]; destruct inp;
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress cbn [hd tl] in *
+ | _ => progress autorewrite with cancel_pair
+ | _ => apply flatten'_div_mod_length
+ | H: _ |- _ => apply in_inv in H; destruct H
+ | _ => solve [auto]
+ end;
+ subst row; distr_length; auto.
+ Qed. Hint Rewrite length_flatten : distr_length.
+
+ Lemma flatten'_cons state x inp :
+ flatten' state (x :: inp)
+ = (fst (sum_rows x (fst (flatten' state inp))), snd (flatten' state inp) + snd (sum_rows x (fst (flatten' state inp)))).
+ Proof. reflexivity. Qed.
+
+ Lemma flatten'_partitions n start_state inp:
+ length (fst start_state) = n ->
+ (forall row, In row inp -> length row = n) ->
+ inp <> nil ->
+ forall i, (i < n)%nat ->
+ nth_default 0 (fst (flatten' start_state inp)) i
+ = ((Positional.eval weight n (fst start_state) + eval n inp) mod weight (S i)) / (weight i).
+ Proof.
+ intro.
+ induction inp; intros; [congruence|].
+ rewrite flatten'_cons.
+ autorewrite with push_fold_right cancel_pair.
+ rewrite sum_rows_partitions with (n:=n) by (rewrite ?length_flatten' with (n:=n); auto).
+ destruct (dec (inp = nil)).
+ { subst inp. cbn. repeat (f_equal; try ring). }
+ { edestruct flatten'_div_mod_length with (inp:=inp) as [Hlen [Hmod Hdiv] ]; eauto.
+ rewrite Hmod. autorewrite with push_eval.
+ rewrite Z.add_mod_full. erewrite @Columns.weight_div_mod with (j:=n) (i:=S i) by eauto.
+ rewrite Z.rem_mul_r by auto using Z.gt_lt, weight_divides_full.
+ autorewrite with zsimplify. pull_Zmod.
+ repeat (f_equal; try ring). }
Qed.
+ Lemma flatten_partitions inp n :
+ (forall row, In row inp -> length row = n) ->
+ forall i, (i < n)%nat ->
+ nth_default 0 (fst (flatten inp)) i = (eval n inp mod weight (S i)) / (weight i).
+ Proof.
+ intros; cbv [flatten].
+ destruct inp; [|destruct inp]; cbn [hd tl].
+ { cbn. autorewrite with push_eval push_nth_default zsimplify. reflexivity. }
+ { cbn.
+ match goal with H : forall r, In r [?x] -> _ |- _ =>
+ specialize (H x ltac:(auto)); rewrite H
+ end.
+ rewrite sum_rows_partitions with (n:=n) by distr_length.
+ autorewrite with push_eval.
+ repeat (f_equal; try ring). }
+ { autorewrite with push_eval.
+ apply flatten'_partitions; auto.
+ congruence. }
+ Qed.
+ End Flatten.
+
End Rows.
End Rows.