From 93503d634054fd0813a41f9484ff08f6056e1fa6 Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Wed, 25 Apr 2018 17:45:05 +0200 Subject: fixed too-many-additions problem by fixing number of limbs in from_associational --- src/Experiments/SimplyTypedArithmetic.v | 188 +++++++++++++++++--------------- 1 file changed, 99 insertions(+), 89 deletions(-) (limited to 'src/Experiments/SimplyTypedArithmetic.v') diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index c9f9a239b..fcba99daf 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -673,6 +673,49 @@ Module Saturated. end. Qed. Hint Rewrite eval_sat_mul : push_eval. + + Definition sat_multerm_const s (t t' : (Z * Z)) : list (Z * Z) := + if snd t =? 1 + then [(fst t * fst t', snd t')] + else if snd t =? -1 + then [(fst t * fst t', - snd t')] + else if snd t =? 0 + then nil + else dlet_nd xy := Z.mul_split s (snd t) (snd t') in + [(fst t * fst t', fst xy); (fst t * fst t' * s, snd xy)]. + + Definition sat_mul_const s (p q : list (Z * Z)) : list (Z * Z) := + flat_map (fun t => flat_map (fun t' => sat_multerm_const s t t') q) p. + + Lemma eval_map_sat_multerm_const s a q (s_nonzero:s<>0): + Associational.eval (flat_map (sat_multerm_const s a) q) = fst a * snd a * Associational.eval q. + Proof. + cbv [sat_multerm_const Let_In]; induction q; + repeat match goal with + | _ => progress autorewrite with cancel_pair push_eval to_div_mod in * + | _ => progress simpl flat_map + | H : _ = 1 |- _ => rewrite H + | H : _ = -1 |- _ => rewrite H + | H : _ = 0 |- _ => rewrite H + | _ => progress break_match; Z.ltb_to_lt + | _ => rewrite IHq + | _ => rewrite Z.mod_eq by assumption + | _ => ring_simplify; omega + end. + Qed. + Hint Rewrite eval_map_sat_multerm_const using (omega || assumption) : push_eval. + + Lemma eval_sat_mul_const s p q (s_nonzero:s<>0): + Associational.eval (sat_mul_const s p q) = Associational.eval p * Associational.eval q. + Proof. + cbv [sat_mul_const]; induction p; [reflexivity|]. + repeat match goal with + | _ => progress (autorewrite with push_flat_map push_eval in * ) + | _ => rewrite IHp + | _ => ring_simplify; omega + end. + Qed. + Hint Rewrite eval_sat_mul_const : push_eval. End Associational. End Associational. @@ -1574,19 +1617,23 @@ Module Rows. let pq_a := Associational.sat_mul base p_a q_a in flatten m (from_associational m pq_a). - Definition mulmod base s c n m (p q : list Z) := + Definition sat_reduce base s c (p : list (Z * Z)) := + let lo_hi := Associational.split s p in + fst lo_hi ++ (Associational.sat_mul_const base c (snd lo_hi)). + + (* TODO : hardcoded to 2 reductions; fix *) + Definition mulmod base s c n (p q : list Z) := let p_a := Positional.to_associational weight n p in let q_a := Positional.to_associational weight n q in let pq_a := Associational.sat_mul base p_a q_a in - let lo_hi := Associational.split s pq_a in - let r_a := fst lo_hi ++ (Associational.sat_mul base c (snd lo_hi)) in (* reduce, but using sat_mul *) - flatten m (from_associational m r_a). + let r_a := sat_reduce base s c (sat_reduce base s c pq_a) in + flatten n (from_associational n r_a). - Hint Rewrite Associational.eval_sat_mul Associational.eval_split using solve [auto] : push_eval. + Hint Rewrite Associational.eval_sat_mul_const Associational.eval_sat_mul Associational.eval_split using solve [auto] : push_eval. Hint Rewrite eval_from_associational using solve [auto] : push_eval. Hint Rewrite eval_partition using solve [auto] : push_eval. Ltac solver := - intros; cbv [sub add mul mulmod]; + intros; cbv [sub add mul mulmod sat_reduce]; rewrite ?flatten_partitions' by (intros; In_cases; subst; distr_length; eauto using length_from_associational); rewrite ?flatten_div by (intros; In_cases; subst; distr_length; eauto using length_from_associational); autorewrite with push_eval; ring_simplify_subterms; @@ -1630,16 +1677,18 @@ Module Rows. fst (mul base n m p q) = partition m (Positional.eval weight n p * Positional.eval weight n q). Proof. solver. Qed. - Lemma eval_mulmod base s c n m p q : + Lemma eval_mulmod base s c n p q : base <> 0 -> s <> 0 -> s - Associational.eval c <> 0 -> - n <> 0%nat -> m <> 0%nat -> length p = n -> length q = n -> - (Positional.eval weight m (fst (mulmod base s c n m p q)) - + weight m * (snd (mulmod base s c n m p q))) mod (s - Associational.eval c) + n <> 0%nat -> length p = n -> length q = n -> + (Positional.eval weight n (fst (mulmod base s c n p q)) + + weight n * (snd (mulmod base s c n p q))) mod (s - Associational.eval c) = (Positional.eval weight n p * Positional.eval weight n q) mod (s - Associational.eval c). Proof. solver. rewrite <-Z.div_mod'' by auto. rewrite <-Associational.reduction_rule by omega. + autorewrite with push_eval. + rewrite <-Associational.reduction_rule by omega. autorewrite with push_eval; auto. Qed. End Ops. @@ -8429,7 +8478,7 @@ Module SaturatedSolinas. Context (s : Z) (c : list (Z * Z)) (s_nz : s <> 0) (modulus_nz : s - Associational.eval c <> 0). Context (log2base : Z) (log2base_pos : 0 < log2base) - (n m : nat) (n_nz : n <> 0%nat) (m_nz : m <> 0%nat). + (n : nat) (n_nz : n <> 0%nat). Let weight := weight log2base 1. Let props : @weight_properties weight := wprops log2base 1 ltac:(omega). @@ -8439,45 +8488,44 @@ Module SaturatedSolinas. SuchThat (forall (f g : list Z) (Hf : length f = n) (Hg : length g = n), - (eval weight m (fst (mulmod f g)) + weight m * (snd (mulmod f g))) mod (s - Associational.eval c) + (eval weight n (fst (mulmod f g)) + weight n * (snd (mulmod f g))) mod (s - Associational.eval c) = (eval weight n f * eval weight n g) mod (s - Associational.eval c)) As eval_mulmod. Proof. intros. - rewrite <-Rows.eval_mulmod with (base:=2^log2base) (s:=s) (c:=c) (m:=m) by auto using base_nz. + rewrite <-Rows.eval_mulmod with (base:=2^log2base) (s:=s) (c:=c) by auto using base_nz. eapply f_equal2; [|trivial]. (* expand_lists (). *) (* uncommenting this line removes some unused multiplications but also inlines a bunch of carry stuff at the end *) subst mulmod. reflexivity. Qed. + Definition mulmod' := fun x y => fst (mulmod x y). End MulMod. Derive mulmod_gen SuchThat (forall (log2base s : Z) (c : list (Z * Z)) (n m : nat) (f g : list Z), - Interp (t:=type.reify_type_of mulmod) - mulmod_gen s c log2base n m f g - = mulmod s c log2base n m f g) + Interp (t:=type.reify_type_of mulmod') + mulmod_gen s c log2base n f g + = mulmod' s c log2base n f g) As mulmod_gen_correct. Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed. Module Export ReifyHints. - Global Hint Extern 1 (_ = mulmod _ _ _ _ _ _ _) => simple apply mulmod_gen_correct : reify_gen_cache. + Global Hint Extern 1 (_ = mulmod' _ _ _ _ _ _) => simple apply mulmod_gen_correct : reify_gen_cache. End ReifyHints. Section rmulmod. - Context (n m : nat) - (s : Z) + Context (s : Z) (c : list (Z * Z)) (machine_wordsize : Z). Definition relax_zrange_of_machine_wordsize := relax_zrange_gen [1; machine_wordsize]%Z. + Let n : nat := Z.to_nat (Z.log2 s / machine_wordsize). Let relax_zrange := relax_zrange_of_machine_wordsize. Let bound := Some r[0 ~> (2^machine_wordsize - 1)]%zrange. Let boundsn : list (ZRange.type.option.interp type.Z) := List.repeat bound n. - Let boundsm : list (ZRange.type.option.interp type.Z) - := List.repeat bound m. Definition check_args {T} (res : Pipeline.ErrorT T) : Pipeline.ErrorT T @@ -8514,8 +8562,8 @@ Module SaturatedSolinas. Definition rmulmod_correct := BoundsPipeline_correct (Some boundsn, Some boundsn) - (Some boundsm, bound) - (mulmod s c machine_wordsize n m). + (Some boundsn) + (mulmod' s c machine_wordsize n). Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _). Definition rmulmod_correctT rv : Prop @@ -8527,14 +8575,12 @@ Ltac solve_rmulmod := solve_rop SaturatedSolinas.rmulmod_correct. Ltac solve_rmulmod_nocache := solve_rop_nocache SaturatedSolinas.rmulmod_correct. Module P192_64. - Definition n := 3%nat. - Definition m := 6%nat. Definition s := 2^192. Definition c := [(2^64, 1); (1,1)]. Definition machine_wordsize := 64. Derive mulmod - SuchThat (SaturatedSolinas.rmulmod_correctT n m s c machine_wordsize mulmod) + SuchThat (SaturatedSolinas.rmulmod_correctT s c machine_wordsize mulmod) As mulmod_correct. Proof. Time solve_rmulmod machine_wordsize. Time Qed. @@ -8567,69 +8613,33 @@ mulmod = fun var : type -> Type => λ x : var (type.list (type.type_primitive ty expr_let x9 := add64 (x6₂, x8₁) in expr_let x10 := adc64 (x9₂, x7₁, x8₂) in expr_let x11 := adc64 (x10₂, x6₁, x7₂) in - expr_let x12 := adx64 (x11₂, 0, 0) in - expr_let x13 := add64 (x4₂, x9₁) in - expr_let x14 := adc64 (x13₂, x5₁, x10₁) in - expr_let x15 := adc64 (x14₂, x5₂, x11₁) in - expr_let x16 := adc64 (x15₂, 0, x12) in - expr_let x17 := adx64 (x16₂, 0, 0) in - expr_let x18 := add64 (x3₁, x13₁) in - expr_let x19 := adc64 (x18₂, x6₂, x14₁) in - expr_let x20 := adc64 (x19₂, x4₁, x15₁) in - expr_let x21 := adc64 (x20₂, x0₂, x16₁) in - expr_let x22 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x21₂, 0, x17)) in - expr_let x23 := adx64 (x22₂, 0, 0) in - expr_let x24 := add64 (x2₂, x18₁) in - expr_let x25 := adc64 (x24₂, x4₂, x19₁) in - expr_let x26 := adc64 (x25₂, x2₁, x20₁) in - expr_let x27 := adc64 (x26₂, 0, x21₁) in - expr_let x28 := adc64 (x27₂, 0, x22₁) in - expr_let x29 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x28₂, 0, x23)) in - expr_let x30 := add64 (x1₁, x24₁) in - expr_let x31 := adc64 (x30₂, x3₁, x25₁) in - expr_let x32 := adc64 (x31₂, 0, x26₁) in - expr_let x33 := adc64 (x32₂, 0, x27₁) in - expr_let x34 := adc64 (x33₂, 0, x28₁) in - expr_let x35 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x34₂, 0, x29₁)) in - expr_let x36 := add64 (x2₂, x31₁) in - expr_let x37 := adc64 (x36₂, 0, x32₁) in - expr_let x38 := adc64 (x37₂, 0, x33₁) in - expr_let x39 := adc64 (x38₂, 0, x34₁) in - expr_let x40 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x39₂, 0, x35₁)) in - expr_let x41 := add64 (x1₁, x36₁) in - expr_let x42 := adc64 (x41₂, 0, x37₁) in - expr_let x43 := adc64 (x42₂, 0, x38₁) in - expr_let x44 := adc64 (x43₂, 0, x39₁) in - expr_let x45 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x44₂, 0, x40₁)) in - expr_let x46 := add64 (x3₂, x42₁) in - expr_let x47 := adc64 (x46₂, 0, x43₁) in - expr_let x48 := adc64 (x47₂, 0, x44₁) in - expr_let x49 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x48₂, 0, x45₁)) in - expr_let x50 := add64 (x3₂, x41₁) in - expr_let x51 := adc64 (x50₂, x1₂, x46₁) in - expr_let x52 := adc64 (x51₂, 0, x47₁) in - expr_let x53 := adc64 (x52₂, 0, x48₁) in - expr_let x54 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x53₂, 0, x49₁)) in - expr_let x55 := add64 (x0₁, x51₁) in - expr_let x56 := adc64 (x55₂, 0, x52₁) in - expr_let x57 := adc64 (x56₂, 0, x53₁) in - expr_let x58 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x57₂, 0, x54₁)) in - expr_let x59 := add64 (x1₂, x50₁) in - expr_let x60 := adc64 (x59₂, 0, x55₁) in - expr_let x61 := adc64 (x60₂, 0, x56₁) in - expr_let x62 := adc64 (x61₂, 0, x57₁) in - expr_let x63 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x62₂, 0, x58₁)) in - expr_let x64 := add64 (x0₁, x59₁) in - expr_let x65 := adc64 (x64₂, 0, x60₁) in - expr_let x66 := adc64 (x65₂, 0, x61₁) in - expr_let x67 := adc64 (x66₂, 0, x62₁) in - expr_let x68 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x67₂, 0, x63₁)) in - expr_let x69 := add64 (x0₂, x65₁) in - expr_let x70 := adc64 (x69₂, 0, x66₁) in - expr_let x71 := adc64 (x70₂, 0, x67₁) in - expr_let x72 := Z.cast2 (bool, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (x71₂, 0, x68₁)) in - (x30₁ :: x64₁ :: x69₁ :: x70₁ :: x71₁ :: x72₁ :: [], Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (Z.cast bool @@ (x29₂ + x35₂) + x40₂) + x45₂) + x49₂) + x54₂) + x58₂) + x63₂) + x68₂) + x72₂)) - : Expr (type.uncurry (type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z) * type.type_primitive type.Z)) + expr_let x12 := add64 (x4₂, x9₁) in + expr_let x13 := adc64 (x12₂, x5₁, x10₁) in + expr_let x14 := adc64 (x13₂, x5₂, x11₁) in + expr_let x15 := add64 (x3₁, x12₁) in + expr_let x16 := adc64 (x15₂, x6₂, x13₁) in + expr_let x17 := adc64 (x16₂, x4₁, x14₁) in + expr_let x18 := add64 (x2₂, x15₁) in + expr_let x19 := adc64 (x18₂, x4₂, x16₁) in + expr_let x20 := adc64 (x19₂, x2₁, x17₁) in + expr_let x21 := add64 (x1₁, x18₁) in + expr_let x22 := adc64 (x21₂, x3₁, x19₁) in + expr_let x23 := adc64 (x22₂, x3₂, x20₁) in + expr_let x24 := add64 (x0₂, x21₁) in + expr_let x25 := adc64 (x24₂, x2₂, x22₁) in + expr_let x26 := adc64 (x25₂, x1₂, x23₁) in + expr_let x27 := add64 (x1₁, x25₁) in + expr_let x28 := adc64 (x27₂, x0₁, x26₁) in + expr_let x29 := add64 (x3₂, x27₁) in + expr_let x30 := adc64 (x29₂, x0₂, x28₁) in + expr_let x31 := add64 (x1₂, x29₁) in + expr_let x32 := adc64 (x31₂, 0, x30₁) in + expr_let x33 := add64 (x0₁, x31₁) in + expr_let x34 := adc64 (x33₂, 0, x32₁) in + expr_let x35 := add64 (x0₂, x33₁) in + expr_let x36 := adc64 (x35₂, 0, x34₁) in + x24₁ :: x35₁ :: x36₁ :: [] + : Expr (type.uncurry (type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z))) *) End P192_64. -- cgit v1.2.3