diff options
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 47 |
1 files changed, 28 insertions, 19 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 1daa54370..8e635c524 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -1245,12 +1245,17 @@ Module Rows. Local Notation fw := (fun i => weight (S i) / weight i) (only parsing). Section SumRows. - Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z := - fold_right (fun next (state : list Z * Z) => - let i := length (fst state) in (* length of output accumulator tells us the index of [next] *) - dlet_nd sum_carry := Z.add_with_get_carry_full (fw i) (snd state) (fst next) (snd next) in - (fst state ++ [fst sum_carry], snd sum_carry)) start_state (rev (combine row1 row2)). - Definition sum_rows := sum_rows' (nil,0). + Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z * nat := + fold_right (fun next (state : list Z * Z * nat) => + let i := snd state in + let low_high' := + dlet_nd low_high := fst state in + let low := fst low_high in + let high := snd low_high in + dlet_nd sum_carry := Z.add_with_get_carry_full (fw i) high (fst next) (snd next) in + (low ++ [fst sum_carry], snd sum_carry) in + (low_high', S i)) start_state (rev (combine row1 row2)). + Definition sum_rows row1 row2 := fst (sum_rows' (nil, 0, 0%nat) row1 row2). Ltac push := repeat match goal with @@ -1273,7 +1278,9 @@ Module Rows. Lemma sum_rows'_cons state x1 row1 x2 row2 : sum_rows' state (x1 :: row1) (x2 :: row2) = - sum_rows' (fst state ++ [(snd state + x1 + x2) mod (fw (length (fst state)))], (snd state + x1 + x2) / fw (length (fst state))) row1 row2. + sum_rows' (fst (fst state) ++ [(snd (fst state) + x1 + x2) mod (fw (snd state))], + (snd (fst state) + x1 + x2) / fw (snd state), + S (snd state)) row1 row2. Proof. cbv [sum_rows' Let_In]; autorewrite with push_combine. rewrite !fold_left_rev_right. cbn [fold_left]. @@ -1288,20 +1295,21 @@ Module Rows. Lemma sum_rows'_div_mod_length row1 : forall nm start_state row2 row1' row2', - let m := length (fst start_state) in + let m := snd start_state in let n := length row1 in length row2 = n -> length row1' = m -> length row2' = m -> + length (fst (fst start_state)) = m -> (nm = n + m)%nat -> let eval := Positional.eval weight in - is_div_mod (eval m) start_state (eval m row1' + eval m row2') (weight m) -> - length (fst (sum_rows' start_state row1 row2)) = nm - /\ is_div_mod (eval nm) (sum_rows' start_state row1 row2) + is_div_mod (eval m) (fst start_state) (eval m row1' + eval m row2') (weight m) -> + length (fst (fst (sum_rows' start_state row1 row2))) = nm + /\ is_div_mod (eval nm) (fst (sum_rows' start_state row1 row2)) (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) (weight nm). Proof. - induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; []. + induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; [ ]. rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length in *; try omega. eapply is_div_mod_step with (x := x1 + x2); try eassumption; push. @@ -1330,18 +1338,19 @@ Module Rows. Lemma sum_rows'_partitions row1 : forall nm start_state row2 row1' row2', - let m := length (fst start_state) in + let m := snd start_state in let n := length row1 in length row2 = n -> length row1' = m -> length row2' = m -> + length (fst (fst start_state)) = m -> nm = (n + m)%nat -> let eval := Positional.eval weight in - snd start_state = (eval m row1' + eval m row2') / weight m -> + snd (fst start_state) = (eval m row1' + eval m row2') / weight m -> (forall j, (j < m)%nat -> - nth_default 0 (fst start_state) j = ((eval m row1' + eval m row2') mod (weight (S j))) / (weight j)) -> + nth_default 0 (fst (fst start_state)) j = ((eval m row1' + eval m row2') mod (weight (S j))) / (weight j)) -> forall i, (i < nm)%nat -> - nth_default 0 (fst (sum_rows' start_state row1 row2)) i + nth_default 0 (fst (fst (sum_rows' start_state row1 row2))) i = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i). Proof. induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; []. @@ -1351,14 +1360,14 @@ Module Rows. repeat match goal with | H : ?LHS = _ |- _ => match LHS with context [start_state] => rewrite H end - | H : context [nth_default 0 (fst start_state)] |- _ => rewrite H by omega + | H : context [nth_default 0 (fst (fst start_state))] |- _ => rewrite H by omega | _ => rewrite <-(Z.add_assoc _ x1 x2) end. { rewrite div_step by auto using Z.gt_lt. rewrite Z.mul_div_eq_full by auto; rewrite weight_multiples by auto. push. } - { rewrite weight_div_mod with (j:=length (fst start_state)) (i:=S j) by (auto; omega). + { rewrite weight_div_mod with (j:=snd start_state) (i:=S j) by (auto; omega). push_Zmod. autorewrite with zsimplify_fast. reflexivity. } - { push. replace (length (fst start_state)) with j in * by omega. + { push. replace (snd start_state) with j in * by omega. push. rewrite add_mod_div_multiple by auto using Z.lt_le_incl. push. } Qed. |