aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-03-01 17:45:06 +0100
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-03 09:00:55 -0400
commit884de22a21f9ae7b8d2743271b66e750fc5d4ff2 (patch)
tree754d26012446ceac61ddf937a08e4a9d8f3c601f
parente6a306ceb5824b161c6e8688333cbde21a0ac8f4 (diff)
rowwise flatten (more fleshed-out) and proof outline
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v203
1 files changed, 163 insertions, 40 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 2b3a767f9..f01fd29e0 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -952,46 +952,6 @@ Module Columns.
reflexivity. } }
Qed.
- Section Rows.
-
- Let fw := (fun i => weight (S i) / weight i).
-
- Fixpoint finish_row (idx: nat) (carry:Z) (inp: list (list Z)) : list (list Z) * Z :=
- match inp with
- | nil => (nil, carry)
- | col :: inp' =>
- match col with
- | nil => ([carry] :: inp', 0)
- | x :: nil =>
- let sum_carry := Z.add_with_get_carry_full (fw idx) carry x 0 in
- let rec := finish_row (S idx) (snd sum_carry) inp' in
- ([(fst sum_carry)] :: fst rec, snd rec)
- | x :: y :: tl =>
- let sum_carry := Z.add_with_get_carry_full (fw idx) carry x y in
- let rec := finish_row (S idx) (snd sum_carry) inp' in
- (((fst sum_carry) :: tl) :: fst rec, snd rec)
- end
- end.
-
- Fixpoint flatten_rowwise' (start_value : Z) (start_col : list Z) (other_cols : list (list Z)) : list Z * Z :=
- match start_col with
- | nil => (start_value :: map sum other_cols, 0)
- | x :: tl =>
- let sum_carry := Z.add_get_carry_full (fw 0%nat) start_value x in
- let other_cols'_carry := finish_row 1%nat (snd sum_carry) other_cols in
- let rec := flatten_rowwise' (fst sum_carry) tl (fst other_cols'_carry) in
- (fst rec, snd rec + snd other_cols'_carry)
- end.
-
- Definition flatten_rowwise (inp : list (list Z)) : list Z * Z :=
- match inp with
- | (x :: tl) :: other_cols =>
- flatten_rowwise' x tl other_cols
- | _ => (map sum inp, 0)
- end.
-
- End Rows.
-
Section mul.
Definition mul s n m (p q : list Z) : list Z :=
let p_a := Positional.to_associational weight n p in
@@ -1096,6 +1056,169 @@ Module Columns.
End mul_converted.
End Columns.
+Module Rows.
+ Section Rows.
+ Context (weight : nat->Z)
+ {weight_0 : weight 0%nat = 1}
+ {weight_nonzero : forall i, weight i <> 0}
+ {weight_positive : forall i, weight i > 0}
+ {weight_multiples : forall i, weight (S i) mod weight i = 0}
+ {weight_divides : forall i : nat, weight (S i) / weight i > 0}.
+
+ Local Notation rows := (list (list Z)) (only parsing).
+ Local Notation cols := (list (list Z)) (only parsing).
+ Hint Rewrite Positional.eval_nil : push_eval.
+ Hint Resolve in_eq in_cons.
+
+ Definition eval n (inp : rows) :=
+ sum (map (Positional.eval weight n) inp).
+ Lemma eval_nil n : eval n nil = 0.
+ Proof. cbv [eval]. rewrite map_nil, sum_nil; reflexivity. Qed.
+ Hint Rewrite eval_nil : push_eval.
+ Lemma eval_cons n r inp : eval n (r :: inp) = Positional.eval weight n r + eval n inp.
+ Proof. cbv [eval]. rewrite map_cons, sum_cons; reflexivity. Qed.
+ Hint Rewrite eval_cons : push_eval.
+
+ Definition extract_row (inp : cols) : cols * list Z :=
+ fold_right (fun col state =>
+ (tl col :: fst state, hd 0 col :: snd state)
+ ) (nil, nil) inp.
+
+ Definition from_columns n (inp : cols) : rows :=
+ snd
+ (fold_right (fun _ (state : cols * rows) =>
+ let cols'_row := extract_row (fst state) in
+ (fst cols'_row, snd state ++ [snd cols'_row])
+ ) (inp, nil) (seq 0 n)).
+
+ Lemma eval_from_columns (inp : cols) :
+ forall n m, length inp = n ->
+ (m = fold_right Nat.max 0%nat (map (@length Z) inp))%nat ->
+ eval n (from_columns m inp) = Columns.eval weight n inp.
+ Proof.
+ cbv [eval Columns.eval from_columns].
+ induction inp; intros; distr_length.
+ { subst n. subst m. reflexivity. }
+ Admitted.
+
+ Local Notation fw := (fun i => weight (S i) / weight i) (only parsing).
+
+ Definition sum_rows (row1 row2 : list Z) : list Z * Z :=
+ fold_left (fun (state : list Z * Z) next =>
+ let i := length (fst state) in (* length of output accumulator tells us the index of [next] *)
+ let sum_carry := Z.add_with_get_carry_full (fw i) (snd state) (fst next) (snd next) in
+ (fst state ++ [fst sum_carry], snd sum_carry)) (combine row1 row2) (nil,0).
+
+ Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z :=
+ fold_right (fun next_row (state : list Z * Z)=>
+ let out_carry := sum_rows next_row (fst state) in
+ (fst out_carry, snd state + snd out_carry)) start_state inp.
+
+ (* For correctness if there is only one row, we add a row of
+ zeroes with the same length so that the add loop still happens. *)
+ Definition flatten (inp : rows) : list Z * Z :=
+ let first_row := hd nil inp in
+ flatten' (first_row, 0) (hd (Positional.zeros (length first_row)) (tl inp) :: tl (tl inp)).
+
+ Lemma sum_rows_mod row1 row2 n:
+ length row1 = n -> length row2 = n ->
+ Positional.eval weight n (fst (sum_rows row1 row2))
+ = (Positional.eval weight n row1 + Positional.eval weight n row2) mod (weight n).
+ Admitted.
+
+ Lemma sum_rows_div row1 row2 n:
+ length row1 = n -> length row2 = n ->
+ snd (sum_rows row1 row2)
+ = (Positional.eval weight n row1 + Positional.eval weight n row2) / (weight n).
+ Admitted.
+
+ Hint Rewrite sum_rows_mod using (auto; solve [distr_length; auto]) : push_eval.
+
+ Lemma sum_rows_partitions row1 row2 n i:
+ length row1 = n -> length row2 = n -> (i < n)%nat ->
+ nth_default 0 (fst (sum_rows row1 row2)) i
+ = ((Positional.eval weight n row1 + Positional.eval weight n row2) / weight i) mod (fw i).
+ Admitted.
+
+ Lemma length_sum_rows row1 row2 n :
+ length row1 = n -> length row2 = n ->
+ length (fst (sum_rows row1 row2)) = n.
+ Admitted.
+
+
+ (* TODO: move to ListUtil *)
+ (* connect state to remaining input *)
+ Lemma fold_right_invariant_strong {A B} (P: list B -> A -> Type) (f: B -> A -> A):
+ forall bs a0
+ (Pnil : P nil a0)
+ (IHfold:
+ forall bs' b a',
+ In b bs ->
+ a' = fold_right f a0 bs' ->
+ P bs' a' ->
+ P (b :: bs') (f b a')),
+ P bs (fold_right f a0 bs).
+ Proof. induction bs; intros; rewrite ?fold_right_cons; auto using in_eq,in_cons. Qed.
+
+ Lemma flatten'_div_mod n start_state inp:
+ length (fst start_state) = n ->
+ (forall row, In row inp -> length row = n) ->
+ length (fst (flatten' start_state inp)) = n
+ /\ (inp <> nil ->
+ Positional.eval weight n (fst (flatten' start_state inp)) = (Positional.eval weight n (fst start_state) + eval n inp) mod (weight n)
+ /\ snd (flatten' start_state inp) = snd start_state + (Positional.eval weight n (fst start_state) + eval n inp) / weight n).
+ Proof.
+ intro.
+ cbv [flatten'].
+ apply fold_right_invariant_strong with (bs:=inp); intros; [ tauto | destruct (dec (bs' = nil)) ];
+ repeat match goal with
+ | _ => subst a'
+ | _ => subst bs'
+ | _ => rewrite sum_rows_div with (n:=n) by auto
+ | _ => rewrite @fold_right_nil in *
+ | _ => progress autorewrite with cancel_pair push_eval in *
+ | H : _ -> _ /\ _ |- _ =>
+ let X := fresh in let Y := fresh in
+ destruct H as [X Y]; [ solve [auto using in_cons] | rewrite ?X, ?Y ]
+ | _ => split
+ | _ => progress autorewrite with pull_Zmod zsimplify_fast
+ | _ => solve [ auto using length_sum_rows ]
+ | _ => solve [ ring_simplify; repeat (f_equal; try ring) ]
+ end.
+ apply Z.mul_cancel_l with (p:=weight n); [ apply weight_nonzero |].
+ autorewrite with push_Zmul.
+ rewrite !Z.mul_div_eq_full by apply weight_nonzero.
+ autorewrite with pull_Zmod.
+ ring_simplify_subterms. ring_simplify.
+ repeat (f_equal; try ring).
+ Qed.
+
+ Hint Rewrite (@Positional.length_zeros weight) : distr_length.
+ Hint Rewrite (@Positional.eval_zeros weight) using auto : push_eval.
+
+ Lemma flatten_div_mod inp n :
+ (forall row, In row inp -> length row = n) ->
+ Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n)
+ /\ snd (flatten inp) = (eval n inp) / weight n.
+ Proof.
+ intros; cbv [flatten].
+ destruct inp; [|destruct inp]; cbn [hd tl].
+ { cbn. autorewrite with push_eval. tauto. }
+ { cbn.
+ match goal with H : forall r, In r [?x] -> _ |- _ =>
+ specialize (H x ltac:(auto)); rewrite H
+ end.
+ rewrite sum_rows_div with (n:=n) by distr_length.
+ autorewrite with push_eval.
+ split; f_equal; ring. }
+ { autorewrite with push_eval.
+ apply flatten'_div_mod; auto.
+ congruence. }
+ Qed.
+
+ End Rows.
+End Rows.
+
Module Import MOVEME.
Fixpoint fold_andb_map {A B} (f : A -> B -> bool) (ls1 : list A) (ls2 : list B)
: bool