aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-03-06 11:17:21 +0100
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-03 09:00:55 -0400
commitfcf5f782aade5339ad91e077f23010e1dd27d98c (patch)
treeb1dc18d00db4136b36ab068847da4117ad8e00c8 /src
parente07980afdbbc95b2aee339e15cf74c69661b2fd9 (diff)
finish flatten_partitions and slightly change the format of _partitions lemma statements
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v143
1 files changed, 101 insertions, 42 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index ae921d1ae..c301fc67f 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -910,45 +910,31 @@ Module Columns.
Lemma flatten_partitions inp:
forall n i, length inp = n -> (i < n)%nat ->
- nth_default 0 (fst (flatten inp)) i = (((eval n inp) / weight i)) mod (weight (S i) / weight i).
+ nth_default 0 (fst (flatten inp)) i = ((eval n inp) mod (weight (S i))) / weight i.
Proof.
- induction inp using rev_ind; distr_length; intros.
- { cbn.
- autorewrite with push_eval push_nth_default zsimplify.
- reflexivity. }
- {
- destruct n as [| n]; [omega|].
- rewrite flatten_snoc, eval_snoc by omega.
+ induction inp using rev_ind; intros; destruct n; distr_length.
+ { rewrite flatten_snoc, eval_snoc by omega.
cbv [flatten_step Let_In]. cbn [fst].
rewrite nth_default_app.
break_match; distr_length.
{ rewrite IHinp with (n:=n) by omega.
- rewrite (Z.div_mod (weight n) (weight i)) by auto.
- rewrite weight_multiples_full by omega.
rewrite (Z.div_mod (weight n) (weight (S i))) by auto.
rewrite weight_multiples_full by omega.
+ push_Zmod.
autorewrite with zsimplify.
- repeat match goal with
- | _ => rewrite Z_divide_div_mul_exact' by (try apply Z.mod_divide; auto)
- | |- context [ (_ + ?a * ?b * ?c) / ?a ] =>
- replace (a * b * c) with (a * (b * c)) by ring;
- rewrite Z.div_add' by auto
- | |- context [ (_ + ?a * ?b * ?c) mod ?b ] =>
- replace (a * b * c) with (a * c * b) by ring;
- rewrite Z.mod_add by auto using ZUtil.Z.positive_is_nonzero
- | _ => reflexivity
- end.
- }
+ reflexivity. }
{ repeat match goal with
| _ => progress replace (Datatypes.length inp) with n by omega
| _ => progress replace i with n by omega
- | _ => rewrite nth_default_cons
| _ => rewrite sum_cons
| _ => rewrite flatten_column_mod
| _ => erewrite flatten_div by eauto
- | _ => progress autorewrite with natsimplify
+ | _ => rewrite <-Z.div_add' by auto
+ | _ => rewrite Z.mod_pull_div by auto using Z.lt_le_incl, Z.gt_lt
+ | _ => rewrite Z.mul_div_eq', weight_multiples by auto
+ | _ => progress autorewrite with push_nth_default natsimplify
end.
- rewrite Z.div_add' by auto.
+ autorewrite with zsimplify.
reflexivity. } }
Qed.
@@ -1039,12 +1025,7 @@ Module Columns.
intros; subst n3; cbv [mul_converted].
erewrite flatten_partitions by (auto; distr_length).
autorewrite with distr_length push_eval.
- pose proof (w_positive 1).
- apply Z.mod_small.
- split; [ solve[Z.zero_bounds] | ].
- apply Z.div_lt_upper_bound; [omega|].
- rewrite Z.mul_div_eq_full by auto.
- rewrite w_multiples. omega.
+ rewrite Z.mod_small; omega.
Qed.
(* shortcut definition for convert-mul-convert for cases when we are halving the bitwidth before multiplying. *)
@@ -1303,6 +1284,10 @@ Module Rows.
rewrite Columns.weight_multiples_full by (auto; omega);
autorewrite with zsimplify_fast.
+ (* TODO: figure out where to put this and weight_multiples_full *)
+ Lemma weight_divides_full j i : (i <= j)%nat -> weight j / weight i > 0.
+ Proof. auto using Z.div_positive_gt_0, Columns.weight_multiples_full. Qed.
+
Lemma sum_rows'_partitions row1 :
forall n m start_state row2 row1' row2',
length (fst start_state) = m -> length row1 = n -> length row2 = n ->
@@ -1311,10 +1296,10 @@ Module Rows.
let eval := Positional.eval weight in
snd start_state = (eval m row1' + eval m row2') / weight m ->
(forall j, (j < m)%nat ->
- nth_default 0 (fst start_state) j = ((eval m row1' + eval m row2') / (weight j)) mod (fw j)) ->
+ nth_default 0 (fst start_state) j = ((eval m row1' + eval m row2') mod (weight (S j))) / (weight j)) ->
forall i, (i < nm)%nat ->
nth_default 0 (fst (sum_rows' start_state row1 row2)) i
- = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / weight i) mod (fw i).
+ = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i).
Proof.
cbv [sum_rows'].
induction row1 as [|x1 row1]; intros;
@@ -1344,21 +1329,22 @@ Module Rows.
rewrite <-Z.div_div by auto using Z.gt_lt.
autorewrite with zsimplify.
f_equal; ring. }
- { mul_div_weights m j.
- ring_simplify_subterms. autorewrite with zsimplify.
- mul_div_weights m (S j).
- rewrite Columns.Z_divide_div_mul_exact' by (try apply Z.mod_divide; auto).
- push_Zmod. autorewrite with zsimplify_fast.
+ { mul_div_weights m (S j).
+ push_Zmod. autorewrite with zsimplify.
lia. }
{ replace j with m by omega.
- autorewrite with push_nth_default natsimplify zsimplify.
- f_equal; ring. }
+ autorewrite with push_nth_default natsimplify.
+ rewrite <-!Z.div_add' by auto.
+ rewrite Z.mod_pull_div, Z.mul_div_eq' by (auto using Z.lt_le_incl, Z.gt_lt).
+ rewrite weight_multiples.
+ autorewrite with zsimplify_fast.
+ repeat (f_equal; try ring). }
Qed.
Lemma sum_rows_partitions row1: forall row2 n i,
length row1 = n -> length row2 = n -> (i < n)%nat ->
nth_default 0 (fst (sum_rows row1 row2)) i
- = ((Positional.eval weight n row1 + Positional.eval weight n row2) / weight i) mod (fw i).
+ = ((Positional.eval weight n row1 + Positional.eval weight n row2) mod weight (S i)) / (weight i).
Proof.
cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n).
rewrite <-(app_nil_l row1), <-(app_nil_l row2).
@@ -1390,7 +1376,7 @@ Module Rows.
P bs (fold_right f a0 bs).
Proof. induction bs; intros; autorewrite with push_fold_right; auto using in_eq,in_cons. Qed.
- Lemma flatten'_div_mod n start_state inp:
+ Lemma flatten'_div_mod_length n start_state inp:
length (fst start_state) = n ->
(forall row, In row inp -> length row = n) ->
length (fst (flatten' start_state inp)) = n
@@ -1442,7 +1428,80 @@ Module Rows.
autorewrite with push_eval.
split; f_equal; ring. }
{ autorewrite with push_eval.
- apply flatten'_div_mod; auto.
+ apply flatten'_div_mod_length; auto.
+ congruence. }
+ Qed.
+
+ Lemma length_flatten' n start_state inp :
+ length (fst start_state) = n ->
+ (forall row, In row inp -> length row = n) ->
+ length (fst (flatten' start_state inp)) = n.
+ Proof. apply flatten'_div_mod_length. Qed.
+ Hint Rewrite length_flatten' : distr_length.
+
+ Lemma length_flatten n inp :
+ (forall row, In row inp -> length row = n) ->
+ inp <> nil ->
+ length (fst (flatten inp)) = n.
+ Proof.
+ intros.
+ destruct inp; [congruence |]; destruct inp;
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress cbn [hd tl] in *
+ | _ => progress autorewrite with cancel_pair
+ | _ => apply flatten'_div_mod_length
+ | H: _ |- _ => apply in_inv in H; destruct H
+ | _ => solve [auto]
+ end;
+ subst row; distr_length; auto.
+ Qed. Hint Rewrite length_flatten : distr_length.
+
+ Lemma flatten'_cons state x inp :
+ flatten' state (x :: inp)
+ = (fst (sum_rows x (fst (flatten' state inp))), snd (flatten' state inp) + snd (sum_rows x (fst (flatten' state inp)))).
+ Proof. reflexivity. Qed.
+
+ Lemma flatten'_partitions n start_state inp:
+ length (fst start_state) = n ->
+ (forall row, In row inp -> length row = n) ->
+ inp <> nil ->
+ forall i, (i < n)%nat ->
+ nth_default 0 (fst (flatten' start_state inp)) i
+ = ((Positional.eval weight n (fst start_state) + eval n inp) mod weight (S i)) / (weight i).
+ Proof.
+ intro.
+ induction inp; intros; [congruence|].
+ rewrite flatten'_cons.
+ autorewrite with push_fold_right cancel_pair.
+ rewrite sum_rows_partitions with (n:=n) by (rewrite ?length_flatten' with (n:=n); auto).
+ destruct (dec (inp = nil)).
+ { subst inp. cbn. repeat (f_equal; try ring). }
+ { edestruct flatten'_div_mod_length with (inp:=inp) as [Hlen [Hmod Hdiv] ]; eauto.
+ rewrite Hmod. autorewrite with push_eval.
+ rewrite Z.add_mod_full. mul_div_weights n (S i).
+ rewrite Z.rem_mul_r by auto using Z.gt_lt, weight_divides_full.
+ autorewrite with zsimplify. pull_Zmod.
+ repeat (f_equal; try ring). }
+ Qed.
+
+ Lemma flatten_partitions inp n :
+ (forall row, In row inp -> length row = n) ->
+ forall i, (i < n)%nat ->
+ nth_default 0 (fst (flatten inp)) i = (eval n inp mod weight (S i)) / (weight i).
+ Proof.
+ intros; cbv [flatten].
+ destruct inp; [|destruct inp]; cbn [hd tl].
+ { cbn. autorewrite with push_eval push_nth_default zsimplify. reflexivity. }
+ { cbn.
+ match goal with H : forall r, In r [?x] -> _ |- _ =>
+ specialize (H x ltac:(auto)); rewrite H
+ end.
+ rewrite sum_rows_partitions with (n:=n) by distr_length.
+ autorewrite with push_eval.
+ repeat (f_equal; try ring). }
+ { autorewrite with push_eval.
+ apply flatten'_partitions; auto.
congruence. }
Qed.