diff options
author | jadep <jade.philipoom@gmail.com> | 2017-06-08 17:52:32 -0400 |
---|---|---|
committer | jadep <jade.philipoom@gmail.com> | 2017-06-08 17:52:32 -0400 |
commit | 727c2902f1ec078c2359c1690125ae5a5d0e40e4 (patch) | |
tree | c05b2bce6d07b6e33738eb0be81e9676e0aeeebd /src/Arithmetic/Saturated.v | |
parent | b9720744fd268072000daa3c1ee5a61e6cc7c954 (diff) |
start saturated-arithmetic API for use in Montgomery (see discussion in #157)
Diffstat (limited to 'src/Arithmetic/Saturated.v')
-rw-r--r-- | src/Arithmetic/Saturated.v | 137 |
1 files changed, 123 insertions, 14 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v index 880adc8f1..f4f17dcee 100644 --- a/src/Arithmetic/Saturated.v +++ b/src/Arithmetic/Saturated.v @@ -404,26 +404,25 @@ Module Columns. autorewrite with uncps push_id push_basesystem_eval in *. rewrite eval_cons_to_nth by omega. nsatz. Qed. - + End Columns. Section Wrappers. - Context (weight : nat->Z) - {add_get_carry: Z ->Z -> Z -> (Z * Z)} - {div modulo : Z -> Z -> Z}. + Context (weight : nat->Z). - Definition add_cps {n} (p q : Z^n) {T} (f : (Z*Z^n)->T) := + Definition add_cps {n1 n2 n3} (p : Z^n1) (q : Z^n2) + {T} (f : (Z*Z^n3)->T) := B.Positional.to_associational_cps weight p (fun P => B.Positional.to_associational_cps weight q - (fun Q => from_associational_cps weight n (P++Q) - (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight R f))). + (fun Q => from_associational_cps weight n3 (P++Q) + (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f))). Definition sub_cps {n} (p q : Z^n) {T} (f : (Z*Z^n)->T) := B.Positional.to_associational_cps weight p (fun P => B.Positional.negate_snd_cps weight q (fun nq => B.Positional.to_associational_cps weight nq (fun Q => from_associational_cps weight n (P++Q) - (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight R f)))). + (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))). End Wrappers. End Columns. @@ -452,15 +451,12 @@ Section Freeze. {weight_positive : forall i, weight i > 0} {weight_multiples : forall i, weight (S i) mod weight i = 0} {weight_divides : forall i : nat, weight (S i) / weight i > 0} - {div modulo : Z -> Z -> Z} - {div_correct : forall a b, div a b = a / b} - {modulo_correct : forall a b, modulo a b = a mod b} . Definition conditional_add_cps {n} mask cond (p q : Z^n) {T} (f:_->T) := B.Positional.select_cps mask cond q - (fun qq => Columns.add_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight p qq f). + (fun qq => Columns.add_cps (n3:=n) weight p qq f). Definition conditional_add {n} mask cond p q := @conditional_add_cps n mask cond p q _ id. Lemma conditional_add_id {n} mask cond p q T f: @@ -482,6 +478,7 @@ Section Freeze. repeat progress autounfold in *. pose proof Z.add_get_carry_full_mod. pose proof Z.add_get_carry_full_div. + pose proof div_correct. pose proof modulo_correct. autorewrite with uncps push_id push_basesystem_eval. break_match; match goal with @@ -509,8 +506,7 @@ Section Freeze. the carry in step 3 was -1, so they cancel out. *) Definition freeze_cps {n} mask (m:Z^n) (p:Z^n) {T} (f : Z^n->T) := - Columns.sub_cps (div:=div) (modulo:=modulo) - (add_get_carry:=Z.add_get_carry_full) weight p m + Columns.sub_cps weight p m (fun carry_p => conditional_add_cps mask (fst carry_p) (snd carry_p) m (fun carry_r => f (snd carry_r))) . @@ -569,6 +565,7 @@ Section Freeze. repeat progress autounfold. pose proof Z.add_get_carry_full_mod. pose proof Z.add_get_carry_full_div. + pose proof div_correct. pose proof modulo_correct. autorewrite with uncps push_id push_basesystem_eval. pose proof (weight_nonzero n). @@ -589,6 +586,118 @@ Section Freeze. Qed. End Freeze. +Section API. + Context {weight : nat -> Z} + {weight_0 : weight 0%nat = 1} + {weight_nonzero : forall i, weight i <> 0} + {weight_positive : forall i, weight i > 0} + {weight_multiples : forall i, weight (S i) mod weight i = 0} + {weight_divides : forall i : nat, weight (S i) / weight i > 0} + . + + Definition T : Type := list Z. + + Definition length : T -> nat := @length Z. + + Definition divmod (p : T) : T * Z := + (List.tl p, List.hd 0 p). + + (* TODO : scmul (Z -> T -> T), using double-output multiply *) + + Definition add (p q : T) : T * Z := + let P := Tuple.from_list (Datatypes.length p) p (eq_refl _) in + let Q := Tuple.from_list (Datatypes.length q) q (eq_refl _) in + let n := max (Datatypes.length p) (Datatypes.length q) in + let carry_result := Columns.add_cps weight P Q id in + (to_list n (snd carry_result), fst carry_result). + + Definition join (to_join : T * Z) : T := + let p := fst to_join in + let P := Tuple.from_list (Datatypes.length p) p (eq_refl _) in + to_list _ (left_append (snd to_join) P). + + Section Proofs. + Definition eval (p : T) : Z := + B.Positional.eval weight (Tuple.from_list (Datatypes.length p) p (eq_refl _)). + + Lemma eval_add_nz p q : + max (Datatypes.length p) (Datatypes.length q) <> 0%nat -> + eval (fst (add p q)) = eval p + eval q - snd (add p q) * weight (Datatypes.length (fst (add p q))). + Proof. + intros. + pose proof Z.add_get_carry_full_div. + pose proof Z.add_get_carry_full_mod. + pose proof div_correct. pose proof modulo_correct. + repeat match goal with + | _ => progress (cbv [add eval]; repeat autounfold) + | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval + | _ => progress + (rewrite <-!from_list_default_eq with (d:=0); + erewrite !length_to_list, !from_list_default_eq, + from_list_to_list) + | _ => rewrite Z.mul_div_eq' by auto using weight_nonzero; + omega + end. + Unshelve. distr_length. + Qed. + + Lemma eval_add_z p q : + max (Datatypes.length p) (Datatypes.length q) = 0%nat -> + eval (fst (add p q)) = eval p + eval q - snd (add p q) * weight (Datatypes.length (fst (add p q))). + Proof. destruct p, q; distr_length; reflexivity. Qed. + + Lemma eval_add p q : + eval (fst (add p q)) = eval p + eval q - snd (add p q) * weight (Datatypes.length (fst (add p q))). + Proof. + destruct (Nat.eq_dec (max (Datatypes.length p) (Datatypes.length q)) 0%nat); auto using eval_add_z, eval_add_nz. + Qed. + Hint Rewrite eval_add : push_basesystem_eval. + + Lemma eval_join to_join : + eval (join to_join) = eval (fst to_join) + snd to_join * weight (Datatypes.length (fst to_join)). + Proof. + repeat match goal with + | _ => progress (cbv [join eval]; repeat autounfold) + | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval + | _ => progress + (rewrite <-!from_list_default_eq with (d:=0); + erewrite !length_to_list, !from_list_default_eq, + from_list_to_list) + | _ => rewrite B.Positional.eval_left_append; ring_simplify; omega + end. + Unshelve. distr_length. + Qed. + Hint Rewrite eval_join : push_basesystem_eval. + + Lemma eval_join_add p q : + eval (join (add p q)) = eval p + eval q. + Proof. autorewrite with push_basesystem_eval; omega. Qed. + + (* TODO : move to tuple *) + Lemma from_list_tl {A n} (ls : list A) H H': + from_list n (List.tl ls) H = tl (from_list (S n) ls H'). + Proof. + induction ls; distr_length. simpl List.tl. + rewrite from_list_cons, tl_append, <-!(from_list_default_eq a ls). + reflexivity. + Qed. + + Lemma divmod_div p : eval (fst (divmod p)) = eval p / weight 1. + Proof. + repeat match goal with + | _ => progress (cbv [divmod eval]; repeat autounfold) + | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval + end. + erewrite from_list_tl. + Admitted. + + Lemma divmod_mod p : eval (fst (divmod p)) = eval p / weight 1. + Proof. + Admitted. + + End Proofs. +End API. + (* (* Just some pretty-printing *) |