aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2017-06-02 14:46:06 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2017-06-02 17:20:58 -0400
commit7534db3ca225bf961059de97113b6dfab059299e (patch)
treea0814b0b3f4a595712932970770aa69be663aa9a /src
parent1714206f23d47c65cf423c17883f44bbe8937276 (diff)
pulled in a CPS version of Karatsuba from another branch
Diffstat (limited to 'src')
-rw-r--r--src/Specific/Karatsuba.v402
1 files changed, 402 insertions, 0 deletions
diff --git a/src/Specific/Karatsuba.v b/src/Specific/Karatsuba.v
new file mode 100644
index 000000000..0f205f253
--- /dev/null
+++ b/src/Specific/Karatsuba.v
@@ -0,0 +1,402 @@
+Require Import Coq.ZArith.ZArith Coq.ZArith.BinIntDef.
+Require Import Coq.Lists.List. Import ListNotations.
+Require Import Crypto.Arithmetic.Core. Import B.
+Require Import Crypto.Arithmetic.PrimeFieldTheorems.
+Require Import (*Crypto.Util.Tactics*) Crypto.Util.Decidable.
+Require Import Crypto.Util.LetIn Crypto.Util.ZUtil Crypto.Util.Tactics.
+Require Import Crypto.Arithmetic.Karatsuba.
+Require Crypto.Util.Tuple.
+Local Notation tuple := Tuple.tuple.
+Local Open Scope list_scope.
+Local Open Scope Z_scope.
+Local Coercion Z.of_nat : nat >-> Z.
+
+(***
+Modulus: 2^448 - 2^224 - 1
+Base: 2^56 (8 limbs of 56 bits each)
+***)
+(* All parameters, operations and proofs for this modulus live in this section.
+   NOTE(review): the section is called Ops51 although the header comment gives
+   56 as the base and the ring defined at the end is named ring_56 -- confirm
+   whether the name should be updated. *)
+Section Ops51.
+ (* In type scope, T^n stands for tuple T n (tuple is Tuple.tuple). *)
+ Local Infix "^" := tuple : type_scope.
+
+ (* These definitions will need to be passed as Ltac arguments (or
+ cleverly inferred) when things are eventually automated *)
+ (* Number of limbs in a field element (elements have type Z^sz). *)
+ Definition sz := 8%nat.
+ (* Passed to freeze in freeze_sig as the mask Z.ones bitwidth. *)
+ Definition bitwidth := 64.
+ (* s and c together determine the modulus m = s - Associational.eval c. *)
+ Definition s : Z := 2^448.
+ Definition c : list B.limb := [(1, 1); (2^224, 1)].
+ Definition coef_div_modulus : nat := 2. (* add 2*modulus before subtracting *)
+ (* Limb indices given to chained_carries_cps in carry_sig: carry_chain1 is
+    0 .. sz-2 and runs first; carry_chain2 runs after carry_reduce_cps. *)
+ Definition carry_chain1 := Eval vm_compute in (seq 0 (pred sz)).
+ Definition carry_chain2 := ([0;1])%nat.
+ (* NOTE(review): 121665 is the a24 constant RFC 7748 lists for curve25519;
+    for curve448 the RFC gives 39081 -- confirm which value is intended. *)
+ Definition a24 := 121665%Z.
+
+ (* These definitions are inferred from those above *)
+ Definition m := Eval vm_compute in Z.to_pos (s - Associational.eval c). (* modulus *)
+ (* Weight of limb i: 2 raised to the ceiling of (Z.log2 s * i) / sz.  The
+    if-branch adds 1 to the truncated quotient exactly when sz does not
+    divide Z.log2 s * i. *)
+ Definition wt := fun i : nat =>
+ let si := Z.log2 s * i in
+ 2 ^ ((si/sz) + (if dec ((si/sz)*sz=si) then 0 else 1)).
+ (* 2*sz - 1, computed at definition time.
+    NOTE(review): not referenced by anything else shown in this file. *)
+ Definition sz2 := Eval vm_compute in ((sz * 2) - 1)%nat.
+ (* The value s - Associational.eval c encoded as sz limbs under wt. *)
+ Definition m_enc :=
+ Eval vm_compute in (Positional.encode (modulo:=modulo) (div:=div) (n:=sz) wt (s-Associational.eval c)).
+ (* Starting from Positional.zeros sz, add m_enc limb-wise coef_div_modulus
+    times; the result is handed to sub_cps and opp_cps as their coef. *)
+ Definition coef := (* subtraction coefficient *)
+ Eval vm_compute in
+ ((fix addm (acc: Z^sz) (ctr : nat) : Z^sz :=
+ match ctr with
+ | O => acc
+ | S n => addm (Positional.add_cps wt acc m_enc id) n
+ end) (Positional.zeros sz) coef_div_modulus).
+ (* Proof by computation (eq_refl) that coef evaluates to 0 modulo m. *)
+ Definition coef_mod : mod_eq m (Positional.eval (n:=sz) wt coef) 0 := eq_refl.
+
+ (* The limb count is not zero; decided by computation. *)
+ Lemma sz_nonzero : sz <> 0%nat. Proof. vm_decide. Qed.
+ (* Every weight is nonzero. *)
+ Lemma wt_nonzero i : wt i <> 0.
+ Proof.
+ (* wt i is a power, so Z.pow_nonzero applies; the remaining side goals are
+    handled by zero_bounds, a case split on the if in wt, and vm_decide. *)
+ apply Z.pow_nonzero; zero_bounds; try break_match; vm_decide.
+ Qed.
+
+ (* For every index in carry_chain1 the quotient of consecutive weights is
+    nonzero. *)
+ Lemma wt_divides_chain1 i (H:In i carry_chain1) : wt (S i) / wt i <> 0.
+ Proof.
+ (* Unfolding In over the concrete list turns H into a finite disjunction. *)
+ cbv [In carry_chain1] in H.
+ (* Split the disjunction; the final False case is absurd, and each concrete
+    index is settled by vm_decide after substitution. *)
+ repeat match goal with H : _ \/ _ |- _ => destruct H end;
+ try (exfalso; assumption); subst; try vm_decide.
+ Qed.
+ (* Same statement and same proof script for the indices in carry_chain2. *)
+ Lemma wt_divides_chain2 i (H:In i carry_chain2) : wt (S i) / wt i <> 0.
+ Proof.
+ cbv [In carry_chain2] in H.
+ repeat match goal with H : _ \/ _ |- _ => destruct H end;
+ try (exfalso; assumption); subst; try vm_decide.
+ Qed.
+ (* The quotient of consecutive weights is nonzero for every index i, not
+    only for the indices listed in the carry chains. *)
+ Lemma wt_divides_full i : wt (S i) / wt i <> 0.
+ Proof.
+ cbv [wt].
+ (* The goal has the shape 2^x / 2^y <> 0; first establish 0 <= y <= x. *)
+ match goal with |- _ ^ ?x / _ ^ ?y <> _ => assert (0 <= y <= x) end.
+ { rewrite Nat2Z.inj_succ.
+ (* Prove both bounds, splitting on the if from wt, by repeatedly applying
+    monotonicity lemmas for division, multiplication and addition. *)
+ split; try break_match; ring_simplify;
+ repeat match goal with
+ | _ => apply Z.div_le_mono; try vm_decide; [ ]
+ | _ => apply Z.mul_le_mono_nonneg_l; try vm_decide; [ ]
+ | _ => apply Z.add_le_mono; try vm_decide; [ ]
+ | |- ?x <= ?y + 1 => assert (x <= y); [|omega]
+ | |- ?x + 1 <= ?y => rewrite <- Z.div_add by vm_decide
+ | _ => progress zero_bounds
+ | _ => progress ring_simplify
+ | _ => vm_decide
+ end. }
+ (* With the bounds in hand, Z.pow_sub_r rewrites the quotient into a single
+    power, which Z.pow_nonzero shows to be nonzero. *)
+ break_match; rewrite <-Z.pow_sub_r by omega;
+ apply Z.pow_nonzero; omega.
+ Qed.
+
+ (* Solves goals of the form
+      { c : Z^sz | Positional.Fdecode (m:=M) wt c = v }
+    for a constant v: it computes Positional.encode of F.to_Z v with
+    vm_compute, supplies that tuple as the witness, and closes the remaining
+    equation with vm_decide.  Fails (lazymatch) on goals of any other shape. *)
+ Local Ltac solve_constant_sig :=
+ lazymatch goal with
+ | [ |- { c : Z^?sz | Positional.Fdecode (m:=?M) ?wt c = ?v } ]
+ => let t := (eval vm_compute in
+ (Positional.encode (n:=sz) (modulo:=modulo) (div:=div) wt (F.to_Z (m:=M) v))) in
+ (exists t; vm_decide)
+ end.
+
+ (* A limb tuple that decodes to 0 in F m, packaged with its proof. *)
+ Definition zero_sig :
+ { zero : Z^sz | Positional.Fdecode (m:=m) wt zero = 0%F}.
+ Proof.
+ solve_constant_sig.
+ Defined.
+
+ (* A limb tuple that decodes to 1 in F m, packaged with its proof. *)
+ Definition one_sig :
+ { one : Z^sz | Positional.Fdecode (m:=m) wt one = 1%F}.
+ Proof.
+ solve_constant_sig.
+ Defined.
+
+ (* A limb tuple that decodes to F.of_Z m a24, packaged with its proof. *)
+ Definition a24_sig :
+ { a24t : Z^sz | Positional.Fdecode (m:=m) wt a24t = F.of_Z m a24 }.
+ Proof.
+ solve_constant_sig.
+ Defined.
+
+ (* Addition on limb tuples together with a proof that decoding the result
+    gives the field sum of the decoded inputs. *)
+ Definition add_sig :
+ { add : (Z^sz -> Z^sz -> Z^sz)%type |
+ forall a b : Z^sz,
+ let eval := Positional.Fdecode (m:=m) wt in
+ eval (add a b) = (eval a + eval b)%F }.
+ Proof.
+ (* The witness starts as an evar and is filled in by the final reflexivity. *)
+ eexists; cbv beta zeta; intros.
+ pose proof wt_nonzero.
+ (* The operation is Positional.add_cps run with the identity continuation. *)
+ let x := constr:(
+ Positional.add_cps (n := sz) wt a b id) in
+ solve_op_F wt x. reflexivity.
+ Defined.
+
+ (* Subtraction on limb tuples together with a proof that decoding the result
+    gives the field difference of the decoded inputs. *)
+ Definition sub_sig :
+ {sub : (Z^sz -> Z^sz -> Z^sz)%type |
+ forall a b : Z^sz,
+ let eval := Positional.Fdecode (m:=m) wt in
+ eval (sub a b) = (eval a - eval b)%F}.
+ Proof.
+ (* The witness starts as an evar and is filled in by the final reflexivity. *)
+ eexists; cbv beta zeta; intros.
+ pose proof wt_nonzero.
+ (* The operation is Positional.sub_cps with the section-level coef, run
+    with the identity continuation. *)
+ let x := constr:(
+ Positional.sub_cps (n:=sz) (coef := coef) wt a b id) in
+ solve_op_F wt x. reflexivity.
+ Defined.
+
+ (* Negation on limb tuples together with a proof that decoding the result
+    gives F.opp of the decoded input. *)
+ Definition opp_sig :
+ {opp : (Z^sz -> Z^sz)%type |
+ forall a : Z^sz,
+ let eval := Positional.Fdecode (m := m) wt in
+ eval (opp a) = F.opp (eval a)}.
+ Proof.
+ (* The witness starts as an evar and is filled in by the final reflexivity. *)
+ eexists; cbv beta zeta; intros.
+ pose proof wt_nonzero.
+ (* The operation is Positional.opp_cps with the section-level coef, run
+    with the identity continuation. *)
+ let x := constr:(
+ Positional.opp_cps (n:=sz) (coef := coef) wt a id) in
+ solve_op_F wt x. reflexivity.
+ Defined.
+
+ (* Half of the limb count, evaluated at definition time.  Used below as the
+    tuple length of each half when instantiating goldilocks_mul_cps. *)
+ Definition half_sz : nat := Eval compute in (sz / 2).
+
+ (* TODO: move *)
+ (* Continuation-passing split of a positional value.  p is converted to
+    associational form, handed to Associational.split_cps together with s,
+    and the two components of the result are converted back to positional
+    tuples with m1 and m2 limbs respectively (using the non-CPS
+    Positional.from_associational).  The resulting pair is passed to f. *)
+ Definition Positional_split_cps {n m1 m2} (s:Z) (p : tuple Z n)
+   {T} (f:(tuple Z m1 * tuple Z m2) -> T) :=
+   Positional.to_associational_cps wt p
+     (fun terms =>
+        Associational.split_cps s terms
+          (fun parts =>
+             f (Positional.from_associational wt m1 (fst parts),
+                Positional.from_associational wt m2 (snd parts)))).
+ (* Continuation-passing multiplication of a positional value by x.  p is
+    converted to associational form, multiplied by the one-term list
+    [(1, x)] with Associational.mul_cps, and the product is converted back
+    to an n-limb tuple whose value is passed to f. *)
+ Definition Positional_scmul_cps {n} (x : Z) (p: tuple Z n)
+   {T} (f:tuple Z n->T) :=
+   Positional.to_associational_cps wt p
+     (fun terms =>
+        Associational.mul_cps terms [(1, x)]
+          (fun scaled =>
+             Positional.from_associational_cps wt n scaled f)).
+ (* Continuation-passing subtraction of positional values.  Both p and q are
+    converted to associational form, Associational.negate_snd_cps is applied
+    to the terms of q, and the terms of p appended with those negated terms
+    are converted back to an n-limb tuple that is passed to f.  Unlike
+    Positional.sub_cps, no coef is involved. *)
+ Definition Positional_sub_cps {n} (p q: tuple Z n)
+   {T} (f:tuple Z n->T) :=
+   Positional.to_associational_cps wt p
+     (fun lhs =>
+        Positional.to_associational_cps wt q
+          (fun rhs =>
+             Associational.negate_snd_cps rhs
+               (fun neg_rhs =>
+                  Positional.from_associational_cps wt n (lhs ++ neg_rhs) f))).
+ (* goldilocks_mul_cps instantiated for this representation: half-size
+    values are tuples of half_sz limbs and full-size values are tuples of sz
+    limbs.  Half-size multiplication and addition come from Positional;
+    full-size split, scalar multiplication and subtraction use the
+    Positional_*_cps helpers defined above. *)
+ Definition goldilocks448_cps :=
+ (goldilocks_mul_cps
+ (T := tuple Z half_sz) (T2 := tuple Z sz)
+ (mul_cps := Positional.mul_cps (n:=half_sz) wt)
+ (add_cps := Positional.add_cps (n:=half_sz) wt)
+ (add2_cps := Positional.add_cps (n:=sz) wt)
+ (split_cps := Positional_split_cps (n:=sz) (m1:=half_sz) (m2 := half_sz))
+ (scmul2_cps := Positional_scmul_cps (n:=sz))
+ (sub2_cps := Positional_sub_cps (n:=sz))
+ ).
+ (* Lets autounfold expose the goldilocks_mul_cps application in proofs. *)
+ Hint Unfold goldilocks448_cps.
+ (* goldilocks_mul_id instantiated with the same CPS operations as
+    goldilocks448_cps.  Each non-CPS operation (mul, add, add2, split,
+    scmul2, sub2) is the corresponding CPS operation applied to the identity
+    continuation.  The *_id arguments are supplied by the caller and
+    forwarded unchanged; presumably each one states that a CPS operation
+    agrees with its non-CPS counterpart -- see goldilocks_mul_id for the
+    exact statements. *)
+ Definition goldilocks448_id
+ mul_id add_id add2_id split_id scmul2_id sub2_id
+ :=
+ (goldilocks_mul_id
+ (T := tuple Z half_sz) (T2 := tuple Z sz)
+ (mul_cps := Positional.mul_cps (n:=half_sz) wt)
+ (add_cps := Positional.add_cps (n:=half_sz) wt)
+ (add2_cps := Positional.add_cps (n:=sz) wt)
+ (split_cps := Positional_split_cps (n:=sz) (m1:=half_sz) (m2 := half_sz))
+ (scmul2_cps := Positional_scmul_cps (n:=sz))
+ (sub2_cps := Positional_sub_cps (n:=sz))
+ (mul := fun a b => Positional.mul_cps (n:= half_sz) wt a b id)
+ (add := fun a b => Positional.add_cps (n:=half_sz) wt a b id)
+ (add2 := fun a b => Positional.add_cps (n:=sz) wt a b id)
+ (split := fun s a => Positional_split_cps (n:=sz) (m1:=half_sz) (m2 := half_sz) s a id)
+ (scmul2 := fun x a => Positional_scmul_cps (n:=sz) x a id)
+ (sub2 := fun a b => Positional_sub_cps (n:=sz) a b id)
+ (mul_id := mul_id)
+ (add_id := add_id)
+ (add2_id := add2_id)
+ (split_id := split_id)
+ (scmul2_id := scmul2_id)
+ (sub2_id := sub2_id)
+ ).
+ (* goldilocks_mul_correct instantiated for this representation.  The two
+    explicit arguments are the evaluation functions for half-size and
+    full-size tuples (Positional.eval at half_sz and at sz).  The operations
+    are the same ones used in goldilocks448_cps and goldilocks448_id; the
+    *_id and eval_* arguments are supplied by the caller and forwarded
+    unchanged. *)
+ Definition goldilocks448_correct'
+ mul_id add_id add2_id split_id scmul2_id sub2_id
+ eval_mul eval_add eval_add2 eval_split eval_scmul2 eval_sub2
+ :=
+ (goldilocks_mul_correct
+ (T := tuple Z half_sz) (T2 := tuple Z sz)
+ (Positional.eval (n:=half_sz) wt)
+ (Positional.eval (n:=sz) wt)
+ (mul_cps := Positional.mul_cps (n:=half_sz) wt)
+ (add_cps := Positional.add_cps (n:=half_sz) wt)
+ (add2_cps := Positional.add_cps (n:=sz) wt)
+ (split_cps := Positional_split_cps (n:=sz) (m1:=half_sz) (m2 := half_sz))
+ (scmul2_cps := Positional_scmul_cps (n:=sz))
+ (sub2_cps := Positional_sub_cps (n:=sz))
+ (mul := fun a b => Positional.mul_cps (n:= half_sz) wt a b id)
+ (add := fun a b => Positional.add_cps (n:=half_sz) wt a b id)
+ (add2 := fun a b => Positional.add_cps (n:=sz) wt a b id)
+ (split := fun s a => Positional_split_cps (n:=sz) (m1:=half_sz) (m2 := half_sz) s a id)
+ (scmul2 := fun x a => Positional_scmul_cps (n:=sz) x a id)
+ (sub2 := fun a b => Positional_sub_cps (n:=sz) a b id)
+ (mul_id := mul_id)
+ (add_id := add_id)
+ (add2_id := add2_id)
+ (split_id := split_id)
+ (scmul2_id := scmul2_id)
+ (sub2_id := sub2_id)
+ (eval_mul := eval_mul)
+ (eval_add := eval_add)
+ (eval_add2 := eval_add2)
+ (eval_split := eval_split)
+ (eval_scmul2 := eval_scmul2)
+ (eval_sub2 := eval_sub2)
+ ).
+ (* Lets autounfold expose the bodies of the three helpers in the proofs
+    that follow. *)
+ Hint Unfold Positional_split_cps Positional_scmul_cps Positional_sub_cps.
+ (* Correctness of goldilocks448_cps: for a nonzero s whose square is
+    congruent to s + 1 modulo p, the result evaluates to something congruent
+    modulo p to the product of the evaluations of xs and ys. *)
+ Lemma goldilocks448_correct :
+ forall p : positive,
+ forall s : Z,
+ s <> 0 ->
+ s ^ 2 mod p = (s + 1) mod p ->
+ forall xs ys : Z ^ sz,
+ mod_eq (Z.to_pos p)
+ (Positional.eval wt (goldilocks448_cps s xs ys _ id))
+ (Positional.eval wt xs * Positional.eval wt ys).
+ Proof.
+ pose proof wt_nonzero.
+ intros; autounfold. cbv [mod_eq].
+ (* Rewrite with goldilocks448_id; each of its hypotheses is discharged by
+    unfolding and the uncps / push_id rewrite databases. *)
+ rewrite goldilocks448_id by (intros; autounfold; autorewrite with uncps push_id; reflexivity). autorewrite with push_id.
+ (* Apply the instantiated correctness theorem; most of its hypotheses fall
+    to assumption or to rewriting followed by reflexivity. *)
+ apply goldilocks448_correct'; try assumption; intros; autounfold;
+ autorewrite with uncps push_id cancel_pair push_basesystem_eval;
+ try reflexivity.
+ (* Two goals remain after the automation above. *)
+ { setoid_rewrite Associational.eval_nil. ring. }
+ { rewrite Pos2Z.id; congruence. }
+ Qed.
+
+ (* Multiplication on limb tuples together with a proof that decoding the
+    result gives the field product of the decoded inputs.  The operation is
+    goldilocks448_cps with s = 2^224 and the identity continuation. *)
+ Definition mul_sig :
+ {mul : (Z^sz -> Z^sz -> Z^sz)%type |
+ forall a b : Z^sz,
+ let eval := Positional.Fdecode (m := m) wt in
+ eval (mul a b) = (eval a * eval b)%F}.
+ Proof.
+ (* The witness starts as an evar and is filled in by the final reflexivity. *)
+ eexists; cbv beta zeta; intros.
+ pose proof wt_nonzero.
+ (* Go through Positional.eval wt x: the second goal (x is correct) is closed
+    by goldilocks448_correct, whose side conditions on s are settled by
+    cbv; congruence.  The first goal (the witness equals x) is left open. *)
+ let x := constr:(
+ goldilocks448_cps (2^224) a b _ id) in
+ F_mod_eq;
+ transitivity (Positional.eval wt x); repeat autounfold;
+
+ [
+ | autorewrite with uncps push_id push_basesystem_eval;
+ apply goldilocks448_correct; cbv; congruence ].
+ (* Strip the common context so that only witness = x remains, partially
+    evaluate the right-hand side, and instantiate the evar with it. *)
+ cbv[mod_eq]; apply f_equal2;
+ [ | reflexivity ]; apply f_equal.
+ basesystem_partial_evaluation_RHS.
+ do_replace_match_with_destructuring_match_in_goal.
+ reflexivity.
+ Defined.
+
+ (* Squaring on limb tuples together with a proof that decoding the result
+    gives the decoded input multiplied by itself.  The operation is derived
+    from mul_sig rather than implemented separately. *)
+ Definition square_sig :
+ {square : (Z^sz -> Z^sz)%type |
+ forall a : Z^sz,
+ let eval := Positional.Fdecode (m := m) wt in
+ eval (square a) = (eval a * eval a)%F}.
+ Proof.
+ eexists; cbv beta zeta; intros.
+ (* Turn the right-hand side into eval (mul a a) using the proof component
+    of mul_sig, then unfold the multiplication and unify the witness with
+    it. *)
+ rewrite <-(proj2_sig mul_sig).
+ apply f_equal.
+ cbv [proj1_sig mul_sig].
+ reflexivity.
+ Defined.
+
+ (* Performs a full carry loop: chained carries over carry_chain1, then
+    carry_reduce_cps with s and c, then chained carries over carry_chain2.
+    The proof component states that the decoded value is unchanged. *)
+ Definition carry_sig :
+ {carry : (Z^sz -> Z^sz)%type |
+ forall a : Z^sz,
+ let eval := Positional.Fdecode (m := m) wt in
+ eval (carry a) = eval a}.
+ Proof.
+ (* The witness starts as an evar and is filled in by the final reflexivity. *)
+ eexists; cbv beta zeta; intros.
+ (* Facts about the weights and about div/mod, put in context for
+    solve_op_F. *)
+ pose proof wt_nonzero. pose proof wt_divides_chain1.
+ pose proof div_mod. pose proof wt_divides_chain2.
+ let x := constr:(
+ Positional.chained_carries_cps (n:=sz) (div:=div)(modulo:=modulo) wt a carry_chain1
+ (fun r => Positional.carry_reduce_cps (n:=sz) (div:=div) (modulo:=modulo) wt s c r
+ (fun rrr => Positional.chained_carries_cps (n:=sz) (div:=div) (modulo:=modulo) wt rrr carry_chain2 id
+ ))) in
+ solve_op_F wt x. reflexivity.
+ Defined.
+
+ Require Import Crypto.Arithmetic.Saturated.
+
+ (* Facts about wt that freeze_sig puts in context before rewriting with
+    eval_freeze. *)
+ Section PreFreeze.
+ (* Every weight is strictly positive. *)
+ Lemma wt_pos i : wt i > 0.
+ Proof.
+ apply Z.lt_gt.
+ apply Z.pow_pos_nonneg; zero_bounds; try break_match; vm_decide.
+ Qed.
+
+ (* Each weight divides the next one.
+    NOTE(review): this lemma is Admitted, i.e. assumed without proof.
+    wt_divides_full_pos and freeze_sig both use it, so they are only as
+    trustworthy as this statement.  TODO: replace Admitted with a proof. *)
+ Lemma wt_multiples i : wt (S i) mod (wt i) = 0.
+ Admitted.
+
+ (* The quotient of consecutive weights is strictly positive. *)
+ Lemma wt_divides_full_pos i : wt (S i) / wt i > 0.
+ Proof.
+ pose proof (wt_divides_full i).
+ apply Z.div_positive_gt_0; auto using wt_pos.
+ (* The divisibility side condition comes from the admitted lemma above. *)
+ apply wt_multiples.
+ Qed.
+ End PreFreeze.
+
+ (* Keep freeze folded in the uncps database and rewrite with freeze_id
+    there. *)
+ Hint Opaque freeze : uncps.
+ Hint Rewrite freeze_id : uncps.
+
+ (* Freeze on limb tuples together with a proof that, for inputs whose
+    evaluation lies in [0, 2*m), decoding the result gives the same field
+    element as decoding the input.  The operation is freeze applied to wt,
+    the mask Z.ones bitwidth, and m_enc. *)
+ Definition freeze_sig :
+ {freeze : (Z^sz -> Z^sz)%type |
+ forall a : Z^sz,
+ (0 <= Positional.eval wt a < 2 * Z.pos m)->
+ let eval := Positional.Fdecode (m := m) wt in
+ eval (freeze a) = eval a}.
+ Proof.
+ (* The witness starts as an evar and is filled in by the final reflexivity. *)
+ eexists; cbv beta zeta; intros.
+ (* Facts put in context for the eassumption / auto calls below.
+    NOTE(review): wt_multiples is Admitted, so this proof depends on an
+    unproven statement. *)
+ pose proof wt_nonzero. pose proof wt_pos.
+ pose proof div_mod. pose proof wt_divides_full_pos.
+ pose proof wt_multiples.
+ pose proof div_correct. pose proof modulo_correct.
+ (* Go through Positional.eval wt x: the second goal (x is correct) is
+    closed by rewriting with eval_freeze and discharging its side
+    conditions.  The first goal (the witness equals x) is left open. *)
+ let x := constr:(freeze (n:=sz) (div:=div) (modulo:=modulo) wt (Z.ones bitwidth) m_enc a) in
+ F_mod_eq;
+ transitivity (Positional.eval wt x); repeat autounfold;
+ [ | autorewrite with uncps push_id push_basesystem_eval;
+ rewrite eval_freeze with (c:=c);
+ try eassumption; try omega; try reflexivity;
+ try solve [auto using B.Positional.select_id,
+ B.Positional.eval_select, zselect_correct];
+ vm_decide].
+ cbv[mod_eq]; apply f_equal2;
+ [ | reflexivity ]; apply f_equal.
+ (* Compute everything except the listed runtime operations, which stay
+    folded in the term that instantiates the witness. *)
+ cbv - [runtime_opp runtime_add runtime_mul runtime_shr runtime_and Let_In Z.add_get_carry Z.zselect].
+ reflexivity.
+ Defined.
+
+ (* Applies Ring.ring_by_isomorphism between F m and limb tuples Z^sz, with
+    Positional.Fencode wt and Positional.Fdecode wt as the two maps.  The
+    operations are the first components of the *_sig definitions above and
+    the homomorphism facts are their second components. *)
+ Definition ring_56 :=
+ (Ring.ring_by_isomorphism
+ (F := F m)
+ (H := Z^sz)
+ (phi := Positional.Fencode wt)
+ (phi' := Positional.Fdecode wt)
+ (zero := proj1_sig zero_sig)
+ (one := proj1_sig one_sig)
+ (opp := proj1_sig opp_sig)
+ (add := proj1_sig add_sig)
+ (sub := proj1_sig sub_sig)
+ (mul := proj1_sig mul_sig)
+ (phi'_zero := proj2_sig zero_sig)
+ (phi'_one := proj2_sig one_sig)
+ (phi'_opp := proj2_sig opp_sig)
+ (* Fdecode_Fencode_id is instantiated with sz_nonzero, div_mod,
+    wt_nonzero and wt_divides_full. *)
+ (Positional.Fdecode_Fencode_id
+ (sz_nonzero := sz_nonzero)
+ (div_mod := div_mod)
+ wt eq_refl wt_nonzero wt_divides_full)
+ (Positional.eq_Feq_iff wt)
+ (proj2_sig add_sig)
+ (proj2_sig sub_sig)
+ (proj2_sig mul_sig)
+ ).
+
+(*
+Eval cbv [proj1_sig add_sig] in (proj1_sig add_sig).
+Eval cbv [proj1_sig sub_sig] in (proj1_sig sub_sig).
+Eval cbv [proj1_sig opp_sig] in (proj1_sig opp_sig).
+Eval cbv [proj1_sig mul_sig] in (proj1_sig mul_sig).
+Eval cbv [proj1_sig carry_sig] in (proj1_sig carry_sig).
+*)
+
+End Ops51.