aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author: Jason Gross <jagro@google.com> 2016-09-15 16:15:42 -0700
committer: Jason Gross <jgross@mit.edu> 2016-09-22 14:58:53 -0400
commit: 36cffa1f6b04497d8935b466c8362afd5f2ae5c8 (patch)
tree: f2649356c5ee61635e3e17bcca9c66316e38f52f /src
parent: 95cd2c60969c8d14e92689336c1d0a93cc105b19 (diff)
Don't inline everything in Montgomery and Barrett
We still use CSE in the fancy machine, because we want to lift the `ldi`s above the rest of the code. However, on quick inspection, the algorithm no longer needs CSE in order to be duplicate-free.
Diffstat (limited to 'src')
-rw-r--r-- src/BoundedArithmetic/ArchitectureToZLike.v 3
-rw-r--r-- src/BoundedArithmetic/ArchitectureToZLikeProofs.v 2
-rw-r--r-- src/BoundedArithmetic/DoubleBounded.v 26
-rw-r--r-- src/BoundedArithmetic/DoubleBoundedProofs.v 31
-rw-r--r-- src/Specific/FancyMachine256/Barrett.v 61
-rw-r--r-- src/Specific/FancyMachine256/Core.v 1
-rw-r--r-- src/Specific/FancyMachine256/Montgomery.v 58
7 files changed, 113 insertions, 69 deletions
diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v
index 3388ece78..80f0d9803 100644
--- a/src/BoundedArithmetic/ArchitectureToZLike.v
+++ b/src/BoundedArithmetic/ArchitectureToZLike.v
@@ -4,6 +4,7 @@ Require Import Crypto.BoundedArithmetic.Interface.
Require Import Crypto.BoundedArithmetic.DoubleBounded.
Require Import Crypto.ModularArithmetic.ZBounded.
Require Import Crypto.Util.Tuple.
+Require Import Crypto.Util.LockedLet.
Local Open Scope Z_scope.
@@ -23,7 +24,7 @@ Section fancy_machine_p256_montgomery_foundation.
DivBy_SmallBound v := snd v;
DivBy_SmallerBound v := if smaller_bound_exp =? n
then snd v
- else shrd (snd v) (fst v) smaller_bound_exp;
+ else llet v := v in shrd (snd v) (fst v) smaller_bound_exp;
Mul x y := muldw x y;
CarryAdd x y := adc x y false;
CarrySubSmall x y := subc x y false;
diff --git a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v
index 804296374..243ecb064 100644
--- a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v
+++ b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v
@@ -8,6 +8,7 @@ Require Import Crypto.BoundedArithmetic.ArchitectureToZLike.
Require Import Crypto.ModularArithmetic.ZBounded.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.ZUtil Crypto.Util.Tactics.
+Require Import Crypto.Util.LockedLet.
Local Open Scope nat_scope.
Local Open Scope Z_scope.
@@ -53,6 +54,7 @@ Section fancy_machine_p256_montgomery_foundation.
Local Ltac post_t_step :=
match goal with
| _ => reflexivity
+ | _ => rewrite !unlock_let
| _ => progress autorewrite with zsimplify_const
| [ |- fst ?x = (?a <=? ?b) :> bool ]
=> cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z);
diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v
index b624c5082..b6aa858ff 100644
--- a/src/BoundedArithmetic/DoubleBounded.v
+++ b/src/BoundedArithmetic/DoubleBounded.v
@@ -6,6 +6,7 @@ Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.ListUtil.
Require Import Crypto.Util.Notations.
+Require Import Crypto.Util.LockedLet.
Local Open Scope nat_scope.
Local Open Scope Z_scope.
@@ -27,10 +28,14 @@ Section ripple_carry_definitions.
: forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k
:= match k return forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k with
| O => f
- | S k' => fun xss yss carry => let '(xs, x) := eta xss in
+ | S k' => fun xss yss carry => llet xss := xss in
+ llet yss := yss in
+ let '(xs, x) := eta xss in
let '(ys, y) := eta yss in
- let '(carry, zs) := eta (@ripple_carry_tuple' _ f k' xs ys carry) in
- let '(carry, z) := eta (f x y carry) in
+ llet addv := (@ripple_carry_tuple' _ f k' xs ys carry) in
+ let '(carry, zs) := eta addv in
+ llet fxy := (f x y carry) in
+ let '(carry, z) := eta fxy in
(carry, (zs, z))
end.
@@ -75,11 +80,16 @@ Section tuple2.
{ldi : load_immediate W}.
Definition mul_double (a b : W) : tuple W 2
- := let out : tuple W 2 := (mulhwll a b, mulhwhh a b) in
- let tmp := mulhwhl a b in
- let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in
- let tmp := mulhwhl b a in
- let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in
+ := llet a := a in
+ llet b := b in
+ let out : tuple W 2 := (mulhwll a b, mulhwhh a b) in
+ llet out := out in
+ llet tmp := mulhwhl a b in
+ llet addv := (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in
+ let '(_, out) := eta addv in
+ llet tmp := mulhwhl b a in
+ llet addv := (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in
+ let '(_, out) := eta addv in
out.
(** Require a dummy [decoder] for these instances to allow
diff --git a/src/BoundedArithmetic/DoubleBoundedProofs.v b/src/BoundedArithmetic/DoubleBoundedProofs.v
index 53ac59d00..95ba35579 100644
--- a/src/BoundedArithmetic/DoubleBoundedProofs.v
+++ b/src/BoundedArithmetic/DoubleBoundedProofs.v
@@ -12,6 +12,7 @@ Require Import Crypto.Util.ZUtil.
Require Import Crypto.Util.ListUtil.
Require Import Crypto.Util.Tactics.
Require Import Crypto.Util.Notations.
+Require Import Crypto.Util.LockedLet.
Import ListNotations.
Local Open Scope list_scope.
@@ -235,6 +236,28 @@ Global Instance decode_mul_double
: forall x y, tuple_decoder (muldw x y) <~=~> (decode x * decode y)%Z
:= proj1 decode_mul_double_iff _.
+
+Lemma ripple_carry_tuple_SS' {T} f k xss yss carry
+ : @ripple_carry_tuple T f (S (S k)) xss yss carry
+ = llet xss := xss in
+ llet yss := yss in
+ let '(xs, x) := eta xss in
+ let '(ys, y) := eta yss in
+ llet addv := (@ripple_carry_tuple _ f (S k) xs ys carry) in
+ let '(carry, zs) := eta addv in
+ llet fxy := (f x y carry) in
+ let '(carry, z) := eta fxy in
+ (carry, (zs, z)).
+Proof. reflexivity. Qed.
+
+Local Ltac eta_expand :=
+ repeat match goal with
+ | _ => rewrite !unlock_let
+ | [ |- context[let '(x, y) := ?e in _] ]
+ => rewrite (surjective_pairing e)
+ | _ => rewrite <- !surjective_pairing
+ end.
+
Lemma ripple_carry_tuple_SS {T} f k xss yss carry
: @ripple_carry_tuple T f (S (S k)) xss yss carry
= let '(xs, x) := eta xss in
@@ -242,7 +265,11 @@ Lemma ripple_carry_tuple_SS {T} f k xss yss carry
let '(carry, zs) := eta (@ripple_carry_tuple _ f (S k) xs ys carry) in
let '(carry, z) := eta (f x y carry) in
(carry, (zs, z)).
-Proof. reflexivity. Qed.
+Proof.
+ rewrite ripple_carry_tuple_SS'.
+ eta_expand.
+ reflexivity.
+Qed.
Lemma carry_is_good (n z0 z1 k : Z)
: 0 <= n ->
@@ -414,7 +441,7 @@ Section tuple2.
Proof.
assert (0 <= 2 * half_n) by eauto using decode_exponent_nonnegative.
assert (0 <= half_n) by omega.
- unfold mul_double.
+ unfold mul_double; eta_expand.
push_decode; autorewrite with simpl_tuple_decoder; simplify_projections.
autorewrite with zsimplify Zshift_to_pow push_Zpow.
rewrite !spread_left_from_shift_half_correct.
diff --git a/src/Specific/FancyMachine256/Barrett.v b/src/Specific/FancyMachine256/Barrett.v
index 7b757cc83..fd57f1fa3 100644
--- a/src/Specific/FancyMachine256/Barrett.v
+++ b/src/Specific/FancyMachine256/Barrett.v
@@ -43,12 +43,14 @@ Section expression.
Local Arguments μ' / .
Local Arguments ldi' / .
+ Local Arguments DoubleBounded.mul_double / .
Definition expression'
:= Eval simpl in
(fun v => proj1_sig (pre_f v)).
+ Local Transparent locked_let.
Definition expression
- := Eval cbv beta iota delta [expression' fst snd] in
+ := Eval cbv beta iota delta [expression' fst snd locked_let] in
fun v => let RegMod := fancy_machine.ldi m in
let RegMu := fancy_machine.ldi μ in
let RegZero := fancy_machine.ldi 0 in
@@ -72,12 +74,9 @@ Section reflected.
(*Compute DefaultRegisters rexpression_simple.*)
Definition registers
- := [RegMod; RegMuLow; x; xHigh; RegMod; RegMuLow; RegZero; tmp;
- qHigh; scratch+3; q; SpecialCarryBit; q;
- SpecialCarryBit; qHigh; scratch+3; SpecialCarryBit;
- q; SpecialCarryBit; qHigh; tmp; scratch+3;
- SpecialCarryBit; tmp; scratch+3; SpecialCarryBit;
- tmp; SpecialCarryBit; tmp; q; out].
+ := [RegMod; RegMuLow; x; xHigh; RegMod; RegMuLow; RegZero; tmp; q; qHigh; scratch+3;
+ SpecialCarryBit; q; SpecialCarryBit; qHigh; scratch+3; SpecialCarryBit; q; SpecialCarryBit; qHigh; tmp;
+ scratch+3; SpecialCarryBit; tmp; scratch+3; SpecialCarryBit; tmp; SpecialCarryBit; tmp; q; out].
Definition compiled_syntax
:= Eval lazy in AssembleSyntax rexpression_simple registers.
@@ -125,30 +124,30 @@ End reflected.
Print compiled_syntax.
(* compiled_syntax =
-fun ops : fancy_machine.instructions 256 =>
-(λn RegMod RegMuLow x xHigh,
- slet RegMod := RegMod in
- slet RegMuLow := RegMuLow in
- slet RegZero := ldi 0 in
- c.Rshi(tmp, xHigh, x, 250),
- c.Mul128(qHigh, c.UpperHalf(tmp), c.UpperHalf(RegMuLow)),
- c.Mul128(scratch+3, c.UpperHalf(tmp), c.LowerHalf(RegMuLow)),
- c.Mul128(q, c.LowerHalf(tmp), c.LowerHalf(RegMuLow)),
- c.Add(q, q, c.LeftShifted{scratch+3, 128}),
- c.Addc(qHigh, qHigh, c.RightShifted{scratch+3, 128}),
- c.Mul128(scratch+3, c.UpperHalf(RegMuLow), c.LowerHalf(tmp)),
- c.Add(q, q, c.LeftShifted{scratch+3, 128}),
- c.Addc(qHigh, qHigh, c.RightShifted{scratch+3, 128}),
- c.Mul128(tmp, c.LowerHalf(qHigh), c.LowerHalf(RegMod)),
- c.Mul128(scratch+3, c.UpperHalf(qHigh), c.LowerHalf(RegMod)),
- c.Add(tmp, tmp, c.LeftShifted{scratch+3, 128}),
- c.Mul128(scratch+3, c.UpperHalf(RegMod), c.LowerHalf(qHigh)),
- c.Add(tmp, tmp, c.LeftShifted{scratch+3, 128}),
- c.Sub(tmp, x, tmp),
- c.Addm(q, tmp, RegZero),
- c.Addm(out, q, RegZero),
- Return out)%nexpr
- : forall ops : fancy_machine.instructions 256,
+fun ops : fancy_machine.instructions (2 * 128) =>
+λn RegMod RegMuLow x xHigh,
+slet RegMod := RegMod in
+slet RegMuLow := RegMuLow in
+slet RegZero := ldi 0 in
+c.Rshi(tmp, xHigh, x, 250),
+c.Mul128(q, c.LowerHalf(tmp), c.LowerHalf(RegMuLow)),
+c.Mul128(qHigh, c.UpperHalf(tmp), c.UpperHalf(RegMuLow)),
+c.Mul128(scratch+3, c.UpperHalf(tmp), c.LowerHalf(RegMuLow)),
+c.Add(q, q, c.LeftShifted{scratch+3, 128}),
+c.Addc(qHigh, qHigh, c.RightShifted{scratch+3, 128}),
+c.Mul128(scratch+3, c.UpperHalf(RegMuLow), c.LowerHalf(tmp)),
+c.Add(q, q, c.LeftShifted{scratch+3, 128}),
+c.Addc(qHigh, qHigh, c.RightShifted{scratch+3, 128}),
+c.Mul128(tmp, c.LowerHalf(qHigh), c.LowerHalf(RegMod)),
+c.Mul128(scratch+3, c.UpperHalf(qHigh), c.LowerHalf(RegMod)),
+c.Add(tmp, tmp, c.LeftShifted{scratch+3, 128}),
+c.Mul128(scratch+3, c.UpperHalf(RegMod), c.LowerHalf(qHigh)),
+c.Add(tmp, tmp, c.LeftShifted{scratch+3, 128}),
+c.Sub(tmp, x, tmp),
+c.Addm(q, tmp, RegZero),
+c.Addm(out, q, RegZero),
+Return out
+ : forall ops : fancy_machine.instructions (2 * 128),
expr base_type
(fun v : base_type =>
match v with
diff --git a/src/Specific/FancyMachine256/Core.v b/src/Specific/FancyMachine256/Core.v
index 440f1aa4f..419f1a24c 100644
--- a/src/Specific/FancyMachine256/Core.v
+++ b/src/Specific/FancyMachine256/Core.v
@@ -20,6 +20,7 @@ Require Export Crypto.Reflection.Reify.
Require Export Crypto.Util.ZUtil.
Require Export Crypto.Util.Notations.
Require Import Crypto.Util.ListUtil.
+Require Export Crypto.Util.LockedLet.
Export ListNotations.
Open Scope Z_scope.
diff --git a/src/Specific/FancyMachine256/Montgomery.v b/src/Specific/FancyMachine256/Montgomery.v
index b899257ca..56879eb57 100644
--- a/src/Specific/FancyMachine256/Montgomery.v
+++ b/src/Specific/FancyMachine256/Montgomery.v
@@ -27,11 +27,13 @@ Section expression.
Local Arguments pre_f / .
Local Arguments ldi' / .
Local Arguments reduce_via_partial / .
+ Local Arguments DoubleBounded.mul_double / .
Definition expression'
:= Eval simpl in f.
+ Local Transparent locked_let.
Definition expression
- := Eval cbv beta delta [expression' fst snd] in
+ := Eval cbv beta delta [expression' fst snd locked_let] in
fun v => let RegMod := fancy_machine.ldi modulus in
let RegPInv := fancy_machine.ldi m' in
let RegZero := fancy_machine.ldi 0 in
@@ -43,7 +45,7 @@ Section expression.
v
Hv
: fancy_machine.decode (expression v) = _
- := @ZBounded.reduce_via_partial_correct (2^256) modulus _ props' (ldi' m') I Hm R' HR0 HR1 v I Hv.
+ := @ZBounded.reduce_via_partial_correct (2^256) modulus _ props' (ldi' m') I Hm R' HR0 HR1 (fst v, snd v) I Hv.
End expression.
Section reflected.
@@ -61,7 +63,7 @@ Section reflected.
Definition registers
:= [RegMod; RegPInv; lo; hi; RegMod; RegPInv; RegZero; y; t1; SpecialCarryBit; y;
- t1; SpecialCarryBit; y; t2; scratch+3; t1; SpecialCarryBit; t1; SpecialCarryBit; t2;
+ t1; SpecialCarryBit; y; t1; t2; scratch+3; SpecialCarryBit; t1; SpecialCarryBit; t2;
scratch+3; SpecialCarryBit; t1; SpecialCarryBit; t2; SpecialCarryBit; lo; SpecialCarryBit; hi; y;
SpecialCarryBit; lo; lo].
@@ -72,6 +74,7 @@ Section reflected.
(props : fancy_machine.arithmetic ops).
Let result (v : tuple fancy_machine.W 2) := Syntax.Interp (interp_op _) rexpression_simple modulus m' (fst v) (snd v).
+
Let assembled_result (v : tuple fancy_machine.W 2) : fancy_machine.W := Core.Interp compiled_syntax modulus m' (fst v) (snd v).
Theorem sanity : result = expression ops modulus m'.
@@ -118,29 +121,29 @@ End reflected.
Print compiled_syntax.
(* compiled_syntax =
fun ops : fancy_machine.instructions (2 * 128) =>
-(λn RegMod RegPInv lo hi,
- slet RegMod := RegMod in
- slet RegPInv := RegPInv in
- slet RegZero := ldi 0 in
- c.Mul128(y, c.LowerHalf(lo), c.LowerHalf(RegPInv)),
- c.Mul128(t1, c.UpperHalf(lo), c.LowerHalf(RegPInv)),
- c.Add(y, y, c.LeftShifted{t1, 128}),
- c.Mul128(t1, c.UpperHalf(RegPInv), c.LowerHalf(lo)),
- c.Add(y, y, c.LeftShifted{t1, 128}),
- c.Mul128(t2, c.UpperHalf(y), c.UpperHalf(RegMod)),
- c.Mul128(scratch+3, c.UpperHalf(y), c.LowerHalf(RegMod)),
- c.Mul128(t1, c.LowerHalf(y), c.LowerHalf(RegMod)),
- c.Add(t1, t1, c.LeftShifted{scratch+3, 128}),
- c.Addc(t2, t2, c.RightShifted{scratch+3, 128}),
- c.Mul128(scratch+3, c.UpperHalf(RegMod), c.LowerHalf(y)),
- c.Add(t1, t1, c.LeftShifted{scratch+3, 128}),
- c.Addc(t2, t2, c.RightShifted{scratch+3, 128}),
- c.Add(lo, lo, t1),
- c.Addc(hi, hi, t2),
- c.Selc(y, RegMod, RegZero),
- c.Sub(lo, hi, y),
- c.Addm(lo, lo, RegZero),
- Return lo)%nexpr
+λn RegMod RegPInv lo hi,
+slet RegMod := RegMod in
+slet RegPInv := RegPInv in
+slet RegZero := ldi 0 in
+c.Mul128(y, c.LowerHalf(lo), c.LowerHalf(RegPInv)),
+c.Mul128(t1, c.UpperHalf(lo), c.LowerHalf(RegPInv)),
+c.Add(y, y, c.LeftShifted{t1, 128}),
+c.Mul128(t1, c.UpperHalf(RegPInv), c.LowerHalf(lo)),
+c.Add(y, y, c.LeftShifted{t1, 128}),
+c.Mul128(t1, c.LowerHalf(y), c.LowerHalf(RegMod)),
+c.Mul128(t2, c.UpperHalf(y), c.UpperHalf(RegMod)),
+c.Mul128(scratch+3, c.UpperHalf(y), c.LowerHalf(RegMod)),
+c.Add(t1, t1, c.LeftShifted{scratch+3, 128}),
+c.Addc(t2, t2, c.RightShifted{scratch+3, 128}),
+c.Mul128(scratch+3, c.UpperHalf(RegMod), c.LowerHalf(y)),
+c.Add(t1, t1, c.LeftShifted{scratch+3, 128}),
+c.Addc(t2, t2, c.RightShifted{scratch+3, 128}),
+c.Add(lo, lo, t1),
+c.Addc(hi, hi, t2),
+c.Selc(y, RegMod, RegZero),
+c.Sub(lo, hi, y),
+c.Addm(lo, lo, RegZero),
+Return lo
: forall ops : fancy_machine.instructions (2 * 128),
expr base_type
(fun v : base_type =>
@@ -148,4 +151,5 @@ fun ops : fancy_machine.instructions (2 * 128) =>
| TZ => Z
| Tbool => bool
| TW => let (W, _, _, _, _, _, _, _, _, _, _, _, _, _) := ops in W
- end) op Register (TZ -> TZ -> TW -> TW -> Tbase TW)%ctype *)
+ end) op Register (TZ -> TZ -> TW -> TW -> Tbase TW)%ctype
+*)