diff options
author | Jade Philipoom <jadep@google.com> | 2018-03-06 11:17:21 +0100 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2018-04-03 09:00:55 -0400 |
commit | fcf5f782aade5339ad91e077f23010e1dd27d98c (patch) | |
tree | b1dc18d00db4136b36ab068847da4117ad8e00c8 /src/Experiments/SimplyTypedArithmetic.v | |
parent | e07980afdbbc95b2aee339e15cf74c69661b2fd9 (diff) |
finish flatten_partitions and slightly change the format of _partitions lemma statements
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 143 |
1 files changed, 101 insertions, 42 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index ae921d1ae..c301fc67f 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -910,45 +910,31 @@ Module Columns. Lemma flatten_partitions inp: forall n i, length inp = n -> (i < n)%nat -> - nth_default 0 (fst (flatten inp)) i = (((eval n inp) / weight i)) mod (weight (S i) / weight i). + nth_default 0 (fst (flatten inp)) i = ((eval n inp) mod (weight (S i))) / weight i. Proof. - induction inp using rev_ind; distr_length; intros. - { cbn. - autorewrite with push_eval push_nth_default zsimplify. - reflexivity. } - { - destruct n as [| n]; [omega|]. - rewrite flatten_snoc, eval_snoc by omega. + induction inp using rev_ind; intros; destruct n; distr_length. + { rewrite flatten_snoc, eval_snoc by omega. cbv [flatten_step Let_In]. cbn [fst]. rewrite nth_default_app. break_match; distr_length. { rewrite IHinp with (n:=n) by omega. - rewrite (Z.div_mod (weight n) (weight i)) by auto. - rewrite weight_multiples_full by omega. rewrite (Z.div_mod (weight n) (weight (S i))) by auto. rewrite weight_multiples_full by omega. + push_Zmod. autorewrite with zsimplify. - repeat match goal with - | _ => rewrite Z_divide_div_mul_exact' by (try apply Z.mod_divide; auto) - | |- context [ (_ + ?a * ?b * ?c) / ?a ] => - replace (a * b * c) with (a * (b * c)) by ring; - rewrite Z.div_add' by auto - | |- context [ (_ + ?a * ?b * ?c) mod ?b ] => - replace (a * b * c) with (a * c * b) by ring; - rewrite Z.mod_add by auto using ZUtil.Z.positive_is_nonzero - | _ => reflexivity - end. - } + reflexivity. } { repeat match goal with | _ => progress replace (Datatypes.length inp) with n by omega | _ => progress replace i with n by omega - | _ => rewrite nth_default_cons | _ => rewrite sum_cons | _ => rewrite flatten_column_mod | _ => erewrite flatten_div by eauto - | _ => progress autorewrite with natsimplify + | _ => rewrite <-Z.div_add' by auto + | _ => rewrite Z.mod_pull_div by auto using Z.lt_le_incl, Z.gt_lt + | _ => rewrite Z.mul_div_eq', weight_multiples by auto + | _ => progress autorewrite with push_nth_default natsimplify end. - rewrite Z.div_add' by auto. + autorewrite with zsimplify. reflexivity. } } Qed. @@ -1039,12 +1025,7 @@ Module Columns. intros; subst n3; cbv [mul_converted]. erewrite flatten_partitions by (auto; distr_length). autorewrite with distr_length push_eval. - pose proof (w_positive 1). - apply Z.mod_small. - split; [ solve[Z.zero_bounds] | ]. - apply Z.div_lt_upper_bound; [omega|]. - rewrite Z.mul_div_eq_full by auto. - rewrite w_multiples. omega. + rewrite Z.mod_small; omega. Qed. (* shortcut definition for convert-mul-convert for cases when we are halving the bitwidth before multiplying. *) @@ -1303,6 +1284,10 @@ Module Rows. rewrite Columns.weight_multiples_full by (auto; omega); autorewrite with zsimplify_fast. + (* TODO: figure out where to put this and weight_multiples_full *) + Lemma weight_divides_full j i : (i <= j)%nat -> weight j / weight i > 0. + Proof. auto using Z.div_positive_gt_0, Columns.weight_multiples_full. Qed. + Lemma sum_rows'_partitions row1 : forall n m start_state row2 row1' row2', length (fst start_state) = m -> length row1 = n -> length row2 = n -> @@ -1311,10 +1296,10 @@ Module Rows. let eval := Positional.eval weight in snd start_state = (eval m row1' + eval m row2') / weight m -> (forall j, (j < m)%nat -> - nth_default 0 (fst start_state) j = ((eval m row1' + eval m row2') / (weight j)) mod (fw j)) -> + nth_default 0 (fst start_state) j = ((eval m row1' + eval m row2') mod (weight (S j))) / (weight j)) -> forall i, (i < nm)%nat -> nth_default 0 (fst (sum_rows' start_state row1 row2)) i - = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / weight i) mod (fw i). + = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i). Proof. cbv [sum_rows']. induction row1 as [|x1 row1]; intros; @@ -1344,21 +1329,22 @@ Module Rows. rewrite <-Z.div_div by auto using Z.gt_lt. autorewrite with zsimplify. f_equal; ring. } - { mul_div_weights m j. - ring_simplify_subterms. autorewrite with zsimplify. - mul_div_weights m (S j). - rewrite Columns.Z_divide_div_mul_exact' by (try apply Z.mod_divide; auto). - push_Zmod. autorewrite with zsimplify_fast. + { mul_div_weights m (S j). + push_Zmod. autorewrite with zsimplify. lia. } { replace j with m by omega. - autorewrite with push_nth_default natsimplify zsimplify. - f_equal; ring. } + autorewrite with push_nth_default natsimplify. + rewrite <-!Z.div_add' by auto. + rewrite Z.mod_pull_div, Z.mul_div_eq' by (auto using Z.lt_le_incl, Z.gt_lt). + rewrite weight_multiples. + autorewrite with zsimplify_fast. + repeat (f_equal; try ring). } Qed. Lemma sum_rows_partitions row1: forall row2 n i, length row1 = n -> length row2 = n -> (i < n)%nat -> nth_default 0 (fst (sum_rows row1 row2)) i - = ((Positional.eval weight n row1 + Positional.eval weight n row2) / weight i) mod (fw i). + = ((Positional.eval weight n row1 + Positional.eval weight n row2) mod weight (S i)) / (weight i). Proof. cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). rewrite <-(app_nil_l row1), <-(app_nil_l row2). @@ -1390,7 +1376,7 @@ Module Rows. P bs (fold_right f a0 bs). Proof. induction bs; intros; autorewrite with push_fold_right; auto using in_eq,in_cons. Qed. - Lemma flatten'_div_mod n start_state inp: + Lemma flatten'_div_mod_length n start_state inp: length (fst start_state) = n -> (forall row, In row inp -> length row = n) -> length (fst (flatten' start_state inp)) = n @@ -1442,7 +1428,80 @@ Module Rows. autorewrite with push_eval. split; f_equal; ring. } { autorewrite with push_eval. - apply flatten'_div_mod; auto. + apply flatten'_div_mod_length; auto. + congruence. } + Qed. + + Lemma length_flatten' n start_state inp : + length (fst start_state) = n -> + (forall row, In row inp -> length row = n) -> + length (fst (flatten' start_state inp)) = n. + Proof. apply flatten'_div_mod_length. Qed. + Hint Rewrite length_flatten' : distr_length. + + Lemma length_flatten n inp : + (forall row, In row inp -> length row = n) -> + inp <> nil -> + length (fst (flatten inp)) = n. + Proof. + intros. + destruct inp; [congruence |]; destruct inp; + repeat match goal with + | _ => progress intros + | _ => progress cbn [hd tl] in * + | _ => progress autorewrite with cancel_pair + | _ => apply flatten'_div_mod_length + | H: _ |- _ => apply in_inv in H; destruct H + | _ => solve [auto] + end; + subst row; distr_length; auto. + Qed. Hint Rewrite length_flatten : distr_length. + + Lemma flatten'_cons state x inp : + flatten' state (x :: inp) + = (fst (sum_rows x (fst (flatten' state inp))), snd (flatten' state inp) + snd (sum_rows x (fst (flatten' state inp)))). + Proof. reflexivity. Qed. + + Lemma flatten'_partitions n start_state inp: + length (fst start_state) = n -> + (forall row, In row inp -> length row = n) -> + inp <> nil -> + forall i, (i < n)%nat -> + nth_default 0 (fst (flatten' start_state inp)) i + = ((Positional.eval weight n (fst start_state) + eval n inp) mod weight (S i)) / (weight i). + Proof. + intro. + induction inp; intros; [congruence|]. + rewrite flatten'_cons. + autorewrite with push_fold_right cancel_pair. + rewrite sum_rows_partitions with (n:=n) by (rewrite ?length_flatten' with (n:=n); auto). + destruct (dec (inp = nil)). + { subst inp. cbn. repeat (f_equal; try ring). } + { edestruct flatten'_div_mod_length with (inp:=inp) as [Hlen [Hmod Hdiv] ]; eauto. + rewrite Hmod. autorewrite with push_eval. + rewrite Z.add_mod_full. mul_div_weights n (S i). + rewrite Z.rem_mul_r by auto using Z.gt_lt, weight_divides_full. + autorewrite with zsimplify. pull_Zmod. + repeat (f_equal; try ring). } + Qed. + + Lemma flatten_partitions inp n : + (forall row, In row inp -> length row = n) -> + forall i, (i < n)%nat -> + nth_default 0 (fst (flatten inp)) i = (eval n inp mod weight (S i)) / (weight i). + Proof. + intros; cbv [flatten]. + destruct inp; [|destruct inp]; cbn [hd tl]. + { cbn. autorewrite with push_eval push_nth_default zsimplify. reflexivity. } + { cbn. + match goal with H : forall r, In r [?x] -> _ |- _ => + specialize (H x ltac:(auto)); rewrite H + end. + rewrite sum_rows_partitions with (n:=n) by distr_length. + autorewrite with push_eval. + repeat (f_equal; try ring). } + { autorewrite with push_eval. + apply flatten'_partitions; auto. congruence. } Qed. |