diff options
Diffstat (limited to 'src/Experiments/NewPipeline/Toplevel1.v')
-rw-r--r-- | src/Experiments/NewPipeline/Toplevel1.v | 2318 |
1 files changed, 2318 insertions, 0 deletions
diff --git a/src/Experiments/NewPipeline/Toplevel1.v b/src/Experiments/NewPipeline/Toplevel1.v new file mode 100644 index 000000000..6168dfd22 --- /dev/null +++ b/src/Experiments/NewPipeline/Toplevel1.v @@ -0,0 +1,2318 @@ +Require Import Coq.ZArith.ZArith Coq.micromega.Lia. +Require Import Coq.derive.Derive. +Require Import Coq.Bool.Bool. +Require Import Coq.Strings.String. +Require Import Coq.Lists.List. +Require Crypto.Util.Strings.String. +Require Import Crypto.Util.Strings.Decimal. +Require Import Crypto.Util.Strings.HexString. +Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. +Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. +Require Import Crypto.Algebra.Ring. +Require Import Crypto.Algebra.SubsetoidRing. +Require Import Crypto.Util.ZRange. +Require Import Crypto.Util.ListUtil.FoldBool. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.Tactics.SplitInContext. +Require Import Crypto.Util.Tactics.SubstEvars. +Require Import Crypto.Util.Tactics.DestructHead. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ListUtil Coq.Lists.List. +Require Import Crypto.Util.Equality. +Require Import Crypto.Util.Tactics.GetGoal. +Require Import Crypto.Arithmetic.BarrettReduction.Generalized. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.ZUtil.Rshi. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.ZUtil.Zselect. +Require Import Crypto.Util.ZUtil.AddModulo. +Require Import Crypto.Util.ZUtil.CC. +Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. +Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. +Require Import Crypto.Util.ErrorT. +Require Import Crypto.Util.Strings.Show. +Require Import Crypto.Util.ZRange.Show. +Require Import Crypto.Experiments.NewPipeline.Arithmetic. +Require Crypto.Experiments.NewPipeline.Language. +Require Crypto.Experiments.NewPipeline.UnderLets. +Require Crypto.Experiments.NewPipeline.AbstractInterpretation. +Require Crypto.Experiments.NewPipeline.AbstractInterpretationProofs. +Require Crypto.Experiments.NewPipeline.Rewriter. +Require Crypto.Experiments.NewPipeline.MiscCompilerPasses. +Require Crypto.Experiments.NewPipeline.CStringification. +Require Import Crypto.Util.Notations. +Import ListNotations. Local Open Scope Z_scope. + +(** NOTE: Module Ring SHOULD NOT depend on any compilers things *) +Module Ring. + Local Notation is_bounded_by0 r v + := ((lower r <=? v) && (v <=? upper r)). + Local Notation is_bounded_by0o r + := (match r with Some r' => fun v' => is_bounded_by0 r' v' | None => fun _ => true end). + Local Notation is_bounded_by bounds ls + := (fold_andb_map (fun r v'' => is_bounded_by0o r v'') bounds ls). + Local Notation is_bounded_by1 bounds ls + := (andb (is_bounded_by bounds (@fst _ unit ls)) true). + Local Notation is_bounded_by2 bounds ls + := (andb (is_bounded_by bounds (fst ls)) (is_bounded_by1 bounds (snd ls))). + + Lemma length_is_bounded_by bounds ls + : is_bounded_by bounds ls = true -> length ls = length bounds. + Proof. + intro H. + apply fold_andb_map_length in H; congruence. + Qed. + + Section ring_goal. + Context (limbwidth_num limbwidth_den : Z) + (n : nat) + (s : Z) + (c : list (Z * Z)) + (tight_bounds : list (option zrange)) + (length_tight_bounds : length tight_bounds = n) + (loose_bounds : list (option zrange)) + (length_loose_bounds : length loose_bounds = n). + Local Notation weight := (weight limbwidth_num limbwidth_den). + Local Notation eval := (Positional.eval weight n). + Let prime_bound : zrange + := r[0~>(s - Associational.eval c - 1)]%zrange. + Let m := Z.to_pos (s - Associational.eval c). + Context (m_eq : Z.pos m = s - Associational.eval c) + (sc_pos : 0 < s - Associational.eval c) + (Interp_rrelaxv : list Z -> list Z) + (HInterp_rrelaxv : forall arg, + is_bounded_by1 tight_bounds arg = true + -> is_bounded_by loose_bounds (Interp_rrelaxv (fst arg)) = true + /\ Interp_rrelaxv (fst arg) = id (fst arg)) + (carry_mulmod : list Z -> list Z -> list Z) + (Hcarry_mulmod + : forall f g, + length f = n -> length g = n -> + (eval (carry_mulmod f g)) mod (s - Associational.eval c) + = (eval f * eval g) mod (s - Associational.eval c)) + (Interp_rcarry_mulv : list Z -> list Z -> list Z) + (HInterp_rcarry_mulv : forall arg, + is_bounded_by2 loose_bounds arg = true + -> is_bounded_by tight_bounds (Interp_rcarry_mulv (fst arg) (fst (snd arg))) = true + /\ Interp_rcarry_mulv (fst arg) (fst (snd arg)) = carry_mulmod (fst arg) (fst (snd arg))) + (carrymod : list Z -> list Z) + (Hcarrymod + : forall f, + length f = n -> + (eval (carrymod f)) mod (s - Associational.eval c) + = (eval f) mod (s - Associational.eval c)) + (Interp_rcarryv : list Z -> list Z) + (HInterp_rcarryv : forall arg, + is_bounded_by1 loose_bounds arg = true + -> is_bounded_by tight_bounds (Interp_rcarryv (fst arg)) = true + /\ Interp_rcarryv (fst arg) = carrymod (fst arg)) + (addmod : list Z -> list Z -> list Z) + (Haddmod + : forall f g, + length f = n -> length g = n -> + (eval (addmod f g)) mod (s - Associational.eval c) + = (eval f + eval g) mod (s - Associational.eval c)) + (Interp_raddv : list Z -> list Z -> list Z) + (HInterp_raddv : forall arg, + is_bounded_by2 tight_bounds arg = true + -> is_bounded_by loose_bounds (Interp_raddv (fst arg) (fst (snd arg))) = true + /\ Interp_raddv (fst arg) (fst (snd arg)) = addmod (fst arg) (fst (snd arg))) + (submod : list Z -> list Z -> list Z) + (Hsubmod + : forall f g, + length f = n -> length g = n -> + (eval (submod f g)) mod (s - Associational.eval c) + = (eval f - eval g) mod (s - Associational.eval c)) + (Interp_rsubv : list Z -> list Z -> list Z) + (HInterp_rsubv : forall arg, + is_bounded_by2 tight_bounds arg = true + -> is_bounded_by loose_bounds (Interp_rsubv (fst arg) (fst (snd arg))) = true + /\ Interp_rsubv (fst arg) (fst (snd arg)) = submod (fst arg) (fst (snd arg))) + (oppmod : list Z -> list Z) + (Hoppmod + : forall f, + length f = n -> + (eval (oppmod f)) mod (s - Associational.eval c) + = (- eval f) mod (s - Associational.eval c)) + (Interp_roppv : list Z -> list Z) + (HInterp_roppv : forall arg, + is_bounded_by1 tight_bounds arg = true + -> is_bounded_by loose_bounds (Interp_roppv (fst arg)) = true + /\ Interp_roppv (fst arg) = oppmod (fst arg)) + (zeromod : list Z) + (Hzeromod + : (eval zeromod) mod (s - Associational.eval c) + = 0 mod (s - Associational.eval c)) + (Interp_rzerov : list Z) + (HInterp_rzerov : is_bounded_by tight_bounds Interp_rzerov = true + /\ Interp_rzerov = zeromod) + (onemod : list Z) + (Honemod + : (eval onemod) mod (s - Associational.eval c) + = 1 mod (s - Associational.eval c)) + (Interp_ronev : list Z) + (HInterp_ronev : is_bounded_by tight_bounds Interp_ronev = true + /\ Interp_ronev = onemod) + (encodemod : Z -> list Z) + (Hencodemod + : forall f, + (eval (encodemod f)) mod (s - Associational.eval c) + = f mod (s - Associational.eval c)) + (Interp_rencodev : Z -> list Z) + (HInterp_rencodev : forall arg, + is_bounded_by0 prime_bound (@fst _ unit arg) && true = true + -> is_bounded_by tight_bounds (Interp_rencodev (fst arg)) = true + /\ Interp_rencodev (fst arg) = encodemod (fst arg)). + + Local Notation T := (list Z) (only parsing). + Local Notation encoded_ok ls + := (is_bounded_by tight_bounds ls = true) (only parsing). + Local Notation encoded_okf := (fun ls => encoded_ok ls) (only parsing). + + Definition Fdecode (v : T) : F m + := F.of_Z m (Positional.eval weight n v). + Definition T_eq (x y : T) + := Fdecode x = Fdecode y. + + Definition encodedT := sig encoded_okf. + + Definition ring_mul (x y : T) : T + := Interp_rcarry_mulv (Interp_rrelaxv x) (Interp_rrelaxv y). + Definition ring_add (x y : T) : T := Interp_rcarryv (Interp_raddv x y). + Definition ring_sub (x y : T) : T := Interp_rcarryv (Interp_rsubv x y). + Definition ring_opp (x : T) : T := Interp_rcarryv (Interp_roppv x). + Definition ring_encode (x : F m) : T := Interp_rencodev (F.to_Z x). + + Definition GoodT : Prop + := @subsetoid_ring + (list Z) encoded_okf T_eq + Interp_rzerov Interp_ronev ring_opp ring_add ring_sub ring_mul + /\ @is_subsetoid_homomorphism + (F m) (fun _ => True) eq 1%F F.add F.mul + (list Z) encoded_okf T_eq Interp_ronev ring_add ring_mul ring_encode + /\ @is_subsetoid_homomorphism + (list Z) encoded_okf T_eq Interp_ronev ring_add ring_mul + (F m) (fun _ => True) eq 1%F F.add F.mul + Fdecode. + + Hint Rewrite ->@F.to_Z_add : push_FtoZ. + Hint Rewrite ->@F.to_Z_mul : push_FtoZ. + Hint Rewrite ->@F.to_Z_opp : push_FtoZ. + Hint Rewrite ->@F.to_Z_of_Z : push_FtoZ. + + Lemma Fm_bounded_alt (x : F m) + : (0 <=? F.to_Z x) && (F.to_Z x <=? Z.pos m - 1) = true. + Proof using m_eq. + clear -m_eq. + destruct x as [x H]; cbn [F.to_Z proj1_sig]. + pose proof (Z.mod_pos_bound x (Z.pos m)). + rewrite andb_true_iff; split; Z.ltb_to_lt; lia. + Qed. + + Lemma Good : GoodT. + Proof. + split_and. + repeat match goal with + | [ H : context[andb _ true] |- _ ] => setoid_rewrite andb_true_r in H + end. + eapply subsetoid_ring_by_ring_isomorphism; + cbv [ring_opp ring_add ring_sub ring_mul ring_encode F.sub] in *; + repeat match goal with + | [ H : forall arg : _ * unit, _ |- _ ] => specialize (fun arg => H (arg, tt)) + | [ H : forall arg : _ * (_ * unit), _ |- _ ] => specialize (fun a b => H (a, (b, tt))) + | _ => progress cbn [fst snd] in * + | _ => solve [ auto using andb_true_intro, conj with nocore ] + | _ => progress intros + | [ H : _ |- is_bounded_by _ _ = true ] => apply H + | [ |- _ <-> _ ] => reflexivity + | [ |- ?x = ?x ] => reflexivity + | [ |- _ = _ :> Z ] => first [ reflexivity | rewrite <- m_eq; reflexivity ] + | [ H : context[?x] |- Fdecode ?x = _ ] => rewrite H + | [ H : context[?x _] |- Fdecode (?x _) = _ ] => rewrite H + | [ H : context[?x _ _] |- Fdecode (?x _ _) = _ ] => rewrite H + | _ => progress cbv [Fdecode] + | [ |- _ = _ :> F _ ] => apply F.eq_to_Z_iff + | _ => progress autorewrite with push_FtoZ + | _ => rewrite m_eq + | [ H : context[?x _ _] |- context[eval (?x _ _)] ] => rewrite H + | [ H : context[?x _] |- context[eval (?x _)] ] => rewrite H + | [ H : context[?x] |- context[eval ?x] ] => rewrite H + | [ |- context[List.length ?x] ] + => erewrite (length_is_bounded_by _ x) + by eauto using andb_true_intro, conj with nocore + | [ |- _ = _ :> Z ] + => push_Zmod; reflexivity + | _ => pull_Zmod; rewrite Z.add_opp_r + | _ => rewrite expanding_id_id + | [ |- context[F.to_Z _ mod (_ - _)] ] + => rewrite <- m_eq, F.mod_to_Z + | _ => rewrite <- m_eq; apply Fm_bounded_alt + | [ |- context[andb _ true] ] => rewrite andb_true_r + end. + Qed. + End ring_goal. +End Ring. + +Import Associational Positional. + +Import + Crypto.Experiments.NewPipeline.Language + Crypto.Experiments.NewPipeline.UnderLets + Crypto.Experiments.NewPipeline.AbstractInterpretation + Crypto.Experiments.NewPipeline.AbstractInterpretationProofs + Crypto.Experiments.NewPipeline.Rewriter + Crypto.Experiments.NewPipeline.MiscCompilerPasses + Crypto.Experiments.NewPipeline.CStringification. + +Import + Language.Compilers + UnderLets.Compilers + AbstractInterpretation.Compilers + AbstractInterpretationProofs.Compilers + Rewriter.Compilers + MiscCompilerPasses.Compilers + CStringification.Compilers. + +Import Compilers.defaults. +Local Coercion Z.of_nat : nat >-> Z. +Local Coercion QArith_base.inject_Z : Z >-> Q. +Notation "x" := (expr.Var x) (only printing, at level 9) : expr_scope. + +Axiom admit_pf : False. +Notation admit := (match admit_pf with end). +Ltac cache_reify _ := + intros; + etransitivity; + [ + | repeat match goal with |- _ = ?f' ?x => is_var x; apply (f_equal (fun f => f _)) end; + Reify_rhs (); + reflexivity ]; + subst_evars; + reflexivity. + +Create HintDb reify_gen_cache. + +Derive carry_mul_gen + SuchThat (forall (limbwidth_num limbwidth_den : Z) + (f g : list Z) + (n : nat) + (s : Z) + (c : list (Z * Z)) + (idxs : list nat), + Interp (t:=reify_type_of carry_mulmod) + carry_mul_gen limbwidth_num limbwidth_den s c n idxs f g + = carry_mulmod limbwidth_num limbwidth_den s c n idxs f g) + As carry_mul_gen_correct. +Proof. Time cache_reify (). Time Qed. +Hint Extern 1 (_ = carry_mulmod _ _ _ _ _ _ _ _) => simple apply carry_mul_gen_correct : reify_gen_cache. + +Derive carry_gen + SuchThat (forall (limbwidth_num limbwidth_den : Z) + (f : list Z) + (n : nat) + (s : Z) + (c : list (Z * Z)) + (idxs : list nat), + Interp (t:=reify_type_of carrymod) + carry_gen limbwidth_num limbwidth_den s c n idxs f + = carrymod limbwidth_num limbwidth_den s c n idxs f) + As carry_gen_correct. +Proof. cache_reify (). Qed. +Hint Extern 1 (_ = carrymod _ _ _ _ _ _ _) => simple apply carry_gen_correct : reify_gen_cache. + +Derive encode_gen + SuchThat (forall (limbwidth_num limbwidth_den : Z) + (v : Z) + (n : nat) + (s : Z) + (c : list (Z * Z)), + Interp (t:=reify_type_of encodemod) + encode_gen limbwidth_num limbwidth_den s c n v + = encodemod limbwidth_num limbwidth_den s c n v) + As encode_gen_correct. +Proof. cache_reify (). Qed. +Hint Extern 1 (_ = encodemod _ _ _ _ _ _) => simple apply encode_gen_correct : reify_gen_cache. + +Derive add_gen + SuchThat (forall (limbwidth_num limbwidth_den : Z) + (f g : list Z) + (n : nat), + Interp (t:=reify_type_of addmod) + add_gen limbwidth_num limbwidth_den n f g + = addmod limbwidth_num limbwidth_den n f g) + As add_gen_correct. +Proof. cache_reify (). Qed. +Hint Extern 1 (_ = addmod _ _ _ _ _) => simple apply add_gen_correct : reify_gen_cache. +Derive sub_gen + SuchThat (forall (limbwidth_num limbwidth_den : Z) + (n : nat) + (s : Z) + (c : list (Z * Z)) + (coef : Z) + (f g : list Z), + Interp (t:=reify_type_of submod) + sub_gen limbwidth_num limbwidth_den s c n coef f g + = submod limbwidth_num limbwidth_den s c n coef f g) + As sub_gen_correct. +Proof. cache_reify (). Qed. +Hint Extern 1 (_ = submod _ _ _ _ _ _ _ _) => simple apply sub_gen_correct : reify_gen_cache. + +Derive opp_gen + SuchThat (forall (limbwidth_num limbwidth_den : Z) + (n : nat) + (s : Z) + (c : list (Z * Z)) + (coef : Z) + (f : list Z), + Interp (t:=reify_type_of oppmod) + opp_gen limbwidth_num limbwidth_den s c n coef f + = oppmod limbwidth_num limbwidth_den s c n coef f) + As opp_gen_correct. +Proof. cache_reify (). Qed. +Hint Extern 1 (_ = oppmod _ _ _ _ _ _ _) => simple apply opp_gen_correct : reify_gen_cache. + +Definition zeromod limbwidth_num limbwidth_den n s c := encodemod limbwidth_num limbwidth_den n s c 0. +Definition onemod limbwidth_num limbwidth_den n s c := encodemod limbwidth_num limbwidth_den n s c 1. + +Derive zero_gen + SuchThat (forall (limbwidth_num limbwidth_den : Z) + (n : nat) + (s : Z) + (c : list (Z * Z)), + Interp (t:=reify_type_of zeromod) + zero_gen limbwidth_num limbwidth_den s c n + = zeromod limbwidth_num limbwidth_den s c n) + As zero_gen_correct. +Proof. cache_reify (). Qed. +Hint Extern 1 (_ = zeromod _ _ _ _ _) => simple apply zero_gen_correct : reify_gen_cache. + +Derive one_gen + SuchThat (forall (limbwidth_num limbwidth_den : Z) + (n : nat) + (s : Z) + (c : list (Z * Z)), + Interp (t:=reify_type_of onemod) + one_gen limbwidth_num limbwidth_den s c n + = onemod limbwidth_num limbwidth_den s c n) + As one_gen_correct. +Proof. cache_reify (). Qed. +Hint Extern 1 (_ = onemod _ _ _ _ _) => simple apply one_gen_correct : reify_gen_cache. + +Derive id_gen + SuchThat (forall (ls : list Z), + Interp (t:=reify_type_of (@id (list Z))) + id_gen ls + = id ls) + As id_gen_correct. +Proof. cache_reify (). Qed. +Hint Extern 1 (_ = id _) => simple apply id_gen_correct : reify_gen_cache. + +Module Pipeline. + Import GeneralizeVar. + Inductive ErrorMessage := + | Computed_bounds_are_not_tight_enough + {t} (computed_bounds expected_bounds : ZRange.type.base.option.interp (type.final_codomain t)) + (syntax_tree : Expr t) (arg_bounds : type.for_each_lhs_of_arrow ZRange.type.option.interp t) + | Bounds_analysis_failed + | Type_too_complicated_for_cps (t : type) + | Value_not_leZ (descr : string) (lhs rhs : Z) + | Value_not_leQ (descr : string) (lhs rhs : Q) + | Value_not_ltZ (descr : string) (lhs rhs : Z) + | Values_not_provably_distinctZ (descr : string) (lhs rhs : Z) + | Values_not_provably_equalZ (descr : string) (lhs rhs : Z) + | Stringification_failed {t} (e : @Compilers.defaults.Expr t). + + Notation ErrorT := (ErrorT ErrorMessage). + + Section show. + Local Open Scope string_scope. + Definition show_prim_zrange_opt_interp {t:base.type.base} + : Show (ZRange.type.base.option.interp t) + := match t return Show (ZRange.type.base.option.interp t) with + | base.type.unit => _ + | base.type.Z => _ + | base.type.nat => _ + | base.type.bool => _ + end. + Global Existing Instance show_prim_zrange_opt_interp. + Fixpoint show_base_zrange_opt_interp {t} : Show (ZRange.type.base.option.interp t) + := fun parens + => match t return ZRange.type.base.option.interp t -> string with + | base.type.type_base t + => fun v : ZRange.type.base.option.interp t + => @show_prim_zrange_opt_interp t parens v + | base.type.prod A B + => fun '(a, b) + => "(" ++ @show_base_zrange_opt_interp A false a + ++ ", " ++ @show_base_zrange_opt_interp B true b + ++ ")" + | base.type.list A + => fun v : option (list (ZRange.type.option.interp A)) + => show parens v + end. + Global Existing Instance show_base_zrange_opt_interp. + Definition show_zrange_opt_interp {t} : Show (ZRange.type.option.interp t) + := fun parens + => match t return ZRange.type.option.interp t -> string with + | type.base t + => fun v : ZRange.type.base.option.interp t + => show parens v + | type.arrow s d => fun _ => "λ" + end. + Global Existing Instance show_zrange_opt_interp. + Fixpoint show_for_each_lhs_of_arrow {base_type} (f : type.type base_type -> Type) (show_f : forall t, Show (f t)) (t : type.type base_type) (p : bool) : type.for_each_lhs_of_arrow f t -> string + := match t return type.for_each_lhs_of_arrow f t -> string with + | type.base t => fun (tt : unit) => show p tt + | type.arrow s d + => fun '((x, xs) : f s * type.for_each_lhs_of_arrow f d) + => let _ : Show (f s) := show_f s in + let _ : Show (type.for_each_lhs_of_arrow f d) := @show_for_each_lhs_of_arrow base_type f show_f d in + show p (x, xs) + end. + Global Instance: forall {base_type f show_f t}, Show (type.for_each_lhs_of_arrow f t) := @show_for_each_lhs_of_arrow. + + Local Notation NewLine := (String "010" "") (only parsing). + + Fixpoint find_too_loose_base_bounds {t} + : ZRange.type.base.option.interp t -> ZRange.type.base.option.interp t-> bool * list (nat * nat) * list (zrange * zrange) + := match t return ZRange.type.base.option.interp t -> ZRange.type.option.interp t-> bool * list (nat * nat) * list (zrange * zrange) with + | base.type.unit + => fun 'tt 'tt => (false, nil, nil) + | base.type.nat + | base.type.bool + => fun _ _ => (false, nil, nil) + | base.type.Z + => fun a b + => match a, b with + | None, None => (false, nil, nil) + | Some _, None => (false, nil, nil) + | None, Some _ => (true, nil, nil) + | Some a, Some b + => if is_tighter_than_bool a b + then (false, nil, nil) + else (false, nil, ((a, b)::nil)) + end + | base.type.prod A B + => fun '(ra, rb) '(ra', rb') + => let '(b1, lens1, ls1) := @find_too_loose_base_bounds A ra ra' in + let '(b2, lens2, ls2) := @find_too_loose_base_bounds B rb rb' in + (orb b1 b2, lens1 ++ lens2, ls1 ++ ls2)%list + | base.type.list A + => fun ls1 ls2 + => match ls1, ls2 with + | None, None + | Some _, None + => (false, nil, nil) + | None, Some _ + => (true, nil, nil) + | Some ls1, Some ls2 + => List.fold_right + (fun '(b, len, err) '(bs, lens, errs) + => (orb b bs, len ++ lens, err ++ errs)%list) + (false, + (if (List.length ls1 =? List.length ls2)%nat + then nil + else ((List.length ls1, List.length ls2)::nil)), + nil) + (List.map + (fun '(a, b) => @find_too_loose_base_bounds A a b) + (List.combine ls1 ls2)) + end + end. + + Definition find_too_loose_bounds {t} + : ZRange.type.option.interp t -> ZRange.type.option.interp t-> bool * list (nat * nat) * list (zrange * zrange) + := match t with + | type.arrow s d => fun _ _ => (false, nil, nil) + | type.base t => @find_too_loose_base_bounds t + end. + Definition explain_too_loose_bounds {t} (b1 b2 : ZRange.type.option.interp t) + : string + := let '(none_some, lens, bs) := find_too_loose_bounds b1 b2 in + String.concat + NewLine + ((if none_some then "Found None where Some was expected"::nil else nil) + ++ (List.map + (A:=nat*nat) + (fun '(l1, l2) => "Found a list of length " ++ show false l1 ++ " where a list of length " ++ show false l2 ++ " was expected.") + lens) + ++ (List.map + (A:=zrange*zrange) + (fun '(b1, b2) => "The bounds " ++ show false b1 ++ " are looser than the expected bounds " ++ show false b2) + bs)). + + Global Instance show_ErrorMessage : Show ErrorMessage + := fun parens e + => maybe_wrap_parens + parens + match e with + | Computed_bounds_are_not_tight_enough t computed_bounds expected_bounds syntax_tree arg_bounds + => ("Computed bounds " ++ show true computed_bounds ++ " are not tight enough (expected bounds not looser than " ++ show true expected_bounds ++ ")." ++ NewLine) + ++ (explain_too_loose_bounds (t:=type.base _) computed_bounds expected_bounds ++ NewLine) + ++ match ToString.C.ToFunctionString + "f" syntax_tree None arg_bounds with + | Some E_str + => ("When doing bounds analysis on the syntax tree:" ++ NewLine) + ++ E_str ++ NewLine + ++ "with input bounds " ++ show true arg_bounds ++ "." ++ NewLine + | None => "(Unprintible syntax tree used in bounds analysis)" ++ NewLine + end + | Bounds_analysis_failed => "Bounds analysis failed." + | Type_too_complicated_for_cps t + => "Type too complicated for cps: " ++ show false t + | Value_not_leZ descr lhs rhs + => "Value not ≤ (" ++ descr ++ ") : expected " ++ show false lhs ++ " ≤ " ++ show false rhs + | Value_not_leQ descr lhs rhs + => "Value not ≤ (" ++ descr ++ ") : expected " ++ show false lhs ++ " ≤ " ++ show false rhs + | Value_not_ltZ descr lhs rhs + => "Value not < (" ++ descr ++ ") : expected " ++ show false lhs ++ " < " ++ show false rhs + | Values_not_provably_distinctZ descr lhs rhs + => "Values not provalby distinct (" ++ descr ++ ") : expected " ++ show true lhs ++ " ≠ " ++ show true rhs + | Values_not_provably_equalZ descr lhs rhs + => "Values not provalby equal (" ++ descr ++ ") : expected " ++ show true lhs ++ " = " ++ show true rhs + | Stringification_failed t e => "Stringification failed on the syntax tree:" ++ NewLine ++ show false e + end. + End show. + + Definition invert_result {T} (v : ErrorT T) + := match v return match v with Success _ => T | _ => ErrorMessage end with + | Success v => v + | Error msg => msg + end. + + Record to_fancy_args := { invert_low : Z (*log2wordmax*) -> Z -> option Z ; invert_high : Z (*log2wordmax*) -> Z -> option Z }. + + Definition BoundsPipeline + (with_dead_code_elimination : bool := true) + (with_subst01 : bool) + (translate_to_fancy : option to_fancy_args) + relax_zrange + {t} + (E : Expr t) + arg_bounds + out_bounds + : ErrorT (Expr t) + := (*let E := expr.Uncurry E in*) + let E := PartialEvaluateWithListInfoFromBounds E arg_bounds in + let E := PartialEvaluate E in + (* Note that DCE evaluates the expr with two different [var] + arguments, and so results in a pipeline that is 2x slower + unless we pass through a uniformly concrete [var] type + first *) + dlet_nd e := ToFlat E in + let E := FromFlat e in + let E := if with_dead_code_elimination then DeadCodeElimination.EliminateDead E else E in + dlet_nd e := ToFlat E in + let E := FromFlat e in + let E := if with_subst01 then Subst01.Subst01 E else E in + let E := UnderLets.LetBindReturn E in + let E := PartialEvaluate E in (* after inlining, see if any new rewrite redexes are available *) + let E := ReassociateSmallConstants.Reassociate (2^8) E in + let E := match translate_to_fancy with + | Some {| invert_low := invert_low ; invert_high := invert_high |} => RewriteRules.RewriteToFancy invert_low invert_high E + | None => E + end in + dlet_nd e := ToFlat E in + let E := FromFlat e in + let E := CheckedPartialEvaluateWithBounds relax_zrange E arg_bounds out_bounds in + match E with + | inl E => Success E + | inr (b, E) + => Error (Computed_bounds_are_not_tight_enough b out_bounds E arg_bounds) + end. + + Definition BoundsPipelineToStrings + (name : string) + (with_dead_code_elimination : bool := true) + (with_subst01 : bool) + (translate_to_fancy : option to_fancy_args) + relax_zrange + {t} + (E : Expr t) + arg_bounds + out_bounds + : ErrorT (list string) + := let E := BoundsPipeline + (*with_dead_code_elimination*) + with_subst01 + translate_to_fancy + relax_zrange + E arg_bounds out_bounds in + match E with + | Success E' => let E := ToString.C.ToFunctionLines + name E' None arg_bounds in + match E with + | Some E => Success E + | None => Error (Stringification_failed E') + end + | Error err => Error err + end. + + Definition BoundsPipelineToString + (name : string) + (with_dead_code_elimination : bool := true) + (with_subst01 : bool) + (translate_to_fancy : option to_fancy_args) + relax_zrange + {t} + (E : Expr t) + arg_bounds + out_bounds + : ErrorT string + := let E := BoundsPipelineToStrings + name + (*with_dead_code_elimination*) + with_subst01 + translate_to_fancy + relax_zrange + E arg_bounds out_bounds in + match E with + | Success E => Success (ToString.C.LinesToString E) + | Error err => Error err + end. + + Lemma BoundsPipeline_correct + (with_dead_code_elimination : bool := true) + (with_subst01 : bool) + (translate_to_fancy : option to_fancy_args) + relax_zrange + (Hrelax : forall r r' z : zrange, + (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) + {t} + (e : Expr t) + arg_bounds + out_bounds + rv + (Hrv : BoundsPipeline (*with_dead_code_elimination*) with_subst01 translate_to_fancy relax_zrange e arg_bounds out_bounds = Success rv) + : forall arg + (Harg : type.andb_bool_for_each_lhs_of_arrow (@ZRange.type.option.is_bounded_by) arg_bounds arg = true), + ZRange.type.base.option.is_bounded_by out_bounds (type.app_curried (Interp rv) arg) = true + /\ forall cast_outside_of_range, type.app_curried (expr.Interp (@ident.gen_interp cast_outside_of_range) rv) arg + = type.app_curried (Interp e) arg. + Proof. + cbv [BoundsPipeline Let_In] in *; + repeat match goal with + | [ H : match ?x with _ => _ end = Success _ |- _ ] + => destruct x eqn:?; cbv beta iota in H; [ | destruct_head'_prod; congruence ]; + let H' := fresh in + inversion H as [H']; clear H; rename H' into H + end. + { intros; + match goal with + | [ H : _ = _ |- _ ] + => eapply CheckedPartialEvaluateWithBounds_Correct in H; + [ destruct H as [H0 H1] | .. ] + end; + [ + | eassumption || (try reflexivity).. ]. + subst. + split; [ assumption | ]. + { intros; rewrite H1. + exact admit. (* interp correctness *) } } + Qed. + + Definition BoundsPipeline_correct_transT + {t} + arg_bounds + out_bounds + (InterpE : type.interp base.interp t) + (rv : Expr t) + := forall arg + (Harg : type.andb_bool_for_each_lhs_of_arrow (@ZRange.type.option.is_bounded_by) arg_bounds arg = true), + ZRange.type.base.option.is_bounded_by out_bounds (type.app_curried (Interp rv) arg) = true + /\ forall cast_outside_of_range, type.app_curried (expr.Interp (@ident.gen_interp cast_outside_of_range) rv) arg + = type.app_curried InterpE arg. + + Lemma BoundsPipeline_correct_trans + (with_dead_code_elimination : bool := true) + (with_subst01 : bool) + (translate_to_fancy : option to_fancy_args) + relax_zrange + (Hrelax + : forall r r' z : zrange, + (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) + {t} + (e : Expr t) + arg_bounds out_bounds + (InterpE : type.interp base.interp t) + (InterpE_correct + : forall arg + (Harg : type.andb_bool_for_each_lhs_of_arrow (@ZRange.type.option.is_bounded_by) arg_bounds arg = true), + type.app_curried (Interp e) arg = type.app_curried InterpE arg) + rv + (Hrv : BoundsPipeline (*with_dead_code_elimination*) with_subst01 translate_to_fancy relax_zrange e arg_bounds out_bounds = Success rv) + : BoundsPipeline_correct_transT arg_bounds out_bounds InterpE rv. + Proof. + intros arg Harg; rewrite <- InterpE_correct by assumption. + eapply @BoundsPipeline_correct; eassumption. + Qed. +End Pipeline. + +Definition round_up_bitwidth_gen (possible_values : list Z) (bitwidth : Z) : option Z + := List.fold_right + (fun allowed cur + => if bitwidth <=? allowed + then Some allowed + else cur) + None + possible_values. + +Lemma round_up_bitwidth_gen_le possible_values bitwidth v + : round_up_bitwidth_gen possible_values bitwidth = Some v + -> bitwidth <= v. +Proof. + cbv [round_up_bitwidth_gen]. + induction possible_values as [|x xs IHxs]; cbn; intros; inversion_option. + break_innermost_match_hyps; Z.ltb_to_lt; inversion_option; subst; trivial. + specialize_by_assumption; omega. +Qed. + +Definition relax_zrange_gen (possible_values : list Z) : zrange -> option zrange + := (fun '(r[ l ~> u ]) + => if (0 <=? l)%Z + then option_map (fun u => r[0~>2^u-1]) + (round_up_bitwidth_gen possible_values (Z.log2_up (u+1))) + else None)%zrange. + +Lemma relax_zrange_gen_good + (possible_values : list Z) + : forall r r' z : zrange, + (z <=? r)%zrange = true -> relax_zrange_gen possible_values r = Some r' -> (z <=? r')%zrange = true. +Proof. + cbv [is_tighter_than_bool relax_zrange_gen]; intros *. + pose proof (Z.log2_up_nonneg (upper r + 1)). + rewrite !Bool.andb_true_iff; destruct_head' zrange; cbn [ZRange.lower ZRange.upper] in *. + cbv [fold_right option_map]. + break_innermost_match; intros; destruct_head'_and; + try match goal with + | [ H : _ |- _ ] => apply round_up_bitwidth_gen_le in H + end; + inversion_option; inversion_zrange; + subst; + repeat apply conj; + Z.ltb_to_lt; try omega; + try (rewrite <- Z.log2_up_le_pow2_full in *; omega). +Qed. + +(** XXX TODO: Translate Jade's python script *) +Module Import UnsaturatedSolinas. + Section rcarry_mul. + Context (n : nat) + (s : Z) + (c : list (Z * Z)) + (machine_wordsize : Z). + + Let limbwidth := (Z.log2_up (s - Associational.eval c) / Z.of_nat n)%Q. + Let idxs := (seq 0 n ++ [0; 1])%list%nat. + Let coef := 2. + Let tight_upperbounds : list Z + := List.map + (fun v : Z => Qceiling (11/10 * v)) + (encode (weight (Qnum limbwidth) (Qden limbwidth)) n s c (s-1)). + Definition prime_bound : ZRange.type.option.interp (base.type.Z) + := Some r[0~>(s - Associational.eval c - 1)]%zrange. + + Definition relax_zrange_of_machine_wordsize + := relax_zrange_gen [machine_wordsize; 2 * machine_wordsize]%Z. + + Let relax_zrange := relax_zrange_of_machine_wordsize. + Definition tight_bounds : list (ZRange.type.option.interp base.type.Z) + := List.map (fun u => Some r[0~>u]%zrange) tight_upperbounds. + Definition loose_bounds : list (ZRange.type.option.interp base.type.Z) + := List.map (fun u => Some r[0 ~> 3*u]%zrange) tight_upperbounds. + + (** Note: If you change the name or type signature of this + function, you will need to update the code in CLI.v *) + Definition check_args {T} (res : Pipeline.ErrorT T) + : Pipeline.ErrorT T + := if negb (Qle_bool 1 limbwidth)%Q + then Error (Pipeline.Value_not_leQ "1 ≤ limbwidth" 1%Q limbwidth) + else if (negb (0 <? s - Associational.eval c))%Z + then Error (Pipeline.Value_not_ltZ "s - Associational.eval c ≤ 0" 0 (s - Associational.eval c)) + else if (s =? 0)%Z + then Error (Pipeline.Values_not_provably_distinctZ "s ≠ 0" s 0) + else if (n =? 0)%nat + then Error (Pipeline.Values_not_provably_distinctZ "n ≠ 0" n 0%nat) + else if (negb (0 <? machine_wordsize)) + then Error (Pipeline.Value_not_ltZ "0 < machine_wordsize" 0 machine_wordsize) + else res. + + Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _). + + Notation BoundsPipeline rop in_bounds out_bounds + := (Pipeline.BoundsPipeline + (*false*) true None + relax_zrange + rop%Expr in_bounds out_bounds). + + Notation BoundsPipeline_correct in_bounds out_bounds op + := (fun rv (rop : Expr (reify_type_of op)) Hrop + => @Pipeline.BoundsPipeline_correct_trans + (*false*) true None + relax_zrange + (relax_zrange_gen_good _) + _ + rop + in_bounds + out_bounds + op + Hrop rv) + (only parsing). + + (* N.B. We only need [rcarry_mul] if we want to extract the Pipeline; otherwise we can just use [rcarry_mul_correct] *) + Definition rcarry_mul + := BoundsPipeline + (carry_mul_gen + @ GallinaReify.Reify (Qnum limbwidth) @ GallinaReify.Reify (Z.pos (Qden limbwidth)) @ GallinaReify.Reify s @ GallinaReify.Reify c @ GallinaReify.Reify n @ GallinaReify.Reify idxs) + (Some loose_bounds, (Some loose_bounds, tt)) + (Some tight_bounds). + + Definition rcarry_mul_correct + := BoundsPipeline_correct + (Some loose_bounds, (Some loose_bounds, tt)) + (Some tight_bounds) + (carry_mulmod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n idxs). + + Definition rcarry + := BoundsPipeline + (carry_gen + @ GallinaReify.Reify (Qnum limbwidth) @ GallinaReify.Reify (Z.pos (Qden limbwidth)) @ GallinaReify.Reify s @ GallinaReify.Reify c @ GallinaReify.Reify n @ GallinaReify.Reify idxs) + (Some loose_bounds, tt) + (Some tight_bounds). + + Definition rcarry_correct + := BoundsPipeline_correct + (Some loose_bounds, tt) + (Some tight_bounds) + (carrymod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n idxs). + + Definition rrelax + := BoundsPipeline + id_gen + (Some tight_bounds, tt) + (Some loose_bounds). + + Definition rrelax_correct + := BoundsPipeline_correct + (Some tight_bounds, tt) + (Some loose_bounds) + (@id (list Z)). + + Definition radd + := BoundsPipeline + (add_gen + @ GallinaReify.Reify (Qnum limbwidth) @ GallinaReify.Reify (Z.pos (Qden limbwidth)) @ GallinaReify.Reify n) + (Some tight_bounds, (Some tight_bounds, tt)) + (Some loose_bounds). + + Definition radd_correct + := BoundsPipeline_correct + (Some tight_bounds, (Some tight_bounds, tt)) + (Some loose_bounds) + (addmod (Qnum limbwidth) (Z.pos (Qden limbwidth)) n). + + Definition rsub + := BoundsPipeline + (sub_gen + @ GallinaReify.Reify (Qnum limbwidth) @ GallinaReify.Reify (Z.pos (Qden limbwidth)) @ GallinaReify.Reify s @ GallinaReify.Reify c @ GallinaReify.Reify n @ GallinaReify.Reify coef) + (Some tight_bounds, (Some tight_bounds, tt)) + (Some loose_bounds). + + Definition rsub_correct + := BoundsPipeline_correct + (Some tight_bounds, (Some tight_bounds, tt)) + (Some loose_bounds) + (submod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n coef). + + Definition ropp + := BoundsPipeline + (opp_gen + @ GallinaReify.Reify (Qnum limbwidth) @ GallinaReify.Reify (Z.pos (Qden limbwidth)) @ GallinaReify.Reify s @ GallinaReify.Reify c @ GallinaReify.Reify n @ GallinaReify.Reify coef) + (Some tight_bounds, tt) + (Some loose_bounds). + + Definition ropp_correct + := BoundsPipeline_correct + (Some tight_bounds, tt) + (Some loose_bounds) + (oppmod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n coef). + + Definition rencode_correct + := BoundsPipeline_correct + (prime_bound, tt) + (Some tight_bounds) + (encodemod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n). + + Definition rzero_correct + := BoundsPipeline_correct + tt + (Some tight_bounds) + (zeromod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n). + + Definition rone_correct + := BoundsPipeline_correct + tt + (Some tight_bounds) + (onemod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n). + + (* we need to strip off [Hrv : ... = Pipeline.Success rv] and related arguments *) + Definition rcarry_mul_correctT rv : Prop + := type_of_strip_3arrow (@rcarry_mul_correct rv). + Definition rcarry_correctT rv : Prop + := type_of_strip_3arrow (@rcarry_correct rv). + Definition rrelax_correctT rv : Prop + := type_of_strip_3arrow (@rrelax_correct rv). + Definition radd_correctT rv : Prop + := type_of_strip_3arrow (@radd_correct rv). + Definition rsub_correctT rv : Prop + := type_of_strip_3arrow (@rsub_correct rv). + Definition ropp_correctT rv : Prop + := type_of_strip_3arrow (@ropp_correct rv). + Definition rencode_correctT rv : Prop + := type_of_strip_3arrow (@rencode_correct rv). + Definition rzero_correctT rv : Prop + := type_of_strip_3arrow (@rzero_correct rv). + Definition rone_correctT rv : Prop + := type_of_strip_3arrow (@rone_correct rv). + + Section make_ring. + Let m : positive := Z.to_pos (s - Associational.eval c). + Context (curve_good : check_args (Success tt) = Success tt) + {rcarry_mulv} (Hrmulv : rcarry_mul_correctT rcarry_mulv) + {rcarryv} (Hrcarryv : rcarry_correctT rcarryv) + {rrelaxv} (Hrrelaxv : rrelax_correctT rrelaxv) + {raddv} (Hraddv : radd_correctT raddv) + {rsubv} (Hrsubv : rsub_correctT rsubv) + {roppv} (Hroppv : ropp_correctT roppv) + {rzerov} (Hrzerov : rzero_correctT rzerov) + {ronev} (Hronev : rone_correctT ronev) + {rencodev} (Hrencodev : rencode_correctT rencodev). + + Local Ltac use_curve_good_t := + repeat first [ progress rewrite ?map_length, ?Z.mul_0_r, ?Pos.mul_1_r, ?Z.mul_1_r in * + | reflexivity + | lia + | rewrite interp_reify_list, ?map_map + | rewrite map_ext with (g:=id), map_id + | progress distr_length + | progress cbv [Qceiling Qfloor Qopp Qdiv Qplus inject_Z Qmult Qinv] in * + | progress cbv [Qle] in * + | progress cbn -[reify_list] in * + | progress intros + | solve [ auto ] ]. + + Lemma use_curve_good + : Z.pos m = s - Associational.eval c + /\ Z.pos m <> 0 + /\ s - Associational.eval c <> 0 + /\ s <> 0 + /\ 0 < machine_wordsize + /\ n <> 0%nat + /\ List.length tight_bounds = n + /\ List.length loose_bounds = n + /\ 0 < Qden limbwidth <= Qnum limbwidth. + Proof. + clear -curve_good. + cbv [check_args] in curve_good. + cbv [tight_bounds loose_bounds prime_bound] in *. + break_innermost_match_hyps; try discriminate. + rewrite negb_false_iff in *. + Z.ltb_to_lt. + rewrite Qle_bool_iff in *. + rewrite NPeano.Nat.eqb_neq in *. + intros. + cbv [Qnum Qden limbwidth Qceiling Qfloor Qopp Qdiv Qplus inject_Z Qmult Qinv] in *. + rewrite ?map_length, ?Z.mul_0_r, ?Pos.mul_1_r, ?Z.mul_1_r in *. + specialize_by lia. + repeat match goal with H := _ |- _ => subst H end. + repeat apply conj. + { destruct (s - Associational.eval c); cbn; lia. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + Qed. + + Definition GoodT : Prop + := @Ring.GoodT + (Qnum limbwidth) + (Z.pos (Qden limbwidth)) + n s c + tight_bounds + (Interp rrelaxv) + (Interp rcarry_mulv) + (Interp rcarryv) + (Interp raddv) + (Interp rsubv) + (Interp roppv) + (Interp rzerov) + (Interp ronev) + (Interp rencodev). + + Theorem Good : GoodT. + Proof. + pose proof use_curve_good; destruct_head'_and; destruct_head_hnf' ex. + eapply Ring.Good; + lazymatch goal with + | [ H : ?P ?rop |- context[expr.Interp _ ?rop] ] + => intros; + let H1 := fresh in + let H2 := fresh in + unshelve edestruct H as [H1 H2]; [ .. | solve [ split; [ eapply H1 | eapply H2 ] ] ]; + solve [ exact tt | eassumption | reflexivity ] + | _ => idtac + end; + repeat first [ assumption + | intros; apply eval_carry_mulmod + | intros; apply eval_carrymod + | intros; apply eval_addmod + | intros; apply eval_submod + | intros; apply eval_oppmod + | intros; apply eval_encodemod + | apply conj ]. + Qed. + End make_ring. + + Section for_stringification. + Local Open Scope string_scope. + + Let ToFunLines t name E arg_bounds + := (name, + match E with + | Success E' + => let E := @ToString.C.ToFunctionLines + name t E' None arg_bounds in + match E with + | Some E => Success E + | None => Error (Pipeline.Stringification_failed E') + end + | Error err => Error err + end). + + (** Note: If you change the name or type signature of this + function, you will need to update the code in CLI.v *) + Definition Synthesize (function_name_prefix : string) : list (string * Pipeline.ErrorT (list string)) + := let loose_bounds := Some loose_bounds in + let tight_bounds := Some tight_bounds in + let fe op := (function_name_prefix ++ op)%string in + [(ToFunLines _ (fe "carry_mul") rcarry_mul (loose_bounds, (loose_bounds, tt))); + (ToFunLines _ (fe "carry") rcarry (loose_bounds, tt)); + (ToFunLines _ (fe "add") radd (tight_bounds, (tight_bounds, tt))); + (ToFunLines _ (fe "sub") rsub (tight_bounds, (tight_bounds, tt))); + (ToFunLines _ (fe "opp") ropp (tight_bounds, tt))]. + End for_stringification. + End rcarry_mul. +End UnsaturatedSolinas. + +Ltac peel_interp_app _ := + lazymatch goal with + | [ |- ?R' (?InterpE ?arg) (?f ?arg) ] + => apply fg_equal_rel; [ | reflexivity ]; + try peel_interp_app () + | [ |- ?R' (Interp ?ev) (?f ?x) ] + => let sv := type of x in + let fx := constr:(f x) in + let dv := type of fx in + let rs := reify_type sv in + let rd := reify_type dv in + etransitivity; + [ apply @expr.Interp_APP_rel_reflexive with (s:=rs) (d:=rd) (R:=R'); + typeclasses eauto + | apply fg_equal_rel; + [ try peel_interp_app () + | try lazymatch goal with + | [ |- ?R (Interp ?ev) (Interp _) ] + => reflexivity + | [ |- ?R (Interp ?ev) ?c ] + => let rc := constr:(GallinaReify.Reify c) in + unify ev rc; reflexivity + end ] ] + end. +Ltac pre_cache_reify _ := + cbv [type.app_curried]; + let arg := fresh "arg" in + intros arg _; + peel_interp_app (); + [ lazymatch goal with + | [ |- ?R (Interp ?ev) _ ] + => (tryif is_evar ev + then let ev' := fresh "ev" in set (ev' := ev) + else idtac) + end; + cbv [pointwise_relation]; intros; clear + | .. ]. +Ltac do_inline_cache_reify do_if_not_cached := + pre_cache_reify (); + [ try solve [ + repeat match goal with H := ?e |- _ => is_evar e; subst H end; + eauto with nocore reify_gen_cache; + do_if_not_cached () + ]; + cache_reify () + | .. ]. + +(* TODO: MOVE ME *) +Ltac vm_compute_lhs_reflexivity := + lazymatch goal with + | [ |- ?LHS = ?RHS ] + => let x := (eval vm_compute in LHS) in + (* we cannot use the unify tactic, which just gives "not + unifiable" as the error message, because we want to see the + terms that were not unifable. See also + COQBUG(https://github.com/coq/coq/issues/7291) *) + let _unify := constr:(ltac:(reflexivity) : RHS = x) in + vm_cast_no_check (eq_refl x) + end. + +Ltac solve_rop' rop_correct do_if_not_cached machine_wordsizev := + eapply rop_correct with (machine_wordsize:=machine_wordsizev); + [ do_inline_cache_reify do_if_not_cached + | subst_evars; vm_compute_lhs_reflexivity (* lazy; reflexivity *) ]. +Ltac solve_rop_nocache rop_correct := + solve_rop' rop_correct ltac:(fun _ => idtac). +Ltac solve_rop rop_correct := + solve_rop' + rop_correct + ltac:(fun _ => let G := get_goal in fail 2 "Could not find a solution in reify_gen_cache for" G). +Ltac solve_rcarry_mul := solve_rop rcarry_mul_correct. +Ltac solve_rcarry_mul_nocache := solve_rop_nocache rcarry_mul_correct. +Ltac solve_rcarry := solve_rop rcarry_correct. +Ltac solve_radd := solve_rop radd_correct. +Ltac solve_rsub := solve_rop rsub_correct. +Ltac solve_ropp := solve_rop ropp_correct. +Ltac solve_rencode := solve_rop rencode_correct. +Ltac solve_rrelax := solve_rop rrelax_correct. +Ltac solve_rzero := solve_rop rzero_correct. +Ltac solve_rone := solve_rop rone_correct. + +Module PrintingNotations. + Export ident. + (*Global Set Printing Width 100000.*) + Open Scope zrange_scope. + Notation "'uint256'" + := (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange) : zrange_scope. + Notation "'uint128'" + := (r[0 ~> 340282366920938463463374607431768211455]%zrange) : zrange_scope. + Notation "'uint64'" + := (r[0 ~> 18446744073709551615]) : zrange_scope. + Notation "'uint32'" + := (r[0 ~> 4294967295]) : zrange_scope. + Notation "'bool'" + := (r[0 ~> 1]%zrange) : zrange_scope. + Notation "( range )( ls [[ n ]] )" + := ((#(ident.Z_cast range) @ (ls [[ n ]]))%expr) + (format "( range )( ls [[ n ]] )") : expr_scope. + (*Notation "( range )( v )" := (ident.Z_cast range @@ v)%expr : expr_scope.*) + Notation "x *₂₅₆ y" + := (#(ident.Z_cast uint256) @ (#ident.Z_mul @ x @ y))%expr (at level 40) : expr_scope. + Notation "x *₁₂₈ y" + := (#(ident.Z_cast uint128) @ (#ident.Z_mul @ x @ y))%expr (at level 40) : expr_scope. + Notation "x *₆₄ y" + := (#(ident.Z_cast uint64) @ (#ident.Z_mul @ x @ y))%expr (at level 40) : expr_scope. + Notation "x *₃₂ y" + := (#(ident.Z_cast uint32) @ (#ident.Z_mul @ x @ y))%expr (at level 40) : expr_scope. + Notation "x +₂₅₆ y" + := (#(ident.Z_cast uint256) @ (#ident.Z_add @ x @ y))%expr (at level 50) : expr_scope. + Notation "x +₁₂₈ y" + := (#(ident.Z_cast uint128) @ (#ident.Z_add @ x @ y))%expr (at level 50) : expr_scope. + Notation "x +₆₄ y" + := (#(ident.Z_cast uint64) @ (#ident.Z_add @ x @ y))%expr (at level 50) : expr_scope. + Notation "x +₃₂ y" + := (#(ident.Z_cast uint32) @ (#ident.Z_add @ x @ y))%expr (at level 50) : expr_scope. + Notation "x -₁₂₈ y" + := (#(ident.Z_cast uint128) @ (#ident.Z_sub @ x @ y))%expr (at level 50) : expr_scope. + Notation "x -₆₄ y" + := (#(ident.Z_cast uint64) @ (#ident.Z_sub @ x @ y))%expr (at level 50) : expr_scope. + Notation "x -₃₂ y" + := (#(ident.Z_cast uint32) @ (#ident.Z_sub @ x @ y))%expr (at level 50) : expr_scope. + Notation "( out_t )( v >> count )" + := ((#(ident.Z_cast out_t) @ (#(ident.Z_shiftr count) @ v))%expr) + (format "( out_t )( v >> count )") : expr_scope. + Notation "( out_t )( v << count )" + := ((#(ident.Z_cast out_t) @ (#(ident.Z_shiftl count) @ v))%expr) + (format "( out_t )( v << count )") : expr_scope. + Notation "( range )( v )" + := ((#(ident.Z_cast range) @ $v)%expr) + (format "( range )( v )") : expr_scope. + Notation "( ( out_t )( v ) & mask )" + := ((#(ident.Z_cast out_t) @ (#(ident.Z_land mask) @ v))%expr) + (format "( ( out_t )( v ) & mask )") + : expr_scope. + + Notation "x" := (#(ident.Z_cast _) @ $x)%expr (only printing, at level 9) : expr_scope. + Notation "x" := (#(ident.Z_cast2 _) @ $x)%expr (only printing, at level 9) : expr_scope. + Notation "v ₁" := (#ident.fst @ $v)%expr (at level 10, format "v ₁") : expr_scope. + Notation "v ₂" := (#ident.snd @ $v)%expr (at level 10, format "v ₂") : expr_scope. + Notation "v ₁" := (#(ident.Z_cast _) @ (#ident.fst @ $v))%expr (at level 10, format "v ₁") : expr_scope. + Notation "v ₂" := (#(ident.Z_cast _) @ (#ident.snd @ $v))%expr (at level 10, format "v ₂") : expr_scope. + Notation "v ₁" := (#(ident.Z_cast _) @ (#ident.fst @ (#(ident.Z_cast2 _) @ $v)))%expr (at level 10, format "v ₁") : expr_scope. + Notation "v ₂" := (#(ident.Z_cast _) @ (#ident.snd @ (#(ident.Z_cast2 _) @ $v)))%expr (at level 10, format "v ₂") : expr_scope. + Notation "x" := (#(ident.Literal x%Z))%expr (only printing) : expr_scope. + + (*Notation "ls [[ n ]]" := (List.nth_default_concrete _ n @@ ls)%expr : expr_scope. + Notation "( range )( v )" := (ident.Z_cast range @@ v)%expr : expr_scope. + Notation "x *₁₂₈ y" + := (ident.Z_cast uint128 @@ (ident.Z.mul (x, y)))%expr (at level 40) : expr_scope. + Notation "( out_t )( v >> count )" + := (ident.Z_cast out_t (ident.Z.shiftr count @@ v)%expr) + (format "( out_t )( v >> count )") : expr_scope. + Notation "( out_t )( v >> count )" + := (ident.Z_cast out_t (ident.Z.shiftr count @@ v)%expr) + (format "( out_t )( v >> count )") : expr_scope. + Notation "v ₁" := (ident.fst @@ v)%expr (at level 10, format "v ₁") : expr_scope. + Notation "v ₂" := (ident.snd @@ v)%expr (at level 10, format "v ₂") : expr_scope.*) + (* + Notation "'ℤ'" + := BoundsAnalysis.type.Z : zrange_scope. + Notation "ls [[ n ]]" := (List.nth n @@ ls)%nexpr : nexpr_scope. + Notation "x *₆₄₋₆₄₋₁₂₈ y" + := (mul uint64 uint64 uint128 @@ (x, y))%nexpr (at level 40) : nexpr_scope. + Notation "x *₆₄₋₆₄₋₆₄ y" + := (mul uint64 uint64 uint64 @@ (x, y))%nexpr (at level 40) : nexpr_scope. + Notation "x *₃₂₋₃₂₋₃₂ y" + := (mul uint32 uint32 uint32 @@ (x, y))%nexpr (at level 40) : nexpr_scope. + Notation "x *₃₂₋₁₂₈₋₁₂₈ y" + := (mul uint32 uint128 uint128 @@ (x, y))%nexpr (at level 40) : nexpr_scope. + Notation "x *₃₂₋₆₄₋₆₄ y" + := (mul uint32 uint64 uint64 @@ (x, y))%nexpr (at level 40) : nexpr_scope. + Notation "x *₃₂₋₃₂₋₆₄ y" + := (mul uint32 uint32 uint64 @@ (x, y))%nexpr (at level 40) : nexpr_scope. + Notation "x +₁₂₈ y" + := (add uint128 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x +₆₄₋₁₂₈₋₁₂₈ y" + := (add uint64 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x +₃₂₋₆₄₋₆₄ y" + := (add uint32 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x +₆₄ y" + := (add uint64 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x +₃₂ y" + := (add uint32 uint32 uint32 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x -₁₂₈ y" + := (sub uint128 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x -₆₄₋₁₂₈₋₁₂₈ y" + := (sub uint64 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x -₃₂₋₆₄₋₆₄ y" + := (sub uint32 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x -₆₄ y" + := (sub uint64 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x -₃₂ y" + := (sub uint32 uint32 uint32 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x" := ({| BoundsAnalysis.type.value := x |}) (only printing) : nexpr_scope. + Notation "( out_t )( v >> count )" + := ((shiftr _ out_t count @@ v)%nexpr) + (format "( out_t )( v >> count )") + : nexpr_scope. + Notation "( out_t )( v << count )" + := ((shiftl _ out_t count @@ v)%nexpr) + (format "( out_t )( v << count )") + : nexpr_scope. + Notation "( ( out_t ) v & mask )" + := ((land _ out_t mask @@ v)%nexpr) + (format "( ( out_t ) v & mask )") + : nexpr_scope. +*) + (* TODO: come up with a better notation for arithmetic with carries + that still distinguishes it from arithmetic without carries? *) + Local Notation "'TwoPow256'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 (only parsing). + Notation "'ADD_256' ( x , y )" := (#(ident.Z_cast2 (uint256, bool)%core) @ (#(ident.Z_add_get_carry_concrete TwoPow256) @ x @ y))%expr : expr_scope. + Notation "'ADD_128' ( x , y )" := (#(ident.Z_cast2 (uint128, bool)%core) @ (#(ident.Z_add_get_carry_concrete TwoPow256) @ x @ y))%expr : expr_scope. + Notation "'ADDC_256' ( x , y , z )" := (#(ident.Z_cast2 (uint256, bool)%core) @ (#(ident.Z_add_with_get_carry_concrete TwoPow256) @ x @ y @ z))%expr : expr_scope. + Notation "'ADDC_128' ( x , y , z )" := (#(ident.Z_cast2 (uint128, bool)%core) @ (#(ident.Z_add_with_get_carry_concrete TwoPow256) @ x @ y @ z))%expr : expr_scope. + Notation "'SUB_256' ( x , y )" := (#(ident.Z_cast2 (uint256, bool)%core) @ (#(ident.Z_sub_get_borrow_concrete TwoPow256) @ x @ y))%expr : expr_scope. + Notation "'SUBB_256' ( x , y , z )" := (#(ident.Z_cast2 (uint256, bool)%core) @ (#(ident.Z_sub_with_get_borrow_concrete TwoPow256) @ x @ y @ z))%expr : expr_scope. + Notation "'ADDM' ( x , y , z )" := (#(ident.Z_cast uint256) @ (#ident.Z_add_modulo @ x @ y @ z))%expr : expr_scope. + Notation "'RSHI' ( x , y , z )" := (#(ident.Z_cast _) @ (#(ident.Z_rshi_concrete _ z) @ x @ y))%expr : expr_scope. + Notation "'SELC' ( x , y , z )" := (#(ident.Z_cast uint256) @ (ident.Z_zselect @ x @ y @ z))%expr : expr_scope. + Notation "'SELM' ( x , y , z )" := (#(ident.Z_cast uint256) @ (ident.Z_zselect @ (#(Z_cast bool) @ (Z_cc_m_concrete _) @ x) @ y @ z))%expr : expr_scope. + Notation "'SELL' ( x , y , z )" := (#(ident.Z_cast uint256) @ (#ident.Z_zselect @ (#(Z_cast bool) @ (#(Z_land 1) @ x)) @ y @ z))%expr : expr_scope. +End PrintingNotations. + +(* +Notation "a ∈ b" := (ZRange.type.is_bounded_by b%zrange a = true) (at level 10) : type_scope. +Notation Interp := (expr.Interp _). +Notation "'ℤ'" := (type.type_primitive type.Z). +Set Printing Width 70. +Goal False. + let rop' := Reify (fun v1v2 : Z * Z => fst v1v2 + snd v1v2) in + pose rop' as rop. + pose (@Pipeline.BoundsPipeline_full + false (fun v => Some v) (type.Z * type.Z) type.Z + rop + (r[0~>10], r[0~>10])%zrange + r[0~>20]%zrange + ) as E. + simple refine (let Ev := _ in + let compiler_outputs_Ev : E = Pipeline.Success Ev := _ in + _); [ shelve | .. ]; revgoals. + clearbody compiler_outputs_Ev. + refine (let H' := + (fun H'' => + @Pipeline.BoundsPipeline_full_correct + _ _ + H'' _ _ _ _ _ _ compiler_outputs_Ev) _ + in _); + clearbody H'. + Focus 2. + { cbv [Pipeline.BoundsPipeline_full] in E. + remember (Pipeline.PrePipeline rop) as cache eqn:Hcache in (value of E). + lazy in Hcache. + subst cache. + lazy in E. + subst E Ev; reflexivity. + } Unfocus. + cbv [rop] in H'; cbn [expr.Interp expr.interp for_reification.ident.interp] in H'. +(* + H' : forall arg : type.interp (ℤ * ℤ), + arg ∈ (r[0 ~> 10], r[0 ~> 10]) -> + (Interp Ev arg) ∈ r[0 ~> 20] /\ + Interp Ev arg = fst arg + snd arg +*) +Abort. +*) + +Module SaturatedSolinas. + Section MulMod. + Context (s : Z) (c : list (Z * Z)) + (s_nz : s <> 0) (modulus_nz : s - Associational.eval c <> 0). + Context (log2base : Z) (log2base_pos : 0 < log2base) + (n nreductions : nat) (n_nz : n <> 0%nat). + + Let weight := weight log2base 1. + Let props : @weight_properties weight := wprops log2base 1 ltac:(omega). + Local Lemma base_nz : 2 ^ log2base <> 0. Proof. auto with zarith. Qed. + + Derive mulmod + SuchThat (forall (f g : list Z) + (Hf : length f = n) + (Hg : length g = n), + (eval weight n (fst (mulmod f g)) + weight n * (snd (mulmod f g))) mod (s - Associational.eval c) + = (eval weight n f * eval weight n g) mod (s - Associational.eval c)) + As eval_mulmod. + Proof. + intros. + rewrite <-Rows.eval_mulmod with (base:=2^log2base) (s:=s) (c:=c) (nreductions:=nreductions) by auto using base_nz. + eapply f_equal2; [|trivial]. + (* expand_lists (). *) (* uncommenting this line removes some unused multiplications but also inlines a bunch of carry stuff at the end *) + subst mulmod. reflexivity. + Qed. + Definition mulmod' := fun x y => fst (mulmod x y). + End MulMod. + + Derive mulmod_gen + SuchThat (forall (log2base s : Z) (c : list (Z * Z)) (n nreductions : nat) + (f g : list Z), + Interp (t:=reify_type_of mulmod') + mulmod_gen s c log2base n nreductions f g + = mulmod' s c log2base n nreductions f g) + As mulmod_gen_correct. + Proof. Time cache_reify (). Time Qed. + Module Export ReifyHints. + Global Hint Extern 1 (_ = mulmod' _ _ _ _ _ _ _) => simple apply mulmod_gen_correct : reify_gen_cache. + End ReifyHints. + + Section rmulmod. + Context (s : Z) + (c : list (Z * Z)) + (machine_wordsize : Z). + + Definition relax_zrange_of_machine_wordsize + := relax_zrange_gen [1; machine_wordsize]%Z. + + Let n : nat := Z.to_nat (Qceiling (Z.log2_up s / machine_wordsize)). + (* Number of reductions is calculated as follows : + Let i be the highest limb index of c. Then, each reduction + decreases the number of extra limbs by (n-i). So, to go from + the n extra limbs we have post-multiplication down to 0, we + need ceil (n / (n - i)) reductions. *) + Let nreductions : nat := + let i := fold_right Z.max 0 (map (fun t => Z.log2 (fst t) / machine_wordsize) c) in + Z.to_nat (Qceiling (Z.of_nat n / (Z.of_nat n - i))). + Let relax_zrange := relax_zrange_of_machine_wordsize. + Let bound := Some r[0 ~> (2^machine_wordsize - 1)]%zrange. + Let boundsn : list (ZRange.type.option.interp base.type.Z) + := repeat bound n. + + (** Note: If you change the name or type signature of this + function, you will need to update the code in CLI.v *) + Definition check_args {T} (res : Pipeline.ErrorT T) + : Pipeline.ErrorT T + := if (negb (0 <? s - Associational.eval c))%Z + then Error (Pipeline.Value_not_ltZ "s - Associational.eval c ≤ 0" 0 (s - Associational.eval c)) + else if (s =? 0)%Z + then Error (Pipeline.Values_not_provably_distinctZ "s ≠ 0" s 0) + else if (n =? 0)%nat + then Error (Pipeline.Values_not_provably_distinctZ "n ≠ 0" n 0) + else if (negb (0 <? machine_wordsize)) + then Error (Pipeline.Value_not_ltZ "0 < machine_wordsize" 0 machine_wordsize) + else res. + + Notation BoundsPipeline rop in_bounds out_bounds + := (Pipeline.BoundsPipeline + (*false*) false None + relax_zrange + rop%Expr in_bounds out_bounds). + + Notation BoundsPipeline_correct in_bounds out_bounds op + := (fun rv (rop : Expr (reify_type_of op)) Hrop + => @Pipeline.BoundsPipeline_correct_trans + (*false*) false None + relax_zrange + (relax_zrange_gen_good _) + _ + rop + in_bounds + out_bounds + op + Hrop rv) + (only parsing). + + Definition rmulmod_correct + := BoundsPipeline_correct + (Some boundsn, (Some boundsn, tt)) + (Some boundsn) + (mulmod' s c machine_wordsize n nreductions). + + Definition rmulmod + := BoundsPipeline + (mulmod_gen @ GallinaReify.Reify s @ GallinaReify.Reify c @ GallinaReify.Reify machine_wordsize @ GallinaReify.Reify n @ GallinaReify.Reify nreductions) + (Some boundsn, (Some boundsn, tt)) + (Some boundsn). + + Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _). + Definition rmulmod_correctT rv : Prop + := type_of_strip_3arrow (@rmulmod_correct rv). + + Section for_stringification. + Local Open Scope string_scope. + + Let ToFunLines t name E arg_bounds + := (name, + match E with + | Success E' + => let E := @ToString.C.ToFunctionLines + name t E' None arg_bounds in + match E with + | Some E => Success E + | None => Error (Pipeline.Stringification_failed E') + end + | Error err => Error err + end). + + (** Note: If you change the name or type signature of this + function, you will need to update the code in CLI.v *) + Definition Synthesize (function_name_prefix : string) : list (string * Pipeline.ErrorT (list string)) + := let loose_bounds := Some loose_bounds in + let tight_bounds := Some tight_bounds in + let fe op := (function_name_prefix ++ op)%string in + [(ToFunLines _ (fe "mulmod") rmulmod (Some boundsn, (Some boundsn, tt)))]. + End for_stringification. + End rmulmod. +End SaturatedSolinas. + +Ltac solve_rmulmod := solve_rop SaturatedSolinas.rmulmod_correct. +Ltac solve_rmulmod_nocache := solve_rop_nocache SaturatedSolinas.rmulmod_correct. + +Module Import InvertHighLow. + Section with_wordmax. + Context (log2wordmax : Z) (consts : list Z). + Let wordmax := 2 ^ log2wordmax. + Let half_bits := log2wordmax / 2. + Let wordmax_half_bits := 2 ^ half_bits. + + Inductive kind_of_constant := upper_half (c : BinInt.Z) | lower_half (c : BinInt.Z). + + Definition constant_to_scalar_single (const x : BinInt.Z) : option kind_of_constant := + if x =? (BinInt.Z.shiftr const half_bits) + then Some (upper_half const) + else if x =? (BinInt.Z.land const (wordmax_half_bits - 1)) + then Some (lower_half const) + else None. + + Definition constant_to_scalar (x : BinInt.Z) + : option kind_of_constant := + fold_right (fun c res => match res with + | Some s => Some s + | None => constant_to_scalar_single c x + end) None consts. + + Definition invert_low (v : BinInt.Z) : option BinInt.Z + := match constant_to_scalar v with + | Some (lower_half v) => Some v + | _ => None + end. + + Definition invert_high (v : BinInt.Z) : option BinInt.Z + := match constant_to_scalar v with + | Some (upper_half v) => Some v + | _ => None + end. + End with_wordmax. +End InvertHighLow. + +Module BarrettReduction. + (* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *) + Section Generic. + Context {T} (rep : T -> Z -> Prop) + (k : Z) (k_pos : 0 < k) + (low : T -> Z) + (low_correct : forall a x, rep a x -> low a = x mod 2 ^ k) + (shiftr : T -> Z -> T) + (shiftr_correct : forall a x n, + rep a x -> + 0 <= n <= k -> + rep (shiftr a n) (x / 2 ^ n)) + (mul_high : T -> T -> Z -> T) + (mul_high_correct : forall a b x y x0y1, + rep a x -> + rep b y -> + 2 ^ k <= x < 2^(k+1) -> + 0 <= y < 2^(k+1) -> + x0y1 = x mod 2 ^ k * (y / 2 ^ k) -> + rep (mul_high a b x0y1) (x * y / 2 ^ k)) + (mul : Z -> Z -> T) + (mul_correct : forall x y, + 0 <= x < 2^k -> + 0 <= y < 2^k -> + rep (mul x y) (x * y)) + (sub : T -> T -> T) + (sub_correct : forall a b x y, + rep a x -> + rep b y -> + 0 <= x - y < 2^k * 2^k -> + rep (sub a b) (x - y)) + (cond_sub1 : T -> Z -> Z) + (cond_sub1_correct : forall a x y, + rep a x -> + 0 <= x < 2 * y -> + 0 <= y < 2 ^ k -> + cond_sub1 a y = if (x <? 2 ^ k) then x else x - y) + (cond_sub2 : Z -> Z -> Z) + (cond_sub2_correct : forall x y, cond_sub2 x y = if (x <? y) then x else x - y). + Context (xt mut : T) (M muSelect: Z). + + Let mu := 2 ^ (2 * k) / M. + Context x (mu_rep : rep mut mu) (x_rep : rep xt x). + Context (M_nz : 0 < M) + (x_range : 0 <= x < M * 2 ^ k) + (M_range : 2 ^ (k - 1) < M < 2 ^ k) + (M_good : 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu) + (muSelect_correct: muSelect = mu mod 2 ^ k * (x / 2 ^ (k - 1) / 2 ^ k)). + + Definition qt := + dlet_nd muSelect := muSelect in (* makes sure muSelect is not inlined in the output *) + dlet_nd q1 := shiftr xt (k - 1) in + dlet_nd twoq := mul_high mut q1 muSelect in + shiftr twoq 1. + Definition reduce := + dlet_nd qt := qt in + dlet_nd r2 := mul (low qt) M in + dlet_nd r := sub xt r2 in + let q3 := cond_sub1 r M in + cond_sub2 q3 M. + + Lemma looser_bound : M * 2 ^ k < 2 ^ (2*k). + Proof. clear -M_range M_nz x_range k_pos; rewrite <-Z.add_diag, Z.pow_add_r; nia. Qed. + + Lemma pow_2k_eq : 2 ^ (2*k) = 2 ^ (k - 1) * 2 ^ (k + 1). + Proof. clear -k_pos; rewrite <-Z.pow_add_r by omega. f_equal; ring. Qed. + + Lemma mu_bounds : 2 ^ k <= mu < 2^(k+1). + Proof. + pose proof looser_bound. + subst mu. split. + { apply Z.div_le_lower_bound; omega. } + { apply Z.div_lt_upper_bound; try omega. + rewrite pow_2k_eq; apply Z.mul_lt_mono_pos_r; auto with zarith. } + Qed. + + Lemma shiftr_x_bounds : 0 <= x / 2 ^ (k - 1) < 2^(k+1). + Proof. + pose proof looser_bound. + split; [ solve [Z.zero_bounds] | ]. + apply Z.div_lt_upper_bound; auto with zarith. + rewrite <-pow_2k_eq. omega. + Qed. + Hint Resolve shiftr_x_bounds. + + Ltac solve_rep := eauto using shiftr_correct, mul_high_correct, mul_correct, sub_correct with omega. + + Let q := mu * (x / 2 ^ (k - 1)) / 2 ^ (k + 1). + + Lemma q_correct : rep qt q . + Proof. + pose proof mu_bounds. cbv [qt]; subst q. + rewrite Z.pow_add_r, <-Z.div_div by Z.zero_bounds. + solve_rep. + Qed. + Hint Resolve q_correct. + + Lemma x_mod_small : x mod 2 ^ (k - 1) <= M. + Proof. transitivity (2 ^ (k - 1)); auto with zarith. Qed. + Hint Resolve x_mod_small. + + Lemma q_bounds : 0 <= q < 2 ^ k. + Proof. + pose proof looser_bound. pose proof x_mod_small. pose proof mu_bounds. + split; subst q; [ solve [Z.zero_bounds] | ]. + edestruct q_nice_strong with (n:=M) as [? Hqnice]; + try rewrite Hqnice; auto; try omega; [ ]. + apply Z.le_lt_trans with (m:= x / M). + { break_match; omega. } + { apply Z.div_lt_upper_bound; omega. } + Qed. + + Lemma two_conditional_subtracts : + forall a x, + rep a x -> + 0 <= x < 2 * M -> + cond_sub2 (cond_sub1 a M) M = cond_sub2 (cond_sub2 x M) M. + Proof. + intros. + erewrite !cond_sub2_correct, !cond_sub1_correct by (eassumption || omega). + break_match; Z.ltb_to_lt; try lia; discriminate. + Qed. + + Lemma r_bounds : 0 <= x - q * M < 2 * M. + Proof. + pose proof looser_bound. pose proof q_bounds. pose proof x_mod_small. + subst q mu; split. + { Z.zero_bounds. apply qn_small; omega. } + { apply r_small_strong; rewrite ?Z.pow_1_r; auto; omega. } + Qed. + + Lemma reduce_correct : reduce = x mod M. + Proof. + pose proof looser_bound. pose proof r_bounds. pose proof q_bounds. + assert (2 * M < 2^k * 2^k) by nia. + rewrite barrett_reduction_small with (k:=k) (m:=mu) (offset:=1) (b:=2) by (auto; omega). + cbv [reduce Let_In]. + erewrite low_correct by eauto. Z.rewrite_mod_small. + erewrite two_conditional_subtracts by solve_rep. + rewrite !cond_sub2_correct. + subst q; reflexivity. + Qed. + End Generic. + + Section BarrettReduction. + Context (k : Z) (k_bound : 2 <= k). + Context (M muLow : Z). + Context (M_pos : 0 < M) + (muLow_eq : muLow + 2^k = 2^(2*k) / M) + (muLow_bounds : 0 <= muLow < 2^k) + (M_bound1 : 2 ^ (k - 1) < M < 2^k) + (M_bound2: 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2^k)). + + Context (n:nat) (Hn_nz: n <> 0%nat) (n_le_k : Z.of_nat n <= k). + Context (nout : nat) (Hnout : nout = 2%nat). + Let w := weight k 1. + Local Lemma k_range : 0 < 1 <= k. Proof. omega. Qed. + Let props : @weight_properties w := wprops k 1 k_range. + + Hint Rewrite Positional.eval_nil Positional.eval_snoc : push_eval. + + Definition low (t : list Z) : Z := nth_default 0 t 0. + Definition high (t : list Z) : Z := nth_default 0 t 1. + Definition represents (t : list Z) (x : Z) := + t = [x mod 2^k; x / 2^k] /\ 0 <= x < 2^k * 2^k. + + Lemma represents_eq t x : + represents t x -> t = [x mod 2^k; x / 2^k]. + Proof. cbv [represents]; tauto. Qed. + + Lemma represents_length t x : represents t x -> length t = 2%nat. + Proof. cbv [represents]; intuition. subst t; reflexivity. Qed. + + Lemma represents_low t x : + represents t x -> low t = x mod 2^k. + Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. + + Lemma represents_high t x : + represents t x -> high t = x / 2^k. + Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. + + Lemma represents_low_range t x : + represents t x -> 0 <= x mod 2^k < 2^k. + Proof. auto with zarith. Qed. + + Lemma represents_high_range t x : + represents t x -> 0 <= x / 2^k < 2^k. + Proof. + destruct 1 as [? [? ?] ]; intros. + auto using Z.div_lt_upper_bound with zarith. + Qed. + Hint Resolve represents_length represents_low_range represents_high_range. + + Lemma represents_range t x : + represents t x -> 0 <= x < 2^k*2^k. + Proof. cbv [represents]; tauto. Qed. + + Lemma represents_id x : + 0 <= x < 2^k * 2^k -> + represents [x mod 2^k; x / 2^k] x. + Proof. + intros; cbv [represents]; autorewrite with cancel_pair. + Z.rewrite_mod_small; tauto. + Qed. + + Local Ltac push_rep := + repeat match goal with + | H : represents ?t ?x |- _ => unique pose proof (represents_low_range _ _ H) + | H : represents ?t ?x |- _ => unique pose proof (represents_high_range _ _ H) + | H : represents ?t ?x |- _ => rewrite (represents_low t x) in * by assumption + | H : represents ?t ?x |- _ => rewrite (represents_high t x) in * by assumption + end. + + Definition shiftr (t : list Z) (n : Z) : list Z := + [Z.rshi (2^k) (high t) (low t) n; Z.rshi (2^k) 0 (high t) n]. + + Lemma shiftr_represents a i x : + represents a x -> + 0 <= i <= k -> + represents (shiftr a i) (x / 2 ^ i). + Proof. + cbv [shiftr]; intros; push_rep. + match goal with H : _ |- _ => pose proof (represents_range _ _ H) end. + assert (0 < 2 ^ i) by auto with zarith. + assert (x < 2 ^ i * 2 ^ k * 2 ^ k) by nia. + assert (0 <= x / 2 ^ k / 2 ^ i < 2 ^ k) by + (split; Z.zero_bounds; auto using Z.div_lt_upper_bound with zarith). + repeat match goal with + | _ => rewrite Z.rshi_correct by auto with zarith + | _ => rewrite <-Z.div_mod''' by auto with zarith + | _ => progress autorewrite with zsimplify_fast + | _ => progress Z.rewrite_mod_small + | |- context [represents [(?a / ?c) mod ?b; ?a / ?b / ?c] ] => + rewrite (Z.div_div_comm a b c) by auto with zarith + | _ => solve [auto using represents_id, Z.div_lt_upper_bound with zarith lia] + end. + Qed. + + Context (Hw : forall i, w i = (2 ^ k) ^ Z.of_nat i). + Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r. + + Definition wideadd t1 t2 := fst (Rows.add w 2 t1 t2). + (* TODO: use this definition once issue #352 is resolved *) + (* Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2). *) + Definition widesub (t1 t2 : list Z) := + let t1_0 := hd 0 t1 in + let t1_1 := hd 0 (tl t1) in + let t2_0 := hd 0 t2 in + let t2_1 := hd 0 (tl t2) in + dlet_nd x0 := Z.sub_get_borrow_full (2^k) t1_0 t2_0 in + dlet_nd x1 := Z.sub_with_get_borrow_full (2^k) (snd x0) t1_1 t2_1 in + [fst x0; fst x1]. + Definition widemul := BaseConversion.widemul_inlined k n nout. + + Lemma partition_represents x : + 0 <= x < 2^k*2^k -> + represents (Rows.partition w 2 x) x. + Proof. + intros; cbn. change_weight. + Z.rewrite_mod_small. + autorewrite with zsimplify_fast. + auto using represents_id. + Qed. + + Lemma eval_represents t x : + represents t x -> eval w 2 t = x. + Proof. + intros; rewrite (represents_eq t x) by assumption. + cbn. change_weight; push_rep. + autorewrite with zsimplify. reflexivity. + Qed. + + Ltac wide_op partitions_pf := + repeat match goal with + | _ => rewrite partitions_pf by eauto + | _ => rewrite partitions_pf by auto with zarith + | _ => erewrite eval_represents by eauto + | _ => solve [auto using partition_represents, represents_id] + end. + + Lemma wideadd_represents t1 t2 x y : + represents t1 x -> + represents t2 y -> + 0 <= x + y < 2^k*2^k -> + represents (wideadd t1 t2) (x + y). + Proof. intros; cbv [wideadd]. wide_op Rows.add_partitions. Qed. + + Lemma widesub_represents t1 t2 x y : + represents t1 x -> + represents t2 y -> + 0 <= x - y < 2^k*2^k -> + represents (widesub t1 t2) (x - y). + Proof. + intros; cbv [widesub Let_In]. + rewrite (represents_eq t1 x) by assumption. + rewrite (represents_eq t2 y) by assumption. + cbn [hd tl]. + autorewrite with to_div_mod. + pull_Zmod. + match goal with |- represents [?m; ?d] ?x => + replace d with (x / 2 ^ k); [solve [auto using represents_id] |] end. + rewrite <-(Z.mod_small ((x - y) / 2^k) (2^k)) by (split; try apply Z.div_lt_upper_bound; Z.zero_bounds). + f_equal. + transitivity ((x mod 2^k - y mod 2^k + 2^k * (x / 2 ^ k) - 2^k * (y / 2^k)) / 2^k). { + rewrite (Z.div_mod x (2^k)) at 1 by auto using Z.pow_nonzero with omega. + rewrite (Z.div_mod y (2^k)) at 1 by auto using Z.pow_nonzero with omega. + f_equal. ring. } + autorewrite with zsimplify. + ring. + Qed. + (* Works with Rows.sub-based widesub definition + Proof. intros; cbv [widesub]. wide_op Rows.sub_partitions. Qed. + *) + + Lemma widemul_represents x y : + 0 <= x < 2^k -> + 0 <= y < 2^k -> + represents (widemul x y) (x * y). + Proof. + intros; cbv [widemul]. + assert (0 <= x * y < 2^k*2^k) by auto with zarith. + wide_op BaseConversion.widemul_correct. + Qed. + + Definition mul_high (a b : list Z) a0b1 : list Z := + dlet_nd a0b0 := widemul (low a) (low b) in + dlet_nd ab := wideadd [high a0b0; high b] [low b; 0] in + wideadd ab [a0b1; 0]. + + Lemma mul_high_idea d a b a0 a1 b0 b1 : + d <> 0 -> + a = d * a1 + a0 -> + b = d * b1 + b0 -> + (a * b) / d = a0 * b0 / d + d * a1 * b1 + a1 * b0 + a0 * b1. + Proof. + intros. subst a b. autorewrite with push_Zmul. + ring_simplify_subterms. rewrite Z.pow_2_r. + rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega). + repeat match goal with + | |- context [d * ?a * ?b * ?c] => + replace (d * a * b * c) with (a * b * c * d) by ring + | |- context [d * ?a * ?b] => + replace (d * a * b) with (a * b * d) by ring + end. + rewrite !Z.div_add by omega. + autorewrite with zsimplify. + rewrite (Z.mul_comm a0 b0). + ring_simplify. ring. + Qed. + + Lemma represents_trans t x y: + represents t y -> y = x -> + represents t x. + Proof. congruence. Qed. + + Lemma represents_add x y : + 0 <= x < 2 ^ k -> + 0 <= y < 2 ^ k -> + represents [x;y] (x + 2^k*y). + Proof. + intros; cbv [represents]; autorewrite with zsimplify. + repeat split; (reflexivity || nia). + Qed. + + Lemma represents_small x : + 0 <= x < 2^k -> + represents [x; 0] x. + Proof. + intros. + eapply represents_trans. + { eauto using represents_add with zarith. } + { ring. } + Qed. + + Lemma mul_high_represents a b x y a0b1 : + represents a x -> + represents b y -> + 2^k <= x < 2^(k+1) -> + 0 <= y < 2^(k+1) -> + a0b1 = x mod 2^k * (y / 2^k) -> + represents (mul_high a b a0b1) ((x * y) / 2^k). + Proof. + cbv [mul_high Let_In]; rewrite Z.pow_add_r, Z.pow_1_r by omega; intros. + assert (4 <= 2 ^ k) by (transitivity (Z.pow 2 2); auto with zarith). + assert (0 <= x * y / 2^k < 2^k*2^k) by (Z.div_mod_to_quot_rem; nia). + + rewrite mul_high_idea with (a:=x) (b:=y) (a0 := low a) (a1 := high a) (b0 := low b) (b1 := high b) in * + by (push_rep; Z.div_mod_to_quot_rem; lia). + + push_rep. subst a0b1. + assert (y / 2 ^ k < 2) by (apply Z.div_lt_upper_bound; omega). + replace (x / 2 ^ k) with 1 in * by (rewrite Z.div_between_1; lia). + autorewrite with zsimplify_fast in *. + + eapply represents_trans. + { repeat (apply wideadd_represents; + [ | apply represents_small; Z.div_mod_to_quot_rem; nia| ]). + erewrite represents_high; [ | apply widemul_represents; solve [ auto with zarith ] ]. + { apply represents_add; try reflexivity; solve [auto with zarith]. } + { match goal with H : 0 <= ?x + ?y < ?z |- 0 <= ?x < ?z => + split; [ solve [Z.zero_bounds] | ]; + eapply Z.le_lt_trans with (m:= x + y); nia + end. } + { omega. } } + { ring. } + Qed. + + Definition cond_sub1 (a : list Z) y : Z := + dlet_nd maybe_y := Z.zselect (Z.cc_l (high a)) 0 y in + dlet_nd diff := Z.sub_get_borrow_full (2^k) (low a) maybe_y in + fst diff. + + Lemma cc_l_only_bit : forall x s, 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s. + Proof. + cbv [Z.cc_l]; intros. + rewrite Z.div_between_0_if by omega. + break_match; Z.ltb_to_lt; Z.rewrite_mod_small; omega. + Qed. + + Lemma cond_sub1_correct a x y : + represents a x -> + 0 <= x < 2 * y -> + 0 <= y < 2 ^ k -> + cond_sub1 a y = if (x <? 2 ^ k) then x else x - y. + Proof. + intros; cbv [cond_sub1 Let_In]. rewrite Z.zselect_correct. push_rep. + break_match; Z.ltb_to_lt; rewrite cc_l_only_bit in *; try omega; + autorewrite with zsimplify_fast to_div_mod pull_Zmod; auto with zarith. + Qed. + + Definition cond_sub2 x y := Z.add_modulo x 0 y. + Lemma cond_sub2_correct x y : + cond_sub2 x y = if (x <? y) then x else x - y. + Proof. + cbv [cond_sub2]. rewrite Z.add_modulo_correct. + autorewrite with zsimplify_fast. break_match; Z.ltb_to_lt; omega. + Qed. + + Section Defn. + Context (xLow xHigh : Z) (xLow_bounds : 0 <= xLow < 2^k) (xHigh_bounds : 0 <= xHigh < M). + Let xt := [xLow; xHigh]. + Let x := xLow + 2^k * xHigh. + + Lemma x_rep : represents xt x. + Proof. cbv [represents]; subst xt x; autorewrite with cancel_pair zsimplify; repeat split; nia. Qed. + + Lemma x_bounds : 0 <= x < M * 2 ^ k. + Proof. subst x; nia. Qed. + + Definition muSelect := Z.zselect (Z.cc_m (2 ^ k) xHigh) 0 muLow. + + Local Hint Resolve Z.div_nonneg Z.div_lt_upper_bound. + Local Hint Resolve shiftr_represents mul_high_represents widemul_represents widesub_represents + cond_sub1_correct cond_sub2_correct represents_low represents_add. + + Lemma muSelect_correct : + muSelect = (2 ^ (2 * k) / M) mod 2 ^ k * ((x / 2 ^ (k - 1)) / 2 ^ k). + Proof. + (* assertions to help arith tactics *) + pose proof x_bounds. + assert (2^k * M < 2 ^ (2*k)) by (rewrite <-Z.add_diag, Z.pow_add_r; nia). + assert (0 <= x / (2 ^ k * (2 ^ k / 2)) < 2) by (Z.div_mod_to_quot_rem; auto with nia). + assert (0 < 2 ^ k / 2) by Z.zero_bounds. + assert (2 ^ (k - 1) <> 0) by auto with zarith. + assert (2 < 2 ^ k) by (eapply Z.le_lt_trans with (m:=2 ^ 1); auto with zarith). + + cbv [muSelect]. rewrite <-muLow_eq. + rewrite Z.zselect_correct, Z.cc_m_eq by auto with zarith. + replace xHigh with (x / 2^k) by (subst x; autorewrite with zsimplify; lia). + autorewrite with pull_Zdiv push_Zpow. + rewrite (Z.mul_comm (2 ^ k / 2)). + break_match; [ ring | ]. + match goal with H : 0 <= ?x < 2, H' : ?x <> 0 |- _ => replace x with 1 by omega end. + autorewrite with zsimplify; reflexivity. + Qed. + + Lemma mu_rep : represents [muLow; 1] (2 ^ (2 * k) / M). + Proof. rewrite <-muLow_eq. eapply represents_trans; auto with zarith. Qed. + + Derive barrett_reduce + SuchThat (barrett_reduce = x mod M) + As barrett_reduce_correct. + Proof. + erewrite <-reduce_correct with (rep:=represents) (muSelect:=muSelect) (k0:=k) (mut:=[muLow;1]) (xt0:=xt) + by (auto using x_bounds, muSelect_correct, x_rep, mu_rep; omega). + subst barrett_reduce. reflexivity. + Qed. + End Defn. + End BarrettReduction. + + (* all the list operations from for_reification.ident *) + Strategy 100 [length seq repeat combine map flat_map partition app rev fold_right update_nth nth_default ]. + Strategy -10 [barrett_reduce reduce]. + + Derive barrett_red_gen + SuchThat (forall (k M muLow : Z) + (n nout: nat) + (xLow xHigh : Z), + Interp (t:=reify_type_of barrett_reduce) + barrett_red_gen k M muLow n nout xLow xHigh + = barrett_reduce k M muLow n nout xLow xHigh) + As barrett_red_gen_correct. + Proof. Time cache_reify (). Time Qed. (* Now only takes ~5-10 s, because we set up [Strategy] commands correctly *) + Module Export ReifyHints. + Global Hint Extern 1 (_ = barrett_reduce _ _ _ _ _ _ _) => simple apply barrett_red_gen_correct : reify_gen_cache. + End ReifyHints. + + Section rbarrett_red. + Context (M : Z) + (machine_wordsize : Z). + + Let bound := Some r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. + Let mu := (2 ^ (2 * machine_wordsize)) / M. + Let muLow := mu mod (2 ^ machine_wordsize). + Let consts_list := [M; muLow]. + + Definition relax_zrange_of_machine_wordsize' + := relax_zrange_gen [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize]%Z. + (* TODO: This is a special-case hack to let the prefancy pass have enough bounds information. *) + Definition relax_zrange_of_machine_wordsize r : option zrange := + if (lower r =? 0) && (upper r =? 2) + then Some r + else relax_zrange_of_machine_wordsize' r. + + Lemma relax_zrange_good (r r' z : zrange) : + (z <=? r)%zrange = true -> + relax_zrange_of_machine_wordsize r = Some r' -> (z <=? r')%zrange = true. + Proof. + cbv [relax_zrange_of_machine_wordsize]; break_match; [congruence|]. + eauto using relax_zrange_gen_good. + Qed. + + Local Arguments relax_zrange_of_machine_wordsize / . + + Let relax_zrange := relax_zrange_of_machine_wordsize. + + Definition check_args {T} (res : Pipeline.ErrorT T) + : Pipeline.ErrorT T + := if (mu / (2 ^ machine_wordsize) =? 0) + then Error (Pipeline.Values_not_provably_distinctZ "mu / 2 ^ k ≠ 0" (mu / 2 ^ machine_wordsize) 0) + else if (machine_wordsize <? 2) + then Error (Pipeline.Value_not_leZ "~ (2 <=k)" 2 machine_wordsize) + else if (negb (Z.log2 M + 1 =? machine_wordsize)) + then Error + (Pipeline.Values_not_provably_equalZ "log2(M)+1 != k" (Z.log2 M + 1) machine_wordsize) + else if (2 ^ (machine_wordsize + 1) - mu <? 2 * (2 ^ (2 * machine_wordsize) mod M)) + then Error + (Pipeline.Value_not_leZ "~ (2 * (2 ^ (2*k) mod M) <= 2^(k + 1) - mu)" + (2 * (2 ^ (2*machine_wordsize) mod M)) + (2^(machine_wordsize + 1) - mu)) + else res. + + Let fancy_args + := (Some {| Pipeline.invert_low log2wordsize := invert_low log2wordsize consts_list; + Pipeline.invert_high log2wordsize := invert_high log2wordsize consts_list |}). + + Notation BoundsPipeline_correct in_bounds out_bounds op + := (fun rv (rop : Expr (reify_type_of op)) Hrop + => @Pipeline.BoundsPipeline_correct_trans + false (* subst01 *) fancy_args + relax_zrange + relax_zrange_good + _ + rop + in_bounds + out_bounds + op + Hrop rv) + (only parsing). + + Definition rbarrett_red_correct + := BoundsPipeline_correct + (bound, (bound, tt)) + bound + (barrett_reduce machine_wordsize M muLow 2 2). + + Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _). + Definition rbarrett_red_correctT rv : Prop + := type_of_strip_3arrow (@rbarrett_red_correct rv). + End rbarrett_red. +End BarrettReduction. + +Ltac solve_rbarrett_red := solve_rop BarrettReduction.rbarrett_red_correct. +Ltac solve_rbarrett_red_nocache := solve_rop_nocache BarrettReduction.rbarrett_red_correct. + +Module MontgomeryReduction. + Section MontRed'. + Context (N R N' R' : Z). + Context (HN_range : 0 <= N < R) (HN'_range : 0 <= N' < R) (HN_nz : N <> 0) (R_gt_1 : R > 1) + (N'_good : Z.equiv_modulo R (N*N') (-1)) (R'_good: Z.equiv_modulo N (R*R') 1). + + Context (Zlog2R : Z) . + Let w : nat -> Z := weight Zlog2R 1. + Context (n:nat) (Hn_nz: n <> 0%nat) (n_good : Zlog2R mod Z.of_nat n = 0). + Context (R_big_enough : n <= Zlog2R) + (R_two_pow : 2^Zlog2R = R). + Let w_mul : nat -> Z := weight (Zlog2R / n) 1. + Context (nout : nat) (Hnout : nout = 2%nat). + + Definition montred' (lo_hi : (Z * Z)) := + dlet_nd y := nth_default 0 (BaseConversion.widemul_inlined Zlog2R n nout (fst lo_hi) N') 0 in + dlet_nd t1_t2 := (BaseConversion.widemul_inlined_reverse Zlog2R n nout N y) in + dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [fst lo_hi; snd lo_hi] t1_t2 in + dlet_nd y' := Z.zselect (snd sum_carry) 0 N in + dlet_nd lo''_carry := Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y' in + Z.add_modulo (fst lo''_carry) 0 N. + + Local Lemma Hw : forall i, w i = R ^ Z.of_nat i. + Proof. + clear -R_big_enough R_two_pow; cbv [w weight]; intro. + autorewrite with zsimplify. + rewrite Z.pow_mul_r, R_two_pow by omega; reflexivity. + Qed. + + Local Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r, ?Z.pow_1_l in *. + Local Ltac solve_range := + repeat match goal with + | _ => progress change_weight + | |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega)) + | |- 0 <= _ => progress Z.zero_bounds + | |- 0 <= _ * _ < _ * _ => + split; [ solve [Z.zero_bounds] | apply Z.mul_lt_mono_nonneg; omega ] + | _ => solve [auto] + | _ => omega + end. + + Local Lemma eval2 x y : eval w 2 [x;y] = x + R * y. + Proof. cbn. change_weight. ring. Qed. + + Hint Rewrite BaseConversion.widemul_inlined_reverse_correct BaseConversion.widemul_inlined_correct + using (autorewrite with widemul push_nth_default; solve [solve_range]) : widemul. + + Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N) + (Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R): + montred' lo_hi = reduce_via_partial N R N' T. + Proof. + rewrite <-reduce_via_partial_alt_eq by nia. + cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In]. + rewrite Hlo, Hhi. + assert (0 <= (T mod R) * N' < w 2) by (solve_range). + + autorewrite with widemul. + rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega). + rewrite R_two_pow. + cbv [Rows.partition seq]. rewrite !eval2. + autorewrite with push_nth_default push_map. + autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct. + change_weight. + + (* pull out value before last modular reduction *) + match goal with |- (if (?n <=? ?x)%Z then ?x - ?n else ?x) = (if (?n <=? ?y) then ?y - ?n else ?y)%Z => + let P := fresh "H" in assert (x = y) as P; [|rewrite P; reflexivity] end. + + autorewrite with zsimplify. + rewrite (Z.mul_comm (((T mod R) * N') mod R) N) in *. + break_match; try reflexivity; Z.ltb_to_lt; rewrite Z.div_small_iff in * by omega; + repeat match goal with + | _ => progress autorewrite with zsimplify_fast + | |- context [?x mod (R * R)] => + unique pose proof (Z.mod_pos_bound x (R * R)); + try rewrite (Z.mod_small x (R * R)) in * by Z.rewrite_mod_small_solver + | _ => omega + | _ => progress Z.rewrite_mod_small + end. + Qed. + + Lemma montred'_correct lo_hi T (HT_range: 0 <= T < R * N) + (Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R): montred' lo_hi = (T * R') mod N. + Proof. + erewrite montred'_eq by eauto. + apply Z.equiv_modulo_mod_small; auto using reduce_via_partial_correct. + replace 0 with (Z.min 0 (R-N)) by (apply Z.min_l; omega). + apply reduce_via_partial_in_range; omega. + Qed. + End MontRed'. + + Derive montred_gen + SuchThat (forall (N R N' : Z) + (Zlog2R : Z) + (n nout: nat) + (lo_hi : Z * Z), + Interp (t:=reify_type_of montred') + montred_gen N R N' Zlog2R n nout lo_hi + = montred' N R N' Zlog2R n nout lo_hi) + As montred_gen_correct. + Proof. Time cache_reify (). Time Qed. + Module Export ReifyHints. + Global Hint Extern 1 (_ = montred' _ _ _ _ _ _ _) => simple apply montred_gen_correct : reify_gen_cache. + End ReifyHints. + + Section rmontred. + Context (N R N' : Z) + (machine_wordsize : Z). + + Let bound := Some r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. + Let consts_list := [N; N']. + + Definition relax_zrange_of_machine_wordsize + := relax_zrange_gen [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize]%Z. + Local Arguments relax_zrange_of_machine_wordsize / . + + Let relax_zrange := relax_zrange_of_machine_wordsize. + + Definition check_args {T} (res : Pipeline.ErrorT T) + : Pipeline.ErrorT T + := res. (* TODO: this should actually check stuff that corresponds with preconditions of montred'_correct *) + + Let fancy_args + := (Some {| Pipeline.invert_low log2wordsize := invert_low log2wordsize consts_list; + Pipeline.invert_high log2wordsize := invert_high log2wordsize consts_list |}). + + Notation BoundsPipeline_correct in_bounds out_bounds op + := (fun rv (rop : Expr (reify_type_of op)) Hrop + => @Pipeline.BoundsPipeline_correct_trans + false (* subst01 *) fancy_args + relax_zrange + (relax_zrange_gen_good _) + _ + rop + in_bounds + out_bounds + op + Hrop rv) + (only parsing). + + Definition rmontred_correct + := BoundsPipeline_correct + ((bound, bound), tt) + bound + (montred' N R N' (Z.log2 R) 2 2). + + Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _). + Definition rmontred_correctT rv : Prop + := type_of_strip_3arrow (@rmontred_correct rv). + End rmontred. +End MontgomeryReduction. + +Ltac solve_rmontred := solve_rop MontgomeryReduction.rmontred_correct. +Ltac solve_rmontred_nocache := solve_rop_nocache MontgomeryReduction.rmontred_correct. + + +Time Compute + (Pipeline.BoundsPipeline + true None (relax_zrange_gen [64; 128]) + ltac:(let r := Reify (to_associational (weight 51 1) 5) in + exact r) + (Some (repeat (@None _) 5), tt) + ZRange.type.base.option.None). + +Time Compute + (Pipeline.BoundsPipeline + true None (relax_zrange_gen [64; 128]) + ltac:(let r := Reify (scmul (weight 51 1) 5) in + exact r) + (None, (Some (repeat (@None _) 5), tt)) + ZRange.type.base.option.None). |