diff options
author | jadep <jade.philipoom@gmail.com> | 2017-04-19 09:16:34 -0400 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2017-05-01 14:34:48 -0400 |
commit | 232702b35096cd00b4843c9b283b36dccab18961 (patch) | |
tree | 72b659e64bf62f90fd13932cbf74abd72c45fa81 /src/Arithmetic | |
parent | a81bce39bf121c41f559a90710892b4e43930f5e (diff) |
prove compact_digit obeys div/mod rule
Diffstat (limited to 'src/Arithmetic')
-rw-r--r-- | src/Arithmetic/Saturated.v | 194 |
1 files changed, 112 insertions, 82 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v index 87c0e5ec9..884a59ef7 100644 --- a/src/Arithmetic/Saturated.v +++ b/src/Arithmetic/Saturated.v @@ -105,16 +105,28 @@ Module Columns. Context (weight : nat->Z) {weight_0 : weight 0%nat = 1} {weight_nonzero : forall i, weight i <> 0} + {weight_positive : forall i, weight i > 0} {weight_multiples : forall i, weight (S i) mod weight i = 0} - (* add_get_carry takes in a number at which to split output *) + {weight_divides : forall i : nat, weight (S i) / weight i > 0} + (* add_get_carry takes in a number at which to split output *) {add_get_carry: Z ->Z -> Z -> (Z * Z)} - {add_get_carry_correct : forall s x y, - fst (add_get_carry s x y) = x + y - s * snd (add_get_carry s x y)} + {add_get_carry_mod : forall s x y, + fst (add_get_carry s x y) = (x + y) mod s} + {add_get_carry_div : forall s x y, + snd (add_get_carry s x y) = (x + y) / s} + {div modulo : Z -> Z -> Z} + {div_correct : forall a b, div a b = a / b} + {modulo_correct : forall a b, modulo a b = a mod b} . + Hint Rewrite div_correct modulo_correct add_get_carry_mod add_get_carry_div : div_mod. Definition eval {n} (x : (list Z)^n) : Z := B.Positional.eval weight (Tuple.map sum x). + Lemma eval_unit (x:unit) : eval (n:=0) x = 0. + Proof. reflexivity. Qed. + Hint Rewrite eval_unit : push_basesystem_eval. + Definition eval_from {n} (offset:nat) (x : (list Z)^n) : Z := B.Positional.eval (fun i => weight (i+offset)) (Tuple.map sum x). @@ -138,12 +150,13 @@ Module Columns. Fixpoint compact_digit_cps n (digit : list Z) {T} (f:Z * Z->T) := match digit with | nil => f (0, 0) - | x :: nil => f (0, x) + | x :: nil => f (div x (weight (S n) / weight n), modulo x (weight (S n) / weight n)) | x :: tl => - compact_digit_cps n tl (fun rec => - dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in - dlet carry' := (fst rec + snd sum_carry)%RT in - f (carry', fst sum_carry)) + compact_digit_cps n tl + (fun rec => + dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in + dlet carry' := (fst rec + snd sum_carry)%RT in + f (carry', fst sum_carry)) end. Definition compact_digit n digit := compact_digit_cps n digit id. @@ -175,32 +188,52 @@ Module Columns. Lemma compact_id {n} xs {T} f : @compact_cps n xs T f = f (compact xs). Proof using Type. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed. - Lemma compact_digit_correct i (xs : list Z) : - snd (compact_digit i xs) = sum xs - (weight (S i) / weight i) * (fst (compact_digit i xs)). - Proof using add_get_carry_correct weight_0. + Lemma compact_digit_mod i (xs : list Z) : + snd (compact_digit i xs) = sum xs mod (weight (S i) / weight i). + Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct. + induction xs; cbv [compact_digit]; simpl compact_digit_cps; + cbv [Let_In]; + repeat match goal with + | _ => progress autorewrite with div_mod + | _ => rewrite IHxs, <-Z.add_mod_r + | _ => progress (rewrite ?sum_cons, ?sum_nil in * ) + | _ => progress (autorewrite with uncps push_id cancel_pair in * ) + | _ => progress break_match; try discriminate + | _ => reflexivity + | _ => f_equal; ring + end. + Qed. Hint Rewrite compact_digit_mod : div_mod. + + Lemma compact_digit_div i (xs : list Z) : + fst (compact_digit i xs) = sum xs / (weight (S i) / weight i). + Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct weight_0 weight_divides. induction xs; cbv [compact_digit]; simpl compact_digit_cps; cbv [Let_In]; repeat match goal with - | _ => rewrite add_get_carry_correct + | _ => progress autorewrite with div_mod + | _ => rewrite IHxs | _ => progress (rewrite ?sum_cons, ?sum_nil in * ) - | _ => progress (autorewrite with uncps push_id in * ) - | _ => progress (autorewrite with cancel_pair in * ) + | _ => progress (autorewrite with uncps push_id cancel_pair in * ) | _ => progress break_match; try discriminate - | _ => progress ring_simplify | _ => reflexivity - | _ => nsatz + | _ => f_equal; ring end. + assert (weight (S i) / weight i <> 0) by auto using Z.positive_is_nonzero. + match goal with |- _ = (?a + ?X) / ?D => + transitivity ((a + X mod D + D * (X / D)) / D); + [| rewrite (Z.div_mod'' X D) at 3; f_equal; auto; ring] + end. + rewrite Z.div_add' by auto; nsatz. Qed. Definition compact_invariant n i (starter rem:Z) (inp : tuple (list Z) n) (out : tuple Z n) := - B.Positional.eval_from weight i out + weight (i + n) * (rem) - = eval_from i inp + weight i*starter. + B.Positional.eval_from weight i out + weight (i + n) * rem = eval_from i inp + weight i*starter. Lemma compact_invariant_holds n i starter rem inp out : compact_invariant n (S i) (fst (compact_step_cps i starter (hd inp) id)) rem (tl inp) out -> compact_invariant (S n) i starter rem inp (append (snd (compact_step_cps i starter (hd inp) id)) out). Proof using Type*. - cbv [compact_invariant B.Positional.eval_from]; intros. + cbv [compact_invariant B.Positional.eval_from]; intros. repeat match goal with | _ => rewrite B.Positional.eval_step | _ => rewrite eval_from_S @@ -212,14 +245,26 @@ Module Columns. | _ => progress ring_simplify | _ => rewrite ZUtil.Z.mul_div_eq_full by apply weight_nonzero | _ => cbv [compact_step_cps] in *; - autorewrite with uncps push_id; - rewrite compact_digit_correct + autorewrite with uncps push_id in *; + rewrite !compact_digit_mod, !compact_digit_div in * | _ => progress (autorewrite with natsimplify in * ) - end. + end; rewrite B.Positional.eval_wt_equiv with (wtb := fun i0 => weight (i0 + S i)) by (intros; f_equal; try omega). + { + rewrite Z.mod_eq by auto using Z.positive_is_nonzero. + rewrite sum_cons in H. + ring_simplify. + match type of H with + context [?y * (?a / (?y / ?x))] => + replace (y * (a / (y / x))) with (x * (y / x) * (a / (y / x))) in H + by (rewrite Z.mul_div_eq_full by auto using Z.positive_is_nonzero; + rewrite weight_multiples; ring) + end. nsatz. + } Qed. + Lemma compact_invariant_base i rem : compact_invariant 0 i rem rem tt tt. Proof using Type. cbv [compact_invariant]. simpl. repeat (f_equal; try omega). Qed. @@ -366,14 +411,21 @@ Hint Rewrite using (assumption || omega): push_basesystem_eval. Section Freeze. - Context (weight : nat->Z) - {weight_0 : weight 0%nat = 1} - {weight_nonzero : forall i, weight i <> 0} - {weight_multiples : forall i, weight (S i) mod weight i = 0} - (* add_get_carry takes in a number at which to split output *) - {add_get_carry: Z ->Z -> Z -> (Z * Z)} - {add_get_carry_correct : forall s x y, - fst (add_get_carry s x y) = x + y - s * snd (add_get_carry s x y)} + Context (weight : nat->Z) + {weight_0 : weight 0%nat = 1} + {weight_nonzero : forall i, weight i <> 0} + {weight_positive : forall i, weight i > 0} + {weight_multiples : forall i, weight (S i) mod weight i = 0} + {weight_divides : forall i : nat, weight (S i) / weight i > 0} + (* add_get_carry takes in a number at which to split output *) + {add_get_carry: Z ->Z -> Z -> (Z * Z)} + {add_get_carry_mod : forall s x y, + fst (add_get_carry s x y) = (x + y) mod s} + {add_get_carry_div : forall s x y, + snd (add_get_carry s x y) = (x + y) / s} + {div modulo : Z -> Z -> Z} + {div_correct : forall a b, div a b = a / b} + {modulo_correct : forall a b, modulo a b = a mod b} . (* adds p and q if cond is 0, else adds 0 to p*) @@ -397,7 +449,7 @@ Section Freeze. Definition conditional_add_cps {n} mask cond (p q : Z^n) {T} (f:_->T) := conditional_mask_cps mask cond q (fun qq => Columns.add_cps weight p qq - (fun R => Columns.compact_cps (add_get_carry:=add_get_carry) weight R f)). + (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight R f)). Definition conditional_add {n} mask cond p q := @conditional_add_cps n mask cond p q _ id. Lemma conditional_add_id {n} mask cond p q T f: @@ -430,13 +482,11 @@ Section Freeze. B.Positional.carry_reduce_cps (div:=div) (modulo:=modulo) weight s c p (fun P => Columns.sub_cps weight P m - (fun Q => Columns.compact_cps (add_get_carry:=add_get_carry) weight Q + (fun Q => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight Q (fun carry_q => conditional_add_cps mask (fst carry_q) (snd carry_q) m (fun carry_r => f (snd carry_r))))) . - SearchAbout (((_ mod _) mod _)). - Definition freeze {n} mask s c m p := @freeze_cps n mask s c m p _ id. Lemma freeze_id {n} mask s c m p T f: @@ -483,27 +533,27 @@ Section Freeze. Hint Rewrite @eval_conditional_add using (omega || assumption) : push_basesystem_eval. - Lemma freezeZ m s c y0 z z0 c0 a : + Lemma freezeZ m s c y y0 z z0 c0 a : m = s - c -> 0 < c < s -> s <> 0 -> - -m <= y0 < m -> + 0 <= y < 2*m -> + y0 = y - m -> z = y0 mod s -> c0 = y0 / s -> c0 <> 0 -> - z0 = z + m -> + z0 = z + (if (dec (c0 = 0)) then 0 else m) -> a = z0 mod s -> a mod m = y0 mod m. Proof. - clear. intros. subst. - rewrite Z.add_mod by assumption. - rewrite Z.mod_mod by assumption. - rewrite <-Z.add_mod by assumption. - assert (~ (0 <= y0 < s)) by (pose proof (Z.div_small y0 s); tauto). - assert (-(s-c) <= y0 < 0) by omega. - rewrite Z.mod_small with (b := s) by omega. - rewrite Z.add_mod, Z.mod_same, Z.add_0_r, Z.mod_mod by omega. - reflexivity. + clear. intros. subst. break_match. + { rewrite Z.add_0_r, Z.mod_mod, !Z.mod_small by omega. + reflexivity. } + { rewrite <-Z.add_mod_l, Z.sub_mod_full. + rewrite Z.mod_same, Z.sub_0_r, Z.mod_mod by omega. + rewrite Z.mod_small with (b := s) + by (pose proof (Z.div_small (y - (s-c)) s); omega). + f_equal. ring. } Qed. Lemma eval_freeze {n} mask s c m p @@ -517,43 +567,23 @@ Section Freeze. (B.Positional.eval weight (@freeze n mask s c m p)) (B.Positional.eval weight p). Proof. - cbv [freeze_cps freeze mod_eq]; repeat progress autounfold; - autorewrite with uncps push_id. - - assert (Z.pos modulus <> 0) by (pose proof Pos2Z.is_pos modulus; omega). - pose proof div_mod. - break_match; subst. - - rewrite Z.mul_0_r, Z.add_0_r, Z.sub_0_r. - (* TODO : how to prove second carry is 0? *) - rewrite Z.add_mod, B.Associational.eval_reduce by assumption. - autorewrite with uncps push_id push_basesystem_eval. - rewrite Hm. autorewrite with zsimplify. - reflexivity. - + cbv [freeze_cps freeze mod_eq conditional_add_cps]. + repeat progress autounfold. + autorewrite with uncps push_id. + + match goal with |- context [B.Associational.reduce ?s ?c ?p] => + remember (B.Associational.reduce s c p) as y end. + match goal with |- context [fst ?x] => + remember x as carry_q end. + match goal with |- context [conditional_mask ?mask ?cond ?p] => + remember (conditional_mask mask cond p) as m0 end. + match goal with |- context [Columns.compact ?w ?p] => + remember (Columns.compact w p) as carry_r end. + destruct carry_q as [c0 z]. + destruct carry_r as [c1 a]. - ring_simplify. - let p := fresh "P" in - let carry := fresh "carry" in - let result := fresh "result" in - match goal with H:fst ?x <> 0 |- _ => - remember x as p; destruct p as [carry result]; - autorewrite with cancel_pair in * - end. - rewrite Columns.eval_from_associational by assumption. - autorewrite with uncps push_id push_basesystem_eval. - rewrite Hm. - Check Z.add_mod_full. - match goal with |- context [(?a + ?b - ?c + ?d) mod ?m] => - replace (a + b - c + d) with (b + (a-c) + d) by ring; - rewrite (Z.add_mod_full _ d), (Z.add_mod_full _ (a-c)) - end. - rewrite !Z.mod_same by assumption. - rewrite !Z.add_0_r, !Z.add_0_l. - rewrite !Z.mod_mod by assumption. - rewrite Z.sub_mod_full. - rewrite B.Associational.eval_reduce by assumption. - autorewrite with uncps push_id push_basesystem_eval. + Admitted. +End Freeze. |