aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-03-02 16:38:46 +0100
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-03 09:00:55 -0400
commite07980afdbbc95b2aee339e15cf74c69661b2fd9 (patch)
tree30e3c6c454911f417a1ace9a35f41e0ed3102672
parent9cfd648b0330bb4fff8a6277496d97b9c00d79fe (diff)
proved admits about sum_rows
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v88
1 files changed, 76 insertions, 12 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 3559955c4..ae921d1ae 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -1076,6 +1076,7 @@ Module Rows.
(* TODO: move to listUtil or wherever push_map is defined *)
Hint Rewrite @map_app : push_map.
Hint Rewrite sum_nil sum_cons sum_app : push_sum.
+ Hint Rewrite @combine_nil_r @combine_cons : push_combine.
Hint Resolve in_eq in_cons.
@@ -1227,18 +1228,18 @@ Module Rows.
Definition flatten (inp : rows) : list Z * Z :=
let first_row := hd nil inp in
flatten' (first_row, 0) (hd (Positional.zeros (length first_row)) (tl inp) :: tl (tl inp)).
-
- Lemma sum_rows'_div_mod row1 :
+ Lemma sum_rows'_div_mod_length row1 :
forall n m start_state row2 row1' row2',
length (fst start_state) = m -> length row1 = n -> length row2 = n ->
length row1' = m -> length row2' = m ->
let nm : nat := (n + m)%nat in
- let eval (n:nat) := Positional.eval weight n in
+ let eval := Positional.eval weight in
eval m (fst start_state) = (eval m row1' + eval m row2') mod (weight m) ->
snd start_state = (eval m row1' + eval m row2') / weight m ->
- eval nm (fst (sum_rows' start_state row1 row2))
- = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight nm)
+ length (fst (sum_rows' start_state row1 row2)) = (n + m)%nat
+ /\ eval nm (fst (sum_rows' start_state row1 row2))
+ = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight nm)
/\ snd (sum_rows' start_state row1 row2)
= (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / (weight nm).
Proof.
@@ -1246,9 +1247,8 @@ Module Rows.
induction row1 as [|x1 row1]; intros;
destruct row2 as [|x2 row2]; distr_length; [ subst n | ];
repeat match goal with
- | _ => progress autorewrite with natsimplify
+ | _ => progress autorewrite with natsimplify list
| _ => progress cbn [fold_left combine]
- | _ => rewrite app_nil_r
| _ => omega
end.
@@ -1281,7 +1281,7 @@ Module Rows.
/\ snd (sum_rows row1 row2) = (eval n row1 + eval n row2) / (weight n).
Proof.
cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
- edestruct sum_rows'_div_mod as [Hmod Hdiv];
+ edestruct sum_rows'_div_mod_length as [Hlen [Hmod Hdiv] ];
try erewrite Hmod, Hdiv; auto using nil_length0;
autorewrite with cancel_pair push_eval zsimplify_fast; distr_length.
Qed.
@@ -1298,18 +1298,82 @@ Module Rows.
Proof. apply sum_rows_div_mod. Qed.
Hint Rewrite sum_rows_mod using (auto; solve [distr_length; auto]) : push_eval.
+ Local Ltac mul_div_weights i j :=
+ rewrite (Z.div_mod (weight i) (weight j)) by auto;
+ rewrite Columns.weight_multiples_full by (auto; omega);
+ autorewrite with zsimplify_fast.
+
+ Lemma sum_rows'_partitions row1 :
+ forall n m start_state row2 row1' row2',
+ length (fst start_state) = m -> length row1 = n -> length row2 = n ->
+ length row1' = m -> length row2' = m ->
+ let nm := (n + m)%nat in
+ let eval := Positional.eval weight in
+ snd start_state = (eval m row1' + eval m row2') / weight m ->
+ (forall j, (j < m)%nat ->
+ nth_default 0 (fst start_state) j = ((eval m row1' + eval m row2') / (weight j)) mod (fw j)) ->
+ forall i, (i < nm)%nat ->
+ nth_default 0 (fst (sum_rows' start_state row1 row2)) i
+ = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / weight i) mod (fw i).
+ Proof.
+ cbv [sum_rows'].
+ induction row1 as [|x1 row1]; intros;
+ destruct row2 as [|x2 row2]; distr_length; [ subst n | ];
+ repeat match goal with
+ | _ => progress cbn [fold_left]
+ | H : length _ = _ |- _ => rewrite H
+ | H: _ |- _ => solve [apply H; omega]
+ | _ => progress autorewrite with push_eval zsimplify_fast push_combine list natsimplify cancel_pair to_div_mod
+ end.
- Lemma sum_rows_partitions row1 row2 n i:
+ specialize (IHrow1 (pred n) (S m)).
+ replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega.
+ rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
+ apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega;
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress autorewrite with push_nth_default
+ | _ => progress break_match
+ | H : length _ = _ |- _ => rewrite H
+ | H : ?LHS = _ |- _ =>
+ match LHS with context [start_state] => rewrite H end
+ | H : context [nth_default 0 (fst start_state)] |- _ => rewrite H by omega
+ | _ => rewrite Positional.eval_snoc with (n:=m) by eauto
+ end.
+ { mul_div_weights (S m) m.
+ rewrite <-Z.div_div by auto using Z.gt_lt.
+ autorewrite with zsimplify.
+ f_equal; ring. }
+ { mul_div_weights m j.
+ ring_simplify_subterms. autorewrite with zsimplify.
+ mul_div_weights m (S j).
+ rewrite Columns.Z_divide_div_mul_exact' by (try apply Z.mod_divide; auto).
+ push_Zmod. autorewrite with zsimplify_fast.
+ lia. }
+ { replace j with m by omega.
+ autorewrite with push_nth_default natsimplify zsimplify.
+ f_equal; ring. }
+ Qed.
+
+ Lemma sum_rows_partitions row1: forall row2 n i,
length row1 = n -> length row2 = n -> (i < n)%nat ->
nth_default 0 (fst (sum_rows row1 row2)) i
= ((Positional.eval weight n row1 + Positional.eval weight n row2) / weight i) mod (fw i).
Proof.
- Admitted.
+ cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
+ rewrite <-(app_nil_l row1), <-(app_nil_l row2).
+ apply sum_rows'_partitions; intros;
+ autorewrite with cancel_pair push_eval zsimplify_fast push_nth_default; distr_length.
+ rewrite Z.div_0_l by auto; omega.
+ Qed.
Lemma length_sum_rows row1 row2 n :
length row1 = n -> length row2 = n ->
length (fst (sum_rows row1 row2)) = n.
- Admitted.
+ Proof.
+ cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
+ eapply sum_rows'_div_mod_length; auto using nil_length0.
+ Qed. Hint Rewrite length_sum_rows : distr_length.
(* TODO: move to ListUtil *)
@@ -1324,7 +1388,7 @@ Module Rows.
P bs' a' ->
P (b :: bs') (f b a')),
P bs (fold_right f a0 bs).
- Proof. induction bs; intros; rewrite ?fold_right_cons; auto using in_eq,in_cons. Qed.
+ Proof. induction bs; intros; autorewrite with push_fold_right; auto using in_eq,in_cons. Qed.
Lemma flatten'_div_mod n start_state inp:
length (fst start_state) = n ->