aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2016-10-17 18:20:33 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2016-10-19 12:37:32 -0400
commit077a20a0018c9823c2568eb624122f48ab35c1d5 (patch)
tree1352f75bb2c65c135e7e6dbb082b5a1f3830f0df /src/ModularArithmetic
parentc1f4f952a034a4a9b8677a5b4823f97a7fab2252 (diff)
Add opt versions of add, sub, opp
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v21
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v107
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v2
3 files changed, 118 insertions, 12 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index 1769f86c4..ff2c23ab8 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -60,9 +60,22 @@ Section ModularBaseSystem.
(* Placeholder *)
Definition div (x y : digits) : digits := encode (F.div (decode x) (decode y)).
+ Definition carry_ (carry_chain : list nat) (us : digits) : digits :=
+ from_list (carry_sequence carry_chain [[us]]) (length_carry_sequence length_to_list).
+
+ Definition carry_add (carry_chain : list nat) (us vs : digits) : digits :=
+ carry_ carry_chain (add us vs).
Definition carry_mul (carry_chain : list nat) (us vs : digits) : digits :=
- from_list (carry_sequence carry_chain [[mul us vs]]) (length_carry_sequence length_to_list).
-
+ carry_ carry_chain (mul us vs).
+ Definition carry_sub (carry_chain : list nat) (modulus_multiple: digits)
+ (modulus_multiple_correct : decode modulus_multiple = 0%F)
+ (us vs : digits) : digits :=
+ carry_ carry_chain (sub modulus_multiple modulus_multiple_correct us vs).
+ Definition carry_opp (carry_chain : list nat) (modulus_multiple : digits)
+ (modulus_multiple_correct : decode modulus_multiple = 0%F)
+ (x : digits) : digits :=
+ carry_ carry_chain (opp modulus_multiple modulus_multiple_correct x).
+
Definition rep (us : digits) (x : F modulus) := decode us = x.
Local Notation "u ~= x" := (rep u x).
Local Hint Unfold rep.
@@ -102,8 +115,8 @@ Section ModularBaseSystem.
Definition pack (x : digits) : target_digits :=
from_list (pack target_widths_nonneg bits_eq [[x]]) length_pack.
-
+
Definition unpack (x : target_digits) : digits :=
from_list (unpack target_widths_nonneg bits_eq [[x]]) length_unpack.
-End ModularBaseSystem. \ No newline at end of file
+End ModularBaseSystem.
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index 3b50662a6..f1b8c601b 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -410,8 +410,32 @@ Section Carries.
End Carries.
+Section CarryChain.
+ Context `{prm : PseudoMersenneBaseParams} {cc : CarryChain limb_widths}.
+ Local Notation digits := (tuple Z (length limb_widths)).
+
+ Definition carry__opt_sig {T} (f : digits -> T) (us : digits)
+ : { x | x = f (carry_ carry_chain us) }.
+ Proof.
+ eexists.
+ cbv [carry_].
+ rewrite <- from_list_default_eq with (d := 0%Z).
+ change @from_list_default with @from_list_default_opt.
+ erewrite <-carry_sequence_opt_cps_correct by eauto using carry_chain_valid, length_to_list.
+ cbv [carry_sequence_opt_cps].
+ reflexivity.
+ Defined.
+
+ Definition carry__opt_cps {T} (f:digits -> T) (us : digits) : T
+ := Eval cbv [proj1_sig carry__opt_sig] in proj1_sig (carry__opt_sig f us).
+
+ Definition carry__opt_cps_correct {T} (f:digits -> T) (us : digits)
+ : carry__opt_cps f us = f (carry_ carry_chain us)
+ := proj2_sig (carry__opt_sig f us).
+End CarryChain.
+
Section Addition.
- Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient}.
+ Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient} {cc : CarryChain limb_widths}.
Local Notation digits := (tuple Z (length limb_widths)).
Definition add_opt_sig (us vs : digits) : { b : digits | b = add us vs }.
@@ -426,10 +450,34 @@ Section Addition.
Definition add_opt_correct us vs
: add_opt us vs = add us vs
:= proj2_sig (add_opt_sig us vs).
+
+ Definition carry_add_opt_sig {T} (f:digits -> T)
+ (us vs : digits) : { x | x = f (carry_add carry_chain us vs) }.
+ Proof.
+ eexists.
+ cbv [carry_add].
+ rewrite <-carry__opt_cps_correct, <-add_opt_correct.
+ cbv [carry_sequence_opt_cps carry__opt_cps add_opt add].
+ rewrite to_list_from_list.
+ reflexivity.
+ Defined.
+
+ Definition carry_add_opt_cps {T} (f:digits -> T) (us vs : digits) : T
+ := Eval cbv [proj1_sig carry_add_opt_sig] in proj1_sig (carry_add_opt_sig f us vs).
+
+ Definition carry_add_opt_cps_correct {T} (f:digits -> T) (us vs : digits)
+ : carry_add_opt_cps f us vs = f (carry_add carry_chain us vs)
+ := proj2_sig (carry_add_opt_sig f us vs).
+
+ Definition carry_add_opt := carry_add_opt_cps id.
+
+ Definition carry_add_opt_correct (us vs : digits)
+ : carry_add_opt us vs = carry_add carry_chain us vs :=
+ carry_add_opt_cps_correct id us vs.
End Addition.
Section Subtraction.
- Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient}.
+ Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient} {cc : CarryChain limb_widths}.
Local Notation digits := (tuple Z (length limb_widths)).
Definition sub_opt_sig (us vs : digits) : { b : digits | b = sub coeff coeff_mod us vs }.
@@ -446,6 +494,30 @@ Section Subtraction.
: sub_opt us vs = sub coeff coeff_mod us vs
:= proj2_sig (sub_opt_sig us vs).
+ Definition carry_sub_opt_sig {T} (f:digits -> T)
+ (us vs : digits) : { x | x = f (carry_sub carry_chain coeff coeff_mod us vs) }.
+ Proof.
+ eexists.
+ cbv [carry_sub].
+ rewrite <-carry__opt_cps_correct, <-sub_opt_correct.
+ cbv [carry_sequence_opt_cps carry__opt_cps sub_opt].
+ rewrite to_list_from_list.
+ reflexivity.
+ Defined.
+
+ Definition carry_sub_opt_cps {T} (f:digits -> T) (us vs : digits) : T
+ := Eval cbv [proj1_sig carry_sub_opt_sig] in proj1_sig (carry_sub_opt_sig f us vs).
+
+ Definition carry_sub_opt_cps_correct {T} (f:digits -> T) (us vs : digits)
+ : carry_sub_opt_cps f us vs = f (carry_sub carry_chain coeff coeff_mod us vs)
+ := proj2_sig (carry_sub_opt_sig f us vs).
+
+ Definition carry_sub_opt := carry_sub_opt_cps id.
+
+ Definition carry_sub_opt_correct (us vs : digits)
+ : carry_sub_opt us vs = carry_sub carry_chain coeff coeff_mod us vs :=
+ carry_sub_opt_cps_correct id us vs.
+
Definition opp_opt_sig (us : digits) : { b : digits | b = opp coeff coeff_mod us }.
Proof.
eexists.
@@ -461,6 +533,30 @@ Section Subtraction.
: opp_opt us = opp coeff coeff_mod us
:= proj2_sig (opp_opt_sig us).
+ Definition carry_opp_opt_sig {T} (f:digits -> T)
+ (us : digits) : { x | x = f (carry_opp carry_chain coeff coeff_mod us) }.
+ Proof.
+ eexists.
+ cbv [carry_opp].
+ rewrite <-carry__opt_cps_correct, <-opp_opt_correct.
+ cbv [carry_sequence_opt_cps carry__opt_cps opp_opt opp sub_opt].
+ rewrite to_list_from_list.
+ reflexivity.
+ Defined.
+
+ Definition carry_opp_opt_cps {T} (f:digits -> T) (us : digits) : T
+ := Eval cbv [proj1_sig carry_opp_opt_sig] in proj1_sig (carry_opp_opt_sig f us).
+
+ Definition carry_opp_opt_cps_correct {T} (f:digits -> T) (us : digits)
+ : carry_opp_opt_cps f us = f (carry_opp carry_chain coeff coeff_mod us)
+ := proj2_sig (carry_opp_opt_sig f us).
+
+ Definition carry_opp_opt := carry_opp_opt_cps id.
+
+ Definition carry_opp_opt_correct (us : digits)
+ : carry_opp_opt us = carry_opp carry_chain coeff coeff_mod us :=
+ carry_opp_opt_cps_correct id us.
+
End Subtraction.
Section Multiplication.
@@ -616,11 +712,8 @@ Section Multiplication.
Proof.
eexists.
cbv [carry_mul].
- rewrite <- from_list_default_eq with (d := 0%Z).
- change @from_list_default with @from_list_default_opt.
- erewrite <-carry_sequence_opt_cps_correct by eauto using carry_chain_valid, length_to_list.
- erewrite <-mul_opt_correct.
- cbv [carry_sequence_opt_cps mul_opt].
+ rewrite <-carry__opt_cps_correct, <-mul_opt_correct.
+ cbv [carry_sequence_opt_cps carry__opt_cps mul_opt].
erewrite from_list_default_eq.
rewrite to_list_from_list.
reflexivity.
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
index c160eca7f..9cef5710d 100644
--- a/src/ModularArithmetic/ModularBaseSystemProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -360,7 +360,7 @@ Section FieldOperationProofs.
+ eapply _zero_neq_one.
+ trivial.
Qed.
- End FieldProofs.
+End FieldProofs.
End FieldOperationProofs.
Opaque encode add mul sub inv pow.