aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/SimplyTypedArithmetic.v
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-04-25 17:45:05 +0200
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-30 06:14:51 -0400
commit93503d634054fd0813a41f9484ff08f6056e1fa6 (patch)
tree781be416de63a303d4040ec9cef11d4b87fb8b6f /src/Experiments/SimplyTypedArithmetic.v
parent58a8ea57b9dde5ff3c4823852d7ff7239dc48aaa (diff)
fixed too-many-additions problem by fixing number of limbs in from_associational
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v188
1 files changed, 99 insertions, 89 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index c9f9a239b..fcba99daf 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -673,6 +673,49 @@ Module Saturated.
end.
Qed.
Hint Rewrite eval_sat_mul : push_eval.
+
+ Definition sat_multerm_const s (t t' : (Z * Z)) : list (Z * Z) :=
+ if snd t =? 1
+ then [(fst t * fst t', snd t')]
+ else if snd t =? -1
+ then [(fst t * fst t', - snd t')]
+ else if snd t =? 0
+ then nil
+ else dlet_nd xy := Z.mul_split s (snd t) (snd t') in
+ [(fst t * fst t', fst xy); (fst t * fst t' * s, snd xy)].
+
+ Definition sat_mul_const s (p q : list (Z * Z)) : list (Z * Z) :=
+ flat_map (fun t => flat_map (fun t' => sat_multerm_const s t t') q) p.
+
+ Lemma eval_map_sat_multerm_const s a q (s_nonzero:s<>0):
+ Associational.eval (flat_map (sat_multerm_const s a) q) = fst a * snd a * Associational.eval q.
+ Proof.
+ cbv [sat_multerm_const Let_In]; induction q;
+ repeat match goal with
+ | _ => progress autorewrite with cancel_pair push_eval to_div_mod in *
+ | _ => progress simpl flat_map
+ | H : _ = 1 |- _ => rewrite H
+ | H : _ = -1 |- _ => rewrite H
+ | H : _ = 0 |- _ => rewrite H
+ | _ => progress break_match; Z.ltb_to_lt
+ | _ => rewrite IHq
+ | _ => rewrite Z.mod_eq by assumption
+ | _ => ring_simplify; omega
+ end.
+ Qed.
+ Hint Rewrite eval_map_sat_multerm_const using (omega || assumption) : push_eval.
+
+ Lemma eval_sat_mul_const s p q (s_nonzero:s<>0):
+ Associational.eval (sat_mul_const s p q) = Associational.eval p * Associational.eval q.
+ Proof.
+ cbv [sat_mul_const]; induction p; [reflexivity|].
+ repeat match goal with
+ | _ => progress (autorewrite with push_flat_map push_eval in * )
+ | _ => rewrite IHp
+ | _ => ring_simplify; omega
+ end.
+ Qed.
+ Hint Rewrite eval_sat_mul_const : push_eval.
End Associational.
End Associational.
@@ -1574,19 +1617,23 @@ Module Rows.
let pq_a := Associational.sat_mul base p_a q_a in
flatten m (from_associational m pq_a).
- Definition mulmod base s c n m (p q : list Z) :=
+ Definition sat_reduce base s c (p : list (Z * Z)) :=
+ let lo_hi := Associational.split s p in
+ fst lo_hi ++ (Associational.sat_mul_const base c (snd lo_hi)).
+
+ (* TODO : hardcoded to 2 reductions; fix *)
+ Definition mulmod base s c n (p q : list Z) :=
let p_a := Positional.to_associational weight n p in
let q_a := Positional.to_associational weight n q in
let pq_a := Associational.sat_mul base p_a q_a in
- let lo_hi := Associational.split s pq_a in
- let r_a := fst lo_hi ++ (Associational.sat_mul base c (snd lo_hi)) in (* reduce, but using sat_mul *)
- flatten m (from_associational m r_a).
+ let r_a := sat_reduce base s c (sat_reduce base s c pq_a) in
+ flatten n (from_associational n r_a).
- Hint Rewrite Associational.eval_sat_mul Associational.eval_split using solve [auto] : push_eval.
+ Hint Rewrite Associational.eval_sat_mul_const Associational.eval_sat_mul Associational.eval_split using solve [auto] : push_eval.
Hint Rewrite eval_from_associational using solve [auto] : push_eval.
Hint Rewrite eval_partition using solve [auto] : push_eval.
Ltac solver :=
- intros; cbv [sub add mul mulmod];
+ intros; cbv [sub add mul mulmod sat_reduce];
rewrite ?flatten_partitions' by (intros; In_cases; subst; distr_length; eauto using length_from_associational);
rewrite ?flatten_div by (intros; In_cases; subst; distr_length; eauto using length_from_associational);
autorewrite with push_eval; ring_simplify_subterms;
@@ -1630,16 +1677,18 @@ Module Rows.
fst (mul base n m p q) = partition m (Positional.eval weight n p * Positional.eval weight n q).
Proof. solver. Qed.
- Lemma eval_mulmod base s c n m p q :
+ Lemma eval_mulmod base s c n p q :
base <> 0 -> s <> 0 -> s - Associational.eval c <> 0 ->
- n <> 0%nat -> m <> 0%nat -> length p = n -> length q = n ->
- (Positional.eval weight m (fst (mulmod base s c n m p q))
- + weight m * (snd (mulmod base s c n m p q))) mod (s - Associational.eval c)
+ n <> 0%nat -> length p = n -> length q = n ->
+ (Positional.eval weight n (fst (mulmod base s c n p q))
+ + weight n * (snd (mulmod base s c n p q))) mod (s - Associational.eval c)
= (Positional.eval weight n p * Positional.eval weight n q) mod (s - Associational.eval c).
Proof.
solver.
rewrite <-Z.div_mod'' by auto.
rewrite <-Associational.reduction_rule by omega.
+ autorewrite with push_eval.
+ rewrite <-Associational.reduction_rule by omega.
autorewrite with push_eval; auto.
Qed.
End Ops.
@@ -8429,7 +8478,7 @@ Module SaturatedSolinas.
Context (s : Z) (c : list (Z * Z))
(s_nz : s <> 0) (modulus_nz : s - Associational.eval c <> 0).
Context (log2base : Z) (log2base_pos : 0 < log2base)
- (n m : nat) (n_nz : n <> 0%nat) (m_nz : m <> 0%nat).
+ (n : nat) (n_nz : n <> 0%nat).
Let weight := weight log2base 1.
Let props : @weight_properties weight := wprops log2base 1 ltac:(omega).
@@ -8439,45 +8488,44 @@ Module SaturatedSolinas.
SuchThat (forall (f g : list Z)
(Hf : length f = n)
(Hg : length g = n),
- (eval weight m (fst (mulmod f g)) + weight m * (snd (mulmod f g))) mod (s - Associational.eval c)
+ (eval weight n (fst (mulmod f g)) + weight n * (snd (mulmod f g))) mod (s - Associational.eval c)
= (eval weight n f * eval weight n g) mod (s - Associational.eval c))
As eval_mulmod.
Proof.
intros.
- rewrite <-Rows.eval_mulmod with (base:=2^log2base) (s:=s) (c:=c) (m:=m) by auto using base_nz.
+ rewrite <-Rows.eval_mulmod with (base:=2^log2base) (s:=s) (c:=c) by auto using base_nz.
eapply f_equal2; [|trivial].
(* expand_lists (). *) (* uncommenting this line removes some unused multiplications but also inlines a bunch of carry stuff at the end *)
subst mulmod. reflexivity.
Qed.
+ Definition mulmod' := fun x y => fst (mulmod x y).
End MulMod.
Derive mulmod_gen
SuchThat (forall (log2base s : Z) (c : list (Z * Z)) (n m : nat)
(f g : list Z),
- Interp (t:=type.reify_type_of mulmod)
- mulmod_gen s c log2base n m f g
- = mulmod s c log2base n m f g)
+ Interp (t:=type.reify_type_of mulmod')
+ mulmod_gen s c log2base n f g
+ = mulmod' s c log2base n f g)
As mulmod_gen_correct.
Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed.
Module Export ReifyHints.
- Global Hint Extern 1 (_ = mulmod _ _ _ _ _ _ _) => simple apply mulmod_gen_correct : reify_gen_cache.
+ Global Hint Extern 1 (_ = mulmod' _ _ _ _ _ _) => simple apply mulmod_gen_correct : reify_gen_cache.
End ReifyHints.
Section rmulmod.
- Context (n m : nat)
- (s : Z)
+ Context (s : Z)
(c : list (Z * Z))
(machine_wordsize : Z).
Definition relax_zrange_of_machine_wordsize
:= relax_zrange_gen [1; machine_wordsize]%Z.
+ Let n : nat := Z.to_nat (Z.log2 s / machine_wordsize).
Let relax_zrange := relax_zrange_of_machine_wordsize.
Let bound := Some r[0 ~> (2^machine_wordsize - 1)]%zrange.
Let boundsn : list (ZRange.type.option.interp type.Z)
:= List.repeat bound n.
- Let boundsm : list (ZRange.type.option.interp type.Z)
- := List.repeat bound m.
Definition check_args {T} (res : Pipeline.ErrorT T)
: Pipeline.ErrorT T
@@ -8514,8 +8562,8 @@ Module SaturatedSolinas.
Definition rmulmod_correct
:= BoundsPipeline_correct
(Some boundsn, Some boundsn)
- (Some boundsm, bound)
- (mulmod s c machine_wordsize n m).
+ (Some boundsn)
+ (mulmod' s c machine_wordsize n).
Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _).
Definition rmulmod_correctT rv : Prop
@@ -8527,14 +8575,12 @@ Ltac solve_rmulmod := solve_rop SaturatedSolinas.rmulmod_correct.
Ltac solve_rmulmod_nocache := solve_rop_nocache SaturatedSolinas.rmulmod_correct.
Module P192_64.
- Definition n := 3%nat.
- Definition m := 6%nat.
Definition s := 2^192.
Definition c := [(2^64, 1); (1,1)].
Definition machine_wordsize := 64.
Derive mulmod
- SuchThat (SaturatedSolinas.rmulmod_correctT n m s c machine_wordsize mulmod)
+ SuchThat (SaturatedSolinas.rmulmod_correctT s c machine_wordsize mulmod)
As mulmod_correct.
Proof. Time solve_rmulmod machine_wordsize. Time Qed.
@@ -8567,69 +8613,33 @@ mulmod = fun var : type -> Type => λ x : var (type.list (type.type_primitive ty
expr_let x9 := add64 (x6₂, x8₁) in
expr_let x10 := adc64 (x9₂, x7₁, x8₂) in
expr_let x11 := adc64 (x10₂, x6₁, x7₂) in
- expr_let x12 := adx64 (x11₂, 0, 0) in
- expr_let x13 := add64 (x4₂, x9₁) in
- expr_let x14 := adc64 (x13₂, x5₁, x10₁) in
- expr_let x15 := adc64 (x14₂, x5₂, x11₁) in
- expr_let x16 := adc64 (x15₂, 0, x12) in
- expr_let x17 := adx64 (x16₂, 0, 0) in
- expr_let x18 := add64 (x3₁, x13₁) in
- expr_let x19 := adc64 (x18₂, x6₂, x14₁) in
- expr_let x20 := adc64 (x19₂, x4₁, x15₁) in
- expr_let x21 := adc64 (x20₂, x0₂, x16₁) in
- expr_let x22 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x21₂, 0, x17)) in
- expr_let x23 := adx64 (x22₂, 0, 0) in
- expr_let x24 := add64 (x2₂, x18₁) in
- expr_let x25 := adc64 (x24₂, x4₂, x19₁) in
- expr_let x26 := adc64 (x25₂, x2₁, x20₁) in
- expr_let x27 := adc64 (x26₂, 0, x21₁) in
- expr_let x28 := adc64 (x27₂, 0, x22₁) in
- expr_let x29 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x28₂, 0, x23)) in
- expr_let x30 := add64 (x1₁, x24₁) in
- expr_let x31 := adc64 (x30₂, x3₁, x25₁) in
- expr_let x32 := adc64 (x31₂, 0, x26₁) in
- expr_let x33 := adc64 (x32₂, 0, x27₁) in
- expr_let x34 := adc64 (x33₂, 0, x28₁) in
- expr_let x35 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x34₂, 0, x29₁)) in
- expr_let x36 := add64 (x2₂, x31₁) in
- expr_let x37 := adc64 (x36₂, 0, x32₁) in
- expr_let x38 := adc64 (x37₂, 0, x33₁) in
- expr_let x39 := adc64 (x38₂, 0, x34₁) in
- expr_let x40 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x39₂, 0, x35₁)) in
- expr_let x41 := add64 (x1₁, x36₁) in
- expr_let x42 := adc64 (x41₂, 0, x37₁) in
- expr_let x43 := adc64 (x42₂, 0, x38₁) in
- expr_let x44 := adc64 (x43₂, 0, x39₁) in
- expr_let x45 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x44₂, 0, x40₁)) in
- expr_let x46 := add64 (x3₂, x42₁) in
- expr_let x47 := adc64 (x46₂, 0, x43₁) in
- expr_let x48 := adc64 (x47₂, 0, x44₁) in
- expr_let x49 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x48₂, 0, x45₁)) in
- expr_let x50 := add64 (x3₂, x41₁) in
- expr_let x51 := adc64 (x50₂, x1₂, x46₁) in
- expr_let x52 := adc64 (x51₂, 0, x47₁) in
- expr_let x53 := adc64 (x52₂, 0, x48₁) in
- expr_let x54 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x53₂, 0, x49₁)) in
- expr_let x55 := add64 (x0₁, x51₁) in
- expr_let x56 := adc64 (x55₂, 0, x52₁) in
- expr_let x57 := adc64 (x56₂, 0, x53₁) in
- expr_let x58 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x57₂, 0, x54₁)) in
- expr_let x59 := add64 (x1₂, x50₁) in
- expr_let x60 := adc64 (x59₂, 0, x55₁) in
- expr_let x61 := adc64 (x60₂, 0, x56₁) in
- expr_let x62 := adc64 (x61₂, 0, x57₁) in
- expr_let x63 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x62₂, 0, x58₁)) in
- expr_let x64 := add64 (x0₁, x59₁) in
- expr_let x65 := adc64 (x64₂, 0, x60₁) in
- expr_let x66 := adc64 (x65₂, 0, x61₁) in
- expr_let x67 := adc64 (x66₂, 0, x62₁) in
- expr_let x68 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x67₂, 0, x63₁)) in
- expr_let x69 := add64 (x0₂, x65₁) in
- expr_let x70 := adc64 (x69₂, 0, x66₁) in
- expr_let x71 := adc64 (x70₂, 0, x67₁) in
- expr_let x72 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x71₂, 0, x68₁)) in
- (x30₁ :: x64₁ :: x69₁ :: x70₁ :: x71₁ :: x72₁ :: [], Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (x29₂ + x35₂) + x40₂) + x45₂) + x49₂) + x54₂) + x58₂) + x63₂) + x68₂) + x72₂))
- : Expr (type.uncurry (type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z) * type.type_primitive type.Z))
+ expr_let x12 := add64 (x4₂, x9₁) in
+ expr_let x13 := adc64 (x12₂, x5₁, x10₁) in
+ expr_let x14 := adc64 (x13₂, x5₂, x11₁) in
+ expr_let x15 := add64 (x3₁, x12₁) in
+ expr_let x16 := adc64 (x15₂, x6₂, x13₁) in
+ expr_let x17 := adc64 (x16₂, x4₁, x14₁) in
+ expr_let x18 := add64 (x2₂, x15₁) in
+ expr_let x19 := adc64 (x18₂, x4₂, x16₁) in
+ expr_let x20 := adc64 (x19₂, x2₁, x17₁) in
+ expr_let x21 := add64 (x1₁, x18₁) in
+ expr_let x22 := adc64 (x21₂, x3₁, x19₁) in
+ expr_let x23 := adc64 (x22₂, x3₂, x20₁) in
+ expr_let x24 := add64 (x0₂, x21₁) in
+ expr_let x25 := adc64 (x24₂, x2₂, x22₁) in
+ expr_let x26 := adc64 (x25₂, x1₂, x23₁) in
+ expr_let x27 := add64 (x1₁, x25₁) in
+ expr_let x28 := adc64 (x27₂, x0₁, x26₁) in
+ expr_let x29 := add64 (x3₂, x27₁) in
+ expr_let x30 := adc64 (x29₂, x0₂, x28₁) in
+ expr_let x31 := add64 (x1₂, x29₁) in
+ expr_let x32 := adc64 (x31₂, 0, x30₁) in
+ expr_let x33 := add64 (x0₁, x31₁) in
+ expr_let x34 := adc64 (x33₂, 0, x32₁) in
+ expr_let x35 := add64 (x0₂, x33₁) in
+ expr_let x36 := adc64 (x35₂, 0, x34₁) in
+ x24₁ :: x35₁ :: x36₁ :: []
+ : Expr (type.uncurry (type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z)))
*)
End P192_64.