aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/Saturated.v
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-06-08 17:52:32 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2017-06-08 17:52:32 -0400
commit727c2902f1ec078c2359c1690125ae5a5d0e40e4 (patch)
treec05b2bce6d07b6e33738eb0be81e9676e0aeeebd /src/Arithmetic/Saturated.v
parentb9720744fd268072000daa3c1ee5a61e6cc7c954 (diff)
start saturated-arithmetic API for use in Montgomery (see discussion in #157)
Diffstat (limited to 'src/Arithmetic/Saturated.v')
-rw-r--r--src/Arithmetic/Saturated.v137
1 files changed, 123 insertions, 14 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v
index 880adc8f1..f4f17dcee 100644
--- a/src/Arithmetic/Saturated.v
+++ b/src/Arithmetic/Saturated.v
@@ -404,26 +404,25 @@ Module Columns.
autorewrite with uncps push_id push_basesystem_eval in *.
rewrite eval_cons_to_nth by omega. nsatz.
Qed.
-
+
End Columns.
Section Wrappers.
- Context (weight : nat->Z)
- {add_get_carry: Z ->Z -> Z -> (Z * Z)}
- {div modulo : Z -> Z -> Z}.
+ Context (weight : nat->Z).
- Definition add_cps {n} (p q : Z^n) {T} (f : (Z*Z^n)->T) :=
+ Definition add_cps {n1 n2 n3} (p : Z^n1) (q : Z^n2)
+ {T} (f : (Z*Z^n3)->T) :=
B.Positional.to_associational_cps weight p
(fun P => B.Positional.to_associational_cps weight q
- (fun Q => from_associational_cps weight n (P++Q)
- (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight R f))).
+ (fun Q => from_associational_cps weight n3 (P++Q)
+ (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f))).
Definition sub_cps {n} (p q : Z^n) {T} (f : (Z*Z^n)->T) :=
B.Positional.to_associational_cps weight p
(fun P => B.Positional.negate_snd_cps weight q
(fun nq => B.Positional.to_associational_cps weight nq
(fun Q => from_associational_cps weight n (P++Q)
- (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=add_get_carry) weight R f)))).
+ (fun R => compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))).
End Wrappers.
End Columns.
@@ -452,15 +451,12 @@ Section Freeze.
{weight_positive : forall i, weight i > 0}
{weight_multiples : forall i, weight (S i) mod weight i = 0}
{weight_divides : forall i : nat, weight (S i) / weight i > 0}
- {div modulo : Z -> Z -> Z}
- {div_correct : forall a b, div a b = a / b}
- {modulo_correct : forall a b, modulo a b = a mod b}
.
Definition conditional_add_cps {n} mask cond (p q : Z^n)
{T} (f:_->T) :=
B.Positional.select_cps mask cond q
- (fun qq => Columns.add_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight p qq f).
+ (fun qq => Columns.add_cps (n3:=n) weight p qq f).
Definition conditional_add {n} mask cond p q :=
@conditional_add_cps n mask cond p q _ id.
Lemma conditional_add_id {n} mask cond p q T f:
@@ -482,6 +478,7 @@ Section Freeze.
repeat progress autounfold in *.
pose proof Z.add_get_carry_full_mod.
pose proof Z.add_get_carry_full_div.
+ pose proof div_correct. pose proof modulo_correct.
autorewrite with uncps push_id push_basesystem_eval.
break_match;
match goal with
@@ -509,8 +506,7 @@ Section Freeze.
the carry in step 3 was -1, so they cancel out.
*)
Definition freeze_cps {n} mask (m:Z^n) (p:Z^n) {T} (f : Z^n->T) :=
- Columns.sub_cps (div:=div) (modulo:=modulo)
- (add_get_carry:=Z.add_get_carry_full) weight p m
+ Columns.sub_cps weight p m
(fun carry_p => conditional_add_cps mask (fst carry_p) (snd carry_p) m
(fun carry_r => f (snd carry_r)))
.
@@ -569,6 +565,7 @@ Section Freeze.
repeat progress autounfold.
pose proof Z.add_get_carry_full_mod.
pose proof Z.add_get_carry_full_div.
+ pose proof div_correct. pose proof modulo_correct.
autorewrite with uncps push_id push_basesystem_eval.
pose proof (weight_nonzero n).
@@ -589,6 +586,118 @@ Section Freeze.
Qed.
End Freeze.
+Section API.
+ Context {weight : nat -> Z}
+ {weight_0 : weight 0%nat = 1}
+ {weight_nonzero : forall i, weight i <> 0}
+ {weight_positive : forall i, weight i > 0}
+ {weight_multiples : forall i, weight (S i) mod weight i = 0}
+ {weight_divides : forall i : nat, weight (S i) / weight i > 0}
+ .
+
+ Definition T : Type := list Z.
+
+ Definition length : T -> nat := @length Z.
+
+ Definition divmod (p : T) : T * Z :=
+ (List.tl p, List.hd 0 p).
+
+ (* TODO : scmul (Z -> T -> T), using double-output multiply *)
+
+ Definition add (p q : T) : T * Z :=
+ let P := Tuple.from_list (Datatypes.length p) p (eq_refl _) in
+ let Q := Tuple.from_list (Datatypes.length q) q (eq_refl _) in
+ let n := max (Datatypes.length p) (Datatypes.length q) in
+ let carry_result := Columns.add_cps weight P Q id in
+ (to_list n (snd carry_result), fst carry_result).
+
+ Definition join (to_join : T * Z) : T :=
+ let p := fst to_join in
+ let P := Tuple.from_list (Datatypes.length p) p (eq_refl _) in
+ to_list _ (left_append (snd to_join) P).
+
+ Section Proofs.
+ Definition eval (p : T) : Z :=
+ B.Positional.eval weight (Tuple.from_list (Datatypes.length p) p (eq_refl _)).
+
+ Lemma eval_add_nz p q :
+ max (Datatypes.length p) (Datatypes.length q) <> 0%nat ->
+ eval (fst (add p q)) = eval p + eval q - snd (add p q) * weight (Datatypes.length (fst (add p q))).
+ Proof.
+ intros.
+ pose proof Z.add_get_carry_full_div.
+ pose proof Z.add_get_carry_full_mod.
+ pose proof div_correct. pose proof modulo_correct.
+ repeat match goal with
+ | _ => progress (cbv [add eval]; repeat autounfold)
+ | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval
+ | _ => progress
+ (rewrite <-!from_list_default_eq with (d:=0);
+ erewrite !length_to_list, !from_list_default_eq,
+ from_list_to_list)
+ | _ => rewrite Z.mul_div_eq' by auto using weight_nonzero;
+ omega
+ end.
+ Unshelve. distr_length.
+ Qed.
+
+ Lemma eval_add_z p q :
+ max (Datatypes.length p) (Datatypes.length q) = 0%nat ->
+ eval (fst (add p q)) = eval p + eval q - snd (add p q) * weight (Datatypes.length (fst (add p q))).
+ Proof. destruct p, q; distr_length; reflexivity. Qed.
+
+ Lemma eval_add p q :
+ eval (fst (add p q)) = eval p + eval q - snd (add p q) * weight (Datatypes.length (fst (add p q))).
+ Proof.
+ destruct (Nat.eq_dec (max (Datatypes.length p) (Datatypes.length q)) 0%nat); auto using eval_add_z, eval_add_nz.
+ Qed.
+ Hint Rewrite eval_add : push_basesystem_eval.
+
+ Lemma eval_join to_join :
+ eval (join to_join) = eval (fst to_join) + snd to_join * weight (Datatypes.length (fst to_join)).
+ Proof.
+ repeat match goal with
+ | _ => progress (cbv [join eval]; repeat autounfold)
+ | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval
+ | _ => progress
+ (rewrite <-!from_list_default_eq with (d:=0);
+ erewrite !length_to_list, !from_list_default_eq,
+ from_list_to_list)
+ | _ => rewrite B.Positional.eval_left_append; ring_simplify; omega
+ end.
+ Unshelve. distr_length.
+ Qed.
+ Hint Rewrite eval_join : push_basesystem_eval.
+
+ Lemma eval_join_add p q :
+ eval (join (add p q)) = eval p + eval q.
+ Proof. autorewrite with push_basesystem_eval; omega. Qed.
+
+ (* TODO : move to tuple *)
+ Lemma from_list_tl {A n} (ls : list A) H H':
+ from_list n (List.tl ls) H = tl (from_list (S n) ls H').
+ Proof.
+ induction ls; distr_length. simpl List.tl.
+ rewrite from_list_cons, tl_append, <-!(from_list_default_eq a ls).
+ reflexivity.
+ Qed.
+
+ Lemma divmod_div p : eval (fst (divmod p)) = eval p / weight 1.
+ Proof.
+ repeat match goal with
+ | _ => progress (cbv [divmod eval]; repeat autounfold)
+ | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval
+ end.
+ erewrite from_list_tl.
+ Admitted.
+
+ Lemma divmod_mod p : eval (fst (divmod p)) = eval p / weight 1.
+ Proof.
+ Admitted.
+
+ End Proofs.
+End API.
+
(*
(* Just some pretty-printing *)