diff options
author | Jade Philipoom <jadep@google.com> | 2018-03-02 16:38:46 +0100 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2018-04-03 09:00:55 -0400 |
commit | e07980afdbbc95b2aee339e15cf74c69661b2fd9 (patch) | |
tree | 30e3c6c454911f417a1ace9a35f41e0ed3102672 | |
parent | 9cfd648b0330bb4fff8a6277496d97b9c00d79fe (diff) |
proved admits about sum_rows
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 88 |
1 files changed, 76 insertions, 12 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 3559955c4..ae921d1ae 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -1076,6 +1076,7 @@ Module Rows. (* TODO: move to listUtil or wherever push_map is defined *) Hint Rewrite @map_app : push_map. Hint Rewrite sum_nil sum_cons sum_app : push_sum. + Hint Rewrite @combine_nil_r @combine_cons : push_combine. Hint Resolve in_eq in_cons. @@ -1227,18 +1228,18 @@ Module Rows. Definition flatten (inp : rows) : list Z * Z := let first_row := hd nil inp in flatten' (first_row, 0) (hd (Positional.zeros (length first_row)) (tl inp) :: tl (tl inp)). - - Lemma sum_rows'_div_mod row1 : + Lemma sum_rows'_div_mod_length row1 : forall n m start_state row2 row1' row2', length (fst start_state) = m -> length row1 = n -> length row2 = n -> length row1' = m -> length row2' = m -> let nm : nat := (n + m)%nat in - let eval (n:nat) := Positional.eval weight n in + let eval := Positional.eval weight in eval m (fst start_state) = (eval m row1' + eval m row2') mod (weight m) -> snd start_state = (eval m row1' + eval m row2') / weight m -> - eval nm (fst (sum_rows' start_state row1 row2)) - = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight nm) + length (fst (sum_rows' start_state row1 row2)) = (n + m)%nat + /\ eval nm (fst (sum_rows' start_state row1 row2)) + = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight nm) /\ snd (sum_rows' start_state row1 row2) = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / (weight nm). Proof. @@ -1246,9 +1247,8 @@ Module Rows. induction row1 as [|x1 row1]; intros; destruct row2 as [|x2 row2]; distr_length; [ subst n | ]; repeat match goal with - | _ => progress autorewrite with natsimplify + | _ => progress autorewrite with natsimplify list | _ => progress cbn [fold_left combine] - | _ => rewrite app_nil_r | _ => omega end. @@ -1281,7 +1281,7 @@ Module Rows. /\ snd (sum_rows row1 row2) = (eval n row1 + eval n row2) / (weight n). Proof. cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). - edestruct sum_rows'_div_mod as [Hmod Hdiv]; + edestruct sum_rows'_div_mod_length as [Hlen [Hmod Hdiv] ]; try erewrite Hmod, Hdiv; auto using nil_length0; autorewrite with cancel_pair push_eval zsimplify_fast; distr_length. Qed. @@ -1298,18 +1298,82 @@ Module Rows. Proof. apply sum_rows_div_mod. Qed. Hint Rewrite sum_rows_mod using (auto; solve [distr_length; auto]) : push_eval. + Local Ltac mul_div_weights i j := + rewrite (Z.div_mod (weight i) (weight j)) by auto; + rewrite Columns.weight_multiples_full by (auto; omega); + autorewrite with zsimplify_fast. + + Lemma sum_rows'_partitions row1 : + forall n m start_state row2 row1' row2', + length (fst start_state) = m -> length row1 = n -> length row2 = n -> + length row1' = m -> length row2' = m -> + let nm := (n + m)%nat in + let eval := Positional.eval weight in + snd start_state = (eval m row1' + eval m row2') / weight m -> + (forall j, (j < m)%nat -> + nth_default 0 (fst start_state) j = ((eval m row1' + eval m row2') / (weight j)) mod (fw j)) -> + forall i, (i < nm)%nat -> + nth_default 0 (fst (sum_rows' start_state row1 row2)) i + = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / weight i) mod (fw i). + Proof. + cbv [sum_rows']. + induction row1 as [|x1 row1]; intros; + destruct row2 as [|x2 row2]; distr_length; [ subst n | ]; + repeat match goal with + | _ => progress cbn [fold_left] + | H : length _ = _ |- _ => rewrite H + | H: _ |- _ => solve [apply H; omega] + | _ => progress autorewrite with push_eval zsimplify_fast push_combine list natsimplify cancel_pair to_div_mod + end. - Lemma sum_rows_partitions row1 row2 n i: + specialize (IHrow1 (pred n) (S m)). + replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega. + rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). + apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega; + repeat match goal with + | _ => progress intros + | _ => progress autorewrite with push_nth_default + | _ => progress break_match + | H : length _ = _ |- _ => rewrite H + | H : ?LHS = _ |- _ => + match LHS with context [start_state] => rewrite H end + | H : context [nth_default 0 (fst start_state)] |- _ => rewrite H by omega + | _ => rewrite Positional.eval_snoc with (n:=m) by eauto + end. + { mul_div_weights (S m) m. + rewrite <-Z.div_div by auto using Z.gt_lt. + autorewrite with zsimplify. + f_equal; ring. } + { mul_div_weights m j. + ring_simplify_subterms. autorewrite with zsimplify. + mul_div_weights m (S j). + rewrite Columns.Z_divide_div_mul_exact' by (try apply Z.mod_divide; auto). + push_Zmod. autorewrite with zsimplify_fast. + lia. } + { replace j with m by omega. + autorewrite with push_nth_default natsimplify zsimplify. + f_equal; ring. } + Qed. + + Lemma sum_rows_partitions row1: forall row2 n i, length row1 = n -> length row2 = n -> (i < n)%nat -> nth_default 0 (fst (sum_rows row1 row2)) i = ((Positional.eval weight n row1 + Positional.eval weight n row2) / weight i) mod (fw i). Proof. - Admitted. + cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). + rewrite <-(app_nil_l row1), <-(app_nil_l row2). + apply sum_rows'_partitions; intros; + autorewrite with cancel_pair push_eval zsimplify_fast push_nth_default; distr_length. + rewrite Z.div_0_l by auto; omega. + Qed. Lemma length_sum_rows row1 row2 n : length row1 = n -> length row2 = n -> length (fst (sum_rows row1 row2)) = n. - Admitted. + Proof. + cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). + eapply sum_rows'_div_mod_length; auto using nil_length0. + Qed. Hint Rewrite length_sum_rows : distr_length. (* TODO: move to ListUtil *) @@ -1324,7 +1388,7 @@ Module Rows. P bs' a' -> P (b :: bs') (f b a')), P bs (fold_right f a0 bs). - Proof. induction bs; intros; rewrite ?fold_right_cons; auto using in_eq,in_cons. Qed. + Proof. induction bs; intros; autorewrite with push_fold_right; auto using in_eq,in_cons. Qed. Lemma flatten'_div_mod n start_state inp: length (fst start_state) = n -> |