diff options
author | 2018-03-01 17:45:06 +0100 | |
---|---|---|
committer | 2018-04-03 09:00:55 -0400 | |
commit | 884de22a21f9ae7b8d2743271b66e750fc5d4ff2 (patch) | |
tree | 754d26012446ceac61ddf937a08e4a9d8f3c601f | |
parent | e6a306ceb5824b161c6e8688333cbde21a0ac8f4 (diff) |
rowwise flatten (more fleshed-out) and proof outline
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 203 |
1 files changed, 163 insertions, 40 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 2b3a767f9..f01fd29e0 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -952,46 +952,6 @@ Module Columns. reflexivity. } } Qed. - Section Rows. - - Let fw := (fun i => weight (S i) / weight i). - - Fixpoint finish_row (idx: nat) (carry:Z) (inp: list (list Z)) : list (list Z) * Z := - match inp with - | nil => (nil, carry) - | col :: inp' => - match col with - | nil => ([carry] :: inp', 0) - | x :: nil => - let sum_carry := Z.add_with_get_carry_full (fw idx) carry x 0 in - let rec := finish_row (S idx) (snd sum_carry) inp' in - ([(fst sum_carry)] :: fst rec, snd rec) - | x :: y :: tl => - let sum_carry := Z.add_with_get_carry_full (fw idx) carry x y in - let rec := finish_row (S idx) (snd sum_carry) inp' in - (((fst sum_carry) :: tl) :: fst rec, snd rec) - end - end. - - Fixpoint flatten_rowwise' (start_value : Z) (start_col : list Z) (other_cols : list (list Z)) : list Z * Z := - match start_col with - | nil => (start_value :: map sum other_cols, 0) - | x :: tl => - let sum_carry := Z.add_get_carry_full (fw 0%nat) start_value x in - let other_cols'_carry := finish_row 1%nat (snd sum_carry) other_cols in - let rec := flatten_rowwise' (fst sum_carry) tl (fst other_cols'_carry) in - (fst rec, snd rec + snd other_cols'_carry) - end. - - Definition flatten_rowwise (inp : list (list Z)) : list Z * Z := - match inp with - | (x :: tl) :: other_cols => - flatten_rowwise' x tl other_cols - | _ => (map sum inp, 0) - end. - - End Rows. - Section mul. Definition mul s n m (p q : list Z) : list Z := let p_a := Positional.to_associational weight n p in @@ -1096,6 +1056,169 @@ Module Columns. End mul_converted. End Columns. +Module Rows. + Section Rows. + Context (weight : nat->Z) + {weight_0 : weight 0%nat = 1} + {weight_nonzero : forall i, weight i <> 0} + {weight_positive : forall i, weight i > 0} + {weight_multiples : forall i, weight (S i) mod weight i = 0} + {weight_divides : forall i : nat, weight (S i) / weight i > 0}. + + Local Notation rows := (list (list Z)) (only parsing). + Local Notation cols := (list (list Z)) (only parsing). + Hint Rewrite Positional.eval_nil : push_eval. + Hint Resolve in_eq in_cons. + + Definition eval n (inp : rows) := + sum (map (Positional.eval weight n) inp). + Lemma eval_nil n : eval n nil = 0. + Proof. cbv [eval]. rewrite map_nil, sum_nil; reflexivity. Qed. + Hint Rewrite eval_nil : push_eval. + Lemma eval_cons n r inp : eval n (r :: inp) = Positional.eval weight n r + eval n inp. + Proof. cbv [eval]. rewrite map_cons, sum_cons; reflexivity. Qed. + Hint Rewrite eval_cons : push_eval. + + Definition extract_row (inp : cols) : cols * list Z := + fold_right (fun col state => + (tl col :: fst state, hd 0 col :: snd state) + ) (nil, nil) inp. + + Definition from_columns n (inp : cols) : rows := + snd + (fold_right (fun _ (state : cols * rows) => + let cols'_row := extract_row (fst state) in + (fst cols'_row, snd state ++ [snd cols'_row]) + ) (inp, nil) (seq 0 n)). + + Lemma eval_from_columns (inp : cols) : + forall n m, length inp = n -> + (m = fold_right Nat.max 0%nat (map (@length Z) inp))%nat -> + eval n (from_columns m inp) = Columns.eval weight n inp. + Proof. + cbv [eval Columns.eval from_columns]. + induction inp; intros; distr_length. + { subst n. subst m. reflexivity. } + Admitted. + + Local Notation fw := (fun i => weight (S i) / weight i) (only parsing). + + Definition sum_rows (row1 row2 : list Z) : list Z * Z := + fold_left (fun (state : list Z * Z) next => + let i := length (fst state) in (* length of output accumulator tells us the index of [next] *) + let sum_carry := Z.add_with_get_carry_full (fw i) (snd state) (fst next) (snd next) in + (fst state ++ [fst sum_carry], snd sum_carry)) (combine row1 row2) (nil,0). + + Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z := + fold_right (fun next_row (state : list Z * Z)=> + let out_carry := sum_rows next_row (fst state) in + (fst out_carry, snd state + snd out_carry)) start_state inp. + + (* For correctness if there is only one row, we add a row of + zeroes with the same length so that the add loop still happens. *) + Definition flatten (inp : rows) : list Z * Z := + let first_row := hd nil inp in + flatten' (first_row, 0) (hd (Positional.zeros (length first_row)) (tl inp) :: tl (tl inp)). + + Lemma sum_rows_mod row1 row2 n: + length row1 = n -> length row2 = n -> + Positional.eval weight n (fst (sum_rows row1 row2)) + = (Positional.eval weight n row1 + Positional.eval weight n row2) mod (weight n). + Admitted. + + Lemma sum_rows_div row1 row2 n: + length row1 = n -> length row2 = n -> + snd (sum_rows row1 row2) + = (Positional.eval weight n row1 + Positional.eval weight n row2) / (weight n). + Admitted. + + Hint Rewrite sum_rows_mod using (auto; solve [distr_length; auto]) : push_eval. + + Lemma sum_rows_partitions row1 row2 n i: + length row1 = n -> length row2 = n -> (i < n)%nat -> + nth_default 0 (fst (sum_rows row1 row2)) i + = ((Positional.eval weight n row1 + Positional.eval weight n row2) / weight i) mod (fw i). + Admitted. + + Lemma length_sum_rows row1 row2 n : + length row1 = n -> length row2 = n -> + length (fst (sum_rows row1 row2)) = n. + Admitted. + + + (* TODO: move to ListUtil *) + (* connect state to remaining input *) + Lemma fold_right_invariant_strong {A B} (P: list B -> A -> Type) (f: B -> A -> A): + forall bs a0 + (Pnil : P nil a0) + (IHfold: + forall bs' b a', + In b bs -> + a' = fold_right f a0 bs' -> + P bs' a' -> + P (b :: bs') (f b a')), + P bs (fold_right f a0 bs). + Proof. induction bs; intros; rewrite ?fold_right_cons; auto using in_eq,in_cons. Qed. + + Lemma flatten'_div_mod n start_state inp: + length (fst start_state) = n -> + (forall row, In row inp -> length row = n) -> + length (fst (flatten' start_state inp)) = n + /\ (inp <> nil -> + Positional.eval weight n (fst (flatten' start_state inp)) = (Positional.eval weight n (fst start_state) + eval n inp) mod (weight n) + /\ snd (flatten' start_state inp) = snd start_state + (Positional.eval weight n (fst start_state) + eval n inp) / weight n). + Proof. + intro. + cbv [flatten']. + apply fold_right_invariant_strong with (bs:=inp); intros; [ tauto | destruct (dec (bs' = nil)) ]; + repeat match goal with + | _ => subst a' + | _ => subst bs' + | _ => rewrite sum_rows_div with (n:=n) by auto + | _ => rewrite @fold_right_nil in * + | _ => progress autorewrite with cancel_pair push_eval in * + | H : _ -> _ /\ _ |- _ => + let X := fresh in let Y := fresh in + destruct H as [X Y]; [ solve [auto using in_cons] | rewrite ?X, ?Y ] + | _ => split + | _ => progress autorewrite with pull_Zmod zsimplify_fast + | _ => solve [ auto using length_sum_rows ] + | _ => solve [ ring_simplify; repeat (f_equal; try ring) ] + end. + apply Z.mul_cancel_l with (p:=weight n); [ apply weight_nonzero |]. + autorewrite with push_Zmul. + rewrite !Z.mul_div_eq_full by apply weight_nonzero. + autorewrite with pull_Zmod. + ring_simplify_subterms. ring_simplify. + repeat (f_equal; try ring). + Qed. + + Hint Rewrite (@Positional.length_zeros weight) : distr_length. + Hint Rewrite (@Positional.eval_zeros weight) using auto : push_eval. + + Lemma flatten_div_mod inp n : + (forall row, In row inp -> length row = n) -> + Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n) + /\ snd (flatten inp) = (eval n inp) / weight n. + Proof. + intros; cbv [flatten]. + destruct inp; [|destruct inp]; cbn [hd tl]. + { cbn. autorewrite with push_eval. tauto. } + { cbn. + match goal with H : forall r, In r [?x] -> _ |- _ => + specialize (H x ltac:(auto)); rewrite H + end. + rewrite sum_rows_div with (n:=n) by distr_length. + autorewrite with push_eval. + split; f_equal; ring. } + { autorewrite with push_eval. + apply flatten'_div_mod; auto. + congruence. } + Qed. + + End Rows. +End Rows. + Module Import MOVEME. Fixpoint fold_andb_map {A B} (f : A -> B -> bool) (ls1 : list A) (ls2 : list B) : bool |