From 0774eb4535eff89d0fd4eba3bc4c4f89864812b1 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Fri, 1 Feb 2019 17:21:39 -0500 Subject: Reify most rewrite rules Currently we don't handle rules that require "concrete list of cons cells" nor rules that require "concrete nat literal to do recursion on"; both of these require special treatment, to be implemented in an upcoming commit. The reifier is kind-of slow, alas. After | File Name | Before || Change | % Change -------------------------------------------------------------------------------------------- 21m52.72s | Total | 21m20.90s || +0m31.82s | +2.48% -------------------------------------------------------------------------------------------- 1m12.22s | Rewriter.vo | 0m47.38s || +0m24.83s | +52.42% 3m14.35s | p384_32.c | 3m19.59s || -0m05.24s | -2.62% 1m45.12s | RewriterRulesGood.vo | 1m39.44s || +0m05.68s | +5.71% 0m40.82s | ExtractionHaskell/unsaturated_solinas | 0m37.58s || +0m03.24s | +8.62% 1m35.04s | RewriterRulesInterpGood.vo | 1m32.48s || +0m02.56s | +2.76% 0m59.49s | ExtractionHaskell/word_by_word_montgomery | 0m56.71s || +0m02.78s | +4.90% 1m45.10s | Fancy/Barrett256.vo | 1m47.00s || -0m01.90s | -1.77% 0m40.21s | p521_64.c | 0m38.97s || +0m01.24s | +3.18% 0m24.42s | SlowPrimeSynthesisExamples.vo | 0m25.67s || -0m01.25s | -4.86% 0m20.48s | secp256k1_32.c | 0m19.27s || +0m01.21s | +6.27% 1m47.10s | RewriterWf2.vo | 1m47.22s || -0m00.12s | -0.11% 0m47.85s | p521_32.c | 0m47.47s || +0m00.38s | +0.80% 0m45.66s | RewriterInterpProofs1.vo | 0m45.74s || -0m00.08s | -0.17% 0m37.18s | Fancy/Montgomery256.vo | 0m37.14s || +0m00.03s | +0.10% 0m36.26s | PushButtonSynthesis/UnsaturatedSolinas.vo | 0m36.14s || +0m00.11s | +0.33% 0m28.38s | ExtractionHaskell/saturated_solinas | 0m28.04s || +0m00.33s | +1.21% 0m26.80s | PushButtonSynthesis/WordByWordMontgomery.vo | 0m26.83s || -0m00.02s | -0.11% 0m24.00s | RewriterWf1.vo | 0m23.96s || +0m00.03s | +0.16% 0m19.82s | ExtractionOCaml/word_by_word_montgomery | 0m20.34s || -0m00.51s | -2.55% 0m19.65s | p256_32.c | 0m19.24s || +0m00.41s | +2.13% 0m18.54s | p448_solinas_64.c | 0m19.24s || -0m00.69s | -3.63% 0m16.18s | p434_64.c | 0m16.49s || -0m00.30s | -1.87% 0m13.67s | ExtractionOCaml/word_by_word_montgomery.ml | 0m14.47s || -0m00.80s | -5.52% 0m13.22s | ExtractionOCaml/unsaturated_solinas | 0m13.72s || -0m00.50s | -3.64% 0m09.91s | p224_32.c | 0m09.89s || +0m00.01s | +0.20% 0m09.85s | ExtractionOCaml/saturated_solinas | 0m09.83s || +0m00.01s | +0.20% 0m08.66s | ExtractionHaskell/word_by_word_montgomery.hs | 0m07.86s || +0m00.79s | +10.17% 0m08.36s | p384_64.c | 0m08.46s || -0m00.10s | -1.18% 0m07.72s | ExtractionOCaml/unsaturated_solinas.ml | 0m08.29s || -0m00.56s | -6.87% 0m06.66s | BoundsPipeline.vo | 0m06.73s || -0m00.07s | -1.04% 0m06.50s | ExtractionHaskell/unsaturated_solinas.hs | 0m06.56s || -0m00.05s | -0.91% 0m06.22s | ExtractionOCaml/saturated_solinas.ml | 0m05.69s || +0m00.52s | +9.31% 0m05.50s | ExtractionHaskell/saturated_solinas.hs | 0m05.88s || -0m00.37s | -6.46% 0m03.51s | PushButtonSynthesis/Primitives.vo | 0m03.48s || +0m00.02s | +0.86% 0m03.34s | PushButtonSynthesis/SmallExamples.vo | 0m03.32s || +0m00.02s | +0.60% 0m03.31s | curve25519_32.c | 0m03.13s || +0m00.18s | +5.75% 0m03.15s | PushButtonSynthesis/SaturatedSolinas.vo | 0m03.12s || +0m00.02s | +0.96% 0m03.11s | PushButtonSynthesis/BarrettReduction.vo | 0m03.08s || +0m00.02s | +0.97% 0m02.80s | PushButtonSynthesis/MontgomeryReduction.vo | 0m02.86s || -0m00.06s | -2.09% 0m02.13s | curve25519_64.c | 0m02.02s || +0m00.10s | +5.44% 0m01.58s | p256_64.c | 0m01.66s || -0m00.07s | -4.81% 0m01.53s | p224_64.c | 0m01.60s || -0m00.07s | -4.37% 0m01.49s | secp256k1_64.c | 0m01.72s || -0m00.23s | -13.37% 0m01.34s | CLI.vo | 0m01.23s || +0m00.11s | +8.94% 0m01.16s | RewriterProofs.vo | 0m01.10s || +0m00.05s | +5.45% 0m01.16s | StandaloneOCamlMain.vo | 0m01.13s || +0m00.03s | +2.65% 0m01.10s | CompilersTestCases.vo | 0m01.09s || +0m00.01s | +0.91% 0m01.07s | StandaloneHaskellMain.vo | 0m01.04s || +0m00.03s | +2.88% --- src/Rewriter.v | 1694 ++++++++++++++++++++++++++++++----------- src/RewriterRulesInterpGood.v | 5 + 2 files changed, 1264 insertions(+), 435 deletions(-) diff --git a/src/Rewriter.v b/src/Rewriter.v index c39b2fe9f..1aefc35e8 100644 --- a/src/Rewriter.v +++ b/src/Rewriter.v @@ -5,9 +5,11 @@ Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.ListUtil.FoldBool Require Import Crypto.Util.Option. Require Import Crypto.Util.OptionList. Require Import Crypto.Util.CPSNotations. +Require Import Crypto.Util.Bool.Reflect. Require Import Crypto.Util.ZRange. Require Import Crypto.Util.ZRange.Operations. Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.Notations. Require Crypto.Util.PrimitiveProd. Require Crypto.Util.PrimitiveHList. Require Import Crypto.Language. @@ -72,6 +74,22 @@ Module Compilers. => @subst_default_relax (fun t => P (Compilers.base.type.option t)) A evm end. + Fixpoint unsubst_default_relax P {t evm} : P (subst_default (relax t) evm) -> P t + := match t return P (subst_default (relax t) evm) -> P t with + | Compilers.base.type.type_base t => fun x => x + | Compilers.base.type.prod A B + => fun v + => @unsubst_default_relax + (fun A' => P (Compilers.base.type.prod A' _)) A evm + (@unsubst_default_relax + (fun B' => P (Compilers.base.type.prod _ B')) B evm + v) + | Compilers.base.type.list A + => @unsubst_default_relax (fun t => P (Compilers.base.type.list t)) A evm + | Compilers.base.type.option A + => @unsubst_default_relax (fun t => P (Compilers.base.type.option t)) A evm + end. + Fixpoint var_types_of (t : type) : Set := match t with | type.var _ => Compilers.base.type @@ -146,6 +164,18 @@ Module Compilers. v) end. + Fixpoint unsubst_default_relax P {t evm} : P (type.subst_default (type.relax t) evm) -> P t + := match t return P (type.subst_default (type.relax t) evm) -> P t with + | type.base t => base.unsubst_default_relax (fun t => P (type.base t)) + | type.arrow A B + => fun v + => @unsubst_default_relax + (fun A' => P (type.arrow A' _)) A evm + (@unsubst_default_relax + (fun B' => P (type.arrow _ B')) B evm + v) + end. + Fixpoint var_types_of (t : type) : Set := match t with | type.base t => base.var_types_of t @@ -542,6 +572,44 @@ Module Compilers. | expr.LetIn A B x f => expr.LetIn (@reify_expr _ x) (fun xv => @reify_expr _ (f (reflect (expr.Var xv)))) end. + (** N.B. In order to preserve the (currently unstated) + invariant that ensures that the rewriter does enough + rewriting, when we reify rewrite rules, we have to perform β + on the RHS to ensure that there are no var nodes holding + values which show up in the heads of app nodes. *) + Fixpoint reflect_expr_beta {t} (e : @expr.expr base.type ident value t) + : UnderLets (value t) + := match e in expr.expr t return UnderLets (value t) with + | expr.Var t v => UnderLets.Base v + | expr.Abs s d f => UnderLets.Base (fun x : value s => fx <----- @reflect_expr_beta d (f x); Base_value fx) + | expr.App s (type.base d) f x + => f <-- @reflect_expr_beta _ f; + x <-- @reflect_expr_beta _ x; + f x + | expr.App s (type.arrow _ _) f x + => f <-- @reflect_expr_beta _ f; + x <-- @reflect_expr_beta _ x; + UnderLets.Base (f x) + | expr.LetIn A B x f + => x <-- @reflect_expr_beta _ x; + UnderLets.UnderLet + (reify x) + (fun xv => @reflect_expr_beta _ (f (reflect (expr.Var xv)))) + | expr.Ident t idc => UnderLets.Base (reflect (expr.Ident idc)) + end%under_lets. + + Definition reify_to_UnderLets {with_lets} {t} : value' with_lets t -> UnderLets (expr t) + := match t, with_lets return value' with_lets t -> UnderLets (expr t) with + | type.base _, false => fun v => UnderLets.Base v + | type.base _, true => fun v => v + | type.arrow s d, _ + => fun f => UnderLets.Base (reify f) + end. + + Definition reify_expr_beta {t} (e : @expr.expr base.type ident value t) + : UnderLets (@expr.expr base.type ident var t) + := e <-- @reflect_expr_beta t e; reify_to_UnderLets e. + Definition reify_and_let_binds_cps {with_lets} {t} : value' with_lets t -> forall T, (expr t -> UnderLets T) -> UnderLets T := match t, with_lets return value' with_lets t -> forall T, (expr t -> UnderLets T) -> UnderLets T with | type.base _, false => reify_and_let_binds_base_cps _ @@ -643,6 +711,20 @@ Module Compilers. (snd xy) end. + Fixpoint lam_unification_resultT' {var t p evm K} {struct p} + : (@unification_resultT' var t p evm -> K) -> @with_unification_resultT' var t p evm K + := match p return (unification_resultT' p evm -> K) -> with_unification_resultT' p evm K with + | pattern.Wildcard t => fun f x => f x + | pattern.Ident t idc => lam_type_of_list + | pattern.App s d f x + => fun (F : unification_resultT' f _ * unification_resultT' x _ -> _) + => @lam_unification_resultT' + _ _ f _ _ + (fun fv + => @lam_unification_resultT' + _ _ x _ _ (fun xv => F (fv, xv))) + end. + (** TODO: Maybe have a fancier version of this that doesn't actually need to insert casts, by doing a fixpoint on the list of elements / the evar map *) @@ -696,6 +778,19 @@ Module Compilers. (fun fx => k (existT _ _ fx)))%option. + Definition lam_unification_resultT {var' t p K} + : (forall v : @unification_resultT var' t p, K (pattern.type.subst_default t (projT1 v))) -> @with_unification_resultT var' t p K + := fun f + => pattern.type.lam_forall_vars + (fun evm + => lam_unification_resultT' + (K:=K (pattern.type.subst_default t evm)) + (fun x' => f (existT (unification_resultT' p) evm x'))). + + Definition partial_lam_unification_resultT {var' t p K} + : (forall evm, @with_unification_resultT' var' t p evm (K (pattern.type.subst_default t evm))) -> @with_unification_resultT var' t p K + := pattern.type.lam_forall_vars. + Definition under_with_unification_resultT {var t p K1 K2} (F : forall evm, K1 (pattern.type.subst_default t evm) -> K2 (pattern.type.subst_default t evm)) : @with_unification_resultT var t p K1 -> @with_unification_resultT var t p K2 @@ -913,8 +1008,11 @@ Module Compilers. end end%option. + Local Notation expr_maybe_do_again should_do_again + := (@expr.expr base.type ident (if should_do_again then value else var)). + Local Notation deep_rewrite_ruleTP_gen' should_do_again with_opt under_lets t - := (match (@expr.expr base.type ident (if should_do_again then value else var) t) with + := (match (expr_maybe_do_again should_do_again t) with | x0 => match (if under_lets then UnderLets x0 else x0) with | x1 => if with_opt then option x1 else x1 end @@ -933,9 +1031,20 @@ Module Compilers. | false, false => fun x => Some (UnderLets.Base x) end%cps. + Definition with_unif_rewrite_ruleTP_gen' {var t} (p : pattern t) (should_do_again : bool) (with_opt : bool) (under_lets : bool) evm + := @with_unification_resultT' var t p evm (deep_rewrite_ruleTP_gen' should_do_again with_opt under_lets (pattern.type.subst_default t evm)). + Definition with_unif_rewrite_ruleTP_gen {var t} (p : pattern t) (should_do_again : bool) (with_opt : bool) (under_lets : bool) := @with_unification_resultT var t p (fun t => deep_rewrite_ruleTP_gen' should_do_again with_opt under_lets t). + Definition lam_unif_rewrite_ruleTP_gen {var t} (p : pattern t) (should_do_again : bool) (with_opt : bool) (under_lets : bool) + : _ -> @with_unif_rewrite_ruleTP_gen var t p should_do_again with_opt under_lets + := lam_unification_resultT. + + Definition partial_lam_unif_rewrite_ruleTP_gen {var t} (p : pattern t) (should_do_again : bool) (with_opt : bool) (under_lets : bool) + : (forall evm, @with_unif_rewrite_ruleTP_gen' var t p should_do_again with_opt under_lets evm) -> @with_unif_rewrite_ruleTP_gen var t p should_do_again with_opt under_lets + := partial_lam_unification_resultT. + Record rewrite_rule_data {t} {p : pattern t} := { rew_should_do_again : bool; rew_with_opt : bool; @@ -1209,6 +1318,559 @@ Module Compilers. := fun var => @rewrite var (rewrite_head var) fuel t (e _). End Compile. + Module Reify. + Import Compile. + Local Notation EvarMap := pattern.EvarMap. + + Inductive dynlist := dynnil | dyncons {T} (x : T) (xs : dynlist). + + Section with_var. + Local Notation type_of_list + := (fold_right (fun a b => prod a b) unit). + Local Notation type_of_list_cps + := (fold_right (fun a K => a -> K)). + Context {ident var : type.type base.type -> Type} + {pident : type.type pattern.base.type -> Type} + (pident_arg_types : forall t, pident t -> list Type) + (pident_type_of_list_arg_types_beq : forall t idc, type_of_list (pident_arg_types t idc) -> type_of_list (pident_arg_types t idc) -> bool) + (pident_of_typed_ident : forall t, ident t -> pident (pattern.type.relax t)) + (pident_arg_types_of_typed_ident : forall t (idc : ident t), type_of_list (pident_arg_types _ (pident_of_typed_ident t idc))). + + Local Notation type := (type.type base.type). + Local Notation expr := (@expr.expr base.type ident var). + Local Notation pattern := (@pattern.pattern pident). + Local Notation ptype := (type.type pattern.base.type). + Local Notation value := (@Compile.value base.type ident var). + Local Notation value_with_lets := (@Compile.value_with_lets base.type ident var). + Local Notation UnderLets := (UnderLets.UnderLets base.type ident var). + Local Notation reify_expr_beta := (@reify_expr_beta ident var). + Local Notation unification_resultT' := (@unification_resultT' pident pident_arg_types). + Local Notation with_unif_rewrite_ruleTP_gen' := (@with_unif_rewrite_ruleTP_gen' ident var pident pident_arg_types value). + Local Notation lam_unification_resultT' := (@lam_unification_resultT' pident pident_arg_types). + + Local Notation expr_maybe_do_again should_do_again + := (@expr.expr base.type ident (if should_do_again then value else var)). + + Fixpoint pattern_of_expr (var' := fun _ => positive) evm (invalid : forall t, @expr.expr base.type ident var' t -> { p : pattern (pattern.type.relax t) & @unification_resultT' var' _ p evm }) + {t} (e : @expr.expr base.type ident var' t) + : { p : pattern (pattern.type.relax t) & @unification_resultT' var' _ p evm } + := match e in expr.expr t return { p : pattern (pattern.type.relax t) & @unification_resultT' var' _ p evm } with + | expr.Ident t idc + => existT _ (pattern.Ident (pident_of_typed_ident _ idc)) + (pident_arg_types_of_typed_ident _ idc) + | expr.Var t v + => existT _ (pattern.Wildcard _) v + | expr.App s d f x + => let 'existT fp fv := @pattern_of_expr evm invalid _ f in + let 'existT xp xv := @pattern_of_expr evm invalid _ x in + existT _ (pattern.App fp xp) + (fv, xv) + | expr.Abs _ _ _ as e + | expr.LetIn _ _ _ _ as e + => invalid _ e + end. + + Definition expr_value_to_rewrite_rule_replacement (should_do_again : bool) {t} (e : @expr.expr base.type ident value t) + : UnderLets (expr_maybe_do_again should_do_again t) + := (e <-- UnderLets.flat_map (@reify_expr_beta) (fun t v => reflect (expr.Var v)) (UnderLets.of_expr e); + if should_do_again return UnderLets (expr_maybe_do_again should_do_again t) + then UnderLets.Base e + else reify_expr_beta e)%under_lets. + + Fixpoint pair'_unification_resultT' {evm t p} + : @unification_resultT' (fun _ => positive) t p evm -> @unification_resultT' value t p evm -> PositiveMap.t { t : _ & value t } * list bool -> PositiveMap.t { t : _ & value t } * list bool + := match p return @unification_resultT' (fun _ => positive) _ p evm -> @unification_resultT' value _ p evm -> PositiveMap.t { t : _ & value t } * list bool -> PositiveMap.t { t : _ & value t } * list bool with + | pattern.Wildcard t => fun p v '(m, l) => (PositiveMap.add p (existT value _ v) m, l) + | pattern.Ident t idc => fun v1 v2 '(m, l) => (m, pident_type_of_list_arg_types_beq t idc v2 v1 :: l) + | pattern.App _ _ F X + => fun x y '(m, l) + => let '(m, l) := @pair'_unification_resultT' _ _ F (fst x) (fst y) (m, l) in + let '(m, l) := @pair'_unification_resultT' _ _ X (snd x) (snd y) (m, l) in + (m, l) + end. + + Fixpoint expr_pos_to_expr_value + (invalid : forall t, positive * type * PositiveMap.t { t : _ & value t } -> @expr.expr base.type ident value t) + {t} (e : @expr.expr base.type ident (fun _ => positive) t) (m : PositiveMap.t { t : _ & value t }) (cur_i : positive) + : @expr.expr base.type ident value t + := match e in expr.expr t return expr.expr t with + | expr.Ident t idc => expr.Ident idc + | expr.App s d f x + => expr.App (@expr_pos_to_expr_value invalid _ f m cur_i) + (@expr_pos_to_expr_value invalid _ x m cur_i) + | expr.Var t v + => match option_map + (fun tv => type.try_transport base.try_make_transport_cps value _ t (projT2 tv)) + (PositiveMap.find v m) with + | Some (Some res) => expr.Var res + | Some None | None => invalid _ (v, t, m) + end + | expr.Abs s d f + => expr.Abs (fun v => @expr_pos_to_expr_value invalid _ (f cur_i) (PositiveMap.add cur_i (existT value _ v) m) (Pos.succ cur_i)) + | expr.LetIn A B x f + => expr.LetIn (@expr_pos_to_expr_value invalid _ x m cur_i) + (fun v => @expr_pos_to_expr_value invalid _ (f cur_i) (PositiveMap.add cur_i (existT value _ v) m) (Pos.succ cur_i)) + end. + + Definition expr_to_pattern_and_replacement + (should_do_again : bool) evm + (invalid : forall A B, A -> B) + {t} (lhs rhs : @expr.expr base.type ident (fun _ => positive) t) + (side_conditions : list bool) + : { p : pattern (pattern.type.relax t) & @with_unif_rewrite_ruleTP_gen' _ p should_do_again true true evm } + := let (p, unif_data_lhs) := @pattern_of_expr evm (fun _ => invalid _ _) t lhs in + existT + _ + p + (pattern.type.subst_default_relax + (fun t' + => with_unification_resultT' + pident_arg_types p evm + (option (UnderLets (expr_maybe_do_again should_do_again t')))) + (lam_unification_resultT' + (fun unif_data_new + => let '(value_map, side_conditions) := pair'_unification_resultT' unif_data_lhs unif_data_new (PositiveMap.empty _, side_conditions) in + let start_i := Pos.succ (PositiveMap.fold (fun k _ max => Pos.max k max) value_map 1%positive) in + let replacement := expr_pos_to_expr_value (fun _ => invalid _ _) rhs value_map start_i in + let replacement := expr_value_to_rewrite_rule_replacement should_do_again replacement in + if fold_left andb (List.rev side_conditions) true + then Some replacement + else None))). + + + Definition expr_to_pattern_and_replacement_unfolded should_do_again evm invalid {t} lhs rhs side_conditions + := Eval cbv beta iota delta [expr_to_pattern_and_replacement pattern_of_expr lam_unification_resultT' Pos.succ pair'_unification_resultT' PositiveMap.empty PositiveMap.fold Pos.max expr_pos_to_expr_value expr_value_to_rewrite_rule_replacement fold_left List.rev List.app value PositiveMap.add PositiveMap.xfoldi Pos.compare Pos.compare_cont FMapPositive.append projT1 projT2 PositiveMap.find Base_value (*UnderLets.map reify_expr_beta reflect_expr_beta*) lam_type_of_list fold_right list_rect pattern.type.relax pattern.type.subst_default pattern.type.subst_default_relax pattern.type.unsubst_default_relax option_map unification_resultT' with_unification_resultT' with_unif_rewrite_ruleTP_gen'] + in @expr_to_pattern_and_replacement should_do_again evm invalid t lhs rhs side_conditions. + + Definition partial_lam_unif_rewrite_ruleTP_gen_unfolded should_do_again {t} p + := Eval cbv beta iota delta [partial_lam_unif_rewrite_ruleTP_gen pattern.collect_vars pattern.type.lam_forall_vars partial_lam_unification_resultT pattern.type.collect_vars pattern.base.collect_vars PositiveSet.union PositiveSet.add PositiveSet.empty pattern.type.lam_forall_vars_gen List.rev PositiveSet.elements PositiveSet.xelements PositiveSet.rev PositiveSet.rev_append List.app orb fold_right PositiveMap.add PositiveMap.empty] + in @partial_lam_unif_rewrite_ruleTP_gen ident var pident pident_arg_types value t p should_do_again true true. + End with_var. + + Ltac strip_functional_dependency term := + lazymatch term with + | fun _ => ?P => P + | _ => let __ := match goal with _ => idtac "Cannot eliminate functional dependencies of" term; fail 1 "Cannot eliminate functional dependencies of" term end in + constr:(I : I) + end. + + Ltac reify_under_forall_types' ty_ctx cur_i lem cont := + lazymatch lem with + | forall T : Type, ?P + => let P' := fresh in + let ty_ctx' := fresh "ty_ctx" in + let t := fresh "t" in + strip_functional_dependency + (fun t : Compilers.base.type + => match PositiveMap.add cur_i t ty_ctx return _ with + | ty_ctx' + => match Compilers.base.interp (pattern.base.lookup_default cur_i ty_ctx') return _ with + | T + => match P return _ with + | P' + => ltac:(let P := (eval cbv delta [P' T ty_ctx'] in P') in + let ty_ctx := (eval cbv delta [ty_ctx'] in ty_ctx') in + clear P' T ty_ctx'; + let cur_i := (eval vm_compute in (Pos.succ cur_i)) in + let res := reify_under_forall_types' ty_ctx cur_i P cont in + exact res) + end + end + end) + | ?lem => cont ty_ctx cur_i lem + end. + + Ltac prop_to_bool H := eval cbv [decb] in (decb H). + + + Ltac push_side_conditions H side_conditions := + constr:(H :: side_conditions). + + Ltac equation_to_parts' lem side_conditions := + lazymatch lem with + | ?H -> ?P + => let H := prop_to_bool H in + let side_conditions := push_side_conditions H side_conditions in + equation_to_parts' P side_conditions + | forall x : ?T, ?P + => let P' := fresh in + constr:( + fun x : T + => match P return _ with + | P' + => ltac:(let P := (eval cbv delta [P'] in P') in + clear P'; + let res := equation_to_parts' P side_conditions in + exact res) + end) + | @eq ?T ?A ?B + => constr:((@eq T A B, side_conditions)) + | ?T => let __ := match goal with _ => fail 1 "Invalid type of equation:" T end in + constr:(I : I) + end. + Ltac equation_to_parts lem := + equation_to_parts' lem (@nil bool). + + Ltac reify_under_forall_types lem cont := + reify_under_forall_types' (@PositiveMap.empty Compilers.base.type) (1%positive) lem cont. + + Ltac preadjust_pattern_type_variables pat := + let pat := (eval cbv [pattern.type.relax pattern.type.subst_default pattern.type.subst_default_relax pattern.type.unsubst_default_relax] in pat) in + let pat := (eval cbn [pattern.base.relax pattern.base.subst_default pattern.base.subst_default_relax pattern.base.unsubst_default_relax] in pat) in + pat. + + Ltac adjust_pattern_type_variables' pat := + lazymatch pat with + | context[pattern.base.relax (pattern.base.lookup_default ?p ?evm')] + => let t := constr:(pattern.base.relax (pattern.base.lookup_default p evm')) in + let T := fresh in + let pat := + lazymatch (eval pattern t in pat) with + | ?pat _ + => let P := match type of pat with forall x, @?P x => P end in + lazymatch pat with + | fun T => ?pat + => constr:(match pattern.base.type.var p as T return P T with + | T => pat + end) + end + end in + adjust_pattern_type_variables' pat + | ?pat => pat + end. + + Ltac adjust_pattern_type_variables pat := + let pat := preadjust_pattern_type_variables pat in + adjust_pattern_type_variables' pat. + + Ltac strip_invalid_or_fail term := + lazymatch term with + | fun _ => ?f => f + | fun invalid : ?T => ?f + => let f' := fresh in + constr:(fun invalid : T + => match f return _ with + | f' + => ltac:(lazymatch (eval cbv [f'] in f') with + | context[invalid _ _ ?x] + => fail 0 "Invalid:" x + | context[invalid _ ?x] + => fail 0 "Invalid:" x + end) + end) + end. + + Definition pattern_base_subst_default_relax' t evm P + := @pattern.base.subst_default_relax P t evm. + Definition pattern_base_unsubst_default_relax' t evm P + := @pattern.base.unsubst_default_relax P t evm. + + Ltac change_pattern_base_subst_default_relax term := + lazymatch (eval pattern (@pattern.base.subst_default_relax), (@pattern.base.unsubst_default_relax) in term) with + | ?f _ _ + => let P := fresh "P" in + let t := fresh "t" in + let evm := fresh "evm" in + (eval cbv beta in (f (fun P t evm => @pattern_base_subst_default_relax' t evm P) (fun P t evm => @pattern_base_unsubst_default_relax' t evm P))) + end. + + Ltac adjust_lookup_default rewr := + lazymatch (eval pattern pattern.base.lookup_default in rewr) with + | ?rewr _ + => let p := fresh "p" in + let evm := fresh "evm" in + (eval cbv beta in (rewr (fun p evm => pattern.base.subst_default (pattern.base.type.var p) evm))) + end. + + Ltac replace_evar_map evm rewr := + let evm' := match rewr with + | context[pattern.base.lookup_default _ ?evm'] + => let __ := match goal with _ => tryif constr_eq evm evm' then fail else idtac end in + evm' + | context[pattern.base.subst_default _ ?evm'] + => let __ := match goal with _ => tryif constr_eq evm evm' then fail else idtac end in + evm' + | _ => tt + end in + lazymatch evm' with + | tt => rewr + | _ + => let rewr := lazymatch (eval pattern evm' in rewr) with + | ?rewr _ => (eval cbv beta in (rewr evm)) + end in + replace_evar_map evm rewr + end. + + Ltac adjust_type_variables rewr := + lazymatch rewr with + | context[pattern.base.subst_default (pattern.base.relax ?t) ?evm''] + => let t' := constr:(pattern.base.subst_default (pattern.base.relax t) evm'') in + let rewr := + lazymatch (eval pattern + t', + (pattern_base_subst_default_relax' t evm''), + (pattern_base_unsubst_default_relax' t evm'') + in rewr) + with + | ?rewr _ _ _ + => (eval cbv beta in (rewr t (fun P x => x) (fun P x => x))) + end in + adjust_type_variables rewr + | _ => rewr + end. + + Ltac replace_type_try_transport term := + lazymatch term with + | context[@type.try_transport ?base_type ?try_make_transport_base_type_cps ?P ?t ?t] + => let v := constr:(@type.try_transport base_type try_make_transport_base_type_cps P t t) in + let term := lazymatch (eval pattern v in term) with + | ?term _ => (eval cbv beta in (term (@Some _))) + end in + replace_type_try_transport term + | _ => term + end. + + Ltac under_binders term cont ctx := + lazymatch term with + | (fun x : ?T => ?f) + => let ctx' := fresh in + let f' := fresh in + constr:(fun x : T + => match f, dyncons x ctx return _ with + | f', ctx' + => ltac:(let ctx := (eval cbv delta [ctx'] in ctx') in + let f := (eval cbv delta [f'] in f') in + clear f' ctx'; + let res := under_binders f cont ctx in + exact res) + end) + | ?term => cont ctx term + end. + Ltac substitute_with term x y := + lazymatch (eval pattern y in term) with + | (fun z => ?term) _ => constr:(match x return _ with z => term end) + end. + Ltac substitute_beq_with full_ctx term beq x := + let is_good y := + lazymatch full_ctx with + | context[dyncons y _] => fail + | _ => is_var y + end in + let y := match term with + | context term' [beq x ?y] + => let __ := is_good y in + constr:(Some (beq x y)) + | context term' [@base.interp_beq ?t x ?y] + => let __ := is_good y in + constr:(Some (@base.interp_beq t x y)) + | context term' [@base.base_interp_beq ?t x ?y] + => let __ := is_good y in + constr:(Some (@base.base_interp_beq t x y)) + | _ => constr:(@None unit) + end in + lazymatch y with + | Some (?beq x ?y) + => lazymatch term with + | context term'[beq x y] + => let term := context term'[true] in + substitute_with term x y + end + | None => term + end. + Ltac remove_andb_true term := + let term := lazymatch (eval pattern andb, (andb true) in term) with + | ?f _ _ => (eval cbn [andb] in (f (fun x y => andb y x) (fun b => b))) + end in + let term := lazymatch (eval pattern andb, (andb true) in term) with + | ?f _ _ => (eval cbn [andb] in (f (fun x y => andb y x) (fun b => b))) + end in + term. + Ltac adjust_if_negb term := + lazymatch term with + | context term'[if negb ?x then ?a else ?b] + => let term := context term'[if x then b else a] in + adjust_if_negb term + | _ => term + end. + Ltac substitute_bool_eqb term := + lazymatch term with + | context term'[Bool.eqb ?x true] + => let term := context term'[x] in + substitute_bool_eqb term + | context term'[Bool.eqb ?x false] + => let term := context term'[negb x] in + substitute_bool_eqb term + | context term'[Bool.eqb true ?x] + => let term := context term'[x] in + substitute_bool_eqb term + | context term'[Bool.eqb false ?x] + => let term := context term'[negb x] in + substitute_bool_eqb term + | _ => term + end. + + Ltac substitute_beq full_ctx ctx term := + lazymatch ctx with + | dynnil + => let term := (eval cbv [base.interp_beq base.base_interp_beq] in term) in + let term := substitute_bool_eqb term in + let term := remove_andb_true term in + let term := adjust_if_negb term in + term + | dyncons ?v ?ctx + => let term := substitute_beq_with full_ctx term zrange_beq v in + let term := substitute_beq_with full_ctx term Z.eqb v in + let term := match constr:(Set) with + | _ => let T := type of v in + let beq := (eval cbv beta delta [Reflect.decb_rel] in (Reflect.decb_rel (@eq T))) in + substitute_beq_with full_ctx term beq v + | _ => term + end in + substitute_beq full_ctx ctx term + end. + + Ltac deep_substitute_beq term := + lazymatch term with + | context term'[@Build_rewrite_rule_data ?ident ?var ?pident ?pident_arg_types ?t ?p ?sda ?wo ?ul ?subterm] + => let subterm := under_binders subterm ltac:(fun ctx term => substitute_beq ctx ctx term) dynnil in + let term := context term'[@Build_rewrite_rule_data ident var pident pident_arg_types t p sda wo ul subterm] in + term + end. + + Ltac clean_beq term := + let term := (eval cbn [Prod.prod_beq] in term) in + let term := (eval cbv [ident.literal] in term) in + let term := deep_substitute_beq term in + let term := (eval cbv [base.interp_beq base.base_interp_beq] in term) in + let term := remove_andb_true term in + term. + + + Ltac reify_to_pattern_and_replacement_in_context ident reify_ident pident pident_arg_types pident_type_of_list_arg_types_beq pident_of_typed_ident pident_arg_types_of_typed_ident type_ctx var should_do_again cur_i term value_ctx := + let t := fresh "t" in + let p := fresh "p" in + let reify_rec_gen type_ctx := reify_to_pattern_and_replacement_in_context ident reify_ident pident pident_arg_types pident_type_of_list_arg_types_beq pident_of_typed_ident pident_arg_types_of_typed_ident type_ctx var should_do_again in + let var_pos := constr:(fun _ : type base.type => positive) in + let value := constr:(@value base.type ident var) in + let cexpr_to_pattern_and_replacement_unfolded := constr:(@expr_to_pattern_and_replacement_unfolded ident var pident pident_arg_types pident_type_of_list_arg_types_beq pident_of_typed_ident pident_arg_types_of_typed_ident should_do_again type_ctx) in + let cpartial_lam_unif_rewrite_ruleTP_gen := constr:(@partial_lam_unif_rewrite_ruleTP_gen_unfolded ident var pident pident_arg_types should_do_again) in + let cwith_unif_rewrite_ruleTP_gen := constr:(fun t p => @with_unif_rewrite_ruleTP_gen ident var pident pident_arg_types value t p should_do_again true true) in + lazymatch term with + | (fun x : ?T => ?f) + => let rT := Compilers.type.reify ltac:(Compilers.base.reify) base.type T in + let not_x1 := fresh in + let not_x2 := fresh in + let next_i := (eval vm_compute in (Pos.succ cur_i)) in + let type_ctx' := fresh in (* COQBUG(https://github.com/coq/coq/issues/7210#issuecomment-470009463) *) + let rf0 := + constr:( + match type_ctx return _ with + | type_ctx' + => fun (x : T) + => match f, @expr.var_context.cons base.type var_pos T rT x cur_i value_ctx return _ with (* c.f. COQBUG(https://github.com/coq/coq/issues/6252#issuecomment-347041995) for [return _] *) + | not_x1, not_x2 + => ltac:( + let f := (eval cbv delta [not_x1] in not_x1) in + let value_ctx := (eval cbv delta [not_x2] in not_x2) in + let type_ctx := (eval cbv delta [type_ctx'] in type_ctx') in + clear not_x1 not_x2 type_ctx'; + let rf := reify_rec_gen type_ctx next_i f value_ctx in + exact rf) + end + end) in + lazymatch rf0 with + | (fun _ => ?f) => f + | _ + => let __ := match goal with + | _ => fail 1 "Failure to eliminate functional dependencies of" rf0 + end in + constr:(I : I) + end + | (@eq ?T ?A ?B, ?side_conditions) + => let rA := expr.reify_in_context base.type ident ltac:(base.reify) reify_ident var_pos A value_ctx tt in + let rB := expr.reify_in_context base.type ident ltac:(base.reify) reify_ident var_pos B value_ctx tt in + let invalid := fresh "invalid" in + let res := constr:(fun invalid => cexpr_to_pattern_and_replacement_unfolded invalid _ rA rB side_conditions) in + let res := (eval cbv [expr_to_pattern_and_replacement_unfolded pident_arg_types pident_of_typed_ident pident_type_of_list_arg_types_beq pident_arg_types_of_typed_ident] in res) in + let res := (eval cbn [fst snd andb pattern.base.relax pattern.base.subst_default pattern.base.subst_default_relax] in res) in + let res := change_pattern_base_subst_default_relax res in + let p := (eval cbv [projT1] in (fun invalid => projT1 (res invalid))) in + let p := strip_invalid_or_fail p in + let p := adjust_pattern_type_variables p in + let res := (eval cbv [projT2] in (fun invalid => projT2 (res invalid))) in + let evm' := fresh "evm" in + let res' := fresh in + let res := + constr:( + fun invalid (evm' : EvarMap) + => match res invalid return _ with + | res' + => ltac:(let res := (eval cbv beta delta [res'] in res') in + clear res'; + let res := adjust_lookup_default res in + let res := adjust_type_variables res in + let res := replace_evar_map evm' res in + let res := replace_type_try_transport res in + exact res) + end) in + let res := (eval cbv [UnderLets.map UnderLets.flat_map reify_expr_beta reflect_expr_beta reify_to_UnderLets] in res) in + let res := (eval cbn [reify reflect UnderLets.of_expr UnderLets.to_expr UnderLets.splice value' Base_value] in res) in + let res := strip_invalid_or_fail res in + (* cbv here not strictly needed *) + let res := (eval cbv [partial_lam_unif_rewrite_ruleTP_gen_unfolded] in + (existT + (cwith_unif_rewrite_ruleTP_gen _) + p + (cpartial_lam_unif_rewrite_ruleTP_gen _ p res))) in + (* not strictly needed *) + let res := (eval cbn [pattern.base.subst_default pattern.base.lookup_default PositiveMap.find type.interp base.interp base.base_interp] in res) in + let res := (eval cbv [projT1 projT2] in + (existT + (@rewrite_ruleTP ident var pident pident_arg_types) + {| pattern.pattern_of_anypattern := projT1 res |} + {| rew_replacement := projT2 res |})) in + let res := clean_beq res in + res + end. + + Ltac reify ident reify_ident pident pident_arg_types pident_type_of_list_arg_types_beq pident_of_typed_ident pident_arg_types_of_typed_ident var should_do_again lem := + reify_under_forall_types + lem + ltac:( + fun ty_ctx cur_i lem + => let lem := equation_to_parts lem in + let res := reify_to_pattern_and_replacement_in_context ident reify_ident pident pident_arg_types pident_type_of_list_arg_types_beq pident_of_typed_ident pident_arg_types_of_typed_ident ty_ctx var should_do_again constr:(1%positive) lem (@expr.var_context.nil base.type (fun _ => positive)) in + res). + + Ltac Reify ident reify_ident pident pident_arg_types pident_type_of_list_arg_types_beq pident_of_typed_ident pident_arg_types_of_typed_ident should_do_again lem := + let var := fresh "var" in + constr:(fun var : Compilers.type.type Compilers.base.type -> Type + => ltac:(let res := reify ident reify_ident pident pident_arg_types pident_type_of_list_arg_types_beq pident_of_typed_ident pident_arg_types_of_typed_ident var should_do_again lem in + exact res)). + + (* lems is either a list of [Prop]s, or a list of [bool (* should_do_again *) * Prop] *) + Ltac reify_list ident reify_ident pident pident_arg_types pident_type_of_list_arg_types_beq pident_of_typed_ident pident_arg_types_of_typed_ident var lems := + let reify' := reify ident reify_ident pident pident_arg_types pident_type_of_list_arg_types_beq pident_of_typed_ident pident_arg_types_of_typed_ident var in + let reify_list_rec := reify_list ident reify_ident pident pident_arg_types pident_type_of_list_arg_types_beq pident_of_typed_ident pident_arg_types_of_typed_ident var in + lazymatch lems with + | (?b, ?lem) :: ?lems + => let rlem := reify' b lem in + let rlems := reify_list_rec lems in + constr:(rlem :: rlems) + | nil => constr:(@nil (@rewrite_ruleT ident var pident pident_arg_types)) + | _ + => let lems := (eval cbv beta iota delta [List.map] in + (List.map (fun p : Prop => (false, p)) lems)) in + reify_list_rec lems + end. + + Ltac Reify_list ident reify_ident pident pident_arg_types pident_type_of_list_arg_types_beq pident_of_typed_ident pident_arg_types_of_typed_ident lems := + let var := fresh "var" in + constr:(fun var : Compilers.type.type Compilers.base.type -> Type + => ltac:(let res := reify_list ident reify_ident pident pident_arg_types pident_type_of_list_arg_types_beq pident_of_typed_ident pident_arg_types_of_typed_ident var lems in + exact res)). + End Reify. + Module Make. Section make_rewrite_rules. Import Compile. @@ -1479,6 +2141,14 @@ Module Compilers. Definition invert_bind_args_unknown := @pattern.Raw.ident.invert_bind_args. Local Notation assemble_identifier_rewriters := (@assemble_identifier_rewriters ident var (@pattern.ident.eta_ident_cps) (@pattern.ident) (@pattern.ident.arg_types) (@pattern.ident.unify) pident_unify_unknown pattern.Raw.ident (@pattern.Raw.ident.full_types) (@pattern.Raw.ident.invert_bind_args) invert_bind_args_unknown (@pattern.Raw.ident.type_of) (@pattern.Raw.ident.to_typed) pattern.Raw.ident.is_simple). + Ltac reify lems := + Reify.reify_list ident ident.reify pattern.ident (@pattern.ident.arg_types) (@pattern.ident.type_of_list_arg_types_beq) (@pattern.ident.of_typed_ident) (@pattern.ident.arg_types_of_typed_ident) var lems. + (* Play games with [match] to get [lems] interpreted as a constr rather than an ident when it's not closed, to get better error messages *) + Local Notation reify lems + := (match lems return _ with + | _LEMS => ltac:(let LEMS := (eval cbv delta [_LEMS] in _LEMS) in let res := reify LEMS in exact res) + end) (only parsing). + Delimit Scope rewrite_scope with rewrite. Delimit Scope rewrite_opt_scope with rewrite_opt. Delimit Scope rewrite_lets_scope with rewrite_lets. @@ -1553,6 +2223,7 @@ Module Compilers. Local Notation ℕ := base.type.nat. Local Notation bool := base.type.bool. Local Notation list := pattern.base.type.list. + Local Notation "' x" := (ident.literal x). (* Local Arguments Make.interp_rewrite_rules / . @@ -1595,16 +2266,25 @@ Module Compilers. Local Arguments pattern.anypattern : clear implicits. Local Arguments Make.interp_rewrite_rules / . Let myapp {A} := Eval cbv [List.app] in @List.app A. + Let myflatten {A} := Eval cbv in List.fold_right myapp (@nil A). + Local Notation do_again P := (true, P) (only parsing). + Local Notation cstZ := (ident.cast ident.cast_outside_of_range). + Local Notation cstZZ := (ident.cast2 ident.cast_outside_of_range). Definition nbe_rewrite_rules : rewrite_rulesT - := Eval cbv [Make.interp_rewrite_rules myapp] in + := Eval cbv [Make.interp_rewrite_rules myapp myflatten] in myapp Make.interp_rewrite_rules - [make_rewrite (#(@pattern.ident.fst '1 '2) @ (??, ??)) (fun _ _ x y => x) - ; make_rewrite (#(@pattern.ident.snd '1 '2) @ (??, ??)) (fun _ x _ y => y) - ; make_rewrite (#(@pattern.ident.List_repeat '1) @ ?? @ #?ℕ) (fun _ x n => reify_list (repeat x n)) - ; make_rewritel (#(@pattern.ident.bool_rect '1) @ ?? @ ?? @ #?𝔹) (fun _ t f b => if b then t ##tt else f ##tt) - ; make_rewritel (#(@pattern.ident.prod_rect '1 '2 '3) @ ?? @ (??, ??)) (fun _ _ _ f x y => f x y) - ; make_rewriteol (??{list '1} ++ ??{list '1}) (fun _ xs ys => rlist_rect ys (fun x _ xs_ys => x :: xs_ys) xs) + (myflatten + [ + (reify + [(forall A B x y, @fst A B (x, y) = x) + ; (forall A B x y, @snd A B (x, y) = y) + ; (forall P t f, @ident.Thunked.bool_rect P t f true = t tt) + ; (forall P t f, @ident.Thunked.bool_rect P t f false = f tt) + ; (forall A B C f x y, @prod_rect A B (fun _ => C) f (x, y) = f x y) + ]) + ; [make_rewrite (#(@pattern.ident.List_repeat '1) @ ?? @ #?ℕ) (fun _ x n => reify_list (repeat x n)) + ; make_rewriteol (??{list '1} ++ ??{list '1}) (fun _ xs ys => rlist_rect ys (fun x _ xs_ys => x :: xs_ys) xs) ; make_rewriteol (#(@pattern.ident.List_firstn '1) @ #?ℕ @ ??) (fun _ n xs => xs <- reflect_list xs; reify_list (List.firstn n xs)) @@ -1641,14 +2321,14 @@ Module Compilers. ; make_rewriteol (#(@pattern.ident.list_rect '1 '2) @ ?? @ ?? @ ??) (fun _ _ Pnil Pcons xs - => rlist_rect (Pnil ##tt) (fun x' xs' rec => Pcons x' (reify_list xs') rec) xs) - ; make_rewritel - (#(@pattern.ident.list_case '1 '2) @ ?? @ ?? @ []) - (fun _ _ Pnil Pcons => Pnil ##tt) - ; make_rewritel - (#(@pattern.ident.list_case '1 '2) @ ?? @ ?? @ (?? :: ??)) - (fun _ _ Pnil Pcons x xs => Pcons x xs) - ; make_rewriteol + => rlist_rect (Pnil ##tt) (fun x' xs' rec => Pcons x' (reify_list xs') rec) xs) ] + + ; (reify + [(forall A P N C, @ident.Thunked.list_case A P N C nil = N tt) + ; (forall A P N C x xs, @ident.Thunked.list_case A P N C (x :: xs) = C x xs) + ]) + ; [ + make_rewriteol (#(@pattern.ident.List_map '1 '2) @ ?? @ ??) (fun _ _ f xs => rlist_rect [] (fun x _ fxs => fx <-- f x; fx :: fxs) xs) ; make_rewriteo @@ -1680,71 +2360,92 @@ Module Compilers. (fun x => x <-- x; f x) (List.map UnderLets.Base ls)); reify_list retv) - ]. + ] ]). Definition arith_rewrite_rules (max_const_val : Z) : rewrite_rulesT - := [make_rewrite (#(@pattern.ident.fst '1 '2) @ (??, ??)) (fun _ _ x y => x) - ; make_rewrite (#(@pattern.ident.snd '1 '2) @ (??, ??)) (fun _ x _ y => y) - ; make_rewriteo (#?ℤ + ??) (fun z v => v when z =? 0) - ; make_rewriteo (?? + #?ℤ ) (fun v z => v when z =? 0) - ; make_rewriteo (#?ℤ + (-??)) (fun z v => ##z - v when z >? 0) - ; make_rewriteo ((-??) + #?ℤ ) (fun v z => ##z - v when z >? 0) - ; make_rewriteo (#?ℤ + (-??)) (fun z v => -(##((-z)%Z) + v) when z -(v + ##((-z)%Z)) when z -(x + y)) - ; make_rewrite ((-??) + ?? ) (fun x y => y - x) - ; make_rewrite ( ?? + (-??)) (fun x y => x - y) - - ; make_rewriteo (#?ℤ - (-??)) (fun z v => v when z =? 0) - ; make_rewriteo (#?ℤ - ?? ) (fun z v => -v when z =? 0) - ; make_rewriteo (?? - #?ℤ ) (fun v z => v when z =? 0) - ; make_rewriteo (#?ℤ - (-??)) (fun z v => ##z + v when z >? 0) - ; make_rewriteo (#?ℤ - (-??)) (fun z v => v - ##((-z)%Z) when z -(##((-z)%Z) + v) when z -(v + ##(z)) when z >? 0) - ; make_rewriteo ((-??) - #?ℤ ) (fun v z => ##((-z)%Z) - v when z v + ##((-z)%Z) when z y - x) - ; make_rewrite ((-??) - ?? ) (fun x y => -(x + y)) - ; make_rewrite ( ?? - (-??)) (fun x y => x + y) - - ; make_rewrite (#?ℤ * #?ℤ ) (fun x y => ##((x*y)%Z)) - ; make_rewriteo (#?ℤ * ??) (fun z v => ##0 when z =? 0) - ; make_rewriteo (?? * #?ℤ ) (fun v z => ##0 when z =? 0) - ; make_rewriteo (#?ℤ * ??) (fun z v => v when z =? 1) - ; make_rewriteo (?? * #?ℤ ) (fun v z => v when z =? 1) - ; make_rewriteo (#?ℤ * (-??)) (fun z v => v when z =? (-1)) - ; make_rewriteo ((-??) * #?ℤ ) (fun v z => v when z =? (-1)) - ; make_rewriteo (#?ℤ * ?? ) (fun z v => -v when z =? (-1)) - ; make_rewriteo (?? * #?ℤ ) (fun v z => -v when z =? (-1)) - ; make_rewriteo (#?ℤ * ?? ) (fun z v => -(##((-z)%Z) * v) when z -(v * ##((-z)%Z)) when z x * y) - ; make_rewrite ((-??) * ?? ) (fun x y => -(x * y)) - ; make_rewrite ( ?? * (-??)) (fun x y => -(x * y)) - - ; make_rewriteo (?? &' #?ℤ) (fun x mask => ##(0)%Z when mask =? 0) - ; make_rewriteo (#?ℤ &' ??) (fun mask x => ##(0)%Z when mask =? 0) - - ; make_rewriteo (?? * #?ℤ) (fun x y => x << ##(Z.log2 y) when (y =? (2^Z.log2 y)) && (negb (y =? 2))) - ; make_rewriteo (#?ℤ * ??) (fun y x => x << ##(Z.log2 y) when (y =? (2^Z.log2 y)) && (negb (y =? 2))) - ; make_rewriteo (?? / #?ℤ) (fun x y => x when (y =? 1)) - ; make_rewriteo (?? mod #?ℤ) (fun x y => ##(0)%Z when (y =? 1)) - ; make_rewriteo (?? / #?ℤ) (fun x y => x >> ##(Z.log2 y) when (y =? (2^Z.log2 y))) - ; make_rewriteo (?? mod #?ℤ) (fun x y => x &' ##(y-1)%Z when (y =? (2^Z.log2 y))) - ; make_rewrite (-(-??)) (fun v => v) - - (* We reassociate some multiplication of small constants *) - ; make_rewriteo (#?ℤ * (#?ℤ * (?? * ??))) (fun c1 c2 x y => (x * (y * (##c1 * ##c2))) when (Z.abs c1 <=? Z.abs max_const_val) && (Z.abs c2 <=? Z.abs max_const_val)) - ; make_rewriteo (#?ℤ * (?? * (?? * #?ℤ))) (fun c1 x y c2 => (x * (y * (##c1 * ##c2))) when (Z.abs c1 <=? Z.abs max_const_val) && (Z.abs c2 <=? Z.abs max_const_val)) - ; make_rewriteo (#?ℤ * (?? * ??)) (fun c x y => (x * (y * ##c)) when (Z.abs c <=? Z.abs max_const_val)) - ; make_rewriteo (#?ℤ * ??) (fun c x => (x * ##c) when (Z.abs c <=? Z.abs max_const_val)) - - ; make_rewrite_step (* _step, so that if one of the arguments is concrete, we automatically get the rewrite rule for [Z_cast] applying to it *) - (#pattern.ident.Z_cast2 @ (??, ??)) (fun r x y => (#(ident.Z_cast (fst r)) @ $x, #(ident.Z_cast (snd r)) @ $y)) - - ; make_rewriteol (-??) (fun e => (llet v := e in -$v) when negb (SubstVarLike.is_var_fst_snd_pair_opp_cast e)) (* inline negation when the rewriter wouldn't already inline it *) - ]. + := Eval cbv [Make.interp_rewrite_rules myapp myflatten] in + myflatten + [reify + [(forall A B x y, @fst A B (x, y) = x) + ; (forall A B x y, @snd A B (x, y) = y) + ; (forall v, 0 + v = v) + ; (forall v, v + 0 = v) + ; (forall x y, (-x) + (-y) = -(x + y)) + ; (forall x y, (-x) + y = y - x) + ; (forall x y, x + (-y) = x - y) + + ; (forall v, 0 - (-v) = v) + ; (forall v, 0 - v = -v) + ; (forall v, v - 0 = v) + ; (forall x y, (-x) - (-y) = y - x) + ; (forall x y, (-x) - y = -(x + y)) + ; (forall x y, x - (-y) = x + y) + + ; (forall v, 0 * v = 0) + ; (forall v, v * 0 = 0) + ; (forall v, 1 * v = v) + ; (forall v, v * 1 = v) + ; (forall v, (-1) * (-v) = v) + ; (forall v, (-v) * (-1) = v) + ; (forall v, (-1) * v = -v) + ; (forall v, v * (-1) = -v) + ; (forall x y, (-x) * (-y) = x * y) + ; (forall x y, (-x) * y = -(x * y)) + ; (forall x y, x * (-y) = -(x * y)) + + ; (forall x, x &' 0 = 0) + + ; (forall x, x / 1 = x) + ; (forall x, x mod 1 = 0) + + ; (forall v, -(-v) = v) + + ; (forall z v, z > 0 -> 'z + (-v) = 'z - v) + ; (forall z v, z > 0 -> (-v) + 'z = 'z - v) + ; (forall z v, z < 0 -> 'z + (-v) = -('(-z) + v)) + ; (forall z v, z < 0 -> (-v) + 'z = -(v + '(-z))) + + ; (forall z v, z > 0 -> 'z - (-v) = 'z + v) + ; (forall z v, z < 0 -> 'z - (-v) = v - '(-z)) + ; (forall z v, z < 0 -> 'z - v = -('(-z) + v)) + ; (forall z v, z > 0 -> (-v) - 'z = -(v + 'z)) + ; (forall z v, z < 0 -> (-v) - 'z = '(-z) - v) + ; (forall z v, z < 0 -> v - 'z = v + '(-z)) + + ; (forall x y, 'x * 'y = '(x*y)) + ; (forall z v, z < 0 -> 'z * v = -('(-z) * v)) + ; (forall z v, z < 0 -> v * 'z = -(v * '(-z))) + + ; (forall x y, y = 2^Z.log2 y -> y <> 2 -> x * 'y = x << '(Z.log2 y)) + ; (forall x y, y = 2^Z.log2 y -> y <> 2 -> 'y * x = x << '(Z.log2 y)) + + ; (forall x y, y = 2^Z.log2 y -> x / 'y = x >> '(Z.log2 y)) + ; (forall x y, y = 2^Z.log2 y -> x mod 'y = x &' '(y-1)) + + (* We reassociate some multiplication of small constants *) + ; (forall c1 c2 x y, + Z.abs c1 <= Z.abs max_const_val + -> Z.abs c2 <= Z.abs max_const_val + -> 'c1 * ('c2 * (x * y)) = (x * (y * ('c1 * 'c2)))) + ; (forall c1 c2 x y, + Z.abs c1 <= Z.abs max_const_val + -> Z.abs c2 <= Z.abs max_const_val + -> 'c1 * (x * (y * 'c2)) = (x * (y * ('c1 * 'c2)))) + ; (forall c x y, + Z.abs c <= Z.abs max_const_val + -> 'c * (x * y) = x * (y * 'c)) + ; (forall c x, + Z.abs c <= Z.abs max_const_val + -> 'c * x = x * 'c) + ] + ; reify + [ (* [do_again], so that if one of the arguments is concrete, we automatically get the rewrite rule for [Z_cast] applying to it *) + do_again (forall r x y, cstZZ r (x, y) = (cstZ (fst r) x, cstZ (snd r) y)) + ] + + ; [ + make_rewriteol (-??) (fun e => (llet v := e in -$v) when negb (SubstVarLike.is_var_fst_snd_pair_opp_cast e)) (* inline negation when the rewriter wouldn't already inline it *) + ] ]. Let cst {var} (r : zrange) (e : @expr.expr _ _ var _) := (#(ident.Z_cast r) @ e)%expr. Let cst' {var} (r : zrange) (e : @expr.expr _ _ var _) := (#(ident.Z_cast (-r)) @ e)%expr. @@ -1764,186 +2465,189 @@ Module Compilers. (cst (fst rvc) (#ident.fst @ (cst2 rvc ($vc))), cst (snd rvc) (#ident.snd @ (cst2 rvc ($vc))))))%expr. + Local Notation "'plet' x := y 'in' z" + := (match y return _ with x => z end). + + Local Notation dlet2_opp2 rvc e + := (plet rvc' := (fst rvc, -snd rvc)%zrange in + plet cst' := cstZZ rvc' in + plet cst1 := cstZ (fst rvc%zrange%zrange) in + plet cst2 := cstZ (snd rvc%zrange%zrange) in + plet cst2' := cstZ (-snd rvc%zrange%zrange) in + (dlet vc := cst' e in + (cst1 (fst (cst' vc)), cst2 (-(cst2' (snd (cst' vc))))))). + + Local Notation dlet2 rvc e + := (dlet vc := cstZZ rvc e in + (cstZ (fst rvc) (fst (cstZZ rvc vc)), + cstZ (snd rvc) (snd (cstZZ rvc vc)))). + + + Local Notation "x '\in' y" := (is_bounded_by_bool x (ZRange.normalize y) = true) : zrange_scope. + Local Notation "x ∈ y" := (is_bounded_by_bool x (ZRange.normalize y) = true) : zrange_scope. + Local Notation "x <= y" := (is_tighter_than_bool (ZRange.normalize x) y = true) : zrange_scope. + Local Notation litZZ x := (ident.literal (fst x), ident.literal (snd x)) (only parsing). + Local Notation n r := (ZRange.normalize r) (only parsing). + Definition arith_with_casts_rewrite_rules : rewrite_rulesT - := [make_rewrite (#(@pattern.ident.fst '1 '2) @ (??, ??)) (fun _ _ x y => x) - ; make_rewrite (#(@pattern.ident.snd '1 '2) @ (??, ??)) (fun _ _ x y => y) - - ; make_rewriteo (??') (fun r v => cst r (##(lower r)) when lower r =? upper r) - - ; make_rewriteo - (#?ℤ' + ?? ) - (fun rz z v => v when (z =? 0) && (is_bounded_by_bool z (ZRange.normalize rz))) - ; make_rewriteo - (?? + #?ℤ') - (fun v rz z => v when (z =? 0) && (is_bounded_by_bool z (ZRange.normalize rz))) - - ; make_rewriteo - (#?ℤ' - (-'??')) - (fun rz z rnv rv v => cst rv v when (z =? 0) && (ZRange.normalize rv <=? -ZRange.normalize rnv)%zrange && (is_bounded_by_bool z rz)) - ; make_rewriteo (#?ℤ' - ?? ) (fun rz z v => -v when (z =? 0) && is_bounded_by_bool z (ZRange.normalize rz)) - - ; make_rewriteo (#?ℤ' << ??) (fun rx x y => ##0 when (x =? 0) && is_bounded_by_bool x (ZRange.normalize rx)) - - ; make_rewriteo (-(-'??')) (fun rnv rv v => cst rv v when (ZRange.normalize rv <=? -ZRange.normalize rnv)%zrange) - - ; make_rewriteo (#pattern.ident.Z_mul_split @ ?? @ #?ℤ' @ ??) (fun s rxx xx y => (cst r[0~>0] ##0, cst r[0~>0] ##0)%Z when (xx =? 0) && is_bounded_by_bool xx (ZRange.normalize rxx)) - ; make_rewriteo (#pattern.ident.Z_mul_split @ ?? @ ?? @ #?ℤ') (fun s y rxx xx => (cst r[0~>0] ##0, cst r[0~>0] ##0)%Z when (xx =? 0) && is_bounded_by_bool xx (ZRange.normalize rxx)) - ; make_rewriteo - (#pattern.ident.Z_mul_split @ #?ℤ' @ #?ℤ' @ ??') - (fun rs s rxx xx ry y => (cst ry y, cst r[0~>0] ##0)%Z when (xx =? 1) && (ZRange.normalize ry <=? r[0~>s-1])%zrange && is_bounded_by_bool s (ZRange.normalize rs) && is_bounded_by_bool xx (ZRange.normalize rxx)) - ; make_rewriteo - (#pattern.ident.Z_mul_split @ #?ℤ' @ ??' @ #?ℤ') - (fun rs s ry y rxx xx => (cst ry y, cst r[0~>0] ##0)%Z when (xx =? 1) && (ZRange.normalize ry <=? r[0~>s-1])%zrange && is_bounded_by_bool s (ZRange.normalize rs) && is_bounded_by_bool xx (ZRange.normalize rxx)) - (* - ; make_rewriteo - (#pattern.ident.Z_mul_split @ #?ℤ @ #?ℤ @ ??') - (fun s xx ry y => (cst' ry (-cst ry y), ##0%Z) when (xx =? (-1)) && (ZRange.normalize ry <=? r[0~>s-1])%zrange) - ; make_rewriteo - (#pattern.ident.Z_mul_split @ #?ℤ @ ??' @ #?ℤ) - (fun s ry y xx => (cst' ry (-cst ry y), ##0%Z) when (xx =? (-1)) && (ZRange.normalize ry <=? r[0~>s-1])%zrange) - *) - - - ; make_rewriteol - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_get_carry @ ?? @ (-'??') @ ??)) - (fun rvc s rny ry y x - => (llet2_opp2 rvc (#ident.Z_sub_get_borrow @ s @ x @ cst ry y)) - when (ZRange.normalize ry <=? -ZRange.normalize rny)%zrange) - ; make_rewriteol - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_get_carry @ ?? @ ?? @ (-'??'))) - (fun rvc s x rny ry y - => (llet2_opp2 rvc (#ident.Z_sub_get_borrow @ s @ x @ cst ry y)) - when (ZRange.normalize ry <=? -ZRange.normalize rny)%zrange) - ; make_rewriteol - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_get_carry @ ?? @ #?ℤ' @ ??)) - (fun rvc s ryy yy x - => (llet2_opp2 rvc (#ident.Z_sub_get_borrow @ s @ x @ cst (ZRange.opp ryy) ##(-yy)%Z)) - when (yy (llet2_opp2 rvc (#ident.Z_sub_get_borrow @ s @ x @ cst (ZRange.opp ryy) ##(-yy)%Z)) - when (yy (llet2_opp2 rvc (#ident.Z_sub_with_get_borrow @ s @ (cst rc c) @ x @ (cst ry y))) - when ((ZRange.normalize ry <=? -ZRange.normalize rny) && (ZRange.normalize rc <=? -ZRange.normalize rnc))%zrange) - ; make_rewriteol - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_with_get_carry @ ?? @ (-'??') @ ?? @ (-'??'))) - (fun rvc s rnc rc c x rny ry y - => (llet2_opp2 rvc (#ident.Z_sub_with_get_borrow @ s @ (cst rc c) @ x @ (cst ry y))) - when ((ZRange.normalize ry <=? -ZRange.normalize rny) && (ZRange.normalize rc <=? -ZRange.normalize rnc))%zrange) - ; make_rewriteol - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_with_get_carry @ ?? @ #?ℤ' @ (-'??') @ ??)) - (fun rvc s rcc cc rny ry y x - => (llet2_opp2 rvc (#ident.Z_sub_get_borrow @ s @ x @ cst ry y)) - when (cc =? 0) && (ZRange.normalize ry <=? -ZRange.normalize rny)%zrange && is_bounded_by_bool cc (ZRange.normalize rcc)) - ; make_rewriteol - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_with_get_carry @ ?? @ #?ℤ' @ (-'??') @ ??)) - (fun rvc s rcc cc rny ry y x - => (llet2_opp2 rvc (#ident.Z_sub_with_get_borrow @ s @ cst (ZRange.opp rcc) ##(-cc)%Z @ x @ cst ry y)) - when (cc (llet2_opp2 rvc (#ident.Z_sub_get_borrow @ s @ x @ cst ry y)) - when (cc =? 0) && (ZRange.normalize ry <=? -ZRange.normalize rny)%zrange && is_bounded_by_bool cc (ZRange.normalize rcc)) - ; make_rewriteol - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_with_get_carry @ ?? @ #?ℤ' @ ?? @ (-'??'))) - (fun rvc s rcc cc x rny ry y - => (llet2_opp2 rvc (#ident.Z_sub_with_get_borrow @ s @ cst (ZRange.opp rcc) ##(-cc)%Z @ x @ cst ry y) - when (cc (llet2_opp2 rvc (#ident.Z_sub_with_get_borrow @ s @ cst rc c @ x @ cst (ZRange.opp ryy) ##(-yy)%Z)) - when (yy <=? 0) && (ZRange.normalize rc <=? -ZRange.normalize rnc)%zrange && is_bounded_by_bool yy (ZRange.normalize ryy)) - ; make_rewriteol - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_with_get_carry @ ?? @ (-'??') @ ?? @ #?ℤ')) - (fun rvc s rnc rc c x ryy yy - => (llet2_opp2 rvc (#ident.Z_sub_with_get_borrow @ s @ cst rc c @ x @ cst (ZRange.opp ryy) ##(-yy)%Z)) - when (yy <=? 0) && (ZRange.normalize rc <=? -ZRange.normalize rnc)%zrange && is_bounded_by_bool yy (ZRange.normalize ryy)) - ; make_rewriteol - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_with_get_carry @ ?? @ #?ℤ' @ #?ℤ' @ ??)) - (fun rvc s rcc cc ryy yy x - => (llet2_opp2 rvc (#ident.Z_sub_with_get_borrow @ s @ cst (ZRange.opp rcc) ##(-cc)%Z @ x @ cst (ZRange.opp ryy) ##(-yy)%Z)) - when (yy <=? 0) && (cc <=? 0) && ((yy + cc) (llet2_opp2 rvc (#ident.Z_sub_with_get_borrow @ s @ cst (ZRange.opp rcc) ##(-cc)%Z @ x @ cst (ZRange.opp ryy) ##(-yy)%Z)) - when (yy <=? 0) && (cc <=? 0) && ((yy + cc) ##(Z.add_get_carry_full s xx yy) when is_bounded_by_bool s (ZRange.normalize rs) && is_bounded_by_bool xx (ZRange.normalize rxx) && is_bounded_by_bool yy (ZRange.normalize ryy)) - ; make_rewriteo - (#pattern.ident.Z_add_get_carry @ #?ℤ' @ #?ℤ' @ ??') - (fun rs s rxx xx ry y => (cst ry y, cst r[0~>0] ##0) when (xx =? 0) && (ZRange.normalize ry <=? r[0~>s-1])%zrange && is_bounded_by_bool xx (ZRange.normalize rxx) && is_bounded_by_bool s (ZRange.normalize rs)) - ; make_rewriteo - (#pattern.ident.Z_add_get_carry @ #?ℤ' @ ??' @ #?ℤ') - (fun rs s ry y rxx xx => (cst ry y, cst r[0~>0] ##0) when (xx =? 0) && (ZRange.normalize ry <=? r[0~>s-1])%zrange && is_bounded_by_bool xx (ZRange.normalize rxx) && is_bounded_by_bool s (ZRange.normalize rs)) - - ; make_rewriteo (#pattern.ident.Z_add_with_carry @ #?ℤ' @ ?? @ ??) (fun rcc cc x y => x + y when (cc =? 0) && is_bounded_by_bool cc (ZRange.normalize rcc)) - (*; make_rewrite_step (#pattern.ident.Z_add_with_carry @ ?? @ ?? @ ??) (fun x y z => $x + $y + $z)*) - - ; make_rewriteo - (#pattern.ident.Z_add_with_get_carry @ #?ℤ' @ #?ℤ' @ #?ℤ' @ #?ℤ') - (fun rs s rcc cc rxx xx ryy yy => ##(Z.add_with_get_carry_full s cc xx yy) when is_bounded_by_bool s (ZRange.normalize rs) && is_bounded_by_bool cc (ZRange.normalize rcc) && is_bounded_by_bool xx (ZRange.normalize rxx) && is_bounded_by_bool yy (ZRange.normalize ryy)) - ; make_rewriteo - (#pattern.ident.Z_add_with_get_carry @ #?ℤ' @ #?ℤ' @ #?ℤ' @ ??') - (fun rs s rcc cc rxx xx ry y => (cst ry y, cst r[0~>0] ##0) when (cc =? 0) && (xx =? 0) && (ZRange.normalize ry <=? r[0~>s-1])%zrange && is_bounded_by_bool s (ZRange.normalize rs) && is_bounded_by_bool cc (ZRange.normalize rcc) && is_bounded_by_bool xx (ZRange.normalize rxx)) - ; make_rewriteo - (#pattern.ident.Z_add_with_get_carry @ #?ℤ' @ #?ℤ' @ ??' @ #?ℤ') - (fun rs s rcc cc ry y rxx xx => (cst ry y, cst r[0~>0] ##0) when (cc =? 0) && (xx =? 0) && (ZRange.normalize ry <=? r[0~>s-1])%zrange && is_bounded_by_bool s (ZRange.normalize rs) && is_bounded_by_bool cc (ZRange.normalize rcc) && is_bounded_by_bool xx (ZRange.normalize rxx)) - (*; make_rewriteo - (#pattern.ident.Z_add_with_get_carry @ ?? @ ?? @ #?ℤ @ #?ℤ) (fun s c xx yy => (c, ##0) when (xx =? 0) && (yy =? 0))*) - ; make_rewriteol (* carry = 0: ADC x y -> ADD x y *) - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_with_get_carry @ ?? @ #?ℤ' @ ?? @ ??)) - (fun rvc s rcc cc x y - => (llet2 rvc (#ident.Z_add_get_carry @ s @ x @ y)) - when (cc =? 0) && is_bounded_by_bool cc (ZRange.normalize rcc)) - ; make_rewriteol (* ADC 0 0 -> (ADX 0 0, 0) *) (* except we don't do ADX, because C stringification doesn't handle it *) - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_with_get_carry @ #?ℤ' @ ??' @ #?ℤ' @ #?ℤ')) - (fun rvc rs s rc c rxx xx ryy yy - => (llet vc := cst2 rvc (#ident.Z_add_with_get_carry @ cst rs ##s @ cst rc c @ cst rxx ##xx @ cst ryy ##yy) in - (cst (fst rvc) (#ident.fst @ cst2 rvc ($vc)), cst r[0~>0] ##0)) - when (xx =? 0) && (yy =? 0) && (ZRange.normalize rc <=? r[0~>s-1])%zrange && is_bounded_by_bool 0 (snd rvc) && is_bounded_by_bool s (ZRange.normalize rs) && is_bounded_by_bool xx (ZRange.normalize rxx) && is_bounded_by_bool yy (ZRange.normalize ryy)) - - - (* let-bind any adc/sbb/mulx *) - ; make_rewritel - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_with_get_carry @ ?? @ ?? @ ?? @ ??)) - (fun rvc s c x y => llet2 rvc (#ident.Z_add_with_get_carry @ s @ c @ x @ y)) - ; make_rewritel - (#pattern.ident.Z_cast @ (#pattern.ident.Z_add_with_carry @ ?? @ ?? @ ??)) - (fun rv c x y => (llet vc := cst rv (#ident.Z_add_with_carry @ c @ x @ y) in - (cst rv ($vc)))) - ; make_rewritel - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_add_get_carry @ ?? @ ?? @ ??)) - (fun rvc s x y => llet2 rvc (#ident.Z_add_get_carry @ s @ x @ y)) - ; make_rewritel - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_sub_with_get_borrow @ ?? @ ?? @ ?? @ ??)) - (fun rvc s c x y => llet2 rvc (#ident.Z_sub_with_get_borrow @ s @ c @ x @ y)) - ; make_rewritel - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_sub_get_borrow @ ?? @ ?? @ ??)) - (fun rvc s x y => llet2 rvc (#ident.Z_sub_get_borrow @ s @ x @ y)) - ; make_rewritel - (#pattern.ident.Z_cast2 @ (#pattern.ident.Z_mul_split @ ?? @ ?? @ ??)) - (fun rvc s x y => llet2 rvc (#ident.Z_mul_split @ s @ x @ y)) - - ; make_rewrite_step (#pattern.ident.Z_cast2 @ (??, ??)) (fun r v1 v2 => (#(ident.Z_cast (fst r)) @ $v1, #(ident.Z_cast (snd r)) @ $v2)) - - ; make_rewriteo - (#pattern.ident.Z_cast @ (#pattern.ident.Z_cast @ ??)) - (fun r1 r2 x => #(ident.Z_cast r2) @ x when ZRange.is_tighter_than_bool (ZRange.normalize r2) (ZRange.normalize r1)) - ]. + := Eval cbv [Make.interp_rewrite_rules myapp myflatten] in + myflatten + [reify + [(forall A B x y, @fst A B (x, y) = x) + ; (forall A B x y, @snd A B (x, y) = y) + ; (forall r v, lower r = upper r -> cstZ r v = cstZ r ('(lower r))) + ; (forall r0 v, 0 ∈ r0 -> cstZ r0 0 + v = v) + ; (forall r0 v, 0 ∈ r0 -> v + cstZ r0 0 = v) + ; (forall r0 v, 0 ∈ r0 -> cstZ r0 0 - v = -v) + ; (forall r0 v, 0 ∈ r0 -> cstZ r0 0 << v = 0) + ; (forall r0 rnv rv v, + (rv <= -n rnv)%zrange -> 0 ∈ r0 + -> cstZ r0 0 - cstZ rnv (-(cstZ rv v)) = cstZ rv v) + ; (forall rnv rv v, + (rv <= -n rnv)%zrange + -> -(cstZ rnv (-(cstZ rv v))) = cstZ rv v) + + ; (forall s r0 y, 0 ∈ r0 -> Z.mul_split s (cstZ r0 0) y = (cstZ r[0~>0] 0, cstZ r[0~>0] 0)) + ; (forall s r0 y, 0 ∈ r0 -> Z.mul_split s y (cstZ r0 0) = (cstZ r[0~>0] 0, cstZ r[0~>0] 0)) + ; (forall rs s r1 ry y, + 1 ∈ r1 -> s ∈ rs -> (ry <= r[0~>s-1])%zrange + -> Z.mul_split (cstZ rs ('s)) (cstZ r1 1) (cstZ ry y) + = (cstZ ry y, cstZ r[0~>0] 0)) + ; (forall rs s r1 ry y, + 1 ∈ r1 -> s ∈ rs -> (ry <= r[0~>s-1])%zrange + -> Z.mul_split (cstZ rs ('s)) (cstZ ry y) (cstZ r1 1) + = (cstZ ry y, cstZ r[0~>0] 0)) + + ; (forall rvc s rny ry y x, + (ry <= -n rny)%zrange + -> cstZZ rvc (Z.add_get_carry_full s (cstZ rny (-cstZ ry y)) x) + = dlet2_opp2 rvc (Z.sub_get_borrow_full s x (cstZ ry y))) + ; (forall rvc s rny ry y x, + (ry <= -n rny)%zrange + -> cstZZ rvc (Z.add_get_carry_full s x (cstZ rny (-cstZ ry y))) + = dlet2_opp2 rvc (Z.sub_get_borrow_full s x (cstZ ry y))) + ; (forall rvc s ryy yy x, + yy ∈ ryy -> yy < 0 + -> cstZZ rvc (Z.add_get_carry_full s (cstZ ryy ('yy)) x) + = dlet2_opp2 rvc (Z.sub_get_borrow_full s x (cstZ (-ryy) ('(-yy))))) + ; (forall rvc s ryy yy x, + yy ∈ ryy -> yy < 0 + -> cstZZ rvc (Z.add_get_carry_full s x (cstZ ryy ('yy))) + = dlet2_opp2 rvc (Z.sub_get_borrow_full s x (cstZ (-ryy) ('(-yy))))) + ; (forall rvc s rnc rc c rny ry y x, + (ry <= -n rny)%zrange -> (rc <= -n rnc)%zrange + -> cstZZ rvc (Z.add_with_get_carry_full s (cstZ rnc (-cstZ rc c)) (cstZ rny (-cstZ ry y)) x) + = dlet2_opp2 rvc (Z.sub_with_get_borrow_full s (cstZ rc c) x (cstZ ry y))) + ; (forall rvc s rnc rc c rny ry y x, + (ry <= -n rny)%zrange -> (rc <= -n rnc)%zrange + -> cstZZ rvc (Z.add_with_get_carry_full s (cstZ rnc (-cstZ rc c)) x (cstZ rny (-cstZ ry y))) + = dlet2_opp2 rvc (Z.sub_with_get_borrow_full s (cstZ rc c) x (cstZ ry y))) + ; (forall rvc s r0 rny ry y x, + 0 ∈ r0 -> (ry <= -n rny)%zrange + -> cstZZ rvc (Z.add_with_get_carry_full s (cstZ r0 0) (cstZ rny (-cstZ ry y)) x) + = dlet2_opp2 rvc (Z.sub_get_borrow_full s x (cstZ ry y))) + ; (forall rvc s rcc cc rny ry y x, + cc < 0 -> cc ∈ rcc -> (ry <= -n rny)%zrange + -> cstZZ rvc (Z.add_with_get_carry_full s (cstZ rcc ('cc)) (cstZ rny (-cstZ ry y)) x) + = dlet2_opp2 rvc (Z.sub_with_get_borrow_full s (cstZ (-rcc) ('(-cc))) x (cstZ ry y))) + ; (forall rvc s r0 rny ry y x, + 0 ∈ r0 -> (ry <= -n rny)%zrange + -> cstZZ rvc (Z.add_with_get_carry_full s (cstZ r0 0) x (cstZ rny (-cstZ ry y))) + = dlet2_opp2 rvc (Z.sub_get_borrow_full s x (cstZ ry y))) + ; (forall rvc s rcc cc rny ry y x, + cc < 0 -> cc ∈ rcc -> (ry <= -n rny)%zrange + -> cstZZ rvc (Z.add_with_get_carry_full s (cstZ rcc ('cc)) x (cstZ rny (-cstZ ry y))) + = dlet2_opp2 rvc (Z.sub_with_get_borrow_full s (cstZ (-rcc) ('(-cc))) x (cstZ ry y))) + ; (forall rvc s rnc rc c ryy yy x, + yy <= 0 -> yy ∈ ryy -> (rc <= -n rnc)%zrange + -> cstZZ rvc (Z.add_with_get_carry_full s (cstZ rnc (-cstZ rc c)) (cstZ ryy ('yy)) x) + = dlet2_opp2 rvc (Z.sub_with_get_borrow_full s (cstZ rc c) x (cstZ (-ryy) ('(-yy))))) + ; (forall rvc s rnc rc c ryy yy x, + yy <= 0 -> yy ∈ ryy -> (rc <= -n rnc)%zrange + -> cstZZ rvc (Z.add_with_get_carry_full s (cstZ rnc (-cstZ rc c)) x (cstZ ryy ('yy))) + = dlet2_opp2 rvc (Z.sub_with_get_borrow_full s (cstZ rc c) x (cstZ (-ryy) ('(-yy))))) + ; (forall rvc s rcc cc ryy yy x, + yy <= 0 -> cc <= 0 -> yy + cc < 0 (* at least one must be strictly negative *) -> yy ∈ ryy -> cc ∈ rcc + -> cstZZ rvc (Z.add_with_get_carry_full s (cstZ rcc ('cc)) (cstZ ryy ('yy)) x) + = dlet2_opp2 rvc (Z.sub_with_get_borrow_full s (cstZ (-rcc) ('(-cc))) x (cstZ (-ryy) ('(-yy))))) + ; (forall rvc s rcc cc ryy yy x, + yy <= 0 -> cc <= 0 -> yy + cc < 0 (* at least one must be strictly negative *) -> yy ∈ ryy -> cc ∈ rcc + -> cstZZ rvc (Z.add_with_get_carry_full s (cstZ rcc ('cc)) x (cstZ ryy ('yy))) + = dlet2_opp2 rvc (Z.sub_with_get_borrow_full s (cstZ (-rcc) ('(-cc))) x (cstZ (-ryy) ('(-yy))))) + + + ; (forall rs s rxx xx ryy yy, + s ∈ rs -> xx ∈ rxx -> yy ∈ ryy + -> Z.add_get_carry_full (cstZ rs ('s)) (cstZ rxx ('xx)) (cstZ ryy ('yy)) + = litZZ (Z.add_get_carry_full s xx yy)) + ; (forall rs s r0 ry y, + s ∈ rs -> 0 ∈ r0 -> (ry <= r[0~>s-1])%zrange + -> Z.add_get_carry_full (cstZ rs ('s)) (cstZ r0 0) (cstZ ry y) + = (cstZ ry y, cstZ r[0~>0] 0)) + ; (forall rs s r0 ry y, + s ∈ rs -> 0 ∈ r0 -> (ry <= r[0~>s-1])%zrange + -> Z.add_get_carry_full (cstZ rs ('s)) (cstZ ry y) (cstZ r0 0) + = (cstZ ry y, cstZ r[0~>0] 0)) + + ; (forall r0 x y, 0 ∈ r0 -> Z.add_with_carry (cstZ r0 0) x y = x + y) + + ; (forall rs s rcc cc rxx xx ryy yy, + s ∈ rs -> cc ∈ rcc -> xx ∈ rxx -> yy ∈ ryy + -> Z.add_with_get_carry_full (cstZ rs ('s)) (cstZ rcc ('cc)) (cstZ rxx ('xx)) (cstZ ryy ('yy)) + = litZZ (Z.add_with_get_carry_full s cc xx yy)) + ; (forall rs s r0c r0x ry y, + s ∈ rs -> 0 ∈ r0c -> 0 ∈ r0x -> (ry <= r[0~>s-1])%zrange + -> Z.add_with_get_carry_full (cstZ rs ('s)) (cstZ r0c 0) (cstZ r0x 0) (cstZ ry y) + = (cstZ ry y, cstZ r[0~>0] 0)) + ; (forall rs s r0c r0x ry y, + s ∈ rs -> 0 ∈ r0c -> 0 ∈ r0x -> (ry <= r[0~>s-1])%zrange + -> Z.add_with_get_carry_full (cstZ rs ('s)) (cstZ r0c 0) (cstZ ry y) (cstZ r0x 0) + = (cstZ ry y, cstZ r[0~>0] 0)) + + ; (forall rvc s r0 x y, (* carry = 0: ADC x y -> ADD x y *) + 0 ∈ r0 + -> cstZZ rvc (Z.add_with_get_carry_full s (cstZ r0 0) x y) + = dlet2 rvc (Z.add_get_carry_full s x y)) + ; (forall rvc rs s rc c r0x r0y, (* ADC 0 0 -> (ADX 0 0, 0) *) (* except we don't do ADX, because C stringification doesn't handle it *) + 0 ∈ r0x -> 0 ∈ r0y -> (rc <= r[0~>s-1])%zrange -> 0 ∈ snd rvc -> s ∈ rs + -> cstZZ rvc (Z.add_with_get_carry_full (cstZ rs ('s)) (cstZ rc c) (cstZ r0x 0) (cstZ r0y 0)) + = (dlet vc := (cstZZ rvc (Z.add_with_get_carry_full (cstZ rs ('s)) (cstZ rc c) (cstZ r0x 0) (cstZ r0y 0))) in + (cstZ (fst rvc) (fst (cstZZ rvc vc)), + cstZ r[0~>0] 0))) + + (* let-bind any adc/sbb/mulx *) + ; (forall rvc s c x y, + cstZZ rvc (Z.add_with_get_carry_full s c x y) + = dlet2 rvc (Z.add_with_get_carry_full s c x y)) + ; (forall rv c x y, + cstZ rv (Z.add_with_carry c x y) + = (dlet vc := cstZ rv (Z.add_with_carry c x y) in + cstZ rv vc)) + ; (forall rvc s x y, + cstZZ rvc (Z.add_get_carry_full s x y) + = dlet2 rvc (Z.add_get_carry_full s x y)) + ; (forall rvc s c x y, + cstZZ rvc (Z.sub_with_get_borrow_full s c x y) + = dlet2 rvc (Z.sub_with_get_borrow_full s c x y)) + ; (forall rvc s x y, + cstZZ rvc (Z.sub_get_borrow_full s x y) + = dlet2 rvc (Z.sub_get_borrow_full s x y)) + ; (forall rvc s x y, + cstZZ rvc (Z.mul_split s x y) + = dlet2 rvc (Z.mul_split s x y)) + ]%Z%zrange + ; reify + [ (* [do_again], so that if one of the arguments is concrete, we automatically get the rewrite rule for [Z_cast] applying to it *) + do_again (forall r x y, cstZZ r (x, y) = (cstZ (fst r) x, cstZ (snd r) y)) + ] + ; reify + [(forall r1 r2 x, (r2 <= n r1)%zrange -> cstZ r1 (cstZ r2 x) = cstZ r2 x) + ]%Z%zrange + ]. Definition strip_literal_casts_rewrite_rules : rewrite_rulesT - := [make_rewriteo (#?ℤ') (fun rx x => ##x when is_bounded_by_bool x (ZRange.normalize rx))]. + := reify + [(forall rx x, x ∈ rx -> cstZ rx ('x) = 'x)]%Z%zrange. Definition nbe_dtree' @@ -2017,12 +2721,14 @@ Module Compilers. Local Notation pcst2 v := (#pattern.ident.Z_cast2 @ v)%pattern. Local Coercion ZRange.constant : Z >-> zrange. (* for ease of use with sanity-checking bounds *) - Let bounds1_good (f : zrange -> zrange) (output x_bs : zrange) - := is_tighter_than_bool (f (ZRange.normalize x_bs)) (ZRange.normalize output). - Let bounds2_good (f : zrange -> zrange -> zrange) (output x_bs y_bs : zrange) - := is_tighter_than_bool (f (ZRange.normalize x_bs) (ZRange.normalize y_bs)) (ZRange.normalize output). - Let range_in_bitwidth r s - := is_tighter_than_bool (ZRange.normalize r) r[0~>s-1]%zrange. + Local Notation bounds1_good f + := (fun (output x_bs : zrange) + => is_tighter_than_bool (f (ZRange.normalize x_bs)) (ZRange.normalize output) = true). + Local Notation bounds2_good f + := (fun (output x_bs y_bs : zrange) + => is_tighter_than_bool (f (ZRange.normalize x_bs) (ZRange.normalize y_bs)) (ZRange.normalize output) = true). + Local Notation range_in_bitwidth r s + := (is_tighter_than_bool (ZRange.normalize r) r[0~>s-1]%zrange = true). Local Notation shiftl_good := (bounds2_good ZRange.shiftl). Local Notation shiftr_good := (bounds2_good ZRange.shiftr). Local Notation land_good := (bounds2_good ZRange.land). @@ -2031,203 +2737,321 @@ Module Compilers. Local Notation lit_good x rx := (is_bounded_by_bool x (ZRange.normalize rx)). Definition fancy_with_casts_rewrite_rules : rewrite_rulesT - := [ - (* + := Eval cbv [Make.interp_rewrite_rules myapp myflatten] in + myflatten + [reify + [(* (Z.add_get_carry_concrete 2^256) @@ (?x, ?y << 128) --> (add 128) @@ (x, y) (Z.add_get_carry_concrete 2^256) @@ (?x << 128, ?y) --> (add 128) @@ (y, x) (Z.add_get_carry_concrete 2^256) @@ (?x, ?y >> 128) --> (add (- 128)) @@ (x, y) (Z.add_get_carry_concrete 2^256) @@ (?x >> 128, ?y) --> (add (- 128)) @@ (y, x) (Z.add_get_carry_concrete 2^256) @@ (?x, ?y) --> (add 0) @@ (y, x) -*) - make_rewriteo - (pcst2 (#pattern.ident.Z_add_get_carry @ #?ℤ' @ ??' @ (pcst (#pattern.ident.Z_shiftl @ (pcst (#pattern.ident.Z_land @ ??' @ #?ℤ')) @ #?ℤ')))) - (fun '((r1, r2)%core) rs s rx x rshiftl rland ry y rmask mask roffset offset => cst2 (r1, r2)%core (#(ident.fancy_add (Z.log2 s) offset) @ (cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && shiftl_good rshiftl rland offset && land_good rland ry mask && range_in_bitwidth rshiftl s && (mask =? Z.ones (Z.log2 s - offset)) && (0 <=? offset) && (offset <=? Z.log2 s) && lit_good s rs && lit_good mask rmask && lit_good offset roffset) - - ; make_rewriteo - (pcst2 (#pattern.ident.Z_add_get_carry @ #?ℤ' @ (pcst (#pattern.ident.Z_shiftl @ (pcst (#pattern.ident.Z_land @ ??' @ #?ℤ')) @ #?ℤ')) @ ??')) - (fun '((r1, r2)%core) rs s rshiftl rland ry y rmask mask roffset offset rx x => cst2 (r1, r2)%core (#(ident.fancy_add (Z.log2 s) offset) @ (cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && shiftl_good rshiftl rland offset && land_good rland ry mask && range_in_bitwidth rshiftl s && (mask =? Z.ones (Z.log2 s - offset)) && (0 <=? offset) && (offset <=? Z.log2 s) && lit_good s rs && lit_good mask rmask && lit_good offset roffset) - - ; make_rewriteo - (pcst2 (#pattern.ident.Z_add_get_carry @ #?ℤ' @ ??' @ (pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ')))) - (fun '((r1, r2)%core) rs s rx x rshiftr ry y roffset offset => cst2 (r1, r2)%core (#(ident.fancy_add (Z.log2 s) (-offset)) @ (cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && shiftr_good rshiftr ry offset && range_in_bitwidth rshiftr s && lit_good s rs && lit_good offset roffset) - - ; make_rewriteo - (pcst2 (#pattern.ident.Z_add_get_carry @ #?ℤ' @ pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ') @ ??')) - (fun '((r1, r2)%core) rs s rshiftr ry y roffset offset rx x => cst2 (r1, r2)%core (#(ident.fancy_add (Z.log2 s) (-offset)) @ (cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && shiftr_good rshiftr ry offset && range_in_bitwidth rshiftr s && lit_good s rs && lit_good offset roffset) - - ; make_rewriteo - (pcst2 (#pattern.ident.Z_add_get_carry @ #?ℤ' @ ??' @ ??')) - (fun '((r1, r2)%core) rs s rx x ry y => cst2 (r1, r2)%core (#(ident.fancy_add (Z.log2 s) 0) @ (cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && range_in_bitwidth ry s && lit_good s rs) -(* + *) + (forall r rs s rx x rshiftl rland ry y rmask mask roffset offset, + s = 2^Z.log2 s -> s ∈ rs -> offset ∈ roffset -> mask ∈ rmask -> shiftl_good rshiftl rland offset -> land_good rland ry mask -> range_in_bitwidth rshiftl s -> (mask = Z.ones (Z.log2 s - offset)) -> (0 <= offset <= Z.log2 s) + -> cstZZ r (Z.add_get_carry_full (cstZ rs ('s)) (cstZ rx x) (cstZ rshiftl ((cstZ rland (cstZ ry y &' cstZ rmask ('mask))) << cstZ roffset ('offset)))) + = cstZZ r (ident.interp (ident.fancy_add (Z.log2 s) (offset)) (cstZ rx x, cstZ ry y))) + ; (forall r rs s rx x rshiftl rland ry y rmask mask roffset offset, + (s = 2^Z.log2 s) -> (mask = Z.ones (Z.log2 s - offset)) -> (0 <= offset <= Z.log2 s) -> s ∈ rs -> mask ∈ rmask -> offset ∈ roffset -> shiftl_good rshiftl rland offset -> land_good rland ry mask -> range_in_bitwidth rshiftl s + -> cstZZ r (Z.add_get_carry_full (cstZ rs ('s)) (cstZ rx x) (cstZ rshiftl (cstZ rland (cstZ ry y &' cstZ rmask ('mask)) << cstZ roffset ('offset)))) + = cstZZ r (ident.interp (ident.fancy_add (Z.log2 s) offset) (cstZ rx x, cstZ ry y))) + + ; (forall r rs s rshiftl rland ry y rmask mask roffset offset rx x, + s ∈ rs -> mask ∈ rmask -> offset ∈ roffset -> (s = 2^Z.log2 s) -> shiftl_good rshiftl rland offset -> land_good rland ry mask -> range_in_bitwidth rshiftl s -> (mask = Z.ones (Z.log2 s - offset)) -> (0 <= offset <= Z.log2 s) + -> cstZZ r (Z.add_get_carry_full (cstZ rs ('s)) (cstZ rshiftl (Z.shiftl (cstZ rland (Z.land (cstZ ry y) (cstZ rmask ('mask)))) (cstZ roffset ('offset)))) (cstZ rx x)) + = cstZZ r (ident.interp (ident.fancy_add (Z.log2 s) offset) (cstZ rx x, cstZ ry y))) + + ; (forall r rs s rx x rshiftr ry y roffset offset, + s ∈ rs -> offset ∈ roffset -> (s = 2^Z.log2 s) -> shiftr_good rshiftr ry offset -> range_in_bitwidth rshiftr s + -> cstZZ r (Z.add_get_carry_full (cstZ rs ('s)) (cstZ rx x) (cstZ rshiftr (Z.shiftr (cstZ ry y) (cstZ roffset ('offset))))) + = cstZZ r (ident.interp (ident.fancy_add (Z.log2 s) (-offset)) (cstZ rx x, cstZ ry y))) + + ; (forall r rs s rshiftr ry y roffset offset rx x, + s ∈ rs -> offset ∈ roffset -> (s = 2^Z.log2 s) -> shiftr_good rshiftr ry offset -> range_in_bitwidth rshiftr s + -> cstZZ r (Z.add_get_carry_full (cstZ rs ('s)) (cstZ rshiftr (Z.shiftr (cstZ ry y) (cstZ roffset ('offset)))) (cstZ rx x)) + = cstZZ r (ident.interp (ident.fancy_add (Z.log2 s) (-offset)) (cstZ rx x, cstZ ry y))) + + ; (forall r rs s rx x ry y, + s ∈ rs -> (s = 2^Z.log2 s) -> range_in_bitwidth ry s + -> cstZZ r (Z.add_get_carry_full (cstZ rs ('s)) (cstZ rx x) (cstZ ry y)) + = cstZZ r (ident.interp (ident.fancy_add (Z.log2 s) 0) (cstZ rx x, cstZ ry y))) + + (* (Z.add_with_get_carry_concrete 2^256) @@ (?c, ?x, ?y << 128) --> (addc 128) @@ (c, x, y) (Z.add_with_get_carry_concrete 2^256) @@ (?c, ?x << 128, ?y) --> (addc 128) @@ (c, y, x) (Z.add_with_get_carry_concrete 2^256) @@ (?c, ?x, ?y >> 128) --> (addc (- 128)) @@ (c, x, y) (Z.add_with_get_carry_concrete 2^256) @@ (?c, ?x >> 128, ?y) --> (addc (- 128)) @@ (c, y, x) (Z.add_with_get_carry_concrete 2^256) @@ (?c, ?x, ?y) --> (addc 0) @@ (c, y, x) - *) - ; make_rewriteo - (pcst2 (#pattern.ident.Z_add_with_get_carry @ #?ℤ' @ ??' @ ??' @ pcst (#pattern.ident.Z_shiftl @ (pcst (#pattern.ident.Z_land @ ??' @ #?ℤ')) @ #?ℤ'))) - (fun '((r1, r2)%core) rs s rc c rx x rshiftl rland ry y rmask mask roffset offset => cst2 (r1, r2)%core (#(ident.fancy_addc (Z.log2 s) offset) @ (cst rc c, cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && shiftl_good rshiftl rland offset && land_good rland ry mask && range_in_bitwidth rshiftl s && (mask =? Z.ones (Z.log2 s - offset)) && (0 <=? offset) && (offset <=? Z.log2 s) && lit_good s rs && lit_good mask rmask && lit_good offset roffset) - - ; make_rewriteo - (pcst2 (#pattern.ident.Z_add_with_get_carry @ #?ℤ' @ ??' @ pcst (#pattern.ident.Z_shiftl @ (pcst (#pattern.ident.Z_land @ ??' @ #?ℤ')) @ #?ℤ') @ ??')) - (fun '((r1, r2)%core) rs s rc c rshiftl rland ry y rmask mask roffset offset rx x => cst2 (r1, r2)%core (#(ident.fancy_addc (Z.log2 s) offset) @ (cst rc c, cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && shiftl_good rshiftl rland offset && range_in_bitwidth rshiftl s && land_good rland ry mask && (mask =? Z.ones (Z.log2 s - offset)) && (0 <=? offset) && (offset <=? Z.log2 s) && lit_good s rs && lit_good mask rmask && lit_good offset roffset) + *) + ; (forall r rs s rc c rx x rshiftl rland ry y rmask mask roffset offset, + s ∈ rs -> mask ∈ rmask -> offset ∈ roffset -> (s = 2^Z.log2 s) -> shiftl_good rshiftl rland offset -> land_good rland ry mask -> range_in_bitwidth rshiftl s -> (mask = Z.ones (Z.log2 s - offset)) -> (0 <= offset <= Z.log2 s) + -> cstZZ r (Z.add_with_get_carry_full (cstZ rs ('s)) (cstZ rc c) (cstZ rx x) (cstZ rshiftl (Z.shiftl (cstZ rland (Z.land (cstZ ry y) (cstZ rmask ('mask)))) (cstZ roffset ('offset))))) + = cstZZ r (ident.interp (ident.fancy_addc (Z.log2 s) offset) (cstZ rc c, cstZ rx x, cstZ ry y))) + + ; (forall r rs s rc c rshiftl rland ry y rmask mask roffset offset rx x, + s ∈ rs -> mask ∈ rmask -> offset ∈ roffset -> (s = 2^Z.log2 s) -> shiftl_good rshiftl rland offset -> range_in_bitwidth rshiftl s -> land_good rland ry mask -> (mask = Z.ones (Z.log2 s - offset)) -> (0 <= offset <= Z.log2 s) + -> cstZZ r (Z.add_with_get_carry_full (cstZ rs ('s)) (cstZ rc c) (cstZ rshiftl (Z.shiftl (cstZ rland (Z.land (cstZ ry y) (cstZ rmask ('mask)))) (cstZ roffset ('offset)))) (cstZ rx x)) + = cstZZ r (ident.interp (ident.fancy_addc (Z.log2 s) offset) (cstZ rc c, cstZ rx x, cstZ ry y))) + + ; (forall r rs s rc c rx x rshiftr ry y roffset offset, + s ∈ rs -> offset ∈ roffset -> (s = 2^Z.log2 s) -> shiftr_good rshiftr ry offset -> range_in_bitwidth rshiftr s + -> cstZZ r (Z.add_with_get_carry_full (cstZ rs ('s)) (cstZ rc c) (cstZ rx x) (cstZ rshiftr (Z.shiftr (cstZ ry y) (cstZ roffset ('offset))))) + = cstZZ r (ident.interp (ident.fancy_addc (Z.log2 s) (-offset)) (cstZ rc c, cstZ rx x, cstZ ry y))) + + ; (forall r rs s rc c rshiftr ry y roffset offset rx x, + s ∈ rs -> offset ∈ roffset -> (s = 2^Z.log2 s) -> shiftr_good rshiftr ry offset -> range_in_bitwidth rshiftr s + -> cstZZ r (Z.add_with_get_carry_full (cstZ rs ('s)) (cstZ rc c) (cstZ rshiftr (Z.shiftr (cstZ ry y) (cstZ roffset ('offset)))) (cstZ rx x)) + = cstZZ r (ident.interp (ident.fancy_addc (Z.log2 s) (-offset)) (cstZ rc c, cstZ rx x, cstZ ry y))) + + ; (forall r rs s rc c rx x ry y, + s ∈ rs -> (s = 2^Z.log2 s) -> range_in_bitwidth ry s + -> cstZZ r (Z.add_with_get_carry_full (cstZ rs ('s)) (cstZ rc c) (cstZ rx x) (cstZ ry y)) + = cstZZ r (ident.interp (ident.fancy_addc (Z.log2 s) 0) (cstZ rc c, cstZ rx x, cstZ ry y))) + + (* +(Z.sub_get_borrow_concrete 2^256) @@ (?x, ?y << 128) --> (sub 128) @@ (x, y) +(Z.sub_get_borrow_concrete 2^256) @@ (?x, ?y >> 128) --> (sub (- 128)) @@ (x, y) +(Z.sub_get_borrow_concrete 2^256) @@ (?x, ?y) --> (sub 0) @@ (y, x) + *) - ; make_rewriteo - (pcst2 (#pattern.ident.Z_add_with_get_carry @ #?ℤ' @ ??' @ ??' @ pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ'))) - (fun '((r1, r2)%core) rs s rc c rx x rshiftr ry y roffset offset => cst2 (r1, r2)%core (#(ident.fancy_addc (Z.log2 s) (-offset)) @ (cst rc c, cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && shiftr_good rshiftr ry offset && range_in_bitwidth rshiftr s && lit_good s rs && lit_good offset roffset) + ; (forall r rs s rx x rshiftl rland ry y rmask mask roffset offset, + s ∈ rs -> mask ∈ rmask -> offset ∈ roffset -> (s = 2^Z.log2 s) -> shiftl_good rshiftl rland offset -> range_in_bitwidth rshiftl s -> land_good rland ry mask -> (mask = Z.ones (Z.log2 s - offset)) -> (0 <= offset <= Z.log2 s) + -> cstZZ r (Z.sub_get_borrow_full (cstZ rs ('s)) (cstZ rx x) (cstZ rshiftl (Z.shiftl (cstZ rland (Z.land (cstZ ry y) (cstZ rmask ('mask)))) (cstZ roffset ('offset))))) + = cstZZ r (ident.interp (ident.fancy_sub (Z.log2 s) offset) (cstZ rx x, cstZ ry y))) - ; make_rewriteo - (pcst2 (#pattern.ident.Z_add_with_get_carry @ #?ℤ' @ ??' @ pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ') @ ??')) - (fun '((r1, r2)%core) rs s rc c rshiftr ry y roffset offset rx x => cst2 (r1, r2)%core (#(ident.fancy_addc (Z.log2 s) (-offset)) @ (cst rc c, cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && shiftr_good rshiftr ry offset && range_in_bitwidth rshiftr s && lit_good s rs && lit_good offset roffset) + ; (forall r rs s rx x rshiftr ry y roffset offset, + s ∈ rs -> offset ∈ roffset -> (s = 2^Z.log2 s) -> shiftr_good rshiftr ry offset -> range_in_bitwidth rshiftr s + -> cstZZ r (Z.sub_get_borrow_full (cstZ rs ('s)) (cstZ rx x) (cstZ rshiftr (Z.shiftr (cstZ ry y) (cstZ roffset ('offset))))) + = cstZZ r (ident.interp (ident.fancy_sub (Z.log2 s) (-offset)) (cstZ rx x, cstZ ry y))) - ; make_rewriteo - (pcst2 (#pattern.ident.Z_add_with_get_carry @ #?ℤ' @ ??' @ ??' @ ??')) - (fun '((r1, r2)%core) rs s rc c rx x ry y => cst2 (r1, r2)%core (#(ident.fancy_addc (Z.log2 s) 0) @ (cst rc c, cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && range_in_bitwidth ry s && lit_good s rs) + ; (forall r rs s rx x ry y, + s ∈ rs -> (s = 2^Z.log2 s) -> range_in_bitwidth ry s + -> cstZZ r (Z.sub_get_borrow_full (cstZ rs ('s)) (cstZ rx x) (cstZ ry y)) + = cstZZ r (ident.interp (ident.fancy_sub (Z.log2 s) 0) (cstZ rx x, cstZ ry y))) -(* -(Z.sub_get_borrow_concrete 2^256) @@ (?x, ?y << 128) --> (sub 128) @@ (x, y) -(Z.sub_get_borrow_concrete 2^256) @@ (?x, ?y >> 128) --> (sub (- 128)) @@ (x, y) -(Z.sub_get_borrow_concrete 2^256) @@ (?x, ?y) --> (sub 0) @@ (y, x) - *) - ; make_rewriteo - (pcst2 (#pattern.ident.Z_sub_get_borrow @ #?ℤ' @ ??' @ pcst (#pattern.ident.Z_shiftl @ (pcst (#pattern.ident.Z_land @ ??' @ #?ℤ')) @ #?ℤ'))) - (fun '((r1, r2)%core) rs s rx x rshiftl rland ry y rmask mask roffset offset => cst2 (r1, r2)%core (#(ident.fancy_sub (Z.log2 s) offset) @ (cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && shiftl_good rshiftl rland offset && range_in_bitwidth rshiftl s && land_good rland ry mask && (mask =? Z.ones (Z.log2 s - offset)) && (0 <=? offset) && (offset <=? Z.log2 s) && lit_good s rs && lit_good mask rmask && lit_good offset roffset) - - ; make_rewriteo - (pcst2 (#pattern.ident.Z_sub_get_borrow @ #?ℤ' @ ??' @ pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ'))) - (fun '((r1, r2)%core) rs s rx x rshiftr ry y roffset offset => cst2 (r1, r2)%core (#(ident.fancy_sub (Z.log2 s) (-offset)) @ (cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && shiftr_good rshiftr ry offset && range_in_bitwidth rshiftr s && lit_good s rs && lit_good offset roffset) - - ; make_rewriteo - (pcst2 (#pattern.ident.Z_sub_get_borrow @ #?ℤ' @ ??' @ ??')) - (fun '((r1, r2)%core) rs s rx x ry y => cst2 (r1, r2)%core (#(ident.fancy_sub (Z.log2 s) 0) @ (cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && range_in_bitwidth ry s && lit_good s rs) -(* + (* (Z.sub_with_get_borrow_concrete 2^256) @@ (?c, ?x, ?y << 128) --> (subb 128) @@ (c, x, y) (Z.sub_with_get_borrow_concrete 2^256) @@ (?c, ?x, ?y >> 128) --> (subb (- 128)) @@ (c, x, y) (Z.sub_with_get_borrow_concrete 2^256) @@ (?c, ?x, ?y) --> (subb 0) @@ (c, y, x) - *) - ; make_rewriteo - (pcst2 (#pattern.ident.Z_sub_with_get_borrow @ #?ℤ' @ ??' @ ??' @ pcst (#pattern.ident.Z_shiftl @ (pcst (#pattern.ident.Z_land @ ??' @ #?ℤ')) @ #?ℤ'))) - (fun '((r1, r2)%core) rs s rb b rx x rshiftl rland ry y rmask mask roffset offset => cst2 (r1, r2)%core (#(ident.fancy_subb (Z.log2 s) offset) @ (cst rb b, cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && shiftl_good rshiftl rland offset && range_in_bitwidth rshiftl s && land_good rland ry mask && (mask =? Z.ones (Z.log2 s - offset)) && (0 <=? offset) && (offset <=? Z.log2 s) && lit_good s rs && lit_good mask rmask && lit_good offset roffset) - - ; make_rewriteo - (pcst2 (#pattern.ident.Z_sub_with_get_borrow @ #?ℤ' @ ??' @ ??' @ pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ'))) - (fun '((r1, r2)%core) rs s rb b rx x rshiftr ry y roffset offset => cst2 (r1, r2)%core (#(ident.fancy_subb (Z.log2 s) (-offset)) @ (cst rb b, cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && shiftr_good rshiftr ry offset && range_in_bitwidth rshiftr s && lit_good s rs && lit_good offset roffset) - - ; make_rewriteo - (pcst2 (#pattern.ident.Z_sub_with_get_borrow @ #?ℤ' @ ??' @ ??' @ ??')) - (fun '((r1, r2)%core) rs s rb b rx x ry y => cst2 (r1, r2)%core (#(ident.fancy_subb (Z.log2 s) 0) @ (cst rb b, cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && range_in_bitwidth ry s && lit_good s rs) - - (*(Z.rshi_concrete 2^256 ?n) @@ (?c, ?x, ?y) --> (rshi n) @@ (x, y)*) - ; make_rewriteo - (pcst (#pattern.ident.Z_rshi @ #?ℤ' @ ??' @ ??' @ #?ℤ')) - (fun r rs s rx x ry y rn n => cst r (#(ident.fancy_rshi (Z.log2 s) n) @ (cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && lit_good s rs && lit_good n rn) -(* + *) + + ; (forall r rs s rb b rx x rshiftl rland ry y rmask mask roffset offset, + s ∈ rs -> mask ∈ rmask -> offset ∈ roffset -> (s = 2^Z.log2 s) -> shiftl_good rshiftl rland offset -> range_in_bitwidth rshiftl s -> land_good rland ry mask -> (mask = Z.ones (Z.log2 s - offset)) -> (0 <= offset <= Z.log2 s) + -> cstZZ r (Z.sub_with_get_borrow_full (cstZ rs ('s)) (cstZ rb b) (cstZ rx x) (cstZ rshiftl (Z.shiftl (cstZ rland (Z.land (cstZ ry y) (cstZ rmask ('mask)))) (cstZ roffset ('offset))))) + = cstZZ r (ident.interp (ident.fancy_subb (Z.log2 s) offset) (cstZ rb b, cstZ rx x, cstZ ry y))) + + ; (forall r rs s rb b rx x rshiftr ry y roffset offset, + s ∈ rs -> offset ∈ roffset -> (s = 2^Z.log2 s) -> shiftr_good rshiftr ry offset -> range_in_bitwidth rshiftr s + -> cstZZ r (Z.sub_with_get_borrow_full (cstZ rs ('s)) (cstZ rb b) (cstZ rx x) (cstZ rshiftr (Z.shiftr (cstZ ry y) (cstZ roffset ('offset))))) + = cstZZ r (ident.interp (ident.fancy_subb (Z.log2 s) (-offset)) (cstZ rb b, cstZ rx x, cstZ ry y))) + + ; (forall r rs s rb b rx x ry y, + s ∈ rs -> (s = 2^Z.log2 s) -> range_in_bitwidth ry s + -> cstZZ r (Z.sub_with_get_borrow_full (cstZ rs ('s)) (cstZ rb b) (cstZ rx x) (cstZ ry y)) + = cstZZ r (ident.interp (ident.fancy_subb (Z.log2 s) 0) (cstZ rb b, cstZ rx x, cstZ ry y))) + + (*(Z.rshi_concrete 2^256 ?n) @@ (?c, ?x, ?y) --> (rshi n) @@ (x, y)*) + + ; (forall r rs s rx x ry y rn n, + s ∈ rs -> n ∈ rn -> (s = 2^Z.log2 s) + -> cstZ r (Z.rshi (cstZ rs ('s)) (cstZ rx x) (cstZ ry y) (cstZ rn ('n))) + = cstZ r (ident.interp (ident.fancy_rshi (Z.log2 s) n) (cstZ rx x, cstZ ry y))) + + (* Z.zselect @@ (Z.cc_m_concrete 2^256 ?c, ?x, ?y) --> selm @@ (c, x, y) Z.zselect @@ (?c &' 1, ?x, ?y) --> sell @@ (c, x, y) Z.zselect @@ (?c, ?x, ?y) --> selc @@ (c, x, y) - *) - ; make_rewriteo - (pcst (#pattern.ident.Z_zselect @ pcst (#pattern.ident.Z_cc_m @ #?ℤ' @ ??') @ ??' @ ??')) - (fun r rccm rs s rc c rx x ry y => cst r (#(ident.fancy_selm (Z.log2 s)) @ (cst rc c, cst rx x, cst ry y)) when (s =? 2^Z.log2 s) && cc_m_good rccm s rc && lit_good s rs) - - ; make_rewriteo - (pcst (#pattern.ident.Z_zselect @ pcst (#pattern.ident.Z_land @ #?ℤ' @ ??') @ ??' @ ??')) - (fun r rland rmask mask rc c rx x ry y => cst r (#ident.fancy_sell @ (cst rc c, cst rx x, cst ry y)) when (mask =? 1) && land_good rland mask rc && lit_good mask rmask) - - ; make_rewriteo - (pcst (#pattern.ident.Z_zselect @ pcst (#pattern.ident.Z_land @ ??' @ #?ℤ') @ ??' @ ??')) - (fun r rland rc c rmask mask rx x ry y => cst r (#ident.fancy_sell @ (cst rc c, cst rx x, cst ry y)) when (mask =? 1) && land_good rland rc mask && lit_good mask rmask) - - ; make_rewrite - (pcst (#pattern.ident.Z_zselect @ ?? @ ?? @ ??)) - (fun r c x y => cst r (#ident.fancy_selc @ (c, x, y))) - -(*Z.add_modulo @@ (?x, ?y, ?m) --> addm @@ (x, y, m)*) - ; make_rewrite - (#pattern.ident.Z_add_modulo @ ?? @ ?? @ ??) - (fun x y m => #ident.fancy_addm @ (x, y, m)) -(* + *) + ; (forall r rccm rs s rc c rx x ry y, + s ∈ rs -> (s = 2^Z.log2 s) -> cc_m_good rccm s rc + -> cstZ r (Z.zselect (cstZ rccm (Z.cc_m (cstZ rs ('s)) (cstZ rc c))) (cstZ rx x) (cstZ ry y)) + = cstZ r (ident.interp (ident.fancy_selm (Z.log2 s)) (cstZ rc c, cstZ rx x, cstZ ry y))) + + ; (forall r rland r1 rc c rx x ry y, + 1 ∈ r1 -> land_good rland 1 rc + -> cstZ r (Z.zselect (cstZ rland (cstZ r1 1 &' cstZ rc c)) (cstZ rx x) (cstZ ry y)) + = cstZ r (ident.interp ident.fancy_sell (cstZ rc c, cstZ rx x, cstZ ry y))) + + ; (forall r rland rc c r1 rx x ry y, + 1 ∈ r1 -> land_good rland rc 1 + -> cstZ r (Z.zselect (cstZ rland (cstZ rc c &' cstZ r1 1)) (cstZ rx x) (cstZ ry y)) + = cstZ r (ident.interp ident.fancy_sell (cstZ rc c, cstZ rx x, cstZ ry y))) + + ; (forall r c x y, + cstZ r (Z.zselect c x y) + = cstZ r (ident.interp ident.fancy_selc (c, x, y))) + + (*Z.add_modulo @@ (?x, ?y, ?m) --> addm @@ (x, y, m)*) + ; (forall x y m, + Z.add_modulo x y m + = ident.interp ident.fancy_addm (x, y, m)) + + (* Z.mul @@ (?x &' (2^128-1), ?y &' (2^128-1)) --> mulll @@ (x, y) Z.mul @@ (?x &' (2^128-1), ?y >> 128) --> mullh @@ (x, y) Z.mul @@ (?x >> 128, ?y &' (2^128-1)) --> mulhl @@ (x, y) Z.mul @@ (?x >> 128, ?y >> 128) --> mulhh @@ (x, y) - *) - (* literal on left *) - ; make_rewriteo - (#?ℤ' *' pcst (#pattern.ident.Z_land @ ??' @ #?ℤ')) - (fun r rx x rland ry y rmask mask => let s := (2*Z.log2_up mask)%Z in x <- invert_low s x; cst r (#(ident.fancy_mulll s) @ (##x, cst ry y)) when (mask =? 2^(s/2)-1) && land_good rland ry mask && lit_good x rx && lit_good mask rmask) - ; make_rewriteo - (#?ℤ' *' pcst (#pattern.ident.Z_land @ #?ℤ' @ ??')) - (fun r rx x rland rmask mask ry y => let s := (2*Z.log2_up mask)%Z in x <- invert_low s x; cst r (#(ident.fancy_mulll s) @ (##x, cst ry y)) when (mask =? 2^(s/2)-1) && land_good rland mask ry && lit_good x rx && lit_good mask rmask) - ; make_rewriteo - (#?ℤ' *' pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ')) - (fun r rx x rshiftr ry y roffset offset => let s := (2*offset)%Z in x <- invert_low s x; cst r (#(ident.fancy_mullh s) @ (##x, cst ry y)) when shiftr_good rshiftr ry offset && lit_good x rx && lit_good offset roffset) - ; make_rewriteo - (#?ℤ' *' pcst (#pattern.ident.Z_land @ #?ℤ' @ ??')) - (fun r rx x rland rmask mask ry y => let s := (2*Z.log2_up mask)%Z in x <- invert_high s x; cst r (#(ident.fancy_mulhl s) @ (##x, cst ry y)) when (mask =? 2^(s/2)-1) && land_good rland mask ry && lit_good x rx && lit_good mask rmask) - ; make_rewriteo - (#?ℤ' *' pcst (#pattern.ident.Z_land @ ??' @ #?ℤ')) - (fun r rx x rland ry y rmask mask => let s := (2*Z.log2_up mask)%Z in x <- invert_high s x; cst r (#(ident.fancy_mulhl s) @ (##x, cst ry y)) when (mask =? 2^(s/2)-1) && land_good rland ry mask && lit_good x rx && lit_good mask rmask) - ; make_rewriteo - (#?ℤ' *' pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ')) - (fun r rx x rshiftr ry y roffset offset => let s := (2*offset)%Z in x <- invert_high s x; cst r (#(ident.fancy_mulhh s) @ (##x, cst ry y)) when shiftr_good rshiftr ry offset && lit_good x rx && lit_good offset roffset) - (* literal on right *) - ; make_rewriteo - (pcst (#pattern.ident.Z_land @ #?ℤ' @ ??') *' #?ℤ') - (fun r rland rmask mask rx x ry y => let s := (2*Z.log2_up mask)%Z in y <- invert_low s y; cst r (#(ident.fancy_mulll s) @ (cst rx x, ##y)) when (mask =? 2^(s/2)-1) && land_good rland mask rx && lit_good y ry && lit_good mask rmask) - ; make_rewriteo - (pcst (#pattern.ident.Z_land @ ??' @ #?ℤ') *' #?ℤ') - (fun r rland rx x rmask mask ry y => let s := (2*Z.log2_up mask)%Z in y <- invert_low s y; cst r (#(ident.fancy_mulll s) @ (cst rx x, ##y)) when (mask =? 2^(s/2)-1) && land_good rland rx mask && lit_good y ry && lit_good mask rmask) - ; make_rewriteo - (pcst (#pattern.ident.Z_land @ #?ℤ' @ ??') *' #?ℤ') - (fun r rland rmask mask rx x ry y => let s := (2*Z.log2_up mask)%Z in y <- invert_high s y; cst r (#(ident.fancy_mullh s) @ (cst rx x, ##y)) when (mask =? 2^(s/2)-1) && land_good rland mask rx && lit_good y ry && lit_good mask rmask) - ; make_rewriteo - (pcst (#pattern.ident.Z_land @ ??' @ #?ℤ') *' #?ℤ') - (fun r rland rx x rmask mask ry y => let s := (2*Z.log2_up mask)%Z in y <- invert_high s y; cst r (#(ident.fancy_mullh s) @ (cst rx x, ##y)) when (mask =? 2^(s/2)-1) && land_good rland rx mask && lit_good y ry && lit_good mask rmask) - ; make_rewriteo - (pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ') *' #?ℤ') - (fun r rshiftr rx x roffset offset ry y => let s := (2*offset)%Z in y <- invert_low s y; cst r (#(ident.fancy_mulhl s) @ (cst rx x, ##y)) when shiftr_good rshiftr rx offset && lit_good y ry && lit_good offset roffset) - ; make_rewriteo - (pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ') *' #?ℤ') - (fun r rshiftr rx x roffset offset ry y => let s := (2*offset)%Z in y <- invert_high s y; cst r (#(ident.fancy_mulhh s) @ (cst rx x, ##y)) when shiftr_good rshiftr rx offset && lit_good y ry && lit_good offset roffset) - (* no literal *) - ; make_rewriteo - (pcst (#pattern.ident.Z_land @ #?ℤ' @ ??') *' pcst (#pattern.ident.Z_land @ #?ℤ' @ ??')) - (fun r rland1 rmask1 mask1 rx x rland2 rmask2 mask2 ry y => let s := (2*Z.log2_up mask1)%Z in cst r (#(ident.fancy_mulll s) @ (cst rx x, cst ry y)) when (mask1 =? 2^(s/2)-1) && (mask2 =? 2^(s/2)-1) && land_good rland1 mask1 rx && land_good rland2 mask2 ry && lit_good mask1 rmask1 && lit_good mask2 rmask2) - ; make_rewriteo - (pcst (#pattern.ident.Z_land @ ??' @ #?ℤ') *' pcst (#pattern.ident.Z_land @ #?ℤ' @ ??')) - (fun r rland1 rx x rmask1 mask1 rland2 rmask2 mask2 ry y => let s := (2*Z.log2_up mask1)%Z in cst r (#(ident.fancy_mulll s) @ (cst rx x, cst ry y)) when (mask1 =? 2^(s/2)-1) && (mask2 =? 2^(s/2)-1) && land_good rland1 rx mask1 && land_good rland2 mask2 ry && lit_good mask1 rmask1 && lit_good mask2 rmask2) - ; make_rewriteo - (pcst (#pattern.ident.Z_land @ #?ℤ' @ ??') *' pcst (#pattern.ident.Z_land @ ??' @ #?ℤ')) - (fun r rland1 rmask1 mask1 rx x rland2 ry y rmask2 mask2 => let s := (2*Z.log2_up mask1)%Z in cst r (#(ident.fancy_mulll s) @ (cst rx x, cst ry y)) when (mask1 =? 2^(s/2)-1) && (mask2 =? 2^(s/2)-1) && land_good rland1 mask1 rx && land_good rland2 ry mask2 && lit_good mask1 rmask1 && lit_good mask2 rmask2) - ; make_rewriteo - (pcst (#pattern.ident.Z_land @ ??' @ #?ℤ') *' pcst (#pattern.ident.Z_land @ ??' @ #?ℤ')) - (fun r rland1 rx x rmask1 mask1 rland2 ry y rmask2 mask2 => let s := (2*Z.log2_up mask1)%Z in cst r (#(ident.fancy_mulll s) @ (cst rx x, cst ry y)) when (mask1 =? 2^(s/2)-1) && (mask2 =? 2^(s/2)-1) && land_good rland1 rx mask1 && land_good rland2 ry mask2 && lit_good mask1 rmask1 && lit_good mask2 rmask2) - ; make_rewriteo - (pcst (#pattern.ident.Z_land @ #?ℤ' @ ??') *' pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ')) - (fun r rland1 rmask mask rx x rshiftr2 ry y roffset offset => let s := (2*offset)%Z in cst r (#(ident.fancy_mullh s) @ (cst rx x, cst ry y)) when (mask =? 2^(s/2)-1) && land_good rland1 mask rx && shiftr_good rshiftr2 ry offset && lit_good mask rmask && lit_good offset roffset) - ; make_rewriteo - (pcst (#pattern.ident.Z_land @ ??' @ #?ℤ') *' pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ')) - (fun r rland1 rx x rmask mask rshiftr2 ry y roffset offset => let s := (2*offset)%Z in cst r (#(ident.fancy_mullh s) @ (cst rx x, cst ry y)) when (mask =? 2^(s/2)-1) && land_good rland1 rx mask && shiftr_good rshiftr2 ry offset && lit_good mask rmask && lit_good offset roffset) - ; make_rewriteo - (pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ') *' pcst (#pattern.ident.Z_land @ #?ℤ' @ ??')) - (fun r rshiftr1 rx x roffset offset rland2 rmask mask ry y => let s := (2*offset)%Z in cst r (#(ident.fancy_mulhl s) @ (cst rx x, cst ry y)) when (mask =? 2^(s/2)-1) && shiftr_good rshiftr1 rx offset && land_good rland2 mask ry && lit_good mask rmask && lit_good offset roffset) - ; make_rewriteo - (pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ') *' pcst (#pattern.ident.Z_land @ ??' @ #?ℤ')) - (fun r rshiftr1 rx x roffset offset rland2 ry y rmask mask => let s := (2*offset)%Z in cst r (#(ident.fancy_mulhl s) @ (cst rx x, cst ry y)) when (mask =? 2^(s/2)-1) && shiftr_good rshiftr1 rx offset && land_good rland2 ry mask && lit_good mask rmask && lit_good offset roffset) - ; make_rewriteo - (pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ') *' pcst (#pattern.ident.Z_shiftr @ ??' @ #?ℤ')) - (fun r rshiftr1 rx x roffset1 offset1 rshiftr2 ry y roffset2 offset2 => let s := (2*offset1)%Z in cst r (#(ident.fancy_mulhh s) @ (cst rx x, cst ry y)) when (offset1 =? offset2) && shiftr_good rshiftr1 rx offset1 && shiftr_good rshiftr2 ry offset2 && lit_good offset1 roffset1 && lit_good offset2 roffset2) - - - - (** Dummy rule to make sure we use the two value ranges; this can be removed *) - ; make_rewriteo - (??') - (fun rx x => cst rx x when is_tighter_than_bool rx value_range || is_tighter_than_bool rx flag_range) - - ]. + *) + (* literal on left *) + ; (forall r rx x rland ry y rmask mask, + plet s := (2*Z.log2_up mask)%Z in + plet xo := invert_low s x in + plet xv := match xo with Some x => x | None => 0 end in + xo <> None -> x ∈ rx -> mask ∈ rmask -> (mask = 2^(s/2)-1) -> land_good rland ry mask + -> cstZ r (cstZ rx ('x) * cstZ rland (Z.land (cstZ ry y) (cstZ rmask ('mask)))) + = cstZ r (ident.interp (ident.fancy_mulll s) ('xv, cstZ ry y))) + + ; (forall r rx x rland rmask mask ry y, + plet s := (2*Z.log2_up mask)%Z in + plet xo := invert_low s x in + plet xv := match xo with Some x => x | None => 0 end in + xo <> None -> x ∈ rx -> mask ∈ rmask -> (mask = 2^(s/2)-1) -> land_good rland mask ry + -> cstZ r (cstZ rx ('x) * cstZ rland (Z.land (cstZ rmask ('mask)) (cstZ ry y))) + = cstZ r (ident.interp (ident.fancy_mulll s) ('xv, cstZ ry y))) + + ; (forall r rx x rshiftr ry y roffset offset, + plet s := (2*offset)%Z in + plet xo := invert_low s x in + plet xv := match xo with Some x => x | None => 0 end in + xo <> None -> x ∈ rx -> offset ∈ roffset -> shiftr_good rshiftr ry offset + -> cstZ r (cstZ rx ('x) * cstZ rshiftr (Z.shiftr (cstZ ry y) (cstZ roffset ('offset)))) + = cstZ r (ident.interp (ident.fancy_mullh s) ('xv, cstZ ry y))) + + ; (forall r rx x rland rmask mask ry y, + plet s := (2*Z.log2_up mask)%Z in + plet xo := invert_high s x in + plet xv := match xo with Some x => x | None => 0 end in + xo <> None -> x ∈ rx -> mask ∈ rmask -> (mask = 2^(s/2)-1) -> land_good rland mask ry + -> cstZ r (cstZ rx ('x) * cstZ rland (Z.land (cstZ rmask ('mask)) (cstZ ry y))) + = cstZ r (ident.interp (ident.fancy_mulhl s) ('xv, cstZ ry y))) + + ; (forall r rx x rland ry y rmask mask, + plet s := (2*Z.log2_up mask)%Z in + plet xo := invert_high s x in + plet xv := match xo with Some x => x | None => 0 end in + xo <> None -> x ∈ rx -> mask ∈ rmask -> (mask = 2^(s/2)-1) -> land_good rland ry mask + -> cstZ r (cstZ rx ('x) * cstZ rland (Z.land (cstZ ry y) (cstZ rmask ('mask)))) + = cstZ r (ident.interp (ident.fancy_mulhl s) ('xv, cstZ ry y))) + + ; (forall r rx x rshiftr ry y roffset offset, + plet s := (2*offset)%Z in + plet xo := invert_high s x in + plet xv := match xo with Some x => x | None => 0 end in + xo <> None -> x ∈ rx -> offset ∈ roffset -> shiftr_good rshiftr ry offset + -> cstZ r (cstZ rx ('x) * cstZ rshiftr (Z.shiftr (cstZ ry y) (cstZ roffset ('offset)))) + = cstZ r (ident.interp (ident.fancy_mulhh s) ('xv, cstZ ry y))) + + (* literal on right *) + ; (forall r rland rmask mask rx x ry y, + plet s := (2*Z.log2_up mask)%Z in + plet yo := invert_low s y in + plet yv := match yo with Some y => y | None => 0 end in + yo <> None -> y ∈ ry -> mask ∈ rmask -> (mask = 2^(s/2)-1) -> land_good rland mask rx + -> cstZ r (cstZ rland (Z.land (cstZ rmask ('mask)) (cstZ rx x)) * cstZ ry ('y)) + = cstZ r (ident.interp (ident.fancy_mulll s) (cstZ rx x, 'yv))) + + ; (forall r rland rx x rmask mask ry y, + plet s := (2*Z.log2_up mask)%Z in + plet yo := invert_low s y in + plet yv := match yo with Some y => y | None => 0 end in + yo <> None -> y ∈ ry -> mask ∈ rmask -> (mask = 2^(s/2)-1) -> land_good rland rx mask + -> cstZ r (cstZ rland (Z.land (cstZ rx x) (cstZ rmask ('mask))) * cstZ ry ('y)) + = cstZ r (ident.interp (ident.fancy_mulll s) (cstZ rx x, 'yv))) + + ; (forall r rland rmask mask rx x ry y, + plet s := (2*Z.log2_up mask)%Z in + plet yo := invert_high s y in + plet yv := match yo with Some y => y | None => 0 end in + yo <> None -> y ∈ ry -> mask ∈ rmask -> (mask = 2^(s/2)-1) -> land_good rland mask rx + -> cstZ r (cstZ rland (Z.land (cstZ rmask ('mask)) (cstZ rx x)) * cstZ ry ('y)) + = cstZ r (ident.interp (ident.fancy_mullh s) (cstZ rx x, 'yv))) + + ; (forall r rland rx x rmask mask ry y, + plet s := (2*Z.log2_up mask)%Z in + plet yo := invert_high s y in + plet yv := match yo with Some y => y | None => 0 end in + yo <> None -> y ∈ ry -> mask ∈ rmask -> (mask = 2^(s/2)-1) -> land_good rland rx mask + -> cstZ r (cstZ rland (Z.land (cstZ rx x) (cstZ rmask ('mask))) * cstZ ry ('y)) + = cstZ r (ident.interp (ident.fancy_mullh s) (cstZ rx x, 'yv))) + + ; (forall r rshiftr rx x roffset offset ry y, + plet s := (2*offset)%Z in + plet yo := invert_low s y in + plet yv := match yo with Some y => y | None => 0 end in + yo <> None -> y ∈ ry -> offset ∈ roffset -> shiftr_good rshiftr rx offset + -> cstZ r (cstZ rshiftr (Z.shiftr (cstZ rx x) (cstZ roffset ('offset))) * cstZ ry ('y)) + = cstZ r (ident.interp (ident.fancy_mulhl s) (cstZ rx x, 'yv))) + + ; (forall r rshiftr rx x roffset offset ry y, + plet s := (2*offset)%Z in + plet yo := invert_high s y in + plet yv := match yo with Some y => y | None => 0 end in + yo <> None -> y ∈ ry -> offset ∈ roffset -> shiftr_good rshiftr rx offset + -> cstZ r (cstZ rshiftr (Z.shiftr (cstZ rx x) (cstZ roffset ('offset))) * cstZ ry ('y)) + = cstZ r (ident.interp (ident.fancy_mulhh s) (cstZ rx x, 'yv))) + + (* no literal *) + ; (forall r rland1 rmask1 mask1 rx x rland2 rmask2 mask2 ry y, + plet s := (2*Z.log2_up mask1)%Z in + mask1 ∈ rmask1 -> mask2 ∈ rmask2 -> (mask1 = 2^(s/2)-1) -> (mask2 = 2^(s/2)-1) -> land_good rland1 mask1 rx -> land_good rland2 mask2 ry + -> cstZ r (cstZ rland1 (Z.land (cstZ rmask1 ('mask1)) (cstZ rx x)) * cstZ rland2 (Z.land (cstZ rmask2 ('mask2)) (cstZ ry y))) + = cstZ r (ident.interp (ident.fancy_mulll s) (cstZ rx x, cstZ ry y))) + + ; (forall r rland1 rx x rmask1 mask1 rland2 rmask2 mask2 ry y, + plet s := (2*Z.log2_up mask1)%Z in + mask1 ∈ rmask1 -> mask2 ∈ rmask2 -> (mask1 = 2^(s/2)-1) -> (mask2 = 2^(s/2)-1) -> land_good rland1 rx mask1 -> land_good rland2 mask2 ry + -> cstZ r (cstZ rland1 (Z.land (cstZ rx x) (cstZ rmask1 ('mask1))) * cstZ rland2 (Z.land (cstZ rmask2 ('mask2)) (cstZ ry y))) + = cstZ r (ident.interp (ident.fancy_mulll s) (cstZ rx x, cstZ ry y))) + + ; (forall r rland1 rmask1 mask1 rx x rland2 ry y rmask2 mask2, + plet s := (2*Z.log2_up mask1)%Z in + mask1 ∈ rmask1 -> mask2 ∈ rmask2 -> (mask1 = 2^(s/2)-1) -> (mask2 = 2^(s/2)-1) -> land_good rland1 mask1 rx -> land_good rland2 ry mask2 + -> cstZ r (cstZ rland1 (Z.land (cstZ rmask1 ('mask1)) (cstZ rx x)) * cstZ rland2 (Z.land (cstZ ry y) (cstZ rmask2 ('mask2)))) + = cstZ r (ident.interp (ident.fancy_mulll s) (cstZ rx x, cstZ ry y))) + + ; (forall r rland1 rx x rmask1 mask1 rland2 ry y rmask2 mask2, + plet s := (2*Z.log2_up mask1)%Z in + mask1 ∈ rmask1 -> mask2 ∈ rmask2 -> (mask1 = 2^(s/2)-1) -> (mask2 = 2^(s/2)-1) -> land_good rland1 rx mask1 -> land_good rland2 ry mask2 + -> cstZ r (cstZ rland1 (Z.land (cstZ rx x) (cstZ rmask1 ('mask1))) * cstZ rland2 (Z.land (cstZ ry y) (cstZ rmask2 ('mask2)))) + = cstZ r (ident.interp (ident.fancy_mulll s) (cstZ rx x, cstZ ry y))) + + ; (forall r rland1 rmask mask rx x rshiftr2 ry y roffset offset, + plet s := (2*offset)%Z in + mask ∈ rmask -> offset ∈ roffset -> (mask = 2^(s/2)-1) -> land_good rland1 mask rx -> shiftr_good rshiftr2 ry offset + -> cstZ r (cstZ rland1 (Z.land (cstZ rmask ('mask)) (cstZ rx x)) * cstZ rshiftr2 (Z.shiftr (cstZ ry y) (cstZ roffset ('offset)))) + = cstZ r (ident.interp (ident.fancy_mullh s) (cstZ rx x, cstZ ry y))) + + ; (forall r rland1 rx x rmask mask rshiftr2 ry y roffset offset, + plet s := (2*offset)%Z in + mask ∈ rmask -> offset ∈ roffset -> (mask = 2^(s/2)-1) -> land_good rland1 rx mask -> shiftr_good rshiftr2 ry offset + -> cstZ r (cstZ rland1 (Z.land (cstZ rx x) (cstZ rmask ('mask))) * cstZ rshiftr2 (Z.shiftr (cstZ ry y) (cstZ roffset ('offset)))) + = cstZ r (ident.interp (ident.fancy_mullh s) (cstZ rx x, cstZ ry y))) + + ; (forall r rshiftr1 rx x roffset offset rland2 rmask mask ry y, + plet s := (2*offset)%Z in + mask ∈ rmask -> offset ∈ roffset -> (mask = 2^(s/2)-1) -> shiftr_good rshiftr1 rx offset -> land_good rland2 mask ry + -> cstZ r (cstZ rshiftr1 (Z.shiftr (cstZ rx x) (cstZ roffset ('offset))) * cstZ rland2 (Z.land (cstZ rmask ('mask)) (cstZ ry y))) + = cstZ r (ident.interp (ident.fancy_mulhl s) (cstZ rx x, cstZ ry y))) + + ; (forall r rshiftr1 rx x roffset offset rland2 ry y rmask mask, + plet s := (2*offset)%Z in + mask ∈ rmask -> offset ∈ roffset -> (mask = 2^(s/2)-1) -> shiftr_good rshiftr1 rx offset -> land_good rland2 ry mask + -> cstZ r (cstZ rshiftr1 (Z.shiftr (cstZ rx x) (cstZ roffset ('offset))) * cstZ rland2 (Z.land (cstZ ry y) (cstZ rmask ('mask)))) + = cstZ r (ident.interp (ident.fancy_mulhl s) (cstZ rx x, cstZ ry y))) + + ; (forall r rshiftr1 rx x roffset1 offset1 rshiftr2 ry y roffset2 offset2, + plet s := (2*offset1)%Z in + offset1 ∈ roffset1 -> offset2 ∈ roffset2 -> (offset1 = offset2) -> shiftr_good rshiftr1 rx offset1 -> shiftr_good rshiftr2 ry offset2 + -> cstZ r (cstZ rshiftr1 (Z.shiftr (cstZ rx x) (cstZ roffset1 ('offset1))) * cstZ rshiftr2 (Z.shiftr (cstZ ry y) (cstZ roffset2 ('offset2)))) + = cstZ r (ident.interp (ident.fancy_mulhh s) (cstZ rx x, cstZ ry y))) + + (** Dummy rule to make sure we use the two value ranges; this can be removed *) + ; (forall rx x, + ((is_tighter_than_bool rx value_range = true) + \/ (is_tighter_than_bool rx flag_range = true)) + -> cstZ rx x = cstZ rx x) + ]%Z%zrange + ]. Definition fancy_dtree' := Eval compute in @compile_rewrites ident var pattern.ident (@pattern.ident.arg_types) pattern.Raw.ident (@pattern.ident.strip_types) pattern.Raw.ident.ident_beq 100 fancy_rewrite_rules. diff --git a/src/RewriterRulesInterpGood.v b/src/RewriterRulesInterpGood.v index a2aa56a36..954340fa4 100644 --- a/src/RewriterRulesInterpGood.v +++ b/src/RewriterRulesInterpGood.v @@ -475,7 +475,10 @@ Module Compilers. | [ |- ?x = ?x ] => reflexivity | [ |- True ] => exact I | [ H : ?x = true, H' : ?x = false |- _ ] => exfalso; clear -H H'; congruence + | [ H : true = false |- _ ]=> exfalso; clear -H; congruence + | [ H : false = true |- _ ]=> exfalso; clear -H; congruence end + | progress cbv [option_beq] in * | match goal with | [ H : context[ZRange.normalize (ZRange.normalize _)] |- _ ] => rewrite ZRange.normalize_idempotent in H @@ -689,6 +692,8 @@ Module Compilers. |- context[?v mod ?m] ] => unique assert (is_bounded_by_bool v r[0~>x-1] = true) by (eapply ZRange.is_bounded_by_of_is_tighter_than; eassumption) + | _ => progress Z.ltb_to_lt + | _ => progress subst end. Local Ltac unfold_cast_lemmas := -- cgit v1.2.3