diff options
author | Jason Gross <jgross@mit.edu> | 2018-02-16 21:32:41 -0500 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-02-19 17:59:16 -0500 |
commit | aa6044f40e9e46856dd94748bfad61565de1266a (patch) | |
tree | 3feeceee7c0483783e59961004fc92ba9290aa22 /src | |
parent | 88ac24f4afe52af45a509c8b3d61a8598f80a233 (diff) |
[experiments] Add some more arithmetic operations
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 85 |
1 files changed, 69 insertions, 16 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 7d1e780b9..e6b8681b0 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -20,9 +20,11 @@ Import ListNotations. Local Open Scope Z_scope. Definition runtime_mul := Z.mul. Definition runtime_add := Z.add. +Definition runtime_opp := Z.opp. Delimit Scope runtime_scope with RT. Infix "*" := runtime_mul : runtime_scope. Infix "+" := runtime_add : runtime_scope. +Notation "- a" := (runtime_opp a%RT) : runtime_scope. Module Associational. Definition eval (p:list (Z*Z)) : Z := @@ -55,6 +57,12 @@ Module Associational. Proof. induction p; cbv [mul]; push; nsatz. Qed. Hint Rewrite eval_mul : push_eval. + Definition negate_snd (p:list (Z*Z)) : list (Z*Z) := + map (fun cx => (fst cx, (-snd cx)%RT)) p. + Lemma eval_negate_snd p : eval (negate_snd p) = - eval p. + Proof. induction p; cbv [negate_snd]; push; nsatz. Qed. + Hint Rewrite eval_negate_snd : push_eval. + Example base10_2digit_mul (a0:Z) (a1:Z) (b0:Z) (b1:Z) : {ab| eval ab = eval [(10,a1);(1,a0)] * eval [(10,b1);(1,b0)]}. eexists ?[ab]. @@ -93,15 +101,11 @@ Module Associational. Hint Rewrite eval_reduce : push_eval. Section Carries. - Context {modulo div : Z -> Z -> Z}. - Context {div_mod : forall a b:Z, b <> 0 -> - a = b * (div a b) + modulo a b}. - Definition carryterm (w fw:Z) (t:Z * Z) := if (Z.eqb (fst t) w) then dlet_nd t2 := snd t in - dlet_nd d2 := div t2 fw in - dlet_nd m2 := modulo t2 fw in + dlet_nd d2 := Z.div t2 fw in + dlet_nd m2 := Z.modulo t2 fw in [(w * fw, d2);(w,m2)] else [t]. @@ -109,7 +113,7 @@ Module Associational. eval (carryterm w fw t) = eval [t]. Proof using Type*. cbv [carryterm Let_In]; break_match; push; [|trivial]. - specialize (div_mod (snd t) fw fw_nonzero). + pose proof (Z.div_mod (snd t) fw fw_nonzero). rewrite Z.eqb_eq in *. nsatz. Qed. Hint Rewrite eval_carryterm using auto : push_eval. @@ -139,8 +143,9 @@ Module Positional. Section Positional. (* SKIP over this: zeros, add_to_nth *) Local Ltac push := autorewrite with push_eval push_map distr_length push_flat_map push_fold_right push_nth_default cancel_pair natsimplify. - Definition zeros n : list Z - := List.repeat 0 n. + Definition zeros n : list Z := List.repeat 0 n. + Lemma length_zeros n : length (zeros n) = n. Proof. cbv [zeros]; distr_length. Qed. + Hint Rewrite length_zeros : distr_length. Lemma eval_zeros n : eval n (zeros n) = 0. Proof. cbv [eval Associational.eval to_associational zeros]. @@ -202,6 +207,9 @@ Module Positional. Section Positional. Hint Rewrite @eval_from_associational : push_eval. Section mulmod. + (** TODO(for jadep, from jgross): Add a comment about why we take + in [m] rather than just using [s - Associational.eval c], or + just remove [m] *) Context (m:Z) (m_nz:m <> 0) (s:Z) (s_nz:s <> 0) (c:list (Z*Z)) (Hm:m = s - Associational.eval c). Definition mulmod (n:nat) (a b:list Z) : list Z @@ -219,14 +227,47 @@ Module Positional. Section Positional. induction c as [|?? IHc]; simpl; trivial. Qed. End mulmod. - Section Carries. - Context {modulo div: Z -> Z -> Z}. - Context {div_mod : forall a b:Z, b <> 0 -> - a = b * (div a b) + modulo a b}. + Section add. + Context (m:Z) (m_nz:m <> 0) (s:Z) (s_nz:s <> 0) + (c:list (Z*Z)) (Hm:m = s - Associational.eval c). + Definition add (n:nat) (a b:list Z) : list Z + := let a_a := to_associational n a in + let b_a := to_associational n b in + from_associational n (a_a ++ b_a). + Lemma eval_add n (f g:list Z) + (Hf : length f = n) (Hg : length g = n) : + eval n (add n f g) mod m = (eval n f + eval n g) mod m. + Proof. cbv [add]; rewrite Hm in *; push; trivial. + destruct n; auto. Qed. + End add. + + Section sub. + (** TODO(jadep): Fill me in *) + Axiom sub : forall (n:nat) (a b:list Z), list Z. (* should be balanced *) + Axiom eval_sub + : forall (s:Z) + (c:list (Z*Z)) + n (f g:list Z) + (Hf : length f = n) (Hg : length g = n), + eval n (sub n f g) mod (s - Associational.eval c) + = (eval n f - eval n g) mod (s - Associational.eval c). + Hint Rewrite eval_sub : push_eval. + Definition opp (n:nat) (a:list Z) : list Z + := sub n (zeros n) a. + Lemma eval_opp + (s:Z) + (c:list (Z*Z)) + n (f:list Z) + (Hf : length f = n) + : eval n (opp n f) mod (s - Associational.eval c) + = (- eval n f) mod (s - Associational.eval c). + Proof. cbv [opp]; push; distr_length. Qed. + End sub. + Section Carries. Definition carry {n m} (index:nat) (p:list Z) : list Z := from_associational - m (@Associational.carry modulo div (weight index) + m (@Associational.carry (weight index) (weight (S index) / weight index) (to_associational n p)). @@ -276,6 +317,18 @@ Module Positional. Section Positional. destruct n; intros; push; auto. Qed. Hint Rewrite @eval_chained_carries : push_eval. + (* Reverse of [eval]; translate from Z to basesystem by putting + everything in first digit and then carrying. *) + Definition encode {n} s c (x : Z) : list Z := + chained_carries (n:=n) s c (from_associational n [(1,x)]) (seq 0 n). + Lemma eval_encode {n} s c x : + (s <> 0) -> (s - Associational.eval c <> 0) -> (n <> 0%nat) -> + (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) -> + eval n (@encode n s c x) mod (s - Associational.eval c) + = x mod (s - Associational.eval c). + Proof using Type*. cbv [encode]; intros; push; auto; f_equal; omega. Qed. + Hint Rewrite @eval_encode : push_eval. + End Carries. @@ -305,8 +358,8 @@ Module Positional. Section Positional. erewrite <-eval_mulmod with (s:=s) (c:=c) by (subst; try assumption; try reflexivity). etransitivity; - [ | rewrite <- @eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs) (modulo:=fun x y => Z.modulo x y) (div:=fun x y => Z.div x y) - by (subst; try assumption; auto using Z.div_mod); reflexivity ]. + [ | rewrite <- @eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs) + by (subst; auto); reflexivity ]. eapply f_equal2; [|trivial]. eapply f_equal. expand_lists (). subst carry_mulmod. |