aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-04-19 09:16:34 -0400
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2017-05-01 14:34:48 -0400
commit232702b35096cd00b4843c9b283b36dccab18961 (patch)
tree72b659e64bf62f90fd13932cbf74abd72c45fa81 /src/Arithmetic
parenta81bce39bf121c41f559a90710892b4e43930f5e (diff)
prove compact_digit obeys div/mod rule
Diffstat (limited to 'src/Arithmetic')
-rw-r--r--src/Arithmetic/Saturated.v194
1 files changed, 112 insertions, 82 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v
index 87c0e5ec9..884a59ef7 100644
--- a/src/Arithmetic/Saturated.v
+++ b/src/Arithmetic/Saturated.v
@@ -105,16 +105,28 @@ Module Columns.
Context (weight : nat->Z)
{weight_0 : weight 0%nat = 1}
{weight_nonzero : forall i, weight i <> 0}
+ {weight_positive : forall i, weight i > 0}
{weight_multiples : forall i, weight (S i) mod weight i = 0}
- (* add_get_carry takes in a number at which to split output *)
+ {weight_divides : forall i : nat, weight (S i) / weight i > 0}
+ (* add_get_carry takes in a number at which to split output *)
{add_get_carry: Z ->Z -> Z -> (Z * Z)}
- {add_get_carry_correct : forall s x y,
- fst (add_get_carry s x y) = x + y - s * snd (add_get_carry s x y)}
+ {add_get_carry_mod : forall s x y,
+ fst (add_get_carry s x y) = (x + y) mod s}
+ {add_get_carry_div : forall s x y,
+ snd (add_get_carry s x y) = (x + y) / s}
+ {div modulo : Z -> Z -> Z}
+ {div_correct : forall a b, div a b = a / b}
+ {modulo_correct : forall a b, modulo a b = a mod b}
.
+ Hint Rewrite div_correct modulo_correct add_get_carry_mod add_get_carry_div : div_mod.
Definition eval {n} (x : (list Z)^n) : Z :=
B.Positional.eval weight (Tuple.map sum x).
+ Lemma eval_unit (x:unit) : eval (n:=0) x = 0.
+ Proof. reflexivity. Qed.
+ Hint Rewrite eval_unit : push_basesystem_eval.
+
Definition eval_from {n} (offset:nat) (x : (list Z)^n) : Z :=
B.Positional.eval (fun i => weight (i+offset)) (Tuple.map sum x).
@@ -138,12 +150,13 @@ Module Columns.
Fixpoint compact_digit_cps n (digit : list Z) {T} (f:Z * Z->T) :=
match digit with
| nil => f (0, 0)
- | x :: nil => f (0, x)
+ | x :: nil => f (div x (weight (S n) / weight n), modulo x (weight (S n) / weight n))
| x :: tl =>
- compact_digit_cps n tl (fun rec =>
- dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in
- dlet carry' := (fst rec + snd sum_carry)%RT in
- f (carry', fst sum_carry))
+ compact_digit_cps n tl
+ (fun rec =>
+ dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in
+ dlet carry' := (fst rec + snd sum_carry)%RT in
+ f (carry', fst sum_carry))
end.
Definition compact_digit n digit := compact_digit_cps n digit id.
@@ -175,32 +188,52 @@ Module Columns.
Lemma compact_id {n} xs {T} f : @compact_cps n xs T f = f (compact xs).
Proof using Type. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed.
- Lemma compact_digit_correct i (xs : list Z) :
- snd (compact_digit i xs) = sum xs - (weight (S i) / weight i) * (fst (compact_digit i xs)).
- Proof using add_get_carry_correct weight_0.
+ Lemma compact_digit_mod i (xs : list Z) :
+ snd (compact_digit i xs) = sum xs mod (weight (S i) / weight i).
+ Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct.
+ induction xs; cbv [compact_digit]; simpl compact_digit_cps;
+ cbv [Let_In];
+ repeat match goal with
+ | _ => progress autorewrite with div_mod
+ | _ => rewrite IHxs, <-Z.add_mod_r
+ | _ => progress (rewrite ?sum_cons, ?sum_nil in * )
+ | _ => progress (autorewrite with uncps push_id cancel_pair in * )
+ | _ => progress break_match; try discriminate
+ | _ => reflexivity
+ | _ => f_equal; ring
+ end.
+ Qed. Hint Rewrite compact_digit_mod : div_mod.
+
+ Lemma compact_digit_div i (xs : list Z) :
+ fst (compact_digit i xs) = sum xs / (weight (S i) / weight i).
+ Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct weight_0 weight_divides.
induction xs; cbv [compact_digit]; simpl compact_digit_cps;
cbv [Let_In];
repeat match goal with
- | _ => rewrite add_get_carry_correct
+ | _ => progress autorewrite with div_mod
+ | _ => rewrite IHxs
| _ => progress (rewrite ?sum_cons, ?sum_nil in * )
- | _ => progress (autorewrite with uncps push_id in * )
- | _ => progress (autorewrite with cancel_pair in * )
+ | _ => progress (autorewrite with uncps push_id cancel_pair in * )
| _ => progress break_match; try discriminate
- | _ => progress ring_simplify
| _ => reflexivity
- | _ => nsatz
+ | _ => f_equal; ring
end.
+ assert (weight (S i) / weight i <> 0) by auto using Z.positive_is_nonzero.
+ match goal with |- _ = (?a + ?X) / ?D =>
+ transitivity ((a + X mod D + D * (X / D)) / D);
+ [| rewrite (Z.div_mod'' X D) at 3; f_equal; auto; ring]
+ end.
+ rewrite Z.div_add' by auto; nsatz.
Qed.
Definition compact_invariant n i (starter rem:Z) (inp : tuple (list Z) n) (out : tuple Z n) :=
- B.Positional.eval_from weight i out + weight (i + n) * (rem)
- = eval_from i inp + weight i*starter.
+ B.Positional.eval_from weight i out + weight (i + n) * rem = eval_from i inp + weight i*starter.
Lemma compact_invariant_holds n i starter rem inp out :
compact_invariant n (S i) (fst (compact_step_cps i starter (hd inp) id)) rem (tl inp) out ->
compact_invariant (S n) i starter rem inp (append (snd (compact_step_cps i starter (hd inp) id)) out).
Proof using Type*.
- cbv [compact_invariant B.Positional.eval_from]; intros.
+ cbv [compact_invariant B.Positional.eval_from]; intros.
repeat match goal with
| _ => rewrite B.Positional.eval_step
| _ => rewrite eval_from_S
@@ -212,14 +245,26 @@ Module Columns.
| _ => progress ring_simplify
| _ => rewrite ZUtil.Z.mul_div_eq_full by apply weight_nonzero
| _ => cbv [compact_step_cps] in *;
- autorewrite with uncps push_id;
- rewrite compact_digit_correct
+ autorewrite with uncps push_id in *;
+ rewrite !compact_digit_mod, !compact_digit_div in *
| _ => progress (autorewrite with natsimplify in * )
- end.
+ end;
rewrite B.Positional.eval_wt_equiv with (wtb := fun i0 => weight (i0 + S i)) by (intros; f_equal; try omega).
+ {
+ rewrite Z.mod_eq by auto using Z.positive_is_nonzero.
+ rewrite sum_cons in H.
+ ring_simplify.
+ match type of H with
+ context [?y * (?a / (?y / ?x))] =>
+ replace (y * (a / (y / x))) with (x * (y / x) * (a / (y / x))) in H
+ by (rewrite Z.mul_div_eq_full by auto using Z.positive_is_nonzero;
+ rewrite weight_multiples; ring)
+ end.
nsatz.
+ }
Qed.
+
Lemma compact_invariant_base i rem : compact_invariant 0 i rem rem tt tt.
Proof using Type. cbv [compact_invariant]. simpl. repeat (f_equal; try omega). Qed.
@@ -366,14 +411,21 @@ Hint Rewrite
using (assumption || omega): push_basesystem_eval.
Section Freeze.
- Context (weight : nat->Z)
- {weight_0 : weight 0%nat = 1}
- {weight_nonzero : forall i, weight i <> 0}
- {weight_multiples : forall i, weight (S i) mod weight i = 0}
- (* add_get_carry takes in a number at which to split output *)
- {add_get_carry: Z ->Z -> Z -> (Z * Z)}
- {add_get_carry_correct : forall s x y,
- fst (add_get_carry s x y) = x + y - s * snd (add_get_carry s x y)}
+ Context (weight : nat->Z)
+ {weight_0 : weight 0%nat = 1}
+ {weight_nonzero : forall i, weight i <> 0}
+ {weight_positive : forall i, weight i > 0}
+ {weight_multiples : forall i, weight (S i) mod weight i = 0}
+ {weight_divides : forall i : nat, weight (S i) / weight i > 0}
+ (* add_get_carry takes in a number at which to split output *)
+ {add_get_carry: Z ->Z -> Z -> (Z * Z)}
+ {add_get_carry_mod : forall s x y,
+ fst (add_get_carry s x y) = (x + y) mod s}
+ {add_get_carry_div : forall s x y,
+ snd (add_get_carry s x y) = (x + y) / s}
+ {div modulo : Z -> Z -> Z}
+ {div_correct : forall a b, div a b = a / b}
+ {modulo_correct : forall a b, modulo a b = a mod b}
.
(* adds p and q if cond is 0, else adds 0 to p*)
@@ -397,7 +449,7 @@ Section Freeze.
Definition conditional_add_cps {n} mask cond (p q : Z^n) {T} (f:_->T) :=
conditional_mask_cps mask cond q
(fun qq => Columns.add_cps weight p qq
- (fun R => Columns.compact_cps (add_get_carry:=add_get_carry) weight R f)).
+ (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight R f)).
Definition conditional_add {n} mask cond p q :=
@conditional_add_cps n mask cond p q _ id.
Lemma conditional_add_id {n} mask cond p q T f:
@@ -430,13 +482,11 @@ Section Freeze.
B.Positional.carry_reduce_cps
(div:=div) (modulo:=modulo) weight s c p
(fun P => Columns.sub_cps weight P m
- (fun Q => Columns.compact_cps (add_get_carry:=add_get_carry) weight Q
+ (fun Q => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight Q
(fun carry_q => conditional_add_cps mask (fst carry_q) (snd carry_q) m
(fun carry_r => f (snd carry_r)))))
.
- SearchAbout (((_ mod _) mod _)).
-
Definition freeze {n} mask s c m p :=
@freeze_cps n mask s c m p _ id.
Lemma freeze_id {n} mask s c m p T f:
@@ -483,27 +533,27 @@ Section Freeze.
Hint Rewrite @eval_conditional_add using (omega || assumption)
: push_basesystem_eval.
- Lemma freezeZ m s c y0 z z0 c0 a :
+ Lemma freezeZ m s c y y0 z z0 c0 a :
m = s - c ->
0 < c < s ->
s <> 0 ->
- -m <= y0 < m ->
+ 0 <= y < 2*m ->
+ y0 = y - m ->
z = y0 mod s ->
c0 = y0 / s ->
c0 <> 0 ->
- z0 = z + m ->
+ z0 = z + (if (dec (c0 = 0)) then 0 else m) ->
a = z0 mod s ->
a mod m = y0 mod m.
Proof.
- clear. intros. subst.
- rewrite Z.add_mod by assumption.
- rewrite Z.mod_mod by assumption.
- rewrite <-Z.add_mod by assumption.
- assert (~ (0 <= y0 < s)) by (pose proof (Z.div_small y0 s); tauto).
- assert (-(s-c) <= y0 < 0) by omega.
- rewrite Z.mod_small with (b := s) by omega.
- rewrite Z.add_mod, Z.mod_same, Z.add_0_r, Z.mod_mod by omega.
- reflexivity.
+ clear. intros. subst. break_match.
+ { rewrite Z.add_0_r, Z.mod_mod, !Z.mod_small by omega.
+ reflexivity. }
+ { rewrite <-Z.add_mod_l, Z.sub_mod_full.
+ rewrite Z.mod_same, Z.sub_0_r, Z.mod_mod by omega.
+ rewrite Z.mod_small with (b := s)
+ by (pose proof (Z.div_small (y - (s-c)) s); omega).
+ f_equal. ring. }
Qed.
Lemma eval_freeze {n} mask s c m p
@@ -517,43 +567,23 @@ Section Freeze.
(B.Positional.eval weight (@freeze n mask s c m p))
(B.Positional.eval weight p).
Proof.
- cbv [freeze_cps freeze mod_eq]; repeat progress autounfold;
- autorewrite with uncps push_id.
-
- assert (Z.pos modulus <> 0) by (pose proof Pos2Z.is_pos modulus; omega).
- pose proof div_mod.
- break_match; subst.
-
- rewrite Z.mul_0_r, Z.add_0_r, Z.sub_0_r.
- (* TODO : how to prove second carry is 0? *)
- rewrite Z.add_mod, B.Associational.eval_reduce by assumption.
- autorewrite with uncps push_id push_basesystem_eval.
- rewrite Hm. autorewrite with zsimplify.
- reflexivity.
-
+ cbv [freeze_cps freeze mod_eq conditional_add_cps].
+ repeat progress autounfold.
+ autorewrite with uncps push_id.
+
+ match goal with |- context [B.Associational.reduce ?s ?c ?p] =>
+ remember (B.Associational.reduce s c p) as y end.
+ match goal with |- context [fst ?x] =>
+ remember x as carry_q end.
+ match goal with |- context [conditional_mask ?mask ?cond ?p] =>
+ remember (conditional_mask mask cond p) as m0 end.
+ match goal with |- context [Columns.compact ?w ?p] =>
+ remember (Columns.compact w p) as carry_r end.
+ destruct carry_q as [c0 z].
+ destruct carry_r as [c1 a].
- ring_simplify.
- let p := fresh "P" in
- let carry := fresh "carry" in
- let result := fresh "result" in
- match goal with H:fst ?x <> 0 |- _ =>
- remember x as p; destruct p as [carry result];
- autorewrite with cancel_pair in *
- end.
- rewrite Columns.eval_from_associational by assumption.
- autorewrite with uncps push_id push_basesystem_eval.
- rewrite Hm.
- Check Z.add_mod_full.
- match goal with |- context [(?a + ?b - ?c + ?d) mod ?m] =>
- replace (a + b - c + d) with (b + (a-c) + d) by ring;
- rewrite (Z.add_mod_full _ d), (Z.add_mod_full _ (a-c))
- end.
- rewrite !Z.mod_same by assumption.
- rewrite !Z.add_0_r, !Z.add_0_l.
- rewrite !Z.mod_mod by assumption.
- rewrite Z.sub_mod_full.
- rewrite B.Associational.eval_reduce by assumption.
- autorewrite with uncps push_id push_basesystem_eval.
+ Admitted.
+End Freeze.