diff options
author | Jade Philipoom <jadep@google.com> | 2018-03-07 16:07:12 +0100 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2018-04-03 09:00:55 -0400 |
commit | a4fd850b6d41a20e0282685186216ea82c707ce8 (patch) | |
tree | bea9fc16faa158c391574feac8b3b94eee6076f0 /src | |
parent | 8690fc7b8054901178a0756d1fb6f47342a2bd55 (diff) |
organize Rows into sections
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 934 |
1 files changed, 469 insertions, 465 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 48661f4c6..4ca2e89be 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -976,496 +976,500 @@ Module Rows. Proof. cbv [eval]; autorewrite with push_map push_sum; reflexivity. Qed. Hint Rewrite eval_app : push_eval. - Definition extract_row (inp : cols) : cols * list Z := (map (fun c => tl c) inp, map (fun c => hd 0 c) inp). - - Definition max_column_size (x:cols) := fold_right (fun a b => Nat.max a b) 0%nat (map (fun c => length c) x). - - Definition from_columns' n start_state : cols * rows := - fold_right (fun _ (state : cols * rows) => - let cols'_row := extract_row (fst state) in - (fst cols'_row, snd state ++ [snd cols'_row]) - ) start_state (List.repeat 0 n). + Section FromAssociational. + (* extract row *) + Definition extract_row (inp : cols) : cols * list Z := (map (fun c => tl c) inp, map (fun c => hd 0 c) inp). - Definition from_columns (inp : cols) : rows := snd (from_columns' (max_column_size inp) (inp, [])). + Lemma eval_extract_row (inp : cols): forall n, + length inp = n -> + Positional.eval weight n (snd (extract_row inp)) = Columns.eval weight n inp - Columns.eval weight n (fst (extract_row inp)) . + Proof. + cbv [extract_row]. + induction inp using rev_ind; [ | destruct n ]; + repeat match goal with + | _ => progress intros + | _ => progress distr_length + | _ => rewrite Positional.eval_snoc with (n:=n) by distr_length + | _ => progress autorewrite with cancel_pair push_eval push_map in * + | _ => ring + end. + rewrite IHinp by distr_length. + destruct x; cbn [hd tl]; rewrite ?sum_nil, ?sum_cons; ring. + Qed. Hint Rewrite eval_extract_row using (solve [distr_length]) : push_eval. + + Lemma length_fst_extract_row n (inp : cols) : + length inp = n -> length (fst (extract_row inp)) = n. + Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed. + Hint Rewrite length_fst_extract_row : distr_length. + + Lemma length_snd_extract_row n (inp : cols) : + length inp = n -> length (snd (extract_row inp)) = n. + Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed. + Hint Rewrite length_snd_extract_row : distr_length. + + (* max column size *) + Definition max_column_size (x:cols) := fold_right (fun a b => Nat.max a b) 0%nat (map (fun c => length c) x). + + (* TODO: move to where list is defined *) + Hint Rewrite @app_nil_l : list. + Hint Rewrite <-@app_comm_cons: list. + + Lemma max_column_size_nil : max_column_size nil = 0%nat. + Proof. reflexivity. Qed. Hint Rewrite max_column_size_nil : push_max_column_size. + Lemma max_column_size_cons col (inp : cols) : + max_column_size (col :: inp) = Nat.max (length col) (max_column_size inp). + Proof. reflexivity. Qed. Hint Rewrite max_column_size_cons : push_max_column_size. + Lemma max_column_size_app (x y : cols) : + max_column_size (x ++ y) = Nat.max (max_column_size x) (max_column_size y). + Proof. induction x; autorewrite with list push_max_column_size; lia. Qed. + Hint Rewrite max_column_size_app : push_max_column_size. + Lemma max_column_size0 (inp : cols) : + forall n, + length inp = n -> (* this is not needed to make the lemma true, but prevents reliance on the implementation of Columns.eval*) + max_column_size inp = 0%nat -> Columns.eval weight n inp = 0. + Proof. + induction inp as [|x inp] using rev_ind; destruct n; try destruct x; intros; + autorewrite with push_max_column_size push_eval push_sum distr_length in *; try lia. + rewrite IHinp; distr_length; lia. + Qed. - Definition from_associational n (p : list (Z * Z)) := from_columns (Columns.from_associational weight n p). - - Lemma eval_extract_row (inp : cols): forall n, - length inp = n -> - Positional.eval weight n (snd (extract_row inp)) = Columns.eval weight n inp - Columns.eval weight n (fst (extract_row inp)) . - Proof. - cbv [extract_row]. - induction inp using rev_ind; [ | destruct n ]; - repeat match goal with - | _ => progress intros - | _ => progress distr_length - | _ => rewrite Positional.eval_snoc with (n:=n) by distr_length - | _ => progress autorewrite with cancel_pair push_eval push_map in * - | _ => ring - end. - rewrite IHinp by distr_length. - destruct x; cbn [hd tl]; rewrite ?sum_nil, ?sum_cons; ring. - Qed. Hint Rewrite eval_extract_row using (solve [distr_length]) : push_eval. - - Lemma length_fst_extract_row n (inp : cols) : - length inp = n -> length (fst (extract_row inp)) = n. - Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed. - Hint Rewrite length_fst_extract_row : distr_length. - - Lemma length_snd_extract_row n (inp : cols) : - length inp = n -> length (snd (extract_row inp)) = n. - Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed. - Hint Rewrite length_snd_extract_row : distr_length. - - (* TODO: move to where list is defined *) - Hint Rewrite @app_nil_l : list. - Hint Rewrite <-@app_comm_cons: list. - - Lemma max_column_size_nil : max_column_size nil = 0%nat. - Proof. reflexivity. Qed. Hint Rewrite max_column_size_nil : push_max_column_size. - Lemma max_column_size_cons col (inp : cols) : - max_column_size (col :: inp) = Nat.max (length col) (max_column_size inp). - Proof. reflexivity. Qed. Hint Rewrite max_column_size_cons : push_max_column_size. - Lemma max_column_size_app (x y : cols) : - max_column_size (x ++ y) = Nat.max (max_column_size x) (max_column_size y). - Proof. induction x; autorewrite with list push_max_column_size; lia. Qed. - Hint Rewrite max_column_size_app : push_max_column_size. - Lemma max_column_size0 (inp : cols) : - forall n, - length inp = n -> (* this is not needed to make the lemma true, but prevents reliance on the implementation of Columns.eval*) - max_column_size inp = 0%nat -> Columns.eval weight n inp = 0. - Proof. - induction inp as [|x inp] using rev_ind; destruct n; try destruct x; intros; - autorewrite with push_max_column_size push_eval push_sum distr_length in *; try lia. - rewrite IHinp; distr_length; lia. - Qed. + (* from_columns *) + Definition from_columns' n start_state : cols * rows := + fold_right (fun _ (state : cols * rows) => + let cols'_row := extract_row (fst state) in + (fst cols'_row, snd state ++ [snd cols'_row]) + ) start_state (List.repeat 0 n). - Local Ltac In_cases := - repeat match goal with - | H: In _ (_ ++ _) |- _ => apply in_app_or in H; destruct H - | H: In _ (_ :: _) |- _ => apply in_inv in H; destruct H - | H: In _ nil |- _ => contradiction H - end. + Definition from_columns (inp : cols) : rows := snd (from_columns' (max_column_size inp) (inp, [])). - Lemma eval_from_columns'_with_length m st n: - (length (fst st) = n) -> - length (fst (from_columns' m st)) = n /\ - ((forall r, In r (snd st) -> length r = n) -> - forall r, In r (snd (from_columns' m st)) -> length r = n) /\ - eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st) - - Columns.eval weight n (fst (from_columns' m st)). - Proof. - cbv [from_columns']; intros. - apply fold_right_invariant; intros; + Local Ltac In_cases := repeat match goal with - | _ => progress (intros; subst) - | _ => progress autorewrite with cancel_pair push_eval - | _ => progress In_cases - | _ => split; try omega - | H: _ /\ _ |- _ => destruct H - | _ => solve [auto using length_fst_extract_row, length_snd_extract_row] + | H: In _ (_ ++ _) |- _ => apply in_app_or in H; destruct H + | H: In _ (_ :: _) |- _ => apply in_inv in H; destruct H + | H: In _ nil |- _ => contradiction H end. - Qed. - Lemma length_fst_from_columns' m st : - length (fst (from_columns' m st)) = length (fst st). - Proof. apply eval_from_columns'_with_length; reflexivity. Qed. - Hint Rewrite length_fst_from_columns' : distr_length. - Lemma length_snd_from_columns' m st : - (forall r, In r (snd st) -> length r = length (fst st)) -> - forall r, In r (snd (from_columns' m st)) -> length r = length (fst st). - Proof. apply eval_from_columns'_with_length. reflexivity. Qed. - Hint Rewrite length_snd_from_columns' : distr_length. - Lemma eval_from_columns' m st n : - (length (fst st) = n) -> - eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st) - - Columns.eval weight n (fst (from_columns' m st)). - Proof. apply eval_from_columns'_with_length. Qed. - Hint Rewrite eval_from_columns' using (auto; solve [distr_length]) : push_eval. - - Lemma max_column_size_extract_row inp : - max_column_size (fst (extract_row inp)) = (max_column_size inp - 1)%nat. - Proof. - cbv [extract_row]. autorewrite with cancel_pair. - induction inp; [ reflexivity | ]. - autorewrite with push_max_column_size push_map distr_length. - rewrite IHinp. auto using Nat.sub_max_distr_r. - Qed. - Hint Rewrite max_column_size_extract_row : push_max_column_size. - Lemma max_column_size_from_columns' m st : - max_column_size (fst (from_columns' m st)) = (max_column_size (fst st) - m)%nat. - Proof. - cbv [from_columns']; induction m; intros; cbn - [max_column_size extract_row]; - autorewrite with push_max_column_size; lia. - Qed. - Hint Rewrite max_column_size_from_columns' : push_max_column_size. - - Lemma eval_from_columns (inp : cols) : - forall n, length inp = n -> eval n (from_columns inp) = Columns.eval weight n inp. - Proof. - intros; cbv [from_columns]; - repeat match goal with - | _ => progress autorewrite with cancel_pair push_eval push_max_column_size - | _ => rewrite max_column_size0 with (inp := fst (from_columns' _ _)) by - (autorewrite with push_max_column_size; distr_length) - | _ => omega - end. - Qed. - Hint Rewrite eval_from_columns using (auto; solve [distr_length]) : push_eval. - - Lemma length_from_columns inp: - forall r, In r (from_columns inp) -> length r = length inp. - Proof. - cbv [from_columns]; intros. - change inp with (fst (inp, @nil (list Z))). - eapply length_snd_from_columns'; eauto. - autorewrite with cancel_pair; intros; In_cases. - Qed. - Hint Rewrite length_from_columns : distr_length. - - Lemma eval_from_associational n p: (n <> 0%nat \/ p = nil) -> - eval n (from_associational n p) = Associational.eval p. - Proof. - intros. cbv [from_associational]. - rewrite eval_from_columns by auto using Columns.length_from_associational. - auto using Columns.eval_from_associational. - Qed. + Lemma eval_from_columns'_with_length m st n: + (length (fst st) = n) -> + length (fst (from_columns' m st)) = n /\ + ((forall r, In r (snd st) -> length r = n) -> + forall r, In r (snd (from_columns' m st)) -> length r = n) /\ + eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st) + - Columns.eval weight n (fst (from_columns' m st)). + Proof. + cbv [from_columns']; intros. + apply fold_right_invariant; intros; + repeat match goal with + | _ => progress (intros; subst) + | _ => progress autorewrite with cancel_pair push_eval + | _ => progress In_cases + | _ => split; try omega + | H: _ /\ _ |- _ => destruct H + | _ => solve [auto using length_fst_extract_row, length_snd_extract_row] + end. + Qed. + Lemma length_fst_from_columns' m st : + length (fst (from_columns' m st)) = length (fst st). + Proof. apply eval_from_columns'_with_length; reflexivity. Qed. + Hint Rewrite length_fst_from_columns' : distr_length. + Lemma length_snd_from_columns' m st : + (forall r, In r (snd st) -> length r = length (fst st)) -> + forall r, In r (snd (from_columns' m st)) -> length r = length (fst st). + Proof. apply eval_from_columns'_with_length. reflexivity. Qed. + Hint Rewrite length_snd_from_columns' : distr_length. + Lemma eval_from_columns' m st n : + (length (fst st) = n) -> + eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st) + - Columns.eval weight n (fst (from_columns' m st)). + Proof. apply eval_from_columns'_with_length. Qed. + Hint Rewrite eval_from_columns' using (auto; solve [distr_length]) : push_eval. + + Lemma max_column_size_extract_row inp : + max_column_size (fst (extract_row inp)) = (max_column_size inp - 1)%nat. + Proof. + cbv [extract_row]. autorewrite with cancel_pair. + induction inp; [ reflexivity | ]. + autorewrite with push_max_column_size push_map distr_length. + rewrite IHinp. auto using Nat.sub_max_distr_r. + Qed. + Hint Rewrite max_column_size_extract_row : push_max_column_size. - Lemma length_from_associational n p : - forall r, In r (from_associational n p) -> length r = n. - Proof. - cbv [from_associational]; intros. - match goal with H: _ |- _ => apply length_from_columns in H end. - rewrite Columns.length_from_associational in *; auto. - Qed. + Lemma max_column_size_from_columns' m st : + max_column_size (fst (from_columns' m st)) = (max_column_size (fst st) - m)%nat. + Proof. + cbv [from_columns']; induction m; intros; cbn - [max_column_size extract_row]; + autorewrite with push_max_column_size; lia. + Qed. + Hint Rewrite max_column_size_from_columns' : push_max_column_size. - Local Notation fw := (fun i => weight (S i) / weight i) (only parsing). - - Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z := - fold_right (fun next (state : list Z * Z) => - let i := length (fst state) in (* length of output accumulator tells us the index of [next] *) - dlet_nd next := next in (* makes the output correctly bind variables *) - dlet_nd sum_carry := Z.add_with_get_carry_full (fw i) (snd state) (fst next) (snd next) in - (fst state ++ [fst sum_carry], snd sum_carry)) start_state (rev (combine row1 row2)). - Definition sum_rows := sum_rows' (nil,0). - - Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z := - fold_right (fun next_row (state : list Z * Z)=> - let out_carry := sum_rows next_row (fst state) in - (fst out_carry, snd state + snd out_carry)) start_state inp. - - (* For correctness if there is only one row, we add a row of - zeroes with the same length so that the add loop still happens. *) - Definition flatten (inp : rows) : list Z * Z := - let first_row := hd nil inp in - flatten' (first_row, 0) (hd (Positional.zeros (length first_row)) (tl inp) :: tl (tl inp)). - - Lemma sum_rows'_div_mod_length row1 : - forall n m start_state row2 row1' row2', - length (fst start_state) = m -> length row1 = n -> length row2 = n -> - length row1' = m -> length row2' = m -> - let nm : nat := (n + m)%nat in - let eval := Positional.eval weight in - eval m (fst start_state) = (eval m row1' + eval m row2') mod (weight m) -> - snd start_state = (eval m row1' + eval m row2') / weight m -> - length (fst (sum_rows' start_state row1 row2)) = (n + m)%nat - /\ eval nm (fst (sum_rows' start_state row1 row2)) - = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight nm) - /\ snd (sum_rows' start_state row1 row2) - = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / (weight nm). - Proof. - cbv [sum_rows' Let_In]. - induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *; - destruct row2 as [|x2 row2]; distr_length; [ subst n | ]; + Lemma eval_from_columns (inp : cols) : + forall n, length inp = n -> eval n (from_columns inp) = Columns.eval weight n inp. + Proof. + intros; cbv [from_columns]; repeat match goal with - | _ => progress autorewrite with natsimplify list - | _ => progress cbn [fold_left combine] + | _ => progress autorewrite with cancel_pair push_eval push_max_column_size + | _ => rewrite max_column_size0 with (inp := fst (from_columns' _ _)) by + (autorewrite with push_max_column_size; distr_length) | _ => omega end. + Qed. + Hint Rewrite eval_from_columns using (auto; solve [distr_length]) : push_eval. - specialize (IHrow1 (pred n) (S m)). - replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega. - rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). - rewrite <-fold_left_rev_right. - apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega; - repeat match goal with - | H : ?LHS = _ |- _ => - match LHS with context [start_state] => rewrite H end - | H : length _ = _ |- _ => rewrite H - | _ => rewrite <-Z.div_add by auto - | _ => rewrite Z.div_div by auto using Z.gt_lt - | _ => rewrite Z.mul_div_eq by auto - | _ => rewrite weight_multiples - | _ => erewrite Positional.eval_snoc by eauto - | _ => progress autorewrite with cancel_pair distr_length to_div_mod in * - | |- context [ ?x mod ?m + ?m * (((?x + ?a * ?m + ?b * ?m)/ ?m) mod ?c) ] => - replace (x mod m) with ((x + a * m + b * m) mod m) by - (autorewrite with zsimplify; ring); - rewrite <-Z.rem_mul_r by auto using Z.gt_lt - | _ => f_equal; ring - end. - Qed. - - Lemma sum_rows_div_mod n row1 row2 : - length row1 = n -> length row2 = n -> - let eval := Positional.eval weight in - eval n (fst (sum_rows row1 row2)) = (eval n row1 + eval n row2) mod (weight n) - /\ snd (sum_rows row1 row2) = (eval n row1 + eval n row2) / (weight n). - Proof. - cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). - edestruct sum_rows'_div_mod_length as [Hlen [Hmod Hdiv] ]; - try erewrite Hmod, Hdiv; auto using nil_length0; - autorewrite with cancel_pair push_eval zsimplify_fast; distr_length. - Qed. + Lemma length_from_columns inp: + forall r, In r (from_columns inp) -> length r = length inp. + Proof. + cbv [from_columns]; intros. + change inp with (fst (inp, @nil (list Z))). + eapply length_snd_from_columns'; eauto. + autorewrite with cancel_pair; intros; In_cases. + Qed. + Hint Rewrite length_from_columns : distr_length. - Lemma sum_rows_mod n row1 row2 : - length row1 = n -> length row2 = n -> - Positional.eval weight n (fst (sum_rows row1 row2)) - = (Positional.eval weight n row1 + Positional.eval weight n row2) mod (weight n). - Proof. apply sum_rows_div_mod. Qed. - Lemma sum_rows_div row1 row2 n: - length row1 = n -> length row2 = n -> - snd (sum_rows row1 row2) - = (Positional.eval weight n row1 + Positional.eval weight n row2) / (weight n). - Proof. apply sum_rows_div_mod. Qed. - - Hint Rewrite sum_rows_mod using (auto; solve [distr_length; auto]) : push_eval. - Local Ltac mul_div_weights i j := - rewrite (Z.div_mod (weight i) (weight j)) by auto; - rewrite Columns.weight_multiples_full by (auto; omega); - autorewrite with zsimplify_fast. - - (* TODO: figure out where to put this and weight_multiples_full *) - Lemma weight_divides_full j i : (i <= j)%nat -> weight j / weight i > 0. - Proof. auto using Z.div_positive_gt_0, Columns.weight_multiples_full. Qed. - - Lemma sum_rows'_partitions row1 : - forall n m start_state row2 row1' row2', - length (fst start_state) = m -> length row1 = n -> length row2 = n -> - length row1' = m -> length row2' = m -> - let nm := (n + m)%nat in - let eval := Positional.eval weight in - snd start_state = (eval m row1' + eval m row2') / weight m -> - (forall j, (j < m)%nat -> - nth_default 0 (fst start_state) j = ((eval m row1' + eval m row2') mod (weight (S j))) / (weight j)) -> - forall i, (i < nm)%nat -> - nth_default 0 (fst (sum_rows' start_state row1 row2)) i - = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i). - Proof. - cbv [sum_rows' Let_In]. - induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *; - destruct row2 as [|x2 row2]; distr_length; [ subst n | ]; - repeat match goal with - | _ => progress cbn [fold_left] - | H : length _ = _ |- _ => rewrite H - | H: _ |- _ => solve [apply H; omega] - | _ => progress autorewrite with push_eval zsimplify_fast push_combine list natsimplify cancel_pair to_div_mod - end. + (* from associational *) + Definition from_associational n (p : list (Z * Z)) := from_columns (Columns.from_associational weight n p). + + Lemma eval_from_associational n p: (n <> 0%nat \/ p = nil) -> + eval n (from_associational n p) = Associational.eval p. + Proof. + intros. cbv [from_associational]. + rewrite eval_from_columns by auto using Columns.length_from_associational. + auto using Columns.eval_from_associational. + Qed. - specialize (IHrow1 (pred n) (S m)). - replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega. - rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). - rewrite <-fold_left_rev_right. - apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega; - repeat match goal with - | _ => progress intros - | _ => progress autorewrite with push_nth_default - | _ => progress break_match - | H : length _ = _ |- _ => rewrite H - | H : ?LHS = _ |- _ => - match LHS with context [start_state] => rewrite H end - | H : context [nth_default 0 (fst start_state)] |- _ => rewrite H by omega - | _ => rewrite Positional.eval_snoc with (n:=m) by eauto - end. - { mul_div_weights (S m) m. - rewrite <-Z.div_div by auto using Z.gt_lt. - autorewrite with zsimplify. - f_equal; ring. } - { mul_div_weights m (S j). - push_Zmod. autorewrite with zsimplify. - lia. } - { replace j with m by omega. - autorewrite with push_nth_default natsimplify. - rewrite <-!Z.div_add' by auto. - rewrite Z.mod_pull_div, Z.mul_div_eq' by (auto using Z.lt_le_incl, Z.gt_lt). - rewrite weight_multiples. - autorewrite with zsimplify_fast. - repeat (f_equal; try ring). } - Qed. - - Lemma sum_rows_partitions row1: forall row2 n i, - length row1 = n -> length row2 = n -> (i < n)%nat -> - nth_default 0 (fst (sum_rows row1 row2)) i - = ((Positional.eval weight n row1 + Positional.eval weight n row2) mod weight (S i)) / (weight i). - Proof. - cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). - rewrite <-(app_nil_l row1), <-(app_nil_l row2). - apply sum_rows'_partitions; intros; - autorewrite with cancel_pair push_eval zsimplify_fast push_nth_default; distr_length. - rewrite Z.div_0_l by auto; omega. - Qed. + Lemma length_from_associational n p : + forall r, In r (from_associational n p) -> length r = n. + Proof. + cbv [from_associational]; intros. + match goal with H: _ |- _ => apply length_from_columns in H end. + rewrite Columns.length_from_associational in *; auto. + Qed. + End FromAssociational. - Lemma length_sum_rows row1 row2 n : - length row1 = n -> length row2 = n -> - length (fst (sum_rows row1 row2)) = n. - Proof. - cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). - eapply sum_rows'_div_mod_length; auto using nil_length0. - Qed. Hint Rewrite length_sum_rows : distr_length. - - - (* TODO: move to ListUtil *) - (* connect state to remaining input *) - Lemma fold_right_invariant_strong {A B} (P: list B -> A -> Type) (f: B -> A -> A): - forall bs a0 - (Pnil : P nil a0) - (IHfold: - forall bs' b a', - In b bs -> - a' = fold_right f a0 bs' -> - P bs' a' -> - P (b :: bs') (f b a')), - P bs (fold_right f a0 bs). - Proof. induction bs; intros; autorewrite with push_fold_right; auto using in_eq,in_cons. Qed. - - Lemma flatten'_div_mod_length n start_state inp: - length (fst start_state) = n -> - (forall row, In row inp -> length row = n) -> - length (fst (flatten' start_state inp)) = n + Section Flatten. + Local Notation fw := (fun i => weight (S i) / weight i) (only parsing). + + Section SumRows. + Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z := + fold_right (fun next (state : list Z * Z) => + let i := length (fst state) in (* length of output accumulator tells us the index of [next] *) + dlet_nd next := next in (* makes the output correctly bind variables *) + dlet_nd sum_carry := Z.add_with_get_carry_full (fw i) (snd state) (fst next) (snd next) in + (fst state ++ [fst sum_carry], snd sum_carry)) start_state (rev (combine row1 row2)). + Definition sum_rows := sum_rows' (nil,0). + + Lemma sum_rows'_div_mod_length row1 : + forall n m start_state row2 row1' row2', + length (fst start_state) = m -> length row1 = n -> length row2 = n -> + length row1' = m -> length row2' = m -> + let nm : nat := (n + m)%nat in + let eval := Positional.eval weight in + eval m (fst start_state) = (eval m row1' + eval m row2') mod (weight m) -> + snd start_state = (eval m row1' + eval m row2') / weight m -> + length (fst (sum_rows' start_state row1 row2)) = (n + m)%nat + /\ eval nm (fst (sum_rows' start_state row1 row2)) + = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight nm) + /\ snd (sum_rows' start_state row1 row2) + = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / (weight nm). + Proof. + cbv [sum_rows' Let_In]. + induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *; + destruct row2 as [|x2 row2]; distr_length; [ subst n | ]; + repeat match goal with + | _ => progress autorewrite with natsimplify list + | _ => progress cbn [fold_left combine] + | _ => omega + end. + + specialize (IHrow1 (pred n) (S m)). + replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega. + rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). + rewrite <-fold_left_rev_right. + apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega; + repeat match goal with + | H : ?LHS = _ |- _ => + match LHS with context [start_state] => rewrite H end + | H : length _ = _ |- _ => rewrite H + | _ => rewrite <-Z.div_add by auto + | _ => rewrite Z.div_div by auto using Z.gt_lt + | _ => rewrite Z.mul_div_eq by auto + | _ => rewrite weight_multiples + | _ => erewrite Positional.eval_snoc by eauto + | _ => progress autorewrite with cancel_pair distr_length to_div_mod in * + | |- context [ ?x mod ?m + ?m * (((?x + ?a * ?m + ?b * ?m)/ ?m) mod ?c) ] => + replace (x mod m) with ((x + a * m + b * m) mod m) by + (autorewrite with zsimplify; ring); + rewrite <-Z.rem_mul_r by auto using Z.gt_lt + | _ => f_equal; ring + end. + Qed. + + Lemma sum_rows_div_mod n row1 row2 : + length row1 = n -> length row2 = n -> + let eval := Positional.eval weight in + eval n (fst (sum_rows row1 row2)) = (eval n row1 + eval n row2) mod (weight n) + /\ snd (sum_rows row1 row2) = (eval n row1 + eval n row2) / (weight n). + Proof. + cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). + edestruct sum_rows'_div_mod_length as [Hlen [Hmod Hdiv] ]; + try erewrite Hmod, Hdiv; auto using nil_length0; + autorewrite with cancel_pair push_eval zsimplify_fast; distr_length. + Qed. + + Lemma sum_rows_mod n row1 row2 : + length row1 = n -> length row2 = n -> + Positional.eval weight n (fst (sum_rows row1 row2)) + = (Positional.eval weight n row1 + Positional.eval weight n row2) mod (weight n). + Proof. apply sum_rows_div_mod. Qed. + Lemma sum_rows_div row1 row2 n: + length row1 = n -> length row2 = n -> + snd (sum_rows row1 row2) + = (Positional.eval weight n row1 + Positional.eval weight n row2) / (weight n). + Proof. apply sum_rows_div_mod. Qed. + + (* TODO: figure out where to put this and weight_multiples_full *) + Lemma weight_divides_full j i : (i <= j)%nat -> weight j / weight i > 0. + Proof. auto using Z.div_positive_gt_0, Columns.weight_multiples_full. Qed. + + Lemma sum_rows'_partitions row1 : + forall n m start_state row2 row1' row2', + length (fst start_state) = m -> length row1 = n -> length row2 = n -> + length row1' = m -> length row2' = m -> + let nm := (n + m)%nat in + let eval := Positional.eval weight in + snd start_state = (eval m row1' + eval m row2') / weight m -> + (forall j, (j < m)%nat -> + nth_default 0 (fst start_state) j = ((eval m row1' + eval m row2') mod (weight (S j))) / (weight j)) -> + forall i, (i < nm)%nat -> + nth_default 0 (fst (sum_rows' start_state row1 row2)) i + = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i). + Proof. + cbv [sum_rows' Let_In]. + induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *; + destruct row2 as [|x2 row2]; distr_length; [ subst n | ]; + repeat match goal with + | _ => progress cbn [fold_left] + | H : length _ = _ |- _ => rewrite H + | H: _ |- _ => solve [apply H; omega] + | _ => progress autorewrite with push_eval zsimplify_fast push_combine list natsimplify cancel_pair to_div_mod + end. + + specialize (IHrow1 (pred n) (S m)). + replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega. + rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). + rewrite <-fold_left_rev_right. + apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega; + repeat match goal with + | _ => progress intros + | _ => progress autorewrite with push_nth_default + | _ => progress break_match + | H : length _ = _ |- _ => rewrite H + | H : ?LHS = _ |- _ => + match LHS with context [start_state] => rewrite H end + | H : context [nth_default 0 (fst start_state)] |- _ => rewrite H by omega + | _ => rewrite Positional.eval_snoc with (n:=m) by eauto + end. + { erewrite @Columns.weight_div_mod with (j:=S m) (i:=m) by eauto. + rewrite <-Z.div_div by auto using Z.gt_lt. + autorewrite with zsimplify. + f_equal; ring. } + { erewrite @Columns.weight_div_mod with (j:=m) (i:=S j) by (eauto; omega). + push_Zmod. autorewrite with zsimplify. lia. } + { replace j with m by omega. + autorewrite with push_nth_default natsimplify. + rewrite <-!Z.div_add' by auto. + rewrite Z.mod_pull_div, Z.mul_div_eq' by (auto using Z.lt_le_incl, Z.gt_lt). + rewrite weight_multiples. + autorewrite with zsimplify_fast. + repeat (f_equal; try ring). } + Qed. + + Lemma sum_rows_partitions row1: forall row2 n i, + length row1 = n -> length row2 = n -> (i < n)%nat -> + nth_default 0 (fst (sum_rows row1 row2)) i + = ((Positional.eval weight n row1 + Positional.eval weight n row2) mod weight (S i)) / (weight i). + Proof. + cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). + rewrite <-(app_nil_l row1), <-(app_nil_l row2). + apply sum_rows'_partitions; intros; + autorewrite with cancel_pair push_eval zsimplify_fast push_nth_default; distr_length. + rewrite Z.div_0_l by auto; omega. + Qed. + + Lemma length_sum_rows row1 row2 n : + length row1 = n -> length row2 = n -> + length (fst (sum_rows row1 row2)) = n. + Proof. + cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). + eapply sum_rows'_div_mod_length; auto using nil_length0. + Qed. Hint Rewrite length_sum_rows : distr_length. + End SumRows. + Hint Rewrite sum_rows_mod using (auto; solve [distr_length; auto]) : push_eval. + + + Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z := + fold_right (fun next_row (state : list Z * Z)=> + let out_carry := sum_rows next_row (fst state) in + (fst out_carry, snd state + snd out_carry)) start_state inp. + + (* For correctness if there is only one row, we add a row of + zeroes with the same length so that the add loop still happens. *) + Definition flatten (inp : rows) : list Z * Z := + let first_row := hd nil inp in + flatten' (first_row, 0) (hd (Positional.zeros (length first_row)) (tl inp) :: tl (tl inp)). + + (* TODO: move to ListUtil *) + (* connect state to remaining input *) + Lemma fold_right_invariant_strong {A B} (P: list B -> A -> Type) (f: B -> A -> A): + forall bs a0 + (Pnil : P nil a0) + (IHfold: + forall bs' b a', + In b bs -> + a' = fold_right f a0 bs' -> + P bs' a' -> + P (b :: bs') (f b a')), + P bs (fold_right f a0 bs). + Proof. induction bs; intros; autorewrite with push_fold_right; auto using in_eq,in_cons. Qed. + + Lemma flatten'_div_mod_length n start_state inp: + length (fst start_state) = n -> + (forall row, In row inp -> length row = n) -> + length (fst (flatten' start_state inp)) = n /\ (inp <> nil -> Positional.eval weight n (fst (flatten' start_state inp)) = (Positional.eval weight n (fst start_state) + eval n inp) mod (weight n) /\ snd (flatten' start_state inp) = snd start_state + (Positional.eval weight n (fst start_state) + eval n inp) / weight n). - Proof. - intro. - cbv [flatten']. - apply fold_right_invariant_strong with (bs:=inp); intros; [ tauto | destruct (dec (bs' = nil)) ]; - repeat match goal with - | _ => subst a' - | _ => subst bs' - | _ => rewrite sum_rows_div with (n:=n) by auto - | _ => rewrite @fold_right_nil in * - | _ => progress autorewrite with cancel_pair push_eval in * - | H : _ -> _ /\ _ |- _ => - let X := fresh in let Y := fresh in - destruct H as [X Y]; [ solve [auto using in_cons] | rewrite ?X, ?Y ] - | _ => split - | _ => progress autorewrite with pull_Zmod zsimplify_fast - | _ => solve [ auto using length_sum_rows ] - | _ => solve [ ring_simplify; repeat (f_equal; try ring) ] - end. - apply Z.mul_cancel_l with (p:=weight n); [ apply weight_nonzero |]. - autorewrite with push_Zmul. - rewrite !Z.mul_div_eq_full by apply weight_nonzero. - autorewrite with pull_Zmod. - ring_simplify_subterms. ring_simplify. - repeat (f_equal; try ring). - Qed. + Proof. + intro. + cbv [flatten']. + apply fold_right_invariant_strong with (bs:=inp); intros; [ tauto | destruct (dec (bs' = nil)) ]; + repeat match goal with + | _ => subst a' + | _ => subst bs' + | _ => rewrite sum_rows_div with (n:=n) by auto + | _ => rewrite @fold_right_nil in * + | _ => progress autorewrite with cancel_pair push_eval in * + | H : _ -> _ /\ _ |- _ => + let X := fresh in let Y := fresh in + destruct H as [X Y]; [ solve [auto using in_cons] | rewrite ?X, ?Y ] + | _ => split + | _ => progress autorewrite with pull_Zmod zsimplify_fast + | _ => solve [ auto using length_sum_rows ] + | _ => solve [ ring_simplify; repeat (f_equal; try ring) ] + end. + apply Z.mul_cancel_l with (p:=weight n); [ apply weight_nonzero |]. + autorewrite with push_Zmul. + rewrite !Z.mul_div_eq_full by apply weight_nonzero. + autorewrite with pull_Zmod. + ring_simplify_subterms. ring_simplify. + repeat (f_equal; try ring). + Qed. - Hint Rewrite (@Positional.length_zeros weight) : distr_length. - Hint Rewrite (@Positional.eval_zeros weight) using auto : push_eval. + Hint Rewrite (@Positional.length_zeros weight) : distr_length. + Hint Rewrite (@Positional.eval_zeros weight) using auto : push_eval. - Lemma flatten_div_mod inp n : - (forall row, In row inp -> length row = n) -> - Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n) - /\ snd (flatten inp) = (eval n inp) / weight n. - Proof. - intros; cbv [flatten]. - destruct inp; [|destruct inp]; cbn [hd tl]. - { cbn. autorewrite with push_eval. tauto. } - { cbn. - match goal with H : forall r, In r [?x] -> _ |- _ => - specialize (H x ltac:(auto)); rewrite H - end. - rewrite sum_rows_div with (n:=n) by distr_length. - autorewrite with push_eval. - split; f_equal; ring. } - { autorewrite with push_eval. - apply flatten'_div_mod_length; auto. - congruence. } - Qed. + Lemma flatten_div_mod inp n : + (forall row, In row inp -> length row = n) -> + Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n) + /\ snd (flatten inp) = (eval n inp) / weight n. + Proof. + intros; cbv [flatten]. + destruct inp; [|destruct inp]; cbn [hd tl]. + { cbn. autorewrite with push_eval. tauto. } + { cbn. + match goal with H : forall r, In r [?x] -> _ |- _ => + specialize (H x ltac:(auto)); rewrite H + end. + rewrite sum_rows_div with (n:=n) by distr_length. + autorewrite with push_eval. + split; f_equal; ring. } + { autorewrite with push_eval. + apply flatten'_div_mod_length; auto. + congruence. } + Qed. - Lemma flatten_mod inp n : - (forall row, In row inp -> length row = n) -> - Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n). - Proof. apply flatten_div_mod. Qed. - Lemma flatten_div inp n : - (forall row, In row inp -> length row = n) -> - snd (flatten inp) = (eval n inp) / (weight n). - Proof. apply flatten_div_mod. Qed. - - Lemma length_flatten' n start_state inp : - length (fst start_state) = n -> - (forall row, In row inp -> length row = n) -> - length (fst (flatten' start_state inp)) = n. - Proof. apply flatten'_div_mod_length. Qed. - Hint Rewrite length_flatten' : distr_length. - - Lemma length_flatten n inp : - (forall row, In row inp -> length row = n) -> - inp <> nil -> - length (fst (flatten inp)) = n. - Proof. - intros. - destruct inp; [congruence |]; destruct inp; - repeat match goal with - | _ => progress intros - | _ => progress cbn [hd tl] in * - | _ => progress autorewrite with cancel_pair - | _ => apply flatten'_div_mod_length - | H: _ |- _ => apply in_inv in H; destruct H - | _ => solve [auto] - end; - subst row; distr_length; auto. - Qed. Hint Rewrite length_flatten : distr_length. - - Lemma flatten'_cons state x inp : - flatten' state (x :: inp) - = (fst (sum_rows x (fst (flatten' state inp))), snd (flatten' state inp) + snd (sum_rows x (fst (flatten' state inp)))). - Proof. reflexivity. Qed. - - Lemma flatten'_partitions n start_state inp: - length (fst start_state) = n -> - (forall row, In row inp -> length row = n) -> - inp <> nil -> - forall i, (i < n)%nat -> - nth_default 0 (fst (flatten' start_state inp)) i - = ((Positional.eval weight n (fst start_state) + eval n inp) mod weight (S i)) / (weight i). - Proof. - intro. - induction inp; intros; [congruence|]. - rewrite flatten'_cons. - autorewrite with push_fold_right cancel_pair. - rewrite sum_rows_partitions with (n:=n) by (rewrite ?length_flatten' with (n:=n); auto). - destruct (dec (inp = nil)). - { subst inp. cbn. repeat (f_equal; try ring). } - { edestruct flatten'_div_mod_length with (inp:=inp) as [Hlen [Hmod Hdiv] ]; eauto. - rewrite Hmod. autorewrite with push_eval. - rewrite Z.add_mod_full. mul_div_weights n (S i). - rewrite Z.rem_mul_r by auto using Z.gt_lt, weight_divides_full. - autorewrite with zsimplify. pull_Zmod. - repeat (f_equal; try ring). } - Qed. + Lemma flatten_mod inp n : + (forall row, In row inp -> length row = n) -> + Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n). + Proof. apply flatten_div_mod. Qed. + Lemma flatten_div inp n : + (forall row, In row inp -> length row = n) -> + snd (flatten inp) = (eval n inp) / (weight n). + Proof. apply flatten_div_mod. Qed. - Lemma flatten_partitions inp n : - (forall row, In row inp -> length row = n) -> - forall i, (i < n)%nat -> - nth_default 0 (fst (flatten inp)) i = (eval n inp mod weight (S i)) / (weight i). - Proof. - intros; cbv [flatten]. - destruct inp; [|destruct inp]; cbn [hd tl]. - { cbn. autorewrite with push_eval push_nth_default zsimplify. reflexivity. } - { cbn. - match goal with H : forall r, In r [?x] -> _ |- _ => - specialize (H x ltac:(auto)); rewrite H - end. - rewrite sum_rows_partitions with (n:=n) by distr_length. - autorewrite with push_eval. - repeat (f_equal; try ring). } - { autorewrite with push_eval. - apply flatten'_partitions; auto. - congruence. } + Lemma length_flatten' n start_state inp : + length (fst start_state) = n -> + (forall row, In row inp -> length row = n) -> + length (fst (flatten' start_state inp)) = n. + Proof. apply flatten'_div_mod_length. Qed. + Hint Rewrite length_flatten' : distr_length. + + Lemma length_flatten n inp : + (forall row, In row inp -> length row = n) -> + inp <> nil -> + length (fst (flatten inp)) = n. + Proof. + intros. + destruct inp; [congruence |]; destruct inp; + repeat match goal with + | _ => progress intros + | _ => progress cbn [hd tl] in * + | _ => progress autorewrite with cancel_pair + | _ => apply flatten'_div_mod_length + | H: _ |- _ => apply in_inv in H; destruct H + | _ => solve [auto] + end; + subst row; distr_length; auto. + Qed. Hint Rewrite length_flatten : distr_length. + + Lemma flatten'_cons state x inp : + flatten' state (x :: inp) + = (fst (sum_rows x (fst (flatten' state inp))), snd (flatten' state inp) + snd (sum_rows x (fst (flatten' state inp)))). + Proof. reflexivity. Qed. + + Lemma flatten'_partitions n start_state inp: + length (fst start_state) = n -> + (forall row, In row inp -> length row = n) -> + inp <> nil -> + forall i, (i < n)%nat -> + nth_default 0 (fst (flatten' start_state inp)) i + = ((Positional.eval weight n (fst start_state) + eval n inp) mod weight (S i)) / (weight i). + Proof. + intro. + induction inp; intros; [congruence|]. + rewrite flatten'_cons. + autorewrite with push_fold_right cancel_pair. + rewrite sum_rows_partitions with (n:=n) by (rewrite ?length_flatten' with (n:=n); auto). + destruct (dec (inp = nil)). + { subst inp. cbn. repeat (f_equal; try ring). } + { edestruct flatten'_div_mod_length with (inp:=inp) as [Hlen [Hmod Hdiv] ]; eauto. + rewrite Hmod. autorewrite with push_eval. + rewrite Z.add_mod_full. erewrite @Columns.weight_div_mod with (j:=n) (i:=S i) by eauto. + rewrite Z.rem_mul_r by auto using Z.gt_lt, weight_divides_full. + autorewrite with zsimplify. pull_Zmod. + repeat (f_equal; try ring). } Qed. + Lemma flatten_partitions inp n : + (forall row, In row inp -> length row = n) -> + forall i, (i < n)%nat -> + nth_default 0 (fst (flatten inp)) i = (eval n inp mod weight (S i)) / (weight i). + Proof. + intros; cbv [flatten]. + destruct inp; [|destruct inp]; cbn [hd tl]. + { cbn. autorewrite with push_eval push_nth_default zsimplify. reflexivity. } + { cbn. + match goal with H : forall r, In r [?x] -> _ |- _ => + specialize (H x ltac:(auto)); rewrite H + end. + rewrite sum_rows_partitions with (n:=n) by distr_length. + autorewrite with push_eval. + repeat (f_equal; try ring). } + { autorewrite with push_eval. + apply flatten'_partitions; auto. + congruence. } + Qed. + End Flatten. + End Rows. End Rows. |