aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/ModularBaseSystemOpt.v
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2016-08-16 18:29:48 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2016-08-16 18:29:48 -0400
commitb3f72699177d4448201daf857d07bef9ede5a2d3 (patch)
tree95c91bcf93c543533de6e9b8b359c65a16e4c994 /src/ModularArithmetic/ModularBaseSystemOpt.v
parent82defeaac51f1b76217fcb548a8115449946e432 (diff)
Added optimized versions of [pack] and [unpack] to ModularBaseSystemOpt. Further optimization, including the unrolling of the entire loop, can be done in Specific/ once limb widths of both ModularBaseSystem format and wire format are known.
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystemOpt.v')
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v100
1 files changed, 99 insertions, 1 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index 0be74a3c0..4349ee9d9 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -27,13 +27,19 @@ Class SubtractionCoefficient (m : Z) (prm : PseudoMersenneBaseParams m) := {
(* Computed versions of some functions. *)
+Definition plus_opt := Eval compute in plus.
+
Definition Z_add_opt := Eval compute in Z.add.
Definition Z_sub_opt := Eval compute in Z.sub.
Definition Z_mul_opt := Eval compute in Z.mul.
Definition Z_div_opt := Eval compute in Z.div.
Definition Z_pow_opt := Eval compute in Z.pow.
Definition Z_opp_opt := Eval compute in Z.opp.
+Definition Z_min_opt := Eval compute in Z.min.
Definition Z_ones_opt := Eval compute in Z.ones.
+Definition Z_of_nat_opt := Eval compute in Z.of_nat.
+Definition Z_le_dec_opt := Eval compute in Z_le_dec.
+Definition Z_lt_dec_opt := Eval compute in Z_lt_dec.
Definition Z_shiftl_opt := Eval compute in Z.shiftl.
Definition Z_shiftl_by_opt := Eval compute in Z.shiftl_by.
@@ -47,6 +53,10 @@ Definition base_from_limb_widths_opt := Eval compute in @Pow2Base.base_from_limb
Definition minus_opt := Eval compute in minus.
Definition max_ones_opt := Eval compute in @max_ones.
Definition from_list_default_opt {A} := Eval compute in (@from_list_default A).
+Definition sum_firstn_opt {A} := Eval compute in (@sum_firstn A).
+Definition zeros_opt := Eval compute in (@zeros).
+Definition bit_index_opt := Eval compute in bit_index.
+Definition digit_index_opt := Eval compute in digit_index.
Definition Let_In {A P} (x : A) (f : forall y : A, P y)
:= let y := x in f y.
@@ -407,7 +417,6 @@ Section Addition.
Definition add_opt_sig (us vs : digits) : { b : digits | b = add us vs }.
Proof.
eexists.
- cbv [BaseSystem.add].
reflexivity.
Defined.
@@ -622,6 +631,95 @@ Section Multiplication.
End Multiplication.
+Section Conversion.
+
+ Definition convert'_opt_sig {lwA lwB}
+ (nonnegA : forall x, In x lwA -> 0 <= x)
+ (nonnegB : forall x, In x lwB -> 0 <= x)
+ bits_fit inp i out :
+ { y | y = convert' nonnegA nonnegB bits_fit inp i out}.
+ Proof.
+ eexists.
+ rewrite convert'_equation.
+ change sum_firstn with @sum_firstn_opt.
+ change length with length_opt.
+ change Z_le_dec with Z_le_dec_opt.
+ change Z.of_nat with Z_of_nat_opt.
+ change digit_index with digit_index_opt.
+ change bit_index with bit_index_opt.
+ change Z.min with Z_min_opt.
+ change (nth_default 0 lwA) with (nth_default_opt 0 lwA).
+ change (nth_default 0 lwB) with (nth_default_opt 0 lwB).
+ cbv [update_by_concat_bits concat_bits Z.pow2_mod].
+ change Z.ones with Z_ones_opt.
+ change @update_nth with @update_nth_opt.
+ change plus with plus_opt.
+ change Z.sub with Z_sub_opt.
+ reflexivity.
+ Defined.
+
+ Definition convert'_opt {lwA lwB}
+ (nonnegA : forall x, In x lwA -> 0 <= x)
+ (nonnegB : forall x, In x lwB -> 0 <= x)
+ bits_fit inp i out :=
+ Eval cbv [proj1_sig convert'_opt_sig] in
+ proj1_sig (convert'_opt_sig nonnegA nonnegB bits_fit inp i out).
+
+ Definition convert'_opt_correct {lwA lwB}
+ (nonnegA : forall x, In x lwA -> 0 <= x)
+ (nonnegB : forall x, In x lwB -> 0 <= x)
+ bits_fit inp i out :
+ convert'_opt nonnegA nonnegB bits_fit inp i out = convert' nonnegA nonnegB bits_fit inp i out :=
+ Eval cbv [proj2_sig convert'_opt_sig] in
+ proj2_sig (convert'_opt_sig nonnegA nonnegB bits_fit inp i out).
+
+ Context {modulus} (prm : PseudoMersenneBaseParams modulus)
+ {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x) (bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn target_widths (length target_widths)).
+ Local Notation digits := (tuple Z (length limb_widths)).
+ Local Notation target_digits := (tuple Z (length target_widths)).
+
+ Definition pack_opt_sig (x : digits) : { y | y = pack target_widths_nonneg bits_eq x}.
+ Proof.
+ eexists.
+ cbv [pack].
+ rewrite <- from_list_default_eq with (d := 0%Z).
+ change @from_list_default with @from_list_default_opt.
+ cbv [ModularBaseSystemList.pack convert].
+ change length with length_opt.
+ change sum_firstn with @sum_firstn_opt.
+ change zeros with zeros_opt.
+ reflexivity.
+ Defined.
+
+ Definition pack_opt (x : digits) : target_digits :=
+ Eval cbv [proj1_sig pack_opt_sig] in proj1_sig (pack_opt_sig x).
+
+ Definition pack_correct (x : digits) :
+ pack_opt x = pack target_widths_nonneg bits_eq x
+ := Eval cbv [proj2_sig pack_opt_sig] in proj2_sig (pack_opt_sig x).
+
+ Definition unpack_opt_sig (x : target_digits) : { y | y = unpack target_widths_nonneg bits_eq x}.
+ Proof.
+ eexists.
+ cbv [unpack].
+ rewrite <- from_list_default_eq with (d := 0%Z).
+ change @from_list_default with @from_list_default_opt.
+ cbv [ModularBaseSystemList.unpack convert].
+ change length with length_opt.
+ change sum_firstn with @sum_firstn_opt.
+ change zeros with zeros_opt.
+ reflexivity.
+ Defined.
+
+ Definition unpack_opt (x : target_digits) : digits :=
+ Eval cbv [proj1_sig unpack_opt_sig] in proj1_sig (unpack_opt_sig x).
+
+ Definition unpack_correct (x : target_digits) :
+ unpack_opt x = unpack target_widths_nonneg bits_eq x
+ := Eval cbv [proj2_sig unpack_opt_sig] in proj2_sig (unpack_opt_sig x).
+
+End Conversion.
+
Section with_base.
Context {modulus} (prm : PseudoMersenneBaseParams modulus).
Local Notation base := (Pow2Base.base_from_limb_widths limb_widths).