aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/SimplyTypedArithmetic.v
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-04-17 13:36:10 +0200
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-30 04:20:04 -0400
commit026c09658d1554e8a24cbab8a147c7675deb961b (patch)
treefac0e92cf4731a5654c01df40d99986d02514ce2 /src/Experiments/SimplyTypedArithmetic.v
parent4793e0570c137e8d890bc8c3b2bff90e2aa692ea (diff)
tweak definition of flatten to use an index rather than check the length of the output accumulator--this prevents the accumulator from repeatedly showing up in the expression and making the term huge
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v47
1 files changed, 28 insertions, 19 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 1daa54370..8e635c524 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -1245,12 +1245,17 @@ Module Rows.
Local Notation fw := (fun i => weight (S i) / weight i) (only parsing).
Section SumRows.
- Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z :=
- fold_right (fun next (state : list Z * Z) =>
- let i := length (fst state) in (* length of output accumulator tells us the index of [next] *)
- dlet_nd sum_carry := Z.add_with_get_carry_full (fw i) (snd state) (fst next) (snd next) in
- (fst state ++ [fst sum_carry], snd sum_carry)) start_state (rev (combine row1 row2)).
- Definition sum_rows := sum_rows' (nil,0).
+ Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z * nat :=
+ fold_right (fun next (state : list Z * Z * nat) =>
+ let i := snd state in
+ let low_high' :=
+ dlet_nd low_high := fst state in
+ let low := fst low_high in
+ let high := snd low_high in
+ dlet_nd sum_carry := Z.add_with_get_carry_full (fw i) high (fst next) (snd next) in
+ (low ++ [fst sum_carry], snd sum_carry) in
+ (low_high', S i)) start_state (rev (combine row1 row2)).
+ Definition sum_rows row1 row2 := fst (sum_rows' (nil, 0, 0%nat) row1 row2).
Ltac push :=
repeat match goal with
@@ -1273,7 +1278,9 @@ Module Rows.
Lemma sum_rows'_cons state x1 row1 x2 row2 :
sum_rows' state (x1 :: row1) (x2 :: row2) =
- sum_rows' (fst state ++ [(snd state + x1 + x2) mod (fw (length (fst state)))], (snd state + x1 + x2) / fw (length (fst state))) row1 row2.
+ sum_rows' (fst (fst state) ++ [(snd (fst state) + x1 + x2) mod (fw (snd state))],
+ (snd (fst state) + x1 + x2) / fw (snd state),
+ S (snd state)) row1 row2.
Proof.
cbv [sum_rows' Let_In]; autorewrite with push_combine.
rewrite !fold_left_rev_right. cbn [fold_left].
@@ -1288,20 +1295,21 @@ Module Rows.
Lemma sum_rows'_div_mod_length row1 :
forall nm start_state row2 row1' row2',
- let m := length (fst start_state) in
+ let m := snd start_state in
let n := length row1 in
length row2 = n ->
length row1' = m ->
length row2' = m ->
+ length (fst (fst start_state)) = m ->
(nm = n + m)%nat ->
let eval := Positional.eval weight in
- is_div_mod (eval m) start_state (eval m row1' + eval m row2') (weight m) ->
- length (fst (sum_rows' start_state row1 row2)) = nm
- /\ is_div_mod (eval nm) (sum_rows' start_state row1 row2)
+ is_div_mod (eval m) (fst start_state) (eval m row1' + eval m row2') (weight m) ->
+ length (fst (fst (sum_rows' start_state row1 row2))) = nm
+ /\ is_div_mod (eval nm) (fst (sum_rows' start_state row1 row2))
(eval nm (row1' ++ row1) + eval nm (row2' ++ row2))
(weight nm).
Proof.
- induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; [].
+ induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; [ ].
rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length in *; try omega.
eapply is_div_mod_step with (x := x1 + x2); try eassumption; push.
@@ -1330,18 +1338,19 @@ Module Rows.
Lemma sum_rows'_partitions row1 :
forall nm start_state row2 row1' row2',
- let m := length (fst start_state) in
+ let m := snd start_state in
let n := length row1 in
length row2 = n ->
length row1' = m ->
length row2' = m ->
+ length (fst (fst start_state)) = m ->
nm = (n + m)%nat ->
let eval := Positional.eval weight in
- snd start_state = (eval m row1' + eval m row2') / weight m ->
+ snd (fst start_state) = (eval m row1' + eval m row2') / weight m ->
(forall j, (j < m)%nat ->
- nth_default 0 (fst start_state) j = ((eval m row1' + eval m row2') mod (weight (S j))) / (weight j)) ->
+ nth_default 0 (fst (fst start_state)) j = ((eval m row1' + eval m row2') mod (weight (S j))) / (weight j)) ->
forall i, (i < nm)%nat ->
- nth_default 0 (fst (sum_rows' start_state row1 row2)) i
+ nth_default 0 (fst (fst (sum_rows' start_state row1 row2))) i
= ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i).
Proof.
induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; [].
@@ -1351,14 +1360,14 @@ Module Rows.
repeat match goal with
| H : ?LHS = _ |- _ =>
match LHS with context [start_state] => rewrite H end
- | H : context [nth_default 0 (fst start_state)] |- _ => rewrite H by omega
+ | H : context [nth_default 0 (fst (fst start_state))] |- _ => rewrite H by omega
| _ => rewrite <-(Z.add_assoc _ x1 x2)
end.
{ rewrite div_step by auto using Z.gt_lt.
rewrite Z.mul_div_eq_full by auto; rewrite weight_multiples by auto. push. }
- { rewrite weight_div_mod with (j:=length (fst start_state)) (i:=S j) by (auto; omega).
+ { rewrite weight_div_mod with (j:=snd start_state) (i:=S j) by (auto; omega).
push_Zmod. autorewrite with zsimplify_fast. reflexivity. }
- { push. replace (length (fst start_state)) with j in * by omega.
+ { push. replace (snd start_state) with j in * by omega.
push. rewrite add_mod_div_multiple by auto using Z.lt_le_incl.
push. }
Qed.