aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-03-07 15:51:49 +0100
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-03 09:00:55 -0400
commit8690fc7b8054901178a0756d1fb6f47342a2bd55 (patch)
tree192844dc24b5b0ae0c67e07ab98c209a3830be4f /src
parent0be8c543e56273ba58d919836e4053e52e87b90a (diff)
organize proofs into sections
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v427
1 files changed, 216 insertions, 211 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 61df7b12f..48661f4c6 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -679,35 +679,29 @@ Module Columns.
apply Positional.eval_snoc; distr_length.
Qed. Hint Rewrite eval_snoc using (solve [distr_length]) : push_eval.
- Section flatten_column.
- Context (fw : Z). (* maximum size of the result *)
-
- (* Outputs (sum, carry) *)
- Definition flatten_column (digit: list Z) : (Z * Z) :=
- list_rect (fun _ => (Z * Z)%type) (0,0)
- (fun xx tl flatten_column_tl =>
- list_rect
- (fun _ => (Z * Z)%type) (xx mod fw, xx / fw)
- (fun yy tl' _ =>
- list_rect
- (fun _ => (Z * Z)%type) (dlet_nd x := xx in dlet_nd y := yy in Z.add_get_carry_full fw x y)
- (fun _ _ _ =>
- dlet_nd x := xx in
- dlet_nd rec := flatten_column_tl in (* recursively get the sum and carry *)
- dlet_nd sum_carry := Z.add_get_carry_full fw x (fst rec) in (* add the new value to the sum *)
- dlet_nd carry' := snd sum_carry + snd rec in (* add the two carries together *)
- (fst sum_carry, carry'))
- tl')
- tl)
- digit.
- End flatten_column.
-
- Definition flatten_step (digit:list Z) (acc_carry:list Z * Z) : list Z * Z :=
- dlet sum_carry := flatten_column (weight (S (length (fst acc_carry))) / weight (length (fst acc_carry))) (snd acc_carry::digit) in
- (fst acc_carry ++ fst sum_carry :: nil, snd sum_carry).
-
- Definition flatten (xs : list (list Z)) : list Z * Z :=
- fold_right (fun a b => flatten_step a b) (nil,0) (rev xs).
+ (* TODO: move out of Columns? *)
+ Section Weight.
+ Lemma weight_multiples_full' j : forall i, weight (i+j) mod weight i = 0.
+ Proof.
+ induction j; intros;
+ repeat match goal with
+ | _ => rewrite Nat.add_succ_r
+ | _ => rewrite IHj
+ | |- context [weight (S ?x) mod weight _] =>
+ rewrite (Z.div_mod (weight (S x)) (weight x)), weight_multiples by auto
+ | _ => progress autorewrite with push_Zmod natsimplify zsimplify_fast
+ | _ => reflexivity
+ end.
+ Qed.
+ Lemma weight_multiples_full j i : (i <= j)%nat -> weight j mod weight i = 0.
+ Proof.
+ intros; replace j with (i + (j - i))%nat by omega.
+ apply weight_multiples_full'.
+ Qed.
+
+ Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i).
+ Proof. intros. apply Z.div_exact; auto using weight_multiples_full. Qed.
+ End Weight.
Lemma list_rect_to_match A (P:list A -> Type) (Pnil: P []) (PS: forall a tl, P (a :: tl)) ls :
@list_rect A P Pnil (fun a tl _ => PS a tl) ls = match ls with
@@ -718,17 +712,47 @@ Module Columns.
Hint Rewrite <- Z.div_add' using omega : pull_Zdiv.
-
- Local Ltac cases :=
+ Ltac cases :=
match goal with
| |- _ /\ _ => split
| H: _ /\ _ |- _ => destruct H
| H: _ \/ _ |- _ => destruct H
| _ => progress break_match; try discriminate
end.
-
- Local Ltac push_fast :=
- repeat match goal with
+
+ Section Flatten.
+ Section flatten_column.
+ Context (fw : Z). (* maximum size of the result *)
+
+ (* Outputs (sum, carry) *)
+ Definition flatten_column (digit: list Z) : (Z * Z) :=
+ list_rect (fun _ => (Z * Z)%type) (0,0)
+ (fun xx tl flatten_column_tl =>
+ list_rect
+ (fun _ => (Z * Z)%type) (xx mod fw, xx / fw)
+ (fun yy tl' _ =>
+ list_rect
+ (fun _ => (Z * Z)%type) (dlet_nd x := xx in dlet_nd y := yy in Z.add_get_carry_full fw x y)
+ (fun _ _ _ =>
+ dlet_nd x := xx in
+ dlet_nd rec := flatten_column_tl in (* recursively get the sum and carry *)
+ dlet_nd sum_carry := Z.add_get_carry_full fw x (fst rec) in (* add the new value to the sum *)
+ dlet_nd carry' := snd sum_carry + snd rec in (* add the two carries together *)
+ (fst sum_carry, carry'))
+ tl')
+ tl)
+ digit.
+ End flatten_column.
+
+ Definition flatten_step (digit:list Z) (acc_carry:list Z * Z) : list Z * Z :=
+ dlet sum_carry := flatten_column (weight (S (length (fst acc_carry))) / weight (length (fst acc_carry))) (snd acc_carry::digit) in
+ (fst acc_carry ++ fst sum_carry :: nil, snd sum_carry).
+
+ Definition flatten (xs : list (list Z)) : list Z * Z :=
+ fold_right (fun a b => flatten_step a b) (nil,0) (rev xs).
+
+ Ltac push_fast :=
+ repeat match goal with
| _ => progress cbv [Let_In]
| |- context [list_rect _ _ _ ?ls] => rewrite list_rect_to_match; destruct ls
| _ => progress (unfold flatten_step in *; fold flatten_step in * )
@@ -740,195 +764,176 @@ Module Columns.
| _ => congruence
| _ => progress cases
end.
- Local Ltac push :=
- repeat match goal with
- | _ => progress push_fast
- | _ => progress autorewrite with cancel_pair to_div_mod
- | _ => progress autorewrite with push_sum push_fold_right push_nth_default in *
- | _ => progress autorewrite with pull_Zmod pull_Zdiv zsimplify_fast
- | _ => progress autorewrite with list distr_length push_eval
- end.
-
-
-
- Lemma flatten_column_mod fw (xs : list Z) :
- fst (flatten_column fw xs) = sum xs mod fw.
- Proof.
- induction xs; simpl flatten_column; cbv [Let_In];
+ Ltac push :=
repeat match goal with
- | _ => rewrite IHxs
- | _ => progress push
+ | _ => progress push_fast
+ | _ => progress autorewrite with cancel_pair to_div_mod
+ | _ => progress autorewrite with push_sum push_fold_right push_nth_default in *
+ | _ => progress autorewrite with pull_Zmod pull_Zdiv zsimplify_fast
+ | _ => progress autorewrite with list distr_length push_eval
end.
- Qed. Hint Rewrite flatten_column_mod : to_div_mod.
+
+ Lemma flatten_column_mod fw (xs : list Z) :
+ fst (flatten_column fw xs) = sum xs mod fw.
+ Proof.
+ induction xs; simpl flatten_column; cbv [Let_In];
+ repeat match goal with
+ | _ => rewrite IHxs
+ | _ => progress push
+ end.
+ Qed. Hint Rewrite flatten_column_mod : to_div_mod.
- Lemma flatten_column_div fw (xs : list Z) (fw_nz : fw <> 0) :
- snd (flatten_column fw xs) = sum xs / fw.
- Proof.
- induction xs; simpl flatten_column; cbv [Let_In];
- repeat match goal with
- | _ => rewrite IHxs
- | _ => rewrite Z.mul_div_eq_full by omega
- | _ => progress push
- end.
- Qed. Hint Rewrite flatten_column_div using auto with zarith : to_div_mod.
-
- (* helper for some of the modular logic in flatten *)
- Lemma flatten_mod_step a b c d: 0 < a -> 0 < b ->
- c mod a + a * ((c / a + d) mod b) = (a * d + c) mod (a * b).
- Proof. intros; rewrite Z.rem_mul_r by omega. push_Zmod. push. Qed.
-
- Lemma flatten_div_step a b c d : 0 < a -> 0 < b ->
- (c / a + d) / b = (a * d + c) / (a * b).
- Proof. intros; push. Qed.
-
- Hint Rewrite Positional.eval_nil : push_eval.
- Hint Resolve Z.gt_lt.
-
- Lemma length_flatten_step digit state :
- length (fst (flatten_step digit state)) = S (length (fst state)).
- Proof. cbv [flatten_step]; push. Qed.
- Hint Rewrite length_flatten_step : distr_length.
- Lemma length_flatten inp : length (fst (flatten inp)) = length inp.
- Proof. cbv [flatten]. induction inp using rev_ind; push. Qed.
- Hint Rewrite length_flatten : distr_length.
-
- Lemma flatten_div_mod n inp :
- length inp = n ->
- (Positional.eval weight n (fst (flatten inp))
- = (eval n inp) mod (weight n))
- /\ (snd (flatten inp) = eval n inp / weight n).
- Proof.
- (* to make the invariant take the right form, we make everything depend on output length, not input length *)
- intro. subst n. rewrite <-(length_flatten inp). cbv [flatten].
- induction inp using rev_ind; intros; [push|].
- repeat match goal with
- | _ => rewrite Nat.add_1_r
- | _ => progress (fold (flatten inp) in * )
- | _ => erewrite Positional.eval_snoc by (distr_length; reflexivity)
- | H: _ = _ mod (weight _) |- _ => rewrite H
- | H: _ = _ / (weight _) |- _ => rewrite H
- | _ => progress rewrite ?flatten_mod_step, ?flatten_div_step by auto
- | _ => progress autorewrite with cancel_pair to_div_mod push_sum list push_fold_right push_eval
- | _ => progress (distr_length; push_fast)
- end.
- Qed.
+ Lemma flatten_column_div fw (xs : list Z) (fw_nz : fw <> 0) :
+ snd (flatten_column fw xs) = sum xs / fw.
+ Proof.
+ induction xs; simpl flatten_column; cbv [Let_In];
+ repeat match goal with
+ | _ => rewrite IHxs
+ | _ => rewrite Z.mul_div_eq_full by omega
+ | _ => progress push
+ end.
+ Qed. Hint Rewrite flatten_column_div using auto with zarith : to_div_mod.
- Lemma flatten_mod {n} inp :
- length inp = n ->
- (Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n)).
- Proof. apply flatten_div_mod. Qed.
- Hint Rewrite @flatten_mod : push_eval.
+ (* helper for some of the modular logic in flatten *)
+ Lemma flatten_mod_step a b c d: 0 < a -> 0 < b ->
+ c mod a + a * ((c / a + d) mod b) = (a * d + c) mod (a * b).
+ Proof. intros; rewrite Z.rem_mul_r by omega. push_Zmod. push. Qed.
- Lemma flatten_div {n} inp :
- length inp = n -> snd (flatten inp) = eval n inp / weight n.
- Proof. apply flatten_div_mod. Qed.
- Hint Rewrite @flatten_div : push_eval.
+ Lemma flatten_div_step a b c d : 0 < a -> 0 < b ->
+ (c / a + d) / b = (a * d + c) / (a * b).
+ Proof. intros; push. Qed.
- (* nils *)
- Definition nils n : list (list Z) := List.repeat nil n.
- Lemma length_nils n : length (nils n) = n. Proof. cbv [nils]. distr_length. Qed.
- Hint Rewrite length_nils : distr_length.
- Lemma eval_nils n : eval n (nils n) = 0.
- Proof.
- erewrite <-Positional.eval_zeros by eauto.
- cbv [eval nils]; rewrite List.map_repeat; reflexivity.
- Qed. Hint Rewrite eval_nils : push_eval.
-
- (* cons_to_nth *)
- Definition cons_to_nth i x (xs : list (list Z)) : list (list Z) :=
- ListUtil.update_nth i (fun y => cons x y) xs.
- Lemma length_cons_to_nth i x xs : length (cons_to_nth i x xs) = length xs.
- Proof. cbv [cons_to_nth]. distr_length. Qed.
- Hint Rewrite length_cons_to_nth : distr_length.
- Lemma cons_to_nth_add_to_nth xs : forall i x,
- map sum (cons_to_nth i x xs) = Positional.add_to_nth i x (map sum xs).
- Proof.
- cbv [cons_to_nth]; induction xs as [|? ? IHxs];
- intros i x; destruct i; simpl; rewrite ?IHxs; reflexivity.
- Qed.
- Lemma eval_cons_to_nth n i x xs : (i < length xs)%nat -> length xs = n ->
- eval n (cons_to_nth i x xs) = weight i * x + eval n xs.
- Proof using Type.
- cbv [eval]; intros. rewrite cons_to_nth_add_to_nth.
- apply Positional.eval_add_to_nth; distr_length.
- Qed. Hint Rewrite eval_cons_to_nth using (solve [distr_length]) : push_eval.
-
- Hint Rewrite Positional.eval_zeros : push_eval.
- Hint Rewrite Positional.length_from_associational : distr_length.
- Hint Rewrite Positional.eval_add_to_nth using (solve [distr_length]): push_eval.
-
- (* from_associational *)
- Definition from_associational n (p:list (Z*Z)) : list (list Z) :=
- List.fold_right (fun t ls =>
- let p := Positional.place weight t (pred n) in
- cons_to_nth (fst p) (snd p) ls ) (nils n) p.
- Lemma length_from_associational n p : length (from_associational n p) = n.
- Proof. cbv [from_associational]. apply fold_right_invariant; intros; distr_length. Qed.
- Hint Rewrite length_from_associational: distr_length.
- Lemma eval_from_associational n p (n_nonzero:n<>0%nat\/p=nil):
- eval n (from_associational n p) = Associational.eval p.
- Proof.
- erewrite <-Positional.eval_from_associational by eauto.
- induction p; push.
- cbv [from_associational Positional.from_associational]; autorewrite with push_fold_right.
- fold (from_associational n p); fold (Positional.from_associational weight n p).
- match goal with |- context [Positional.place _ ?x ?n] =>
- pose proof (Positional.place_in_range weight x n) end.
- repeat match goal with
- | _ => rewrite Nat.succ_pred in * by auto
- | _ => rewrite IHp by auto
- | _ => progress push
- end.
- Qed.
+ Hint Rewrite Positional.eval_nil : push_eval.
+ Hint Resolve Z.gt_lt.
- Lemma flatten_snoc x inp : flatten (inp ++ [x]) = flatten_step x (flatten inp).
- Proof. cbv [flatten]. rewrite rev_unit. reflexivity. Qed.
+ Lemma length_flatten_step digit state :
+ length (fst (flatten_step digit state)) = S (length (fst state)).
+ Proof. cbv [flatten_step]; push. Qed.
+ Hint Rewrite length_flatten_step : distr_length.
+ Lemma length_flatten inp : length (fst (flatten inp)) = length inp.
+ Proof. cbv [flatten]. induction inp using rev_ind; push. Qed.
+ Hint Rewrite length_flatten : distr_length.
- Lemma weight_multiples_full' j : forall i, weight (i+j) mod weight i = 0.
- Proof.
- induction j; intros;
+ Lemma flatten_div_mod n inp :
+ length inp = n ->
+ (Positional.eval weight n (fst (flatten inp))
+ = (eval n inp) mod (weight n))
+ /\ (snd (flatten inp) = eval n inp / weight n).
+ Proof.
+ (* to make the invariant take the right form, we make everything depend on output length, not input length *)
+ intro. subst n. rewrite <-(length_flatten inp). cbv [flatten].
+ induction inp using rev_ind; intros; [push|].
repeat match goal with
- | _ => rewrite Nat.add_succ_r
- | _ => rewrite IHj
- | |- context [weight (S ?x) mod weight _] =>
- rewrite (Z.div_mod (weight (S x)) (weight x)), weight_multiples by auto
- | _ => progress autorewrite with push_Zmod natsimplify zsimplify_fast
- | _ => reflexivity
+ | _ => rewrite Nat.add_1_r
+ | _ => progress (fold (flatten inp) in * )
+ | _ => erewrite Positional.eval_snoc by (distr_length; reflexivity)
+ | H: _ = _ mod (weight _) |- _ => rewrite H
+ | H: _ = _ / (weight _) |- _ => rewrite H
+ | _ => progress rewrite ?flatten_mod_step, ?flatten_div_step by auto
+ | _ => progress autorewrite with cancel_pair to_div_mod push_sum list push_fold_right push_eval
+ | _ => progress (distr_length; push_fast)
end.
- Qed.
-
- Lemma weight_multiples_full j i : (i <= j)%nat -> weight j mod weight i = 0.
- Proof.
- intros; replace j with (i + (j - i))%nat by omega.
- apply weight_multiples_full'.
- Qed.
+ Qed.
- Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i).
- Proof. intros. apply Z.div_exact; auto using weight_multiples_full. Qed.
+ Lemma flatten_mod {n} inp :
+ length inp = n ->
+ (Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n)).
+ Proof. apply flatten_div_mod. Qed.
+ Hint Rewrite @flatten_mod : push_eval.
- (* TODO: move to ZUtil *)
- Lemma Z_divide_div_mul_exact' a b c : b <> 0 -> (b | a) -> a * c / b = c * (a / b).
- Proof. intros. rewrite Z.mul_comm. auto using Z.divide_div_mul_exact. Qed.
+ Lemma flatten_div {n} inp :
+ length inp = n -> snd (flatten inp) = eval n inp / weight n.
+ Proof. apply flatten_div_mod. Qed.
+ Hint Rewrite @flatten_div : push_eval.
- Lemma flatten_partitions inp:
- forall n i, length inp = n -> (i < n)%nat ->
- nth_default 0 (fst (flatten inp)) i = ((eval n inp) mod (weight (S i))) / weight i.
- Proof.
- induction inp using rev_ind; intros; destruct n; distr_length.
- rewrite flatten_snoc.
- push; distr_length;
- [rewrite IHinp with (n:=n) by omega; rewrite (weight_div_mod n (S i)) by omega; push_Zmod; push |].
- repeat match goal with
- | _ => progress replace (length inp) with n by omega
- | _ => progress replace i with n by omega
- | _ => progress push
- | _ => erewrite flatten_div by eauto
- | _ => rewrite <-Z.div_add' by auto
- | _ => rewrite Z.mul_div_eq' by auto
- | _ => rewrite Z.mod_pull_div by auto using Z.lt_le_incl
- | _ => progress autorewrite with push_nth_default natsimplify
- end.
- Qed.
+ Lemma flatten_snoc x inp : flatten (inp ++ [x]) = flatten_step x (flatten inp).
+ Proof. cbv [flatten]. rewrite rev_unit. reflexivity. Qed.
+
+ (* TODO: move to ZUtil *)
+ Lemma Z_divide_div_mul_exact' a b c : b <> 0 -> (b | a) -> a * c / b = c * (a / b).
+ Proof. intros. rewrite Z.mul_comm. auto using Z.divide_div_mul_exact. Qed.
+
+ Lemma flatten_partitions inp:
+ forall n i, length inp = n -> (i < n)%nat ->
+ nth_default 0 (fst (flatten inp)) i = ((eval n inp) mod (weight (S i))) / weight i.
+ Proof.
+ induction inp using rev_ind; intros; destruct n; distr_length.
+ rewrite flatten_snoc.
+ push; distr_length;
+ [rewrite IHinp with (n:=n) by omega; rewrite (weight_div_mod n (S i)) by omega; push_Zmod; push |].
+ repeat match goal with
+ | _ => progress replace (length inp) with n by omega
+ | _ => progress replace i with n by omega
+ | _ => progress push
+ | _ => erewrite flatten_div by eauto
+ | _ => rewrite <-Z.div_add' by auto
+ | _ => rewrite Z.mul_div_eq' by auto
+ | _ => rewrite Z.mod_pull_div by auto using Z.lt_le_incl
+ | _ => progress autorewrite with push_nth_default natsimplify
+ end.
+ Qed.
+ End Flatten.
+
+ Section FromAssociational.
+ (* nils *)
+ Definition nils n : list (list Z) := List.repeat nil n.
+ Lemma length_nils n : length (nils n) = n. Proof. cbv [nils]. distr_length. Qed.
+ Hint Rewrite length_nils : distr_length.
+ Lemma eval_nils n : eval n (nils n) = 0.
+ Proof.
+ erewrite <-Positional.eval_zeros by eauto.
+ cbv [eval nils]; rewrite List.map_repeat; reflexivity.
+ Qed. Hint Rewrite eval_nils : push_eval.
+
+ (* cons_to_nth *)
+ Definition cons_to_nth i x (xs : list (list Z)) : list (list Z) :=
+ ListUtil.update_nth i (fun y => cons x y) xs.
+ Lemma length_cons_to_nth i x xs : length (cons_to_nth i x xs) = length xs.
+ Proof. cbv [cons_to_nth]. distr_length. Qed.
+ Hint Rewrite length_cons_to_nth : distr_length.
+ Lemma cons_to_nth_add_to_nth xs : forall i x,
+ map sum (cons_to_nth i x xs) = Positional.add_to_nth i x (map sum xs).
+ Proof.
+ cbv [cons_to_nth]; induction xs as [|? ? IHxs];
+ intros i x; destruct i; simpl; rewrite ?IHxs; reflexivity.
+ Qed.
+ Lemma eval_cons_to_nth n i x xs : (i < length xs)%nat -> length xs = n ->
+ eval n (cons_to_nth i x xs) = weight i * x + eval n xs.
+ Proof using Type.
+ cbv [eval]; intros. rewrite cons_to_nth_add_to_nth.
+ apply Positional.eval_add_to_nth; distr_length.
+ Qed. Hint Rewrite eval_cons_to_nth using (solve [distr_length]) : push_eval.
+
+ Hint Rewrite Positional.eval_zeros : push_eval.
+ Hint Rewrite Positional.length_from_associational : distr_length.
+ Hint Rewrite Positional.eval_add_to_nth using (solve [distr_length]): push_eval.
+
+ (* from_associational *)
+ Definition from_associational n (p:list (Z*Z)) : list (list Z) :=
+ List.fold_right (fun t ls =>
+ let p := Positional.place weight t (pred n) in
+ cons_to_nth (fst p) (snd p) ls ) (nils n) p.
+ Lemma length_from_associational n p : length (from_associational n p) = n.
+ Proof. cbv [from_associational]. apply fold_right_invariant; intros; distr_length. Qed.
+ Hint Rewrite length_from_associational: distr_length.
+ Lemma eval_from_associational n p (n_nonzero:n<>0%nat\/p=nil):
+ eval n (from_associational n p) = Associational.eval p.
+ Proof.
+ erewrite <-Positional.eval_from_associational by eauto.
+ induction p; [ autorewrite with push_eval; congruence |].
+ cbv [from_associational Positional.from_associational]; autorewrite with push_fold_right.
+ fold (from_associational n p); fold (Positional.from_associational weight n p).
+ match goal with |- context [Positional.place _ ?x ?n] =>
+ pose proof (Positional.place_in_range weight x n) end.
+ repeat match goal with
+ | _ => rewrite Nat.succ_pred in * by auto
+ | _ => rewrite IHp by auto
+ | _ => progress autorewrite with push_eval
+ | _ => progress cases
+ | _ => congruence
+ end.
+ Qed.
+ End FromAssociational.
Section mul.
Definition mul s n m (p q : list Z) : list Z :=
@@ -951,9 +956,9 @@ Module Rows.
Local Notation rows := (list (list Z)) (only parsing).
Local Notation cols := (list (list Z)) (only parsing).
+
Hint Rewrite Positional.eval_nil Positional.eval0 @Positional.eval_snoc
Columns.eval_nil Columns.eval_snoc using (auto; solve [distr_length]) : push_eval.
-
Hint Resolve in_eq in_cons.
Definition eval n (inp : rows) :=