diff options
author | jadep <jade.philipoom@gmail.com> | 2017-06-12 22:37:30 -0400 |
---|---|---|
committer | jadep <jade.philipoom@gmail.com> | 2017-06-12 22:37:37 -0400 |
commit | 3b1d3856b138d84cbb0429a6bcdaa9080233fa9b (patch) | |
tree | a2c4d4241fa09bf8199cfc717d2e8d4edaa1954c /src/Arithmetic | |
parent | 5085effb7df589bd346b43685889f077cbbb78f1 (diff) |
finish computational portions of operations needed for Montgomery, and sketch out some of the proofs as discussed in #157
Diffstat (limited to 'src/Arithmetic')
-rw-r--r-- | src/Arithmetic/Saturated.v | 278 |
1 files changed, 219 insertions, 59 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v index ea44d6c3d..b72bde4b8 100644 --- a/src/Arithmetic/Saturated.v +++ b/src/Arithmetic/Saturated.v @@ -103,6 +103,70 @@ check confirms our result. ***) +Module Associational. + Section Associational. + Context {mul_split : Z -> Z -> Z -> Z * Z} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *) + {mul_split_mod : forall s x y, + fst (mul_split s x y) = (x * y) mod s} + {mul_split_div : forall s x y, + snd (mul_split s x y) = (x * y) / s} + . + + Definition multerm_cps s (t t' : B.limb) {T} (f:list B.limb ->T) := + dlet xy := mul_split s (snd t) (snd t') in + f ((fst t * fst t', fst xy) :: (fst t * fst t' * s, snd xy) :: nil). + + Definition multerm s t t' := multerm_cps s t t' id. + Lemma multerm_id s t t' T f : + @multerm_cps s t t' T f = f (multerm s t t'). + Proof. reflexivity. Qed. + Hint Opaque multerm : uncps. + Hint Rewrite multerm_id : uncps. + + Definition mul_cps s (p q : list B.limb) {T} (f : list B.limb -> T) := + flat_map_cps (fun t => @flat_map_cps _ _ (multerm_cps s t) q) p f. + + Definition mul s p q := mul_cps s p q id. + Lemma mul_id s p q T f : @mul_cps s p q T f = f (mul s p q). + Proof. cbv [mul mul_cps]. autorewrite with uncps. reflexivity. Qed. + Hint Opaque mul : uncps. + Hint Rewrite mul_id : uncps. + + Lemma eval_map_multerm s a q (s_nonzero:s<>0): + B.Associational.eval (flat_map (multerm s a) q) = fst a * snd a * B.Associational.eval q. + Proof. + cbv [multerm multerm_cps Let_In]; induction q; + repeat match goal with + | _ => progress (autorewrite with uncps push_id cancel_pair push_basesystem_eval in * ) + | _ => progress simpl flat_map + | _ => progress rewrite ?IHq, ?mul_split_mod, ?mul_split_div + | _ => rewrite Z.mod_eq by assumption + | _ => ring_simplify; omega + end. + Qed. + Hint Rewrite eval_map_multerm using (omega || assumption) + : push_basesystem_eval. + + Lemma eval_mul s p q (s_nonzero:s<>0): + B.Associational.eval (mul s p q) = B.Associational.eval p * B.Associational.eval q. + Proof. + cbv [mul mul_cps]; induction p; [reflexivity|]. + repeat match goal with + | _ => progress (autorewrite with uncps push_id push_basesystem_eval in * ) + | _ => progress simpl flat_map + | _ => rewrite IHp + | _ => progress change (fun x => multerm_cps s a x id) with (multerm s a) + | _ => ring_simplify; omega + end. + Qed. + Hint Rewrite eval_mul : push_basesystem_eval. + + End Associational. +End Associational. +Hint Opaque Associational.mul Associational.multerm : uncps. +Hint Rewrite @Associational.mul_id @Associational.multerm_id : uncps. +Hint Rewrite @Associational.eval_mul @Associational.eval_map_multerm : push_basesystem_eval. + Module Columns. Section Columns. Context (weight : nat->Z) @@ -316,6 +380,36 @@ Module Columns. Proof. apply (proj2 (compact_div_mod inp)). Qed. Hint Rewrite @compact_div : push_basesystem_eval. + (* TODO : move to tuple *) + Lemma hd_to_list {A n} a (t : A^(S n)) : List.hd a (to_list (S n) t) = hd t. + Proof. + rewrite (subst_append t), to_list_append, hd_append. reflexivity. + Qed. + + Lemma small_compact {n} x : + hd (n:=n) (snd (Columns.compact x)) < weight 1 / weight 0. + Proof. + match goal with + |- ?G => assert (G /\ fst (compact x) = fst (compact x)); [|tauto] + end. (* assert a dummy second statement so that fst (compact x) is in context *) + cbv [compact compact_cps compact_step compact_step_cps]; + autorewrite with uncps push_id. + change (fun i s a => compact_digit_cps i (s :: a) id) + with (fun i s a => compact_digit i (s :: a)). + remember (fun i s a => compact_digit i (s :: a)) as f. + + rewrite <-hd_to_list with (a:=0). + + apply @mapi_with'_linvariant with (n:=S n) (start:=0) (f:=f) (inp:=x); intros; try tauto; (split; [|reflexivity]). + { let P := fresh "H" in + match goal with H : _ /\ _ |- _ => destruct H end. + destruct n0; simpl. subst f. + { rewrite compact_digit_mod. + apply Z.mod_pos_bound, Z.gt_lt, weight_divides. } + { rewrite hd_to_list in H1. assumption. } } + { simpl. apply Z.gt_lt, weight_divides. } + Qed. + Definition cons_to_nth_cps {n} i (x:Z) (t:(list Z)^n) {T} (f:(list Z)^n->T) := @on_tuple_cps _ _ nil (update_nth_cps i (cons x)) n n t _ f. @@ -424,11 +518,20 @@ Module Columns. (fun Q => from_associational_cps weight n (P++Q) (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))). + Definition mul_cps {n1 n2 n3 mul_split} s (p : Z^n1) (q : Z^n2) + {T} (f : (Z*Z^n3)->T) := + B.Positional.to_associational_cps weight p + (fun P => B.Positional.to_associational_cps weight q + (fun Q => Associational.mul_cps (mul_split := mul_split) s P Q + (fun PQ => from_associational_cps weight n3 PQ + (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))). + End Wrappers. End Columns. Hint Unfold Columns.add_cps - Columns.sub_cps. + Columns.sub_cps + Columns.mul_cps. Hint Rewrite @Columns.compact_digit_id @Columns.compact_step_id @@ -586,41 +689,65 @@ Section Freeze. Qed. End Freeze. -Section API. - Context {weight : nat -> Z} - {weight_0 : weight 0%nat = 1} - {weight_nonzero : forall i, weight i <> 0} - {weight_positive : forall i, weight i > 0} - {weight_multiples : forall i, weight (S i) mod weight i = 0} - {weight_divides : forall i : nat, weight (S i) / weight i > 0} - . +Section UniformWeight. + Context (bound : Z) {bound_pos : bound > 0}. + + Definition uweight : nat -> Z := fun i => bound ^ Z.of_nat i. + Lemma uweight_0 : uweight 0%nat = 1. Proof. reflexivity. Qed. + Lemma uweight_positive i : uweight i > 0. + Proof. apply Z.lt_gt, Z.pow_pos_nonneg; omega. Qed. + Lemma uweight_nonzero i : uweight i <> 0. + Proof. auto using Z.positive_is_nonzero, uweight_positive. Qed. + Lemma uweight_multiples i : uweight (S i) mod uweight i = 0. + Proof. apply Z.mod_same_pow; rewrite Nat2Z.inj_succ; omega. Qed. + Lemma uweight_divides i : uweight (S i) / uweight i > 0. + Proof. + cbv [uweight]. rewrite <-Z.pow_sub_r by (rewrite ?Nat2Z.inj_succ; omega). + apply Z.lt_gt, Z.pow_pos_nonneg; rewrite ?Nat2Z.inj_succ; omega. + Qed. +End UniformWeight. + +Section API. + Context (bound : Z) {bound_pos : bound > 0}. + Context {mul_split : Z -> Z -> Z -> Z * Z} + {mul_split_mod : forall s x y, + fst (mul_split s x y) = (x * y) mod s} + {mul_split_div : forall s x y, + snd (mul_split s x y) = (x * y) / s}. Definition T : Type := list Z. - Definition length : T -> nat := @length Z. + Definition numlimbs : T -> nat := @length Z. - Definition zero (n:nat) : T := to_list _ (B.Positional.zeros n). + (* lowest limb is less than its bound; this is required for [divmod] + to simply separate the lowest limb from the rest and be equivalent + to normal div/mod with [bound]. *) + Definition small (p : T) : Prop := List.hd 0 p < bound. - Definition divmod (p : T) : T * Z := - (List.tl p, List.hd 0 p). - - (* TODO : scmul (Z -> T -> T), using double-output multiply *) - - Definition add (p q : T) : T * Z := - let P := Tuple.from_list (Datatypes.length p) p (eq_refl _) in - let Q := Tuple.from_list (Datatypes.length q) q (eq_refl _) in - let n := max (Datatypes.length p) (Datatypes.length q) in - let carry_result := Columns.add_cps weight P Q id in - (to_list n (snd carry_result), fst carry_result). - - Definition join (to_join : T * Z) : T := - let p := fst to_join in - let P := Tuple.from_list (Datatypes.length p) p (eq_refl _) in - to_list _ (left_append (snd to_join) P). + Definition zero (n:nat) : T := to_list _ (B.Positional.zeros n). + Definition divmod (p : T) : T * Z := (List.tl p, List.hd 0 p). + + Definition drop_high (n : nat) (p : T) : T := firstn n p. + + Definition scmul (c : Z) (p : T) : T := + let P := Tuple.from_list (length p) p (eq_refl _) in + Columns.mul_cps (mul_split := mul_split) (n1:=1) (n3:=length p) (uweight bound) bound c P + (fun carry_result => + to_list _ (left_append (fst carry_result) (snd carry_result))). + + Definition add (p q : T) : T := + let P := Tuple.from_list (length p) p (eq_refl _) in + let Q := Tuple.from_list (length q) q (eq_refl _) in + dlet n := max (length p) (length q) in + Columns.add_cps (uweight bound) P Q + (fun carry_result => + to_list (S n) (left_append (fst carry_result) (snd carry_result))). + Section Proofs. + Definition eval (p : T) : Z := - B.Positional.eval weight (Tuple.from_list (Datatypes.length p) p (eq_refl _)). + B.Positional.eval (uweight bound) (Tuple.from_list (length p) p (eq_refl _)). Lemma eval_zero n : eval (zero n) = 0. Proof. @@ -631,61 +758,80 @@ Section API. Unshelve. distr_length. Qed. - Lemma length_zero n : length (zero n) = n. + Lemma numlimbs_zero n : numlimbs (zero n) = n. Proof. cbv [eval zero]. apply length_to_list. Qed. + Local Ltac pose_uweight bound := + match goal with H : bound > 0 |- _ => + pose proof (uweight_0 bound); + pose proof (@uweight_positive bound H); + pose proof (@uweight_nonzero bound H); + pose proof (@uweight_multiples bound); + pose proof (@uweight_divides bound H) + end. + Lemma eval_add_nz p q : - max (Datatypes.length p) (Datatypes.length q) <> 0%nat -> - eval (fst (add p q)) = eval p + eval q - snd (add p q) * weight (Datatypes.length (fst (add p q))). + max (length p) (length q) <> 0%nat -> + eval (add p q) = eval p + eval q. Proof. - intros. + intros. pose_uweight bound. pose proof Z.add_get_carry_full_div. pose proof Z.add_get_carry_full_mod. pose proof div_correct. pose proof modulo_correct. repeat match goal with - | _ => progress (cbv [add eval]; repeat autounfold) + | _ => progress (cbv [add eval Let_In]; repeat autounfold) | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval + | _ => rewrite B.Positional.eval_left_append + | _ => progress (rewrite <-!from_list_default_eq with (d:=0); erewrite !length_to_list, !from_list_default_eq, from_list_to_list) - | _ => rewrite Z.mul_div_eq' by auto using weight_nonzero; + | _ => rewrite Z.mul_div_eq by auto; omega end. Unshelve. distr_length. Qed. Lemma eval_add_z p q : - max (Datatypes.length p) (Datatypes.length q) = 0%nat -> - eval (fst (add p q)) = eval p + eval q - snd (add p q) * weight (Datatypes.length (fst (add p q))). + max (length p) (length q) = 0%nat -> + eval (add p q) = eval p + eval q. Proof. destruct p, q; distr_length; reflexivity. Qed. - Lemma eval_add p q : - eval (fst (add p q)) = eval p + eval q - snd (add p q) * weight (Datatypes.length (fst (add p q))). + Lemma eval_add p q : eval (add p q) = eval p + eval q. Proof. - destruct (Nat.eq_dec (max (Datatypes.length p) (Datatypes.length q)) 0%nat); auto using eval_add_z, eval_add_nz. + destruct (Nat.eq_dec (max (length p) (length q)) 0%nat); auto using eval_add_z, eval_add_nz. Qed. Hint Rewrite eval_add : push_basesystem_eval. - - Lemma eval_join to_join : - eval (join to_join) = eval (fst to_join) + snd to_join * weight (Datatypes.length (fst to_join)). + + Lemma small_add a b : small (add a b). Proof. - repeat match goal with - | _ => progress (cbv [join eval]; repeat autounfold) - | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval - | _ => progress - (rewrite <-!from_list_default_eq with (d:=0); - erewrite !length_to_list, !from_list_default_eq, - from_list_to_list) - | _ => rewrite B.Positional.eval_left_append; ring_simplify; omega - end. - Unshelve. distr_length. + intros. pose_uweight bound. + pose proof Z.add_get_carry_full_div. + pose proof Z.add_get_carry_full_mod. + pose proof div_correct. pose proof modulo_correct. + cbv [small add Let_In]. repeat autounfold. + autorewrite with uncps push_id. + destruct (max (length a) (length b)); [simpl; omega |]. + rewrite Columns.hd_to_list, hd_left_append. + eapply Z.lt_le_trans. + { apply Columns.small_compact; auto. } + { cbv [uweight]. simpl Z.of_nat. + autorewrite with zsimplify. + rewrite Z.pow_1_r. reflexivity. } Qed. - Hint Rewrite eval_join : push_basesystem_eval. + + Lemma numlimbs_add a b: numlimbs (add a b) = S (max (numlimbs a) (numlimbs b)). + Proof. + Admitted. - Lemma eval_join_add p q : - eval (join (add p q)) = eval p + eval q. - Proof. autorewrite with push_basesystem_eval; omega. Qed. + Lemma eval_scmul a v: eval (scmul a v) = a * eval v. + Proof. + Admitted. + + Lemma numlimbs_scmul a v: 0 <= a < bound -> + numlimbs (scmul a v) = S (numlimbs v). + Admitted. (* TODO : move to tuple *) Lemma from_list_tl {A n} (ls : list A) H H': @@ -696,18 +842,32 @@ Section API. reflexivity. Qed. - Lemma divmod_div p : eval (fst (divmod p)) = eval p / weight 1. + Lemma eval_div p : small p -> eval (fst (divmod p)) = eval p / bound. Proof. repeat match goal with - | _ => progress (cbv [divmod eval]; repeat autounfold) + | _ => progress (intros; cbv [divmod eval]; repeat autounfold) | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval end. erewrite from_list_tl. Admitted. - Lemma divmod_mod p : eval (fst (divmod p)) = eval p / weight 1. + Lemma eval_mod p : small p -> snd (divmod p) = eval p mod bound. Proof. Admitted. + + Lemma small_div v : small v -> small (fst (divmod v)). + Admitted. + + Lemma numlimbs_div v : numlimbs (fst (divmod v)) = pred (numlimbs v). + Admitted. + + Lemma eval_drop_high n v : + small v -> eval (drop_high n v) = eval v mod (uweight bound n). + Admitted. + + Lemma numlimbs_drop_high n v : + numlimbs (drop_high n v) = min (numlimbs v) n. + Admitted. End Proofs. End API. |