aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-03-06 13:39:13 +0100
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-03 09:00:55 -0400
commitd809687fc691725f0347ceb4d76b35c373965980 (patch)
tree28fcbfbf2eb53f82884587810e981a7335cf9de8 /src
parent05567335df0a787e66877a222b2284975b0f7f0a (diff)
add Rows.from_associational and some more length proofs that allow Rows.length_from_associational
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v79
1 files changed, 71 insertions, 8 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 2e46d0dd8..4d0af6491 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -998,6 +998,8 @@ Module Rows.
) start_state (List.repeat 0 n).
Definition from_columns (inp : cols) : rows := snd (from_columns' (max_column_size inp) (inp, [])).
+
+ Definition from_associational n (p : list (Z * Z)) := from_columns (Columns.from_associational weight n p).
Lemma eval_extract_row (inp : cols): forall n,
length inp = n ->
@@ -1016,10 +1018,15 @@ Module Rows.
destruct x; cbn [hd tl]; rewrite ?sum_nil, ?sum_cons; ring.
Qed. Hint Rewrite eval_extract_row using (solve [distr_length]) : push_eval.
- Lemma length_extract_row n (inp : cols) :
+ Lemma length_fst_extract_row n (inp : cols) :
length inp = n -> length (fst (extract_row inp)) = n.
Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed.
- Hint Rewrite length_extract_row : distr_length.
+ Hint Rewrite length_fst_extract_row : distr_length.
+
+ Lemma length_snd_extract_row n (inp : cols) :
+ length inp = n -> length (snd (extract_row inp)) = n.
+ Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed.
+ Hint Rewrite length_snd_extract_row : distr_length.
(* TODO: move to where list is defined *)
Hint Rewrite @app_nil_l : list.
@@ -1044,20 +1051,41 @@ Module Rows.
rewrite IHinp; distr_length; lia.
Qed.
+ Local Ltac In_cases :=
+ repeat match goal with
+ | H: In _ (_ ++ _) |- _ => apply in_app_or in H; destruct H
+ | H: In _ (_ :: _) |- _ => apply in_inv in H; destruct H
+ | H: In _ nil |- _ => contradiction H
+ end.
+
Lemma eval_from_columns'_with_length m st n:
(length (fst st) = n) ->
length (fst (from_columns' m st)) = n /\
+ ((forall r, In r (snd st) -> length r = n) ->
+ forall r, In r (snd (from_columns' m st)) -> length r = n) /\
eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st)
- Columns.eval weight n (fst (from_columns' m st)).
Proof.
cbv [from_columns']; intros.
- apply fold_right_invariant; intros; [ split; assumption || ring | ].
- autorewrite with cancel_pair push_eval.
- split; [ | omega]. apply length_extract_row; tauto.
+ apply fold_right_invariant; intros;
+ repeat match goal with
+ | _ => progress (intros; subst)
+ | _ => progress autorewrite with cancel_pair push_eval
+ | _ => progress In_cases
+ | _ => split; try omega
+ | H: _ /\ _ |- _ => destruct H
+ | _ => solve [auto using length_fst_extract_row, length_snd_extract_row]
+ end.
Qed.
- Lemma length_from_columns' m st : length (fst (from_columns' m st)) = length (fst st).
+ Lemma length_fst_from_columns' m st :
+ length (fst (from_columns' m st)) = length (fst st).
Proof. apply eval_from_columns'_with_length; reflexivity. Qed.
- Hint Rewrite length_from_columns' : distr_length.
+ Hint Rewrite length_fst_from_columns' : distr_length.
+ Lemma length_snd_from_columns' m st :
+ (forall r, In r (snd st) -> length r = length (fst st)) ->
+ forall r, In r (snd (from_columns' m st)) -> length r = length (fst st).
+ Proof. apply eval_from_columns'_with_length. reflexivity. Qed.
+ Hint Rewrite length_snd_from_columns' : distr_length.
Lemma eval_from_columns' m st n :
(length (fst st) = n) ->
eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st)
@@ -1101,6 +1129,32 @@ Module Rows.
Qed.
Hint Rewrite eval_from_columns using (auto; solve [distr_length]) : push_eval.
+ Lemma length_from_columns inp:
+ forall r, In r (from_columns inp) -> length r = length inp.
+ Proof.
+ cbv [from_columns]; intros.
+ change inp with (fst (inp, @nil (list Z))).
+ eapply length_snd_from_columns'; eauto.
+ autorewrite with cancel_pair; intros; In_cases.
+ Qed.
+ Hint Rewrite length_from_columns : distr_length.
+
+ Lemma eval_from_associational n p: (n <> 0%nat \/ p = nil) ->
+ eval n (from_associational n p) = Associational.eval p.
+ Proof.
+ intros. cbv [from_associational].
+ rewrite eval_from_columns by auto using Columns.length_from_associational.
+ auto using Columns.eval_from_associational.
+ Qed.
+
+ Lemma length_from_associational n p :
+ forall r, In r (from_associational n p) -> length r = n.
+ Proof.
+ cbv [from_associational]; intros.
+ match goal with H: _ |- _ => apply length_from_columns in H end.
+ rewrite Columns.length_from_associational in *; auto.
+ Qed.
+
Local Notation fw := (fun i => weight (S i) / weight i) (only parsing).
Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z :=
@@ -1343,6 +1397,15 @@ Module Rows.
congruence. }
Qed.
+ Lemma flatten_mod inp n :
+ (forall row, In row inp -> length row = n) ->
+ Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n).
+ Proof. apply flatten_div_mod. Qed.
+ Lemma flatten_div inp n :
+ (forall row, In row inp -> length row = n) ->
+ snd (flatten inp) = (eval n inp) / (weight n).
+ Proof. apply flatten_div_mod. Qed.
+
Lemma length_flatten' n start_state inp :
length (fst start_state) = n ->
(forall row, In row inp -> length row = n) ->
@@ -1366,7 +1429,7 @@ Module Rows.
| _ => solve [auto]
end;
subst row; distr_length; auto.
- Qed. Hint Rewrite length_flatten : distr_length.
+ Qed. Hint Rewrite length_flatten : distr_length.
Lemma flatten'_cons state x inp :
flatten' state (x :: inp)