diff options
author | Jason Gross <jagro@google.com> | 2016-07-06 15:41:14 -0700 |
---|---|---|
committer | Jason Gross <jagro@google.com> | 2016-07-06 15:41:14 -0700 |
commit | 56e58b21bb80e7b460b0010a8b307f97c3fefea4 (patch) | |
tree | 436ce6f5151b59b2452c39de699db695c33e2f4f /src/Util/ListUtil.v | |
parent | 6ddfe39affef2f47836b03d49fc6e4b9266e75c8 (diff) |
Add [update_nth] to ListUtil, change [set_nth]
Define [set_nth] in terms of [update_nth]
Diffstat (limited to 'src/Util/ListUtil.v')
-rw-r--r-- | src/Util/ListUtil.v | 241 |
1 files changed, 195 insertions, 46 deletions
diff --git a/src/Util/ListUtil.v b/src/Util/ListUtil.v index 060372341..50927c2a4 100644 --- a/src/Util/ListUtil.v +++ b/src/Util/ListUtil.v @@ -1,8 +1,27 @@ Require Import Coq.Lists.List. Require Import Coq.omega.Omega. Require Import Coq.Arith.Peano_dec. +Require Import Coq.Classes.Morphisms. Require Import Crypto.Tactics.VerdiTactics. Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import Crypto.Util.NatUtil. + +Create HintDb distr_length discriminated. +Create HintDb simpl_set_nth discriminated. +Create HintDb simpl_update_nth discriminated. + +Hint Rewrite + @app_length + @rev_length + @map_length + @seq_length + @fold_left_length + @split_length_l + @split_length_r + @firstn_length + @combine_length + @prod_length + : distr_length. Definition sum_firstn l n := fold_right Z.add 0%Z (firstn n l). @@ -106,42 +125,130 @@ Proof. induction xs; boring. Qed. -(* xs[n] := x *) -Fixpoint set_nth {T} n x (xs:list T) {struct n} := +(* xs[n] := f xs[n] *) +Fixpoint update_nth {T} n f (xs:list T) {struct n} := match n with | O => match xs with | nil => nil - | x'::xs' => x::xs' + | x'::xs' => f x'::xs' end | S n' => match xs with | nil => nil - | x'::xs' => x'::set_nth n' x xs' + | x'::xs' => x'::update_nth n' f xs' end end. -Lemma nth_set_nth : forall m {T} (xs:list T) (n:nat) (x x':T), - nth_error (set_nth m x xs) n = - if eq_nat_dec n m - then (if lt_dec n (length xs) then Some x else None) - else nth_error xs n. +(* xs[n] := x *) +Definition set_nth {T} n x (xs:list T) + := update_nth n (fun _ => x) xs. + +Lemma unfold_set_nth {T} n x + : forall xs, + @set_nth T n x xs + = match n with + | O => match xs with + | nil => nil + | x'::xs' => x::xs' + end + | S n' => match xs with + | nil => nil + | x'::xs' => x'::set_nth n' x xs' + end + end. Proof. - induction m. + induction n; destruct xs; reflexivity. +Qed. + +Lemma simpl_set_nth_0 {T} x + : forall xs, + @set_nth T 0 x xs + = match xs with + | nil => nil + | x'::xs' => x::xs' + end. +Proof. intro; rewrite unfold_set_nth; reflexivity. Qed. + +Lemma simpl_set_nth_S {T} x n + : forall xs, + @set_nth T (S n) x xs + = match xs with + | nil => nil + | x'::xs' => x'::set_nth n x xs' + end. +Proof. intro; rewrite unfold_set_nth; reflexivity. Qed. + +Hint Rewrite @simpl_set_nth_S @simpl_set_nth_0 : simpl_set_nth. + +Lemma update_nth_ext {T} f g n + : forall xs, (forall x, nth_error xs n = Some x -> f x = g x) + -> @update_nth T n f xs = @update_nth T n g xs. +Proof. + induction n; destruct xs; simpl; intros H; + try rewrite IHn; try rewrite H; + try congruence; trivial. +Qed. + +Global Instance update_nth_Proper {T} + : Proper (eq ==> pointwise_relation _ eq ==> eq ==> eq) (@update_nth T). +Proof. repeat intro; subst; apply update_nth_ext; trivial. Qed. + +Lemma update_nth_id_eq_specific {T} f n + : forall (xs : list T) (H : forall x, nth_error xs n = Some x -> f x = x), + update_nth n f xs = xs. +Proof. + induction n; destruct xs; simpl; intros; + try rewrite IHn; try rewrite H; + try congruence; assumption. +Qed. - destruct n, xs; auto. +Hint Rewrite @update_nth_id_eq_specific using congruence : simpl_update_nth. - intros; destruct xs, n; auto. - simpl; unfold error; match goal with - [ |- None = if ?x then None else None ] => destruct x - end; auto. +Lemma update_nth_id_eq : forall {T} f (H : forall x, f x = x) n (xs : list T), + update_nth n f xs = xs. +Proof. intros; apply update_nth_id_eq_specific; trivial. Qed. - simpl nth_error; erewrite IHm by auto; clear IHm. - destruct (eq_nat_dec n m), (eq_nat_dec (S n) (S m)); nth_tac. +Hint Rewrite @update_nth_id_eq using congruence : simpl_update_nth. + +Lemma update_nth_id : forall {T} n (xs : list T), + update_nth n (fun x => x) xs = xs. +Proof. intros; apply update_nth_id_eq; trivial. Qed. + +Hint Rewrite @update_nth_id : simpl_update_nth. + +Lemma nth_update_nth : forall m {T} (xs:list T) (n:nat) (f:T -> T), + nth_error (update_nth m f xs) n = + if eq_nat_dec n m + then option_map f (nth_error xs n) + else nth_error xs n. +Proof. + induction m. + { destruct n, xs; auto. } + { destruct xs, n; intros; simpl; auto; + [ | rewrite IHm ]; clear IHm; + edestruct eq_nat_dec; reflexivity. } Qed. -Lemma length_set_nth : forall {T} i (x:T) xs, length (set_nth i x xs) = length xs. +Lemma length_update_nth : forall {T} i f (xs:list T), length (update_nth i f xs) = length xs. +Proof. induction i, xs; boring. Qed. +Lemma nth_set_nth : forall m {T} (xs:list T) (n:nat) x, + nth_error (set_nth m x xs) n = + if eq_nat_dec n m + then (if lt_dec n (length xs) then Some x else None) + else nth_error xs n. +Proof. + intros; unfold set_nth; rewrite nth_update_nth. + destruct (nth_error xs n) eqn:?, (lt_dec n (length xs)) as [p|p]; + rewrite <- nth_error_Some in p; + solve [ reflexivity + | exfalso; apply p; congruence ]. +Qed. + +Lemma length_set_nth : forall {T} i x (xs:list T), length (set_nth i x xs) = length xs. +Proof. intros; apply length_update_nth. Qed. + Lemma nth_error_length_exists_value : forall {A} (i : nat) (xs : list A), (i < length xs)%nat -> exists x, nth_error xs i = Some x. Proof. @@ -208,11 +315,50 @@ Lemma set_nth_equiv_splice_nth: forall {T} n x (xs:list T), then splice_nth n x xs else xs. Proof. - induction n; destruct xs; intros; simpl in *; - try (rewrite IHn; clear IHn); auto. + induction n; destruct xs; intros; + autorewrite with simpl_set_nth in *; simpl in *; + try (rewrite IHn; clear IHn); auto. break_if; break_if; auto; omega. Qed. +Lemma combine_update_nth : forall {A B} n f g (xs:list A) (ys:list B), + combine (update_nth n f xs) (update_nth n g ys) = + update_nth n (fun xy => (f (fst xy), g (snd xy))) (combine xs ys). +Proof. + induction n; destruct xs, ys; simpl; try rewrite IHn; reflexivity. +Qed. + +(* grumble, grumble, [rewrite] is bad at inferring the identity function, and constant functions *) +Ltac rewrite_rev_combine_update_nth := + let lem := match goal with + | [ |- appcontext[update_nth ?n (fun xy => (@?f xy, @?g xy)) (combine ?xs ?ys)] ] + => let f := match (eval cbv [fst] in (fun y x => f (x, y))) with + | fun _ => ?f => f + end in + let g := match (eval cbv [snd] in (fun x y => g (x, y))) with + | fun _ => ?g => g + end in + constr:(@combine_update_nth _ _ n f g xs ys) + end in + rewrite <- lem. + +Lemma combine_update_nth_l : forall {A B} n (f : A -> A) xs (ys:list B), + combine (update_nth n f xs) ys = + update_nth n (fun xy => (f (fst xy), snd xy)) (combine xs ys). +Proof. + intros ??? f xs ys. + etransitivity; [ | apply combine_update_nth with (g := fun x => x) ]. + rewrite update_nth_id; reflexivity. +Qed. + +Lemma combine_update_nth_r : forall {A B} n (g : B -> B) (xs:list A) (ys:list B), + combine xs (update_nth n g ys) = + update_nth n (fun xy => (fst xy, g (snd xy))) (combine xs ys). +Proof. + intros ??? g xs ys. + etransitivity; [ | apply combine_update_nth with (f := fun x => x) ]. + rewrite update_nth_id; reflexivity. +Qed. Lemma combine_set_nth : forall {A B} n (x:A) xs (ys:list B), combine (set_nth n x xs) ys = @@ -221,12 +367,12 @@ Lemma combine_set_nth : forall {A B} n (x:A) xs (ys:list B), | Some y => set_nth n (x,y) (combine xs ys) end. Proof. - (* TODO(andreser): this proof can totally be automated, but requires writing ltac that vets multiple hypotheses at once *) - induction n, xs, ys; nth_tac; try rewrite IHn; nth_tac; - try (f_equal; specialize (IHn x xs ys ); rewrite H in IHn; rewrite <- IHn); - try (specialize (nth_error_value_length _ _ _ _ H); omega). - assert (Some b0=Some b1) as HA by (rewrite <-H, <-H0; auto). - injection HA; intros; subst; auto. + intros; unfold set_nth; rewrite combine_update_nth_l. + nth_tac; + [ repeat rewrite_rev_combine_update_nth; apply f_equal2 + | assert (nth_error (combine xs ys) n = None) + by (apply nth_error_None; rewrite combine_length; omega * ) ]; + autorewrite with simpl_update_nth; reflexivity. Qed. Lemma nth_error_value_In : forall {T} n xs (x:T), @@ -461,42 +607,40 @@ Proof. reflexivity. Qed. +Lemma update_nth_cons : forall {T} f (u0 : T) us, update_nth 0 f (u0 :: us) = (f u0) :: us. +Proof. reflexivity. Qed. + Lemma set_nth_cons : forall {T} (x u0 : T) us, set_nth 0 x (u0 :: us) = x :: us. -Proof. - auto. -Qed. +Proof. intros; apply update_nth_cons. Qed. -Create HintDb distr_length discriminated. Hint Rewrite @nil_length0 @length_cons - @app_length - @rev_length - @map_length - @seq_length - @fold_left_length - @split_length_l - @split_length_r - @firstn_length @skipn_length - @combine_length - @prod_length + @length_update_nth @length_set_nth : distr_length. Ltac distr_length := autorewrite with distr_length in *; try solve [simpl in *; omega]. -Lemma cons_set_nth : forall {T} n (x y : T) us, - y :: set_nth n x us = set_nth (S n) x (y :: us). +Lemma cons_update_nth : forall {T} n f (y : T) us, + y :: update_nth n f us = update_nth (S n) f (y :: us). Proof. induction n; boring. Qed. -Lemma set_nth_nil : forall {T} n (x : T), set_nth n x nil = nil. +Lemma update_nth_nil : forall {T} n f, set_nth n f (@nil T) = @nil T. Proof. induction n; boring. Qed. +Lemma cons_set_nth : forall {T} n (x y : T) us, + y :: set_nth n x us = set_nth (S n) x (y :: us). +Proof. intros; apply cons_update_nth. Qed. + +Lemma set_nth_nil : forall {T} n (x : T), set_nth n x nil = nil. +Proof. intros; apply update_nth_nil. Qed. + Lemma nth_default_nil : forall {T} n (d : T), nth_default d nil n = d. Proof. induction n; boring. @@ -629,15 +773,20 @@ Proof. omega. Qed. -Lemma set_nth_nth_default : forall {A} (d:A) n x l i, (0 <= i < length l)%nat -> - nth_default d (set_nth n x l) i = - if (eq_nat_dec i n) then x else nth_default d l i. +Lemma update_nth_nth_default : forall {A} (d:A) n f l i, (0 <= i < length l)%nat -> + nth_default d (update_nth n f l) i = + if (eq_nat_dec i n) then f (nth_default d l i) else nth_default d l i. Proof. induction n; (destruct l; [intros; simpl in *; omega | ]); simpl; destruct i; break_if; try omega; intros; try apply nth_default_cons; rewrite !nth_default_cons_S, ?IHn; try break_if; omega || reflexivity. Qed. +Lemma set_nth_nth_default : forall {A} (d:A) n x l i, (0 <= i < length l)%nat -> + nth_default d (set_nth n x l) i = + if (eq_nat_dec i n) then x else nth_default d l i. +Proof. intros; apply update_nth_nth_default; assumption. Qed. + Lemma nth_default_preserves_properties : forall {A} (P : A -> Prop) l n d, (forall x, In x l -> P x) -> P d -> P (nth_default d l n). Proof. |