diff options
author | Andres Erbsen <andreser@mit.edu> | 2019-01-08 04:21:38 -0500 |
---|---|---|
committer | Andres Erbsen <andreser@mit.edu> | 2019-01-09 22:49:02 -0500 |
commit | 3ca227f1137e6a3b65bc33f5689e1c230d591595 (patch) | |
tree | e1e5a2dd2a2f34f239d3276227ddbdc69eeeb667 /src/Experiments/SimplyTypedArithmetic.v | |
parent | 3ec21c64b3682465ca8e159a187689b207c71de4 (diff) |
remove old pipeline
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 12784 |
1 files changed, 0 insertions, 12784 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v deleted file mode 100644 index 92799222e..000000000 --- a/src/Experiments/SimplyTypedArithmetic.v +++ /dev/null @@ -1,12784 +0,0 @@ -(* Following http://adam.chlipala.net/theses/andreser.pdf chapter 3 *) -Require Import Coq.ZArith.ZArith Coq.micromega.Lia Crypto.Algebra.Nsatz. -Require Import Coq.Strings.String. -Require Import Coq.MSets.MSetPositive. -Require Import Coq.FSets.FMapPositive. -Require Import Coq.derive.Derive. -Require Import Crypto.Util.Tactics.UniquePose Crypto.Util.Decidable. -Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn. -Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil. -Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. -Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. -Require Import Crypto.Algebra.Ring. -Require Import Crypto.Algebra.SubsetoidRing. -Require Import Crypto.Arithmetic.PrimeFieldTheorems. -Require Import Crypto.Arithmetic.BarrettReduction.Generalized. -Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. -Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. -Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. -Require Import Crypto.Util.ZRange. -Require Import Crypto.Util.ZRange.Operations. -Require Import Crypto.Util.Tactics.RunTacticAsConstr. -Require Import Crypto.Util.Tactics.Head. -Require Import Crypto.Util.Option. -Require Import Crypto.Util.OptionList. -Require Import Crypto.Util.Sum. -Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div Crypto.Util.ZUtil.Hints.Core. -Require Import Crypto.Util.ZUtil.Hints.PullPush. -Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. -Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. -Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. -Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. -Require Import Crypto.Util.ZUtil.Le. -Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall. -Require Import Crypto.Util.ZUtil.Log2. -Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. -Require Import Crypto.Util.ZUtil.Notations. -Require Import Crypto.Util.ZUtil.Shift. -Require Import Crypto.Util.ZUtil.LandLorShiftBounds. -Require Import Crypto.Util.ZUtil.Testbit. -Require Import Crypto.Util.ZUtil.Notations. -Require Import Crypto.Util.Tactics.SpecializeBy. -Require Import Crypto.Util.Tactics.SplitInContext. -Require Import Crypto.Util.Tactics.SubstEvars. -Require Crypto.Util.Strings.String. -Require Import Crypto.Util.Strings.Decimal. -Require Import Crypto.Util.Strings.HexString. -Require Import Crypto.Util.Notations. -Require Import Crypto.Util.ZUtil.Definitions. -Require Import Crypto.Util.ZUtil.CC Crypto.Util.ZUtil.Rshi. -Require Import Crypto.Util.ZUtil.Zselect Crypto.Util.ZUtil.AddModulo. -Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. -Require Import Crypto.Util.ZUtil.Hints.Core. -Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div. -Require Import Crypto.Util.ZUtil.Hints.PullPush. -Require Import Crypto.Util.ZUtil.EquivModulo. -Require Import Crypto.Util.Tactics.DebugPrint. -Require Import Crypto.Util.CPSNotations. -Require Import Crypto.Util.Equality. -Import ListNotations. Local Open Scope Z_scope. - -Module Associational. - Definition eval (p:list (Z*Z)) : Z := - fold_right (fun x y => x + y) 0%Z (map (fun t => fst t * snd t) p). - - Lemma eval_nil : eval nil = 0. - Proof. trivial. Qed. - Lemma eval_cons p q : eval (p::q) = fst p * snd p + eval q. - Proof. trivial. Qed. - Lemma eval_app p q: eval (p++q) = eval p + eval q. - Proof. induction p; rewrite <-?List.app_comm_cons; - rewrite ?eval_nil, ?eval_cons; nsatz. Qed. - - Hint Rewrite eval_nil eval_cons eval_app : push_eval. - Local Ltac push := autorewrite with - push_eval push_map push_partition push_flat_map - push_fold_right push_nth_default cancel_pair. - - Lemma eval_map_mul (a x:Z) (p:list (Z*Z)) - : eval (List.map (fun t => (a*fst t, x*snd t)) p) = a*x*eval p. - Proof. induction p; push; nsatz. Qed. - Hint Rewrite eval_map_mul : push_eval. - - Definition mul (p q:list (Z*Z)) : list (Z*Z) := - flat_map (fun t => - map (fun t' => - (fst t * fst t', snd t * snd t')) - q) p. - Lemma eval_mul p q : eval (mul p q) = eval p * eval q. - Proof. induction p; cbv [mul]; push; nsatz. Qed. - Hint Rewrite eval_mul : push_eval. - - Definition negate_snd (p:list (Z*Z)) : list (Z*Z) := - map (fun cx => (fst cx, -snd cx)) p. - Lemma eval_negate_snd p : eval (negate_snd p) = - eval p. - Proof. induction p; cbv [negate_snd]; push; nsatz. Qed. - Hint Rewrite eval_negate_snd : push_eval. - - Example base10_2digit_mul (a0:Z) (a1:Z) (b0:Z) (b1:Z) : - {ab| eval ab = eval [(10,a1);(1,a0)] * eval [(10,b1);(1,b0)]}. - eexists ?[ab]. - (* Goal: eval ?ab = eval [(10,a1);(1,a0)] * eval [(10,b1);(1,b0)] *) - rewrite <-eval_mul. - (* Goal: eval ?ab = eval (mul [(10,a1);(1,a0)] [(10,b1);(1,b0)]) *) - cbv -[Z.mul eval]; cbn -[eval]. - (* Goal: eval ?ab = eval [(100,(a1*b1));(10,a1*b0);(10,a0*b1);(1,a0*b0)]%RT *) - trivial. Defined. - - Definition split (s:Z) (p:list (Z*Z)) : list (Z*Z) * list (Z*Z) - := let hi_lo := partition (fun t => fst t mod s =? 0) p in - (snd hi_lo, map (fun t => (fst t / s, snd t)) (fst hi_lo)). - Lemma eval_split s p (s_nz:s<>0) : - eval (fst (split s p)) + s * eval (snd (split s p)) = eval p. - Proof. cbv [Let_In split]; induction p; - repeat match goal with - | |- context[?a/?b] => - unique pose proof (Z_div_exact_full_2 a b ltac:(trivial) ltac:(trivial)) - | _ => progress push - | _ => progress break_match - | _ => progress nsatz end. Qed. - - Lemma reduction_rule a b s c (modulus_nz:s-c<>0) : - (a + s * b) mod (s - c) = (a + c * b) mod (s - c). - Proof. replace (a + s * b) with ((a + c*b) + b*(s-c)) by nsatz. - rewrite Z.add_mod,Z_mod_mult,Z.add_0_r,Z.mod_mod;trivial. Qed. - - Definition reduce (s:Z) (c:list _) (p:list _) : list (Z*Z) := - let lo_hi := split s p in fst lo_hi ++ mul c (snd lo_hi). - - Lemma eval_reduce s c p (s_nz:s<>0) (modulus_nz:s-eval c<>0) : - eval (reduce s c p) mod (s - eval c) = eval p mod (s - eval c). - Proof. cbv [reduce]; push. - rewrite <-reduction_rule, eval_split; trivial. Qed. - Hint Rewrite eval_reduce : push_eval. - - Definition bind_snd (p : list (Z*Z)) := - map (fun t => dlet_nd t2 := snd t in (fst t, t2)) p. - - Lemma bind_snd_correct p : bind_snd p = p. - Proof. - cbv [bind_snd]; induction p as [| [? ?] ]; - push; [|rewrite IHp]; reflexivity. - Qed. - - Lemma eval_rev p : eval (rev p) = eval p. - Proof. induction p; cbn [rev]; push; lia. Qed. - - Section Carries. - Definition carryterm (w fw:Z) (t:Z * Z) := - if (Z.eqb (fst t) w) - then dlet_nd t2 := snd t in - dlet_nd d2 := t2 / fw in - dlet_nd m2 := t2 mod fw in - [(w * fw, d2);(w,m2)] - else [t]. - - Lemma eval_carryterm w fw (t:Z * Z) (fw_nonzero:fw<>0): - eval (carryterm w fw t) = eval [t]. - Proof using Type*. - cbv [carryterm Let_In]; break_match; push; [|trivial]. - pose proof (Z.div_mod (snd t) fw fw_nonzero). - rewrite Z.eqb_eq in *. - nsatz. - Qed. Hint Rewrite eval_carryterm using auto : push_eval. - - Definition carry (w fw:Z) (p:list (Z * Z)):= - flat_map (carryterm w fw) p. - - Lemma eval_carry w fw p (fw_nonzero:fw<>0): - eval (carry w fw p) = eval p. - Proof using Type*. cbv [carry]; induction p; push; nsatz. Qed. - Hint Rewrite eval_carry using auto : push_eval. - End Carries. -End Associational. - -Module Positional. Section Positional. - Context (weight : nat -> Z) - (weight_0 : weight 0%nat = 1) - (weight_nz : forall i, weight i <> 0). - - Definition to_associational (n:nat) (xs:list Z) : list (Z*Z) - := combine (map weight (List.seq 0 n)) xs. - Definition eval n x := Associational.eval (@to_associational n x). - Lemma eval_to_associational n x : - Associational.eval (@to_associational n x) = eval n x. - Proof. trivial. Qed. - Hint Rewrite @eval_to_associational : push_eval. - Lemma eval_nil n : eval n [] = 0. - Proof. cbv [eval to_associational]. rewrite combine_nil_r. reflexivity. Qed. - Hint Rewrite eval_nil : push_eval. - Lemma eval0 p : eval 0 p = 0. - Proof. cbv [eval to_associational]. reflexivity. Qed. - Hint Rewrite eval0 : push_eval. - - Lemma eval_snoc n m x y : n = length x -> m = S n -> eval m (x ++ [y]) = eval n x + weight n * y. - Proof. - cbv [eval to_associational]; intros; subst n m. - rewrite seq_snoc, map_app. - rewrite combine_app_samelength by distr_length. - autorewrite with push_eval. simpl. - autorewrite with push_eval cancel_pair; ring. - Qed. - - (* SKIP over this: zeros, add_to_nth *) - Local Ltac push := autorewrite with push_eval push_map distr_length - push_flat_map push_fold_right push_nth_default cancel_pair natsimplify. - Definition zeros n : list Z := repeat 0 n. - Lemma length_zeros n : length (zeros n) = n. Proof. cbv [zeros]; distr_length. Qed. - Hint Rewrite length_zeros : distr_length. - Lemma eval_zeros n : eval n (zeros n) = 0. - Proof. - cbv [eval Associational.eval to_associational zeros]. - rewrite <- (seq_length n 0) at 2. - generalize dependent (List.seq 0 n); intro xs. - induction xs; simpl; nsatz. Qed. - Definition add_to_nth i x (ls : list Z) : list Z - := ListUtil.update_nth i (fun y => x + y) ls. - Lemma length_add_to_nth i x ls : length (add_to_nth i x ls) = length ls. - Proof. cbv [add_to_nth]; distr_length. Qed. - Hint Rewrite length_add_to_nth : distr_length. - Lemma eval_add_to_nth (n:nat) (i:nat) (x:Z) (xs:list Z) (H:(i<length xs)%nat) - (Hn : length xs = n) (* N.B. We really only need [i < Nat.min n (length xs)] *) : - eval n (add_to_nth i x xs) = weight i * x + eval n xs. - Proof. - subst n. - cbv [eval to_associational add_to_nth]. - rewrite ListUtil.combine_update_nth_r at 1. - rewrite <-(update_nth_id i (List.combine _ _)) at 2. - rewrite <-!(ListUtil.splice_nth_equiv_update_nth_update _ _ - (weight 0, 0)) by (push; lia); cbv [ListUtil.splice_nth id]. - repeat match goal with - | _ => progress push - | _ => progress break_match - | _ => progress (apply Zminus_eq; ring_simplify) - | _ => rewrite <-ListUtil.map_nth_default_always - end; lia. Qed. - Hint Rewrite @eval_add_to_nth eval_zeros : push_eval. - - Definition place (t:Z*Z) (i:nat) : nat * Z := - nat_rect - (fun _ => (nat * Z)%type) - (O, fst t * snd t) - (fun i' place_i' - => let i := S i' in - if (fst t mod weight i =? 0) - then (i, let c := fst t / weight i in c * snd t) - else place_i') - i. - - Lemma place_in_range (t:Z*Z) (n:nat) : (fst (place t n) < S n)%nat. - Proof. induction n; cbv [place nat_rect] in *; break_match; autorewrite with cancel_pair; try omega. Qed. - Lemma weight_place t i : weight (fst (place t i)) * snd (place t i) = fst t * snd t. - Proof. induction i; cbv [place nat_rect] in *; break_match; push; - repeat match goal with |- context[?a/?b] => - unique pose proof (Z_div_exact_full_2 a b ltac:(auto) ltac:(auto)) - end; nsatz. Qed. - Hint Rewrite weight_place : push_eval. - - Definition from_associational n (p:list (Z*Z)) := - List.fold_right (fun t ls => - dlet_nd p := place t (pred n) in - add_to_nth (fst p) (snd p) ls ) (zeros n) p. - Lemma eval_from_associational n p (n_nz:n<>O \/ p = nil) : - eval n (from_associational n p) = Associational.eval p. - Proof. destruct n_nz; [ induction p | subst p ]; - cbv [from_associational Let_In] in *; push; try - pose proof place_in_range a (pred n); try omega; try nsatz; - apply fold_right_invariant; cbv [zeros add_to_nth]; - intros; rewrite ?map_length, ?List.repeat_length, ?seq_length, ?length_update_nth; - try omega. Qed. - Hint Rewrite @eval_from_associational : push_eval. - Lemma length_from_associational n p : length (from_associational n p) = n. - Proof. cbv [from_associational Let_In]. apply fold_right_invariant; intros; distr_length. Qed. - Hint Rewrite length_from_associational : distr_length. - - Section mulmod. - Context (s:Z) (s_nz:s <> 0) - (c:list (Z*Z)) - (m_nz:s - Associational.eval c <> 0). - Definition mulmod (n:nat) (a b:list Z) : list Z - := let a_a := to_associational n a in - let b_a := to_associational n b in - let ab_a := Associational.mul a_a b_a in - let abm_a := Associational.reduce s c ab_a in - from_associational n abm_a. - Lemma eval_mulmod n (f g:list Z) - (Hf : length f = n) (Hg : length g = n) : - eval n (mulmod n f g) mod (s - Associational.eval c) - = (eval n f * eval n g) mod (s - Associational.eval c). - Proof. cbv [mulmod]; push; trivial. - destruct f, g; simpl in *; [ right; subst n | left; try omega.. ]. - clear; cbv -[Associational.reduce]. - induction c as [|?? IHc]; simpl; trivial. Qed. - End mulmod. - Hint Rewrite @eval_mulmod : push_eval. - - Definition add (n:nat) (a b:list Z) : list Z - := let a_a := to_associational n a in - let b_a := to_associational n b in - from_associational n (a_a ++ b_a). - Lemma eval_add n (f g:list Z) - (Hf : length f = n) (Hg : length g = n) : - eval n (add n f g) = (eval n f + eval n g). - Proof. cbv [add]; push; trivial. destruct n; auto. Qed. - Hint Rewrite @eval_add : push_eval. - Lemma length_add n f g - (Hf : length f = n) (Hg : length g = n) : - length (add n f g) = n. - Proof. clear -Hf Hf; cbv [add]; distr_length. Qed. - Hint Rewrite @length_add : distr_length. - - Section Carries. - Definition carry n m (index:nat) (p:list Z) : list Z := - from_associational - m (@Associational.carry (weight index) - (weight (S index) / weight index) - (to_associational n p)). - - Lemma length_carry n m index p : length (carry n m index p) = m. - Proof. cbv [carry]; distr_length. Qed. - Lemma eval_carry n m i p: (n <> 0%nat) -> (m <> 0%nat) -> - weight (S i) / weight i <> 0 -> - eval m (carry n m i p) = eval n p. - Proof. - cbv [carry]; intros; push; [|tauto]. - rewrite @Associational.eval_carry by eauto. - apply eval_to_associational. - Qed. Hint Rewrite @eval_carry : push_eval. - - Definition carry_reduce n (s:Z) (c:list (Z * Z)) - (index:nat) (p : list Z) := - from_associational - n (Associational.reduce - s c (to_associational (S n) (@carry n (S n) index p))). - - Lemma eval_carry_reduce n s c index p : - (s <> 0) -> (s - Associational.eval c <> 0) -> (n <> 0%nat) -> - (weight (S index) / weight index <> 0) -> - eval n (carry_reduce n s c index p) mod (s - Associational.eval c) - = eval n p mod (s - Associational.eval c). - Proof. cbv [carry_reduce]; intros; push; auto. Qed. - Hint Rewrite @eval_carry_reduce : push_eval. - Lemma length_carry_reduce n s c index p - : length p = n -> length (carry_reduce n s c index p) = n. - Proof. cbv [carry_reduce]; distr_length. Qed. - Hint Rewrite @length_carry_reduce : distr_length. - - (* N.B. It is important to reverse [idxs] here, because fold_right is - written such that the first terms in the list are actually used - last in the computation. For example, running: - - `Eval cbv - [Z.add] in (fun a b c d => fold_right Z.add d [a;b;c]).` - - will produce [fun a b c d => (a + (b + (c + d)))].*) - Definition chained_carries n s c p (idxs : list nat) := - fold_right (fun a b => carry_reduce n s c a b) p (rev idxs). - - Lemma eval_chained_carries n s c p idxs : - (s <> 0) -> (s - Associational.eval c <> 0) -> (n <> 0%nat) -> - (forall i, In i idxs -> weight (S i) / weight i <> 0) -> - eval n (chained_carries n s c p idxs) mod (s - Associational.eval c) - = eval n p mod (s - Associational.eval c). - Proof using Type*. - cbv [chained_carries]; intros; push. - apply fold_right_invariant; [|intro; rewrite <-in_rev]; - destruct n; intros; push; auto. - Qed. Hint Rewrite @eval_chained_carries : push_eval. - Lemma length_chained_carries n s c p idxs - : length p = n -> length (@chained_carries n s c p idxs) = n. - Proof. - intros; cbv [chained_carries]; induction (rev idxs) as [|x xs IHxs]; - cbn [fold_right]; distr_length. - Qed. Hint Rewrite @length_chained_carries : distr_length. - - (* carries without modular reduction; useful for converting between bases *) - Definition chained_carries_no_reduce n p (idxs : list nat) := - fold_right (fun a b => carry n n a b) p (rev idxs). - Lemma eval_chained_carries_no_reduce n p idxs: - (forall i, In i idxs -> weight (S i) / weight i <> 0) -> - eval n (chained_carries_no_reduce n p idxs) = eval n p. - Proof. - cbv [chained_carries_no_reduce]; intros. - destruct n; [push;reflexivity|]. - apply fold_right_invariant; [|intro; rewrite <-in_rev]; - intros; push; auto. - Qed. Hint Rewrite @eval_chained_carries_no_reduce : push_eval. - - (* Reverse of [eval]; translate from Z to basesystem by putting - everything in first digit and then carrying. *) - Definition encode n s c (x : Z) : list Z := - chained_carries n s c (from_associational n [(1,x)]) (seq 0 n). - Lemma eval_encode n s c x : - (s <> 0) -> (s - Associational.eval c <> 0) -> (n <> 0%nat) -> - (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) -> - eval n (encode n s c x) mod (s - Associational.eval c) - = x mod (s - Associational.eval c). - Proof using Type*. cbv [encode]; intros; push; auto; f_equal; omega. Qed. - Lemma length_encode n s c x - : length (encode n s c x) = n. - Proof. cbv [encode]; repeat distr_length. Qed. - - End Carries. - Hint Rewrite @eval_encode : push_eval. - Hint Rewrite @length_encode : distr_length. - - Section sub. - Context (n:nat) - (s:Z) (s_nz:s <> 0) - (c:list (Z * Z)) - (m_nz:s - Associational.eval c <> 0) - (coef:Z). - - Definition negate_snd (a:list Z) : list Z - := let A := to_associational n a in - let negA := Associational.negate_snd A in - from_associational n negA. - - Definition scmul (x:Z) (a:list Z) : list Z - := let A := to_associational n a in - let R := Associational.mul A [(1, x)] in - from_associational n R. - - Definition balance : list Z - := scmul coef (encode n s c (s - Associational.eval c)). - - Definition sub (a b:list Z) : list Z - := let ca := add n balance a in - let _b := negate_snd b in - add n ca _b. - Lemma eval_sub a b - : (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) -> - (List.length a = n) -> (List.length b = n) -> - eval n (sub a b) mod (s - Associational.eval c) - = (eval n a - eval n b) mod (s - Associational.eval c). - Proof. - destruct (zerop n); subst; try reflexivity. - intros; cbv [sub balance scmul negate_snd]; push; repeat distr_length; - eauto with omega. - push_Zmod; push; pull_Zmod; push_Zmod; pull_Zmod; distr_length; eauto. - Qed. - Hint Rewrite eval_sub : push_eval. - Lemma length_sub a b - : length a = n -> length b = n -> - length (sub a b) = n. - Proof. intros; cbv [sub balance scmul negate_snd]; repeat distr_length. Qed. - Hint Rewrite length_sub : distr_length. - Definition opp (a:list Z) : list Z - := sub (zeros n) a. - Lemma eval_opp - (a:list Z) - : (length a = n) -> - (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) -> - eval n (opp a) mod (s - Associational.eval c) - = (- eval n a) mod (s - Associational.eval c). - Proof. intros; cbv [opp]; push; distr_length; auto. Qed. - Lemma length_opp a - : length a = n -> length (opp a) = n. - Proof. cbv [opp]; intros; repeat distr_length. Qed. - End sub. - Hint Rewrite @eval_opp @eval_sub : push_eval. - Hint Rewrite @length_sub @length_opp : distr_length. -End Positional. -(* Hint Rewrite disappears after the end of a section *) -Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_sub @length_opp : distr_length. -End Positional. - -Record weight_properties {weight : nat -> Z} := - { - weight_0 : weight 0%nat = 1; - weight_positive : forall i, 0 < weight i; - weight_multiples : forall i, weight (S i) mod weight i = 0; - weight_divides : forall i : nat, 0 < weight (S i) / weight i; - }. -Hint Resolve weight_0 weight_positive weight_multiples weight_divides. - -Section mod_ops. - Import Positional. - Local Coercion Z.of_nat : nat >-> Z. - Local Coercion QArith_base.inject_Z : Z >-> Q. - (* Design constraints: - - inputs must be [Z] (b/c reification does not support Q) - - internal structure must not match on the arguments (b/c reification does not support [positive]) *) - Context (limbwidth_num limbwidth_den : Z) - (limbwidth_good : 0 < limbwidth_den <= limbwidth_num) - (s : Z) - (c : list (Z*Z)) - (n : nat) - (len_c : nat) - (idxs : list nat) - (len_idxs : nat) - (m_nz:s - Associational.eval c <> 0) (s_nz:s <> 0) - (Hn_nz : n <> 0%nat) - (Hc : length c = len_c) - (Hidxs : length idxs = len_idxs). - Definition weight (i : nat) - := 2^(-(-(limbwidth_num * i) / limbwidth_den)). - - Local Ltac Q_cbv := - cbv [Qceiling inject_Z Qle Qfloor Qdiv Qnum Qden Qmult Qinv Qopp]. - - Local Lemma weight_ZQ_correct i - (limbwidth := (limbwidth_num / limbwidth_den)%Q) - : weight i = 2^Qceiling(limbwidth*i). - Proof. - clear -limbwidth_good. - cbv [limbwidth weight]; Q_cbv. - destruct limbwidth_num, limbwidth_den, i; try reflexivity; - repeat rewrite ?Pos.mul_1_l, ?Pos.mul_1_r, ?Z.mul_0_l, ?Zdiv_0_l, ?Zdiv_0_r, ?Z.mul_1_l, ?Z.mul_1_r, <- ?Z.opp_eq_mul_m1, ?Pos2Z.opp_pos; - try reflexivity; try lia. - Qed. - - Local Ltac t_weight_with lem := - clear -limbwidth_good; - intros; rewrite !weight_ZQ_correct; - apply lem; - try omega; Q_cbv; destruct limbwidth_den; cbn; try lia. - - Definition wprops : @weight_properties weight. - Proof. - constructor. - { cbv [weight Z.of_nat]; autorewrite with zsimplify_fast; reflexivity. } - { intros; apply Z.gt_lt. t_weight_with (@pow_ceil_mul_nat_pos 2). } - { t_weight_with (@pow_ceil_mul_nat_multiples 2). } - { intros; apply Z.gt_lt. t_weight_with (@pow_ceil_mul_nat_divide 2). } - Defined. - Local Hint Immediate (weight_0 wprops). - Local Hint Immediate (weight_positive wprops). - Local Hint Immediate (weight_multiples wprops). - Local Hint Immediate (weight_divides wprops). - Local Hint Resolve Z.positive_is_nonzero Z.lt_gt. - - Local Lemma weight_1_gt_1 : weight 1 > 1. - Proof. - clear -limbwidth_good. - cut (1 < weight 1); [ lia | ]. - cbv [weight Z.of_nat]; autorewrite with zsimplify_fast. - apply Z.pow_gt_1; [ omega | ]. - Z.div_mod_to_quot_rem_in_goal; nia. - Qed. - - Derive carry_mulmod - SuchThat (forall (f g : list Z) - (Hf : length f = n) - (Hg : length g = n), - (eval weight n (carry_mulmod f g)) mod (s - Associational.eval c) - = (eval weight n f * eval weight n g) mod (s - Associational.eval c)) - As eval_carry_mulmod. - Proof. - intros. - rewrite <-eval_mulmod with (s:=s) (c:=c) by auto. - etransitivity; - [ | rewrite <- @eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs) - by auto; reflexivity ]. - eapply f_equal2; [|trivial]. eapply f_equal. - expand_lists (). - subst carry_mulmod; reflexivity. - Qed. - - Derive carrymod - SuchThat (forall (f : list Z) - (Hf : length f = n), - (eval weight n (carrymod f)) mod (s - Associational.eval c) - = (eval weight n f) mod (s - Associational.eval c)) - As eval_carrymod. - Proof. - intros. - etransitivity; - [ | rewrite <- @eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs) - by auto; reflexivity ]. - eapply f_equal2; [|trivial]. eapply f_equal. - expand_lists (). - subst carrymod; reflexivity. - Qed. - - Derive addmod - SuchThat (forall (f g : list Z) - (Hf : length f = n) - (Hg : length g = n), - (eval weight n (addmod f g)) mod (s - Associational.eval c) - = (eval weight n f + eval weight n g) mod (s - Associational.eval c)) - As eval_addmod. - Proof. - intros. - rewrite <-eval_add by auto. - eapply f_equal2; [|trivial]. eapply f_equal. - expand_lists (). - subst addmod; reflexivity. - Qed. - - Derive submod - SuchThat (forall (coef:Z) - (f g : list Z) - (Hf : length f = n) - (Hg : length g = n), - (eval weight n (submod coef f g)) mod (s - Associational.eval c) - = (eval weight n f - eval weight n g) mod (s - Associational.eval c)) - As eval_submod. - Proof. - intros. - rewrite <-eval_sub with (coef:=coef) by auto. - eapply f_equal2; [|trivial]. eapply f_equal. - expand_lists (). - subst submod; reflexivity. - Qed. - - Derive oppmod - SuchThat (forall (coef:Z) - (f: list Z) - (Hf : length f = n), - (eval weight n (oppmod coef f)) mod (s - Associational.eval c) - = (- eval weight n f) mod (s - Associational.eval c)) - As eval_oppmod. - Proof. - intros. - rewrite <-eval_opp with (coef:=coef) by auto. - eapply f_equal2; [|trivial]. eapply f_equal. - expand_lists (). - subst oppmod; reflexivity. - Qed. - - Derive encodemod - SuchThat (forall (f:Z), - (eval weight n (encodemod f)) mod (s - Associational.eval c) - = f mod (s - Associational.eval c)) - As eval_encodemod. - Proof. - intros. - etransitivity. - 2:rewrite <-@eval_encode with (weight:=weight) (n:=n) by auto; reflexivity. - eapply f_equal2; [|trivial]. eapply f_equal. - expand_lists (). - subst encodemod; reflexivity. - Qed. -End mod_ops. - -Module Saturated. - Hint Resolve weight_positive weight_0 weight_multiples weight_divides. - Hint Resolve Z.positive_is_nonzero Z.lt_gt Nat2Z.is_nonneg. - - Section Weight. - Context weight {wprops : @weight_properties weight}. - - Lemma weight_multiples_full' j : forall i, weight (i+j) mod weight i = 0. - Proof. - induction j; intros; - repeat match goal with - | _ => rewrite Nat.add_succ_r - | _ => rewrite IHj - | |- context [weight (S ?x) mod weight _] => - rewrite (Z.div_mod (weight (S x)) (weight x)), weight_multiples by auto - | _ => progress autorewrite with push_Zmod natsimplify zsimplify_fast - | _ => reflexivity - end. - Qed. - - Lemma weight_multiples_full j i : (i <= j)%nat -> weight j mod weight i = 0. - Proof. - intros; replace j with (i + (j - i))%nat by omega. - apply weight_multiples_full'. - Qed. - - Lemma weight_divides_full j i : (i <= j)%nat -> 0 < weight j / weight i. - Proof. auto using Z.gt_lt, Z.div_positive_gt_0, weight_multiples_full. Qed. - - Lemma weight_div_mod j i : (i <= j)%nat -> weight j = weight i * (weight j / weight i). - Proof. intros. apply Z.div_exact; auto using weight_multiples_full. Qed. - End Weight. - - Module Associational. - Section Associational. - - Definition sat_multerm s (t t' : (Z * Z)) : list (Z * Z) := - dlet_nd xy := Z.mul_split s (snd t) (snd t') in - [(fst t * fst t', fst xy); (fst t * fst t' * s, snd xy)]. - - Definition sat_mul s (p q : list (Z * Z)) : list (Z * Z) := - flat_map (fun t => flat_map (fun t' => sat_multerm s t t') q) p. - - Lemma eval_map_sat_multerm s a q (s_nonzero:s<>0): - Associational.eval (flat_map (sat_multerm s a) q) = fst a * snd a * Associational.eval q. - Proof. - cbv [sat_multerm Let_In]; induction q; - repeat match goal with - | _ => progress autorewrite with cancel_pair push_eval to_div_mod in * - | _ => progress simpl flat_map - | _ => rewrite IHq - | _ => rewrite Z.mod_eq by assumption - | _ => ring_simplify; omega - end. - Qed. - Hint Rewrite eval_map_sat_multerm using (omega || assumption) : push_eval. - - Lemma eval_sat_mul s p q (s_nonzero:s<>0): - Associational.eval (sat_mul s p q) = Associational.eval p * Associational.eval q. - Proof. - cbv [sat_mul]; induction p; [reflexivity|]. - repeat match goal with - | _ => progress (autorewrite with push_flat_map push_eval in * ) - | _ => rewrite IHp - | _ => ring_simplify; omega - end. - Qed. - Hint Rewrite eval_sat_mul : push_eval. - - Definition sat_multerm_const s (t t' : (Z * Z)) : list (Z * Z) := - if snd t =? 1 - then [(fst t * fst t', snd t')] - else if snd t =? -1 - then [(fst t * fst t', - snd t')] - else if snd t =? 0 - then nil - else dlet_nd xy := Z.mul_split s (snd t) (snd t') in - [(fst t * fst t', fst xy); (fst t * fst t' * s, snd xy)]. - - Definition sat_mul_const s (p q : list (Z * Z)) : list (Z * Z) := - flat_map (fun t => flat_map (fun t' => sat_multerm_const s t t') q) p. - - Lemma eval_map_sat_multerm_const s a q (s_nonzero:s<>0): - Associational.eval (flat_map (sat_multerm_const s a) q) = fst a * snd a * Associational.eval q. - Proof. - cbv [sat_multerm_const Let_In]; induction q; - repeat match goal with - | _ => progress autorewrite with cancel_pair push_eval to_div_mod in * - | _ => progress simpl flat_map - | H : _ = 1 |- _ => rewrite H - | H : _ = -1 |- _ => rewrite H - | H : _ = 0 |- _ => rewrite H - | _ => progress break_match; Z.ltb_to_lt - | _ => rewrite IHq - | _ => rewrite Z.mod_eq by assumption - | _ => ring_simplify; omega - end. - Qed. - Hint Rewrite eval_map_sat_multerm_const using (omega || assumption) : push_eval. - - Lemma eval_sat_mul_const s p q (s_nonzero:s<>0): - Associational.eval (sat_mul_const s p q) = Associational.eval p * Associational.eval q. - Proof. - cbv [sat_mul_const]; induction p; [reflexivity|]. - repeat match goal with - | _ => progress (autorewrite with push_flat_map push_eval in * ) - | _ => rewrite IHp - | _ => ring_simplify; omega - end. - Qed. - Hint Rewrite eval_sat_mul_const : push_eval. - End Associational. - End Associational. - - Section DivMod. - Lemma mod_step a b c d: 0 < a -> 0 < b -> - c mod a + a * ((c / a + d) mod b) = (a * d + c) mod (a * b). - Proof. - intros; rewrite Z.rem_mul_r by omega. push_Zmod. - autorewrite with zsimplify pull_Zmod. repeat (f_equal; try ring). - Qed. - - Lemma div_step a b c d : 0 < a -> 0 < b -> - (c / a + d) / b = (a * d + c) / (a * b). - Proof. intros; Z.div_mod_to_quot_rem_in_goal; nia. Qed. - - Lemma add_mod_div_multiple a b n m: - n > 0 -> - 0 <= m / n -> - m mod n = 0 -> - (a / n + b) mod (m / n) = (a + n * b) mod m / n. - Proof. - intros. rewrite <-!Z.div_add' by auto using Z.positive_is_nonzero. - rewrite Z.mod_pull_div, Z.mul_div_eq' by auto using Z.gt_lt. - repeat (f_equal; try omega). - Qed. - - Lemma add_mod_l_multiple a b n m: - 0 < n / m -> m <> 0 -> n mod m = 0 -> - (a mod n + b) mod m = (a + b) mod m. - Proof. - intros. - rewrite (proj2 (Z.div_exact n m ltac:(auto))) by auto. - rewrite Z.rem_mul_r by auto. - push_Zmod. autorewrite with zsimplify. - pull_Zmod. reflexivity. - Qed. - - Definition is_div_mod {T} (evalf : T -> Z) dm y n := - evalf (fst dm) = y mod n /\ snd dm = y / n. - - Lemma is_div_mod_step {T} evalf1 evalf2 dm1 dm2 y1 y2 n1 n2 x : - n1 > 0 -> - 0 < n2 / n1 -> - n2 mod n1 = 0 -> - evalf2 (fst dm2) = evalf1 (fst dm1) + n1 * ((snd dm1 + x) mod (n2 / n1)) -> - snd dm2 = (snd dm1 + x) / (n2 / n1) -> - y2 = y1 + n1 * x -> - @is_div_mod T evalf1 dm1 y1 n1 -> - @is_div_mod T evalf2 dm2 y2 n2. - Proof. - intros; subst y2; cbv [is_div_mod] in *. - repeat match goal with - | H: _ /\ _ |- _ => destruct H - | H: ?LHS = _ |- _ => match LHS with context [dm2] => rewrite H end - | H: ?LHS = _ |- _ => match LHS with context [dm1] => rewrite H end - | _ => rewrite mod_step by omega - | _ => rewrite div_step by omega - | _ => rewrite Z.mul_div_eq_full by omega - end. - split; f_equal; omega. - Qed. - - Lemma is_div_mod_result_equal {T} evalf dm y1 y2 n : - y1 = y2 -> - @is_div_mod T evalf dm y1 n -> - @is_div_mod T evalf dm y2 n. - Proof. congruence. Qed. - End DivMod. -End Saturated. - -Module Columns. - Import Saturated. - Section Columns. - Context weight {wprops : @weight_properties weight}. - - Definition eval n (x : list (list Z)) : Z := Positional.eval weight n (map sum x). - - Lemma eval_nil n : eval n [] = 0. - Proof. cbv [eval]; simpl. apply Positional.eval_nil. Qed. - Hint Rewrite eval_nil : push_eval. - Lemma eval_snoc n x y : n = length x -> eval (S n) (x ++ [y]) = eval n x + weight n * sum y. - Proof. - cbv [eval]; intros; subst. rewrite map_app. simpl map. - apply Positional.eval_snoc; distr_length. - Qed. Hint Rewrite eval_snoc using (solve [distr_length]) : push_eval. - - Hint Rewrite <- Z.div_add' using omega : pull_Zdiv. - - Ltac cases := - match goal with - | |- _ /\ _ => split - | H: _ /\ _ |- _ => destruct H - | H: _ \/ _ |- _ => destruct H - | _ => progress break_match; try discriminate - end. - - Section Flatten. - Section flatten_column. - Context (fw : Z). (* maximum size of the result *) - - (* Outputs (sum, carry) *) - Definition flatten_column (digit: list Z) : (Z * Z) := - list_rect (fun _ => (Z * Z)%type) (0,0) - (fun xx tl flatten_column_tl => - list_rect - (fun _ => (Z * Z)%type) (xx mod fw, xx / fw) - (fun yy tl' _ => - list_rect - (fun _ => (Z * Z)%type) (dlet_nd x := xx in dlet_nd y := yy in Z.add_get_carry_full fw x y) - (fun _ _ _ => - dlet_nd x := xx in - dlet_nd rec := flatten_column_tl in (* recursively get the sum and carry *) - dlet_nd sum_carry := Z.add_get_carry_full fw x (fst rec) in (* add the new value to the sum *) - dlet_nd carry' := snd sum_carry + snd rec in (* add the two carries together *) - (fst sum_carry, carry')) - tl') - tl) - digit. - End flatten_column. - - Definition flatten_step (digit:list Z) (acc_carry:list Z * Z) : list Z * Z := - dlet sum_carry := flatten_column (weight (S (length (fst acc_carry))) / weight (length (fst acc_carry))) (snd acc_carry::digit) in - (fst acc_carry ++ fst sum_carry :: nil, snd sum_carry). - - Definition flatten (xs : list (list Z)) : list Z * Z := - fold_right (fun a b => flatten_step a b) (nil,0) (rev xs). - - Ltac push_fast := - repeat match goal with - | _ => progress cbv [Let_In] - | |- context [list_rect _ _ _ ?ls] => rewrite single_list_rect_to_match; destruct ls - | _ => progress (unfold flatten_step in *; fold flatten_step in * ) - | _ => rewrite Nat.add_1_r - | _ => rewrite Z.mul_div_eq_full by (auto; omega) - | _ => rewrite weight_multiples - | _ => reflexivity - | _ => solve [repeat (f_equal; try ring)] - | _ => congruence - | _ => progress cases - end. - Ltac push := - repeat match goal with - | _ => progress push_fast - | _ => progress autorewrite with cancel_pair to_div_mod - | _ => progress autorewrite with push_sum push_fold_right push_nth_default in * - | _ => progress autorewrite with pull_Zmod pull_Zdiv zsimplify_fast - | _ => progress autorewrite with list distr_length push_eval - end. - - Lemma flatten_column_mod fw (xs : list Z) : - fst (flatten_column fw xs) = sum xs mod fw. - Proof. - induction xs; simpl flatten_column; cbv [Let_In]; - repeat match goal with - | _ => rewrite IHxs - | _ => progress push - end. - Qed. Hint Rewrite flatten_column_mod : to_div_mod. - - Lemma flatten_column_div fw (xs : list Z) (fw_nz : fw <> 0) : - snd (flatten_column fw xs) = sum xs / fw. - Proof. - induction xs; simpl flatten_column; cbv [Let_In]; - repeat match goal with - | _ => rewrite IHxs - | _ => rewrite Z.mul_div_eq_full by omega - | _ => progress push - end. - Qed. Hint Rewrite flatten_column_div using auto with zarith : to_div_mod. - - Hint Rewrite Positional.eval_nil : push_eval. - Hint Resolve Z.gt_lt. - - Lemma length_flatten_step digit state : - length (fst (flatten_step digit state)) = S (length (fst state)). - Proof. cbv [flatten_step]; push. Qed. - Hint Rewrite length_flatten_step : distr_length. - Lemma length_flatten inp : length (fst (flatten inp)) = length inp. - Proof. cbv [flatten]. induction inp using rev_ind; push. Qed. - Hint Rewrite length_flatten : distr_length. - - Lemma flatten_div_mod n inp : - length inp = n -> - (Positional.eval weight n (fst (flatten inp)) - = (eval n inp) mod (weight n)) - /\ (snd (flatten inp) = eval n inp / weight n). - Proof. - (* to make the invariant take the right form, we make everything depend on output length, not input length *) - intro. subst n. rewrite <-(length_flatten inp). cbv [flatten]. - induction inp using rev_ind; intros; [push|]. - repeat match goal with - | _ => rewrite Nat.add_1_r - | _ => progress (fold (flatten inp) in * ) - | _ => erewrite Positional.eval_snoc by (distr_length; reflexivity) - | H: _ = _ mod (weight _) |- _ => rewrite H - | H: _ = _ / (weight _) |- _ => rewrite H - | _ => progress rewrite ?mod_step, ?div_step by auto - | _ => progress autorewrite with cancel_pair to_div_mod push_sum list push_fold_right push_eval - | _ => progress (distr_length; push_fast) - end. - Qed. - - Lemma flatten_mod {n} inp : - length inp = n -> - (Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n)). - Proof. apply flatten_div_mod. Qed. - Hint Rewrite @flatten_mod : push_eval. - - Lemma flatten_div {n} inp : - length inp = n -> snd (flatten inp) = eval n inp / weight n. - Proof. apply flatten_div_mod. Qed. - Hint Rewrite @flatten_div : push_eval. - - Lemma flatten_snoc x inp : flatten (inp ++ [x]) = flatten_step x (flatten inp). - Proof. cbv [flatten]. rewrite rev_unit. reflexivity. Qed. - - Lemma flatten_partitions inp: - forall n i, length inp = n -> (i < n)%nat -> - nth_default 0 (fst (flatten inp)) i = ((eval n inp) mod (weight (S i))) / weight i. - Proof. - induction inp using rev_ind; intros; destruct n; distr_length. - rewrite flatten_snoc. - push; distr_length; - [rewrite IHinp with (n:=n) by omega; rewrite weight_div_mod with (j:=n) (i:=S i) by (eauto; omega); push_Zmod; push |]. - repeat match goal with - | _ => progress replace (length inp) with n by omega - | _ => progress replace i with n by omega - | _ => progress push - | _ => erewrite flatten_div by eauto - | _ => rewrite <-Z.div_add' by auto - | _ => rewrite Z.mul_div_eq' by auto - | _ => rewrite Z.mod_pull_div by auto using Z.lt_le_incl - | _ => progress autorewrite with push_nth_default natsimplify - end. - Qed. - End Flatten. - - Section FromAssociational. - (* nils *) - Definition nils n : list (list Z) := repeat nil n. - Lemma length_nils n : length (nils n) = n. Proof. cbv [nils]. distr_length. Qed. - Hint Rewrite length_nils : distr_length. - Lemma eval_nils n : eval n (nils n) = 0. - Proof. - erewrite <-Positional.eval_zeros by eauto. - cbv [eval nils]; rewrite List.map_repeat; reflexivity. - Qed. Hint Rewrite eval_nils : push_eval. - - (* cons_to_nth *) - Definition cons_to_nth i x (xs : list (list Z)) : list (list Z) := - ListUtil.update_nth i (fun y => cons x y) xs. - Lemma length_cons_to_nth i x xs : length (cons_to_nth i x xs) = length xs. - Proof. cbv [cons_to_nth]. distr_length. Qed. - Hint Rewrite length_cons_to_nth : distr_length. - Lemma cons_to_nth_add_to_nth xs : forall i x, - map sum (cons_to_nth i x xs) = Positional.add_to_nth i x (map sum xs). - Proof. - cbv [cons_to_nth]; induction xs as [|? ? IHxs]; - intros i x; destruct i; simpl; rewrite ?IHxs; reflexivity. - Qed. - Lemma eval_cons_to_nth n i x xs : (i < length xs)%nat -> length xs = n -> - eval n (cons_to_nth i x xs) = weight i * x + eval n xs. - Proof using Type. - cbv [eval]; intros. rewrite cons_to_nth_add_to_nth. - apply Positional.eval_add_to_nth; distr_length. - Qed. Hint Rewrite eval_cons_to_nth using (solve [distr_length]) : push_eval. - - Hint Rewrite Positional.eval_zeros : push_eval. - Hint Rewrite Positional.length_from_associational : distr_length. - Hint Rewrite Positional.eval_add_to_nth using (solve [distr_length]): push_eval. - - (* from_associational *) - Definition from_associational n (p:list (Z*Z)) : list (list Z) := - List.fold_right (fun t ls => - dlet_nd p := Positional.place weight t (pred n) in - cons_to_nth (fst p) (snd p) ls ) (nils n) p. - Lemma length_from_associational n p : length (from_associational n p) = n. - Proof. cbv [from_associational Let_In]. apply fold_right_invariant; intros; distr_length. Qed. - Hint Rewrite length_from_associational: distr_length. - Lemma eval_from_associational n p (n_nonzero:n<>0%nat\/p=nil): - eval n (from_associational n p) = Associational.eval p. - Proof. - erewrite <-Positional.eval_from_associational by eauto. - induction p; [ autorewrite with push_eval; solve [auto] |]. - cbv [from_associational Positional.from_associational]; autorewrite with push_fold_right. - fold (from_associational n p); fold (Positional.from_associational weight n p). - cbv [Let_In]. - match goal with |- context [Positional.place _ ?x ?n] => - pose proof (Positional.place_in_range weight x n) end. - repeat match goal with - | _ => rewrite Nat.succ_pred in * by auto - | _ => rewrite IHp by auto - | _ => progress autorewrite with push_eval - | _ => progress cases - | _ => congruence - end. - Qed. - - Lemma from_associational_step n t p : - from_associational n (t :: p) = - cons_to_nth (fst (Positional.place weight t (Nat.pred n))) - (snd (Positional.place weight t (Nat.pred n))) - (from_associational n p). - Proof. reflexivity. Qed. - End FromAssociational. - End Columns. -End Columns. - -Module Rows. - Import Saturated. - Section Rows. - Context weight {wprops : @weight_properties weight}. - - Local Notation rows := (list (list Z)) (only parsing). - Local Notation cols := (list (list Z)) (only parsing). - - Hint Rewrite Positional.eval_nil Positional.eval0 @Positional.eval_snoc - Positional.eval_to_associational - Columns.eval_nil Columns.eval_snoc using (auto; solve [distr_length]) : push_eval. - Hint Resolve in_eq in_cons. - - Definition eval n (inp : rows) := - sum (map (Positional.eval weight n) inp). - Lemma eval_nil n : eval n nil = 0. - Proof. cbv [eval]. rewrite map_nil, sum_nil; reflexivity. Qed. - Hint Rewrite eval_nil : push_eval. - Lemma eval0 x : eval 0 x = 0. - Proof. cbv [eval]. induction x; autorewrite with push_map push_sum push_eval; omega. Qed. - Hint Rewrite eval0 : push_eval. - Lemma eval_cons n r inp : eval n (r :: inp) = Positional.eval weight n r + eval n inp. - Proof. cbv [eval]; autorewrite with push_map push_sum; reflexivity. Qed. - Hint Rewrite eval_cons : push_eval. - Lemma eval_app n x y : eval n (x ++ y) = eval n x + eval n y. - Proof. cbv [eval]; autorewrite with push_map push_sum; reflexivity. Qed. - Hint Rewrite eval_app : push_eval. - - Ltac In_cases := - repeat match goal with - | H: In _ (_ ++ _) |- _ => apply in_app_or in H; destruct H - | H: In _ (_ :: _) |- _ => apply in_inv in H; destruct H - | H: In _ nil |- _ => contradiction H - | H: forall x, In x (?y :: ?ls) -> ?P |- _ => - unique pose proof (H y ltac:(apply in_eq)); - unique assert (forall x, In x ls -> P) by auto - | H: forall x, In x (?ls ++ ?y :: nil) -> ?P |- _ => - unique pose proof (H y ltac:(auto using in_or_app, in_eq)); - unique assert (forall x, In x ls -> P) by eauto using in_or_app - end. - - Section FromAssociational. - (* extract row *) - Definition extract_row (inp : cols) : cols * list Z := (map (fun c => tl c) inp, map (fun c => hd 0 c) inp). - - Lemma eval_extract_row (inp : cols): forall n, - length inp = n -> - Positional.eval weight n (snd (extract_row inp)) = Columns.eval weight n inp - Columns.eval weight n (fst (extract_row inp)) . - Proof. - cbv [extract_row]. - induction inp using rev_ind; [ | destruct n ]; - repeat match goal with - | _ => progress intros - | _ => progress distr_length - | _ => rewrite Positional.eval_snoc with (n:=n) by distr_length - | _ => progress autorewrite with cancel_pair push_eval push_map in * - | _ => ring - end. - rewrite IHinp by distr_length. - destruct x; cbn [hd tl]; rewrite ?sum_nil, ?sum_cons; ring. - Qed. Hint Rewrite eval_extract_row using (solve [distr_length]) : push_eval. - - Lemma length_fst_extract_row n (inp : cols) : - length inp = n -> length (fst (extract_row inp)) = n. - Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed. - Hint Rewrite length_fst_extract_row : distr_length. - - Lemma length_snd_extract_row n (inp : cols) : - length inp = n -> length (snd (extract_row inp)) = n. - Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed. - Hint Rewrite length_snd_extract_row : distr_length. - - (* max column size *) - Definition max_column_size (x:cols) := fold_right (fun a b => Nat.max a b) 0%nat (map (fun c => length c) x). - - (* TODO: move to where list is defined *) - Hint Rewrite @app_nil_l : list. - Hint Rewrite <-@app_comm_cons: list. - - Lemma max_column_size_nil : max_column_size nil = 0%nat. - Proof. reflexivity. Qed. Hint Rewrite max_column_size_nil : push_max_column_size. - Lemma max_column_size_cons col (inp : cols) : - max_column_size (col :: inp) = Nat.max (length col) (max_column_size inp). - Proof. reflexivity. Qed. Hint Rewrite max_column_size_cons : push_max_column_size. - Lemma max_column_size_app (x y : cols) : - max_column_size (x ++ y) = Nat.max (max_column_size x) (max_column_size y). - Proof. induction x; autorewrite with list push_max_column_size; lia. Qed. - Hint Rewrite max_column_size_app : push_max_column_size. - Lemma max_column_size0 (inp : cols) : - forall n, - length inp = n -> (* this is not needed to make the lemma true, but prevents reliance on the implementation of Columns.eval*) - max_column_size inp = 0%nat -> Columns.eval weight n inp = 0. - Proof. - induction inp as [|x inp] using rev_ind; destruct n; try destruct x; intros; - autorewrite with push_max_column_size push_eval push_sum distr_length in *; try lia. - rewrite IHinp; distr_length; lia. - Qed. - - (* from_columns *) - Definition from_columns' n start_state : cols * rows := - fold_right (fun _ (state : cols * rows) => - let cols'_row := extract_row (fst state) in - (fst cols'_row, snd state ++ [snd cols'_row]) - ) start_state (repeat 0 n). - - Definition from_columns (inp : cols) : rows := snd (from_columns' (max_column_size inp) (inp, [])). - - Lemma eval_from_columns'_with_length m st n: - (length (fst st) = n) -> - length (fst (from_columns' m st)) = n /\ - ((forall r, In r (snd st) -> length r = n) -> - forall r, In r (snd (from_columns' m st)) -> length r = n) /\ - eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st) - - Columns.eval weight n (fst (from_columns' m st)). - Proof. - cbv [from_columns']; intros. - apply fold_right_invariant; intros; - repeat match goal with - | _ => progress (intros; subst) - | _ => progress autorewrite with cancel_pair push_eval - | _ => progress In_cases - | _ => split; try omega - | H: _ /\ _ |- _ => destruct H - | _ => solve [auto using length_fst_extract_row, length_snd_extract_row] - end. - Qed. - Lemma length_fst_from_columns' m st : - length (fst (from_columns' m st)) = length (fst st). - Proof. apply eval_from_columns'_with_length; reflexivity. Qed. - Hint Rewrite length_fst_from_columns' : distr_length. - Lemma length_snd_from_columns' m st : - (forall r, In r (snd st) -> length r = length (fst st)) -> - forall r, In r (snd (from_columns' m st)) -> length r = length (fst st). - Proof. apply eval_from_columns'_with_length. reflexivity. Qed. - Hint Rewrite length_snd_from_columns' : distr_length. - Lemma eval_from_columns' m st n : - (length (fst st) = n) -> - eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st) - - Columns.eval weight n (fst (from_columns' m st)). - Proof. apply eval_from_columns'_with_length. Qed. - Hint Rewrite eval_from_columns' using (auto; solve [distr_length]) : push_eval. - - Lemma max_column_size_extract_row inp : - max_column_size (fst (extract_row inp)) = (max_column_size inp - 1)%nat. - Proof. - cbv [extract_row]. autorewrite with cancel_pair. - induction inp; [ reflexivity | ]. - autorewrite with push_max_column_size push_map distr_length. - rewrite IHinp. auto using Nat.sub_max_distr_r. - Qed. - Hint Rewrite max_column_size_extract_row : push_max_column_size. - - Lemma max_column_size_from_columns' m st : - max_column_size (fst (from_columns' m st)) = (max_column_size (fst st) - m)%nat. - Proof. - cbv [from_columns']; induction m; intros; cbn - [max_column_size extract_row]; - autorewrite with push_max_column_size; lia. - Qed. - Hint Rewrite max_column_size_from_columns' : push_max_column_size. - - Lemma eval_from_columns (inp : cols) : - forall n, length inp = n -> eval n (from_columns inp) = Columns.eval weight n inp. - Proof. - intros; cbv [from_columns]; - repeat match goal with - | _ => progress autorewrite with cancel_pair push_eval push_max_column_size - | _ => rewrite max_column_size0 with (inp := fst (from_columns' _ _)) by - (autorewrite with push_max_column_size; distr_length) - | _ => omega - end. - Qed. - Hint Rewrite eval_from_columns using (auto; solve [distr_length]) : push_eval. - - Lemma length_from_columns inp: - forall r, In r (from_columns inp) -> length r = length inp. - Proof. - cbv [from_columns]; intros. - change inp with (fst (inp, @nil (list Z))). - eapply length_snd_from_columns'; eauto. - autorewrite with cancel_pair; intros; In_cases. - Qed. - Hint Rewrite length_from_columns : distr_length. - - (* from associational *) - Definition from_associational n (p : list (Z * Z)) := from_columns (Columns.from_associational weight n p). - - Lemma eval_from_associational n p: (n <> 0%nat \/ p = nil) -> - eval n (from_associational n p) = Associational.eval p. - Proof. - intros. cbv [from_associational]. - rewrite eval_from_columns by auto using Columns.length_from_associational. - auto using Columns.eval_from_associational. - Qed. - - Lemma length_from_associational n p : - forall r, In r (from_associational n p) -> length r = n. - Proof. - cbv [from_associational]; intros. - match goal with H: _ |- _ => apply length_from_columns in H end. - rewrite Columns.length_from_associational in *; auto. - Qed. - - Lemma max_column_size_zero_iff x : - max_column_size x = 0%nat <-> (forall c, In c x -> c = nil). - Proof. - cbv [max_column_size]; induction x; intros; [ cbn; tauto | ]. - autorewrite with push_fold_right push_map. - rewrite max_0_iff, IHx. - split; intros; [ | rewrite length_zero_iff_nil; solve [auto] ]. - match goal with H : _ /\ _ |- _ => destruct H end. - In_cases; subst; auto using length0_nil. - Qed. - - Lemma max_column_size_Columns_from_associational n p : - n <> 0%nat -> p <> nil -> - max_column_size (Columns.from_associational weight n p) <> 0%nat. - Proof. - intros. - rewrite max_column_size_zero_iff. - intro. destruct p; [congruence | ]. - rewrite Columns.from_associational_step in *. - cbv [Columns.cons_to_nth] in *. - match goal with H : forall c, In c (update_nth ?n ?f ?ls) -> _ |- _ => - assert (n < length (update_nth n f ls))%nat; - [ | specialize (H (nth n (update_nth n f ls) nil) ltac:(auto using nth_In)) ] - end. - { distr_length. - rewrite Columns.length_from_associational. - remember (Nat.pred n) as m. replace n with (S m) by omega. - apply Positional.place_in_range. } - rewrite <-nth_default_eq in *. - autorewrite with push_nth_default in *. - rewrite eq_nat_dec_refl in *. - congruence. - Qed. - - Lemma from_associational_nonnil n p : - n <> 0%nat -> p <> nil -> - from_associational n p <> nil. - Proof. - intros; cbv [from_associational from_columns from_columns']. - pose proof (max_column_size_Columns_from_associational n p ltac:(auto) ltac:(auto)). - case_eq (max_column_size (Columns.from_associational weight n p)); [omega|]. - intros; cbn. - rewrite <-length_zero_iff_nil. distr_length. - Qed. - End FromAssociational. - - Section Flatten. - Local Notation fw := (fun i => weight (S i) / weight i) (only parsing). - - Section SumRows. - Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z * nat := - fold_right (fun next (state : list Z * Z * nat) => - let i := snd state in - let low_high' := - let low_high := fst state in - let low := fst low_high in - let high := snd low_high in - dlet_nd sum_carry := Z.add_with_get_carry_full (fw i) high (fst next) (snd next) in - (low ++ [fst sum_carry], snd sum_carry) in - (low_high', S i)) start_state (rev (combine row1 row2)). - Definition sum_rows row1 row2 := fst (sum_rows' (nil, 0, 0%nat) row1 row2). - - Ltac push := - repeat match goal with - | _ => progress intros - | _ => progress cbv [Let_In] - | _ => rewrite Nat.add_1_r - | _ => erewrite Positional.eval_snoc by eauto - | H : length _ = _ |- _ => rewrite H - | H: 0%nat = _ |- _ => rewrite <-H - | [p := _ |- _] => subst p - | _ => progress autorewrite with cancel_pair natsimplify push_sum_rows list push_nth_default - | _ => progress autorewrite with cancel_pair in * - | _ => progress distr_length - | _ => progress break_match - | _ => ring - | _ => solve [ repeat (f_equal; try ring) ] - | _ => tauto - | _ => solve [eauto] - end. - - Lemma sum_rows'_cons state x1 row1 x2 row2 : - sum_rows' state (x1 :: row1) (x2 :: row2) = - sum_rows' (fst (fst state) ++ [(snd (fst state) + x1 + x2) mod (fw (snd state))], - (snd (fst state) + x1 + x2) / fw (snd state), - S (snd state)) row1 row2. - Proof. - cbv [sum_rows' Let_In]; autorewrite with push_combine. - rewrite !fold_left_rev_right. cbn [fold_left]. - autorewrite with cancel_pair to_div_mod. congruence. - Qed. - - Lemma sum_rows'_nil state : - sum_rows' state nil nil = state. - Proof. reflexivity. Qed. - - Hint Rewrite sum_rows'_cons sum_rows'_nil : push_sum_rows. - - Lemma sum_rows'_div_mod_length row1 : - forall nm start_state row2 row1' row2', - let m := snd start_state in - let n := length row1 in - length row2 = n -> - length row1' = m -> - length row2' = m -> - length (fst (fst start_state)) = m -> - (nm = n + m)%nat -> - let eval := Positional.eval weight in - is_div_mod (eval m) (fst start_state) (eval m row1' + eval m row2') (weight m) -> - length (fst (fst (sum_rows' start_state row1 row2))) = nm - /\ is_div_mod (eval nm) (fst (sum_rows' start_state row1 row2)) - (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) - (weight nm). - Proof. - induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; [ ]. - rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). - apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length in *; try omega. - eapply is_div_mod_step with (x := x1 + x2); try eassumption; push. - Qed. - - Lemma sum_rows_div_mod n row1 row2 : - length row1 = n -> length row2 = n -> - let eval := Positional.eval weight in - is_div_mod (eval n) (sum_rows row1 row2) (eval n row1 + eval n row2) (weight n). - Proof. - cbv [sum_rows]; intros. - apply sum_rows'_div_mod_length with (row1':=nil) (row2':=nil); - cbv [is_div_mod]; autorewrite with cancel_pair push_eval zsimplify; distr_length. - Qed. - - Lemma sum_rows_mod n row1 row2 : - length row1 = n -> length row2 = n -> - Positional.eval weight n (fst (sum_rows row1 row2)) - = (Positional.eval weight n row1 + Positional.eval weight n row2) mod (weight n). - Proof. apply sum_rows_div_mod. Qed. - Lemma sum_rows_div row1 row2 n: - length row1 = n -> length row2 = n -> - snd (sum_rows row1 row2) - = (Positional.eval weight n row1 + Positional.eval weight n row2) / (weight n). - Proof. apply sum_rows_div_mod. Qed. - - Lemma sum_rows'_partitions row1 : - forall nm start_state row2 row1' row2', - let m := snd start_state in - let n := length row1 in - length row2 = n -> - length row1' = m -> - length row2' = m -> - length (fst (fst start_state)) = m -> - nm = (n + m)%nat -> - let eval := Positional.eval weight in - snd (fst start_state) = (eval m row1' + eval m row2') / weight m -> - (forall j, (j < m)%nat -> - nth_default 0 (fst (fst start_state)) j = ((eval m row1' + eval m row2') mod (weight (S j))) / (weight j)) -> - forall i, (i < nm)%nat -> - nth_default 0 (fst (fst (sum_rows' start_state row1 row2))) i - = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i). - Proof. - induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; []. - - rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). - apply IHrow1; clear IHrow1; push; - repeat match goal with - | H : ?LHS = _ |- _ => - match LHS with context [start_state] => rewrite H end - | H : context [nth_default 0 (fst (fst start_state))] |- _ => rewrite H by omega - | _ => rewrite <-(Z.add_assoc _ x1 x2) - end. - { rewrite div_step by auto using Z.gt_lt. - rewrite Z.mul_div_eq_full by auto; rewrite weight_multiples by auto. push. } - { rewrite weight_div_mod with (j:=snd start_state) (i:=S j) by (auto; omega). - push_Zmod. autorewrite with zsimplify_fast. reflexivity. } - { push. replace (snd start_state) with j in * by omega. - push. rewrite add_mod_div_multiple by auto using Z.lt_le_incl. - push. } - Qed. - - Lemma sum_rows_partitions row1: forall row2 n i, - length row1 = n -> length row2 = n -> (i < n)%nat -> - nth_default 0 (fst (sum_rows row1 row2)) i - = ((Positional.eval weight n row1 + Positional.eval weight n row2) mod weight (S i)) / (weight i). - Proof. - cbv [sum_rows]; intros. rewrite <-(Nat.add_0_r n). - rewrite <-(app_nil_l row1), <-(app_nil_l row2). - apply sum_rows'_partitions; intros; - autorewrite with cancel_pair push_eval zsimplify_fast push_nth_default; distr_length. - Qed. - - Lemma length_sum_rows row1 row2 n: - length row1 = n -> length row2 = n -> - length (fst (sum_rows row1 row2)) = n. - Proof. - cbv [sum_rows]; intros. - eapply sum_rows'_div_mod_length; cbv [is_div_mod]; - autorewrite with cancel_pair; distr_length; auto using nil_length0. - Qed. Hint Rewrite length_sum_rows : distr_length. - End SumRows. - Hint Resolve length_sum_rows. - Hint Rewrite sum_rows_mod using (auto; solve [distr_length; auto]) : push_eval. - - Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z := - fold_right (fun next_row (state : list Z * Z)=> - let out_carry := sum_rows next_row (fst state) in - (fst out_carry, snd state + snd out_carry)) start_state inp. - - (* In order for the output to have the right length and bounds, - we insert rows of zeroes if there are fewer than two rows. *) - Definition flatten n (inp : rows) : list Z * Z := - let default := Positional.zeros n in - flatten' (hd default inp, 0) (hd default (tl inp) :: tl (tl inp)). - - Lemma flatten'_cons state r inp : - flatten' state (r :: inp) = (fst (sum_rows r (fst (flatten' state inp))), snd (flatten' state inp) + snd (sum_rows r (fst (flatten' state inp)))). - Proof. cbv [flatten']; autorewrite with list push_fold_right. reflexivity. Qed. - Lemma flatten'_snoc state r inp : - flatten' state (inp ++ r :: nil) = flatten' (fst (sum_rows r (fst state)), snd state + snd (sum_rows r (fst state))) inp. - Proof. cbv [flatten']; autorewrite with list push_fold_right. reflexivity. Qed. - Lemma flatten'_nil state : flatten' state [] = state. Proof. reflexivity. Qed. - Hint Rewrite flatten'_cons flatten'_snoc flatten'_nil : push_flatten. - - Ltac push := - repeat match goal with - | _ => progress intros - | H: length ?x = ?n |- context [snd (sum_rows ?x _)] => rewrite sum_rows_div with (n:=n) by (distr_length; eauto) - | H: length ?x = ?n |- context [snd (sum_rows _ ?x)] => rewrite sum_rows_div with (n:=n) by (distr_length; eauto) - | H: length _ = _ |- _ => rewrite H - | _ => progress autorewrite with cancel_pair push_flatten push_eval distr_length zsimplify_fast - | _ => progress In_cases - | |- _ /\ _ => split - | |- context [?x mod ?y] => unique pose proof (Z.mul_div_eq_full x y ltac:(auto)); lia - | _ => apply length_sum_rows - | _ => solve [repeat (ring_simplify; f_equal; try ring)] - | _ => congruence - | _ => solve [eauto] - end. - - Lemma flatten'_div_mod_length n inp : forall start_state, - length (fst start_state) = n -> - (forall row, In row inp -> length row = n) -> - length (fst (flatten' start_state inp)) = n - /\ (inp <> nil -> - is_div_mod (Positional.eval weight n) (flatten' start_state inp) - (Positional.eval weight n (fst start_state) + eval n inp + weight n * snd start_state) - (weight n)). - Proof. - induction inp using rev_ind; push; [apply IHinp; push|]. - destruct (dec (inp = nil)); [subst inp; cbv [is_div_mod] - | eapply is_div_mod_result_equal; try apply IHinp]; push. - { autorewrite with zsimplify; push. } - { rewrite Z.div_add' by auto; push. } - Qed. - - Hint Rewrite (@Positional.length_zeros weight) : distr_length. - Hint Rewrite (@Positional.eval_zeros weight) using auto : push_eval. - - Lemma flatten_div_mod inp n : - (forall row, In row inp -> length row = n) -> - is_div_mod (Positional.eval weight n) (flatten n inp) (eval n inp) (weight n). - Proof. - intros; cbv [flatten]. - destruct inp; [|destruct inp]; cbn [hd tl]. - { cbv [is_div_mod]; push. - erewrite sum_rows_div by (distr_length; reflexivity). - push. } - { cbv [is_div_mod]; push. } - { eapply is_div_mod_result_equal; try apply flatten'_div_mod_length; push. } - Qed. - - Lemma flatten_mod inp n : - (forall row, In row inp -> length row = n) -> - Positional.eval weight n (fst (flatten n inp)) = (eval n inp) mod (weight n). - Proof. apply flatten_div_mod. Qed. - Lemma flatten_div inp n : - (forall row, In row inp -> length row = n) -> - snd (flatten n inp) = (eval n inp) / (weight n). - Proof. apply flatten_div_mod. Qed. - - Lemma length_flatten' n start_state inp : - length (fst start_state) = n -> - (forall row, In row inp -> length row = n) -> - length (fst (flatten' start_state inp)) = n. - Proof. apply flatten'_div_mod_length. Qed. - Hint Rewrite length_flatten' : distr_length. - - Lemma length_flatten n inp : - (forall row, In row inp -> length row = n) -> - length (fst (flatten n inp)) = n. - Proof. - intros. - apply length_flatten'; push; - destruct inp as [|? [|? ?] ]; try congruence; cbn [hd tl] in *; push; - subst row; distr_length. - Qed. Hint Rewrite length_flatten : distr_length. - - Lemma flatten'_partitions n inp : forall start_state, - inp <> nil -> - length (fst start_state) = n -> - (forall row, In row inp -> length row = n) -> - forall i, (i < n)%nat -> - nth_default 0 (fst (flatten' start_state inp)) i - = ((Positional.eval weight n (fst start_state) + eval n inp) mod weight (S i)) / (weight i). - Proof. - induction inp using rev_ind; push. - destruct (dec (inp = nil)). - { subst inp; push. rewrite sum_rows_partitions with (n:=n) by eauto. push. } - { erewrite IHinp; push. - rewrite add_mod_l_multiple by auto using weight_divides_full, weight_multiples_full. - push. } - Qed. - - Lemma flatten_partitions inp n : - (forall row, In row inp -> length row = n) -> - forall i, (i < n)%nat -> - nth_default 0 (fst (flatten n inp)) i = (eval n inp mod weight (S i)) / (weight i). - Proof. - intros; cbv [flatten]. - intros; destruct inp as [| ? [| ? ?] ]; try congruence; cbn [hd tl] in *; try solve [push]. - { cbn. autorewrite with push_nth_default. - rewrite sum_rows_partitions with (n:=n) by distr_length. - autorewrite with push_eval zsimplify_fast. - auto with zarith. } - { push. rewrite sum_rows_partitions with (n:=n) by distr_length; push. } - { rewrite flatten'_partitions with (n:=n); push. } - Qed. - - Definition partition n x := - map (fun i => (x mod weight (S i)) / weight i) (seq 0 n). - - Lemma nth_default_partitions x : forall p n, - (forall i, (i < n)%nat -> nth_default 0 p i = (x mod weight (S i)) / weight i) -> - length p = n -> - p = partition n x. - Proof. - cbv [partition]; induction p using rev_ind; intros; distr_length; subst n; [reflexivity|]. - rewrite Nat.add_1_r, seq_snoc. - autorewrite with natsimplify push_map. - rewrite <-IHp; auto; intros; - match goal with H : context [nth_default _ (p ++ [ _ ])] |- _ => - rewrite <-H by omega end. - { autorewrite with push_nth_default natsimplify. reflexivity. } - { autorewrite with push_nth_default natsimplify. - break_match; omega. } - Qed. - - Lemma partition_step n x : - partition (S n) x = partition n x ++ [(x mod weight (S n)) / weight n]. - Proof. - cbv [partition]. rewrite seq_snoc. - autorewrite with natsimplify push_map. reflexivity. - Qed. - - Lemma length_partition n x : length (partition n x) = n. - Proof. cbv [partition]; distr_length. Qed. - Hint Rewrite length_partition : distr_length. - - Lemma eval_partition n x : - Positional.eval weight n (partition n x) = x mod (weight n). - Proof. - induction n; intros. - { cbn. rewrite (weight_0); auto with zarith. } - { rewrite (Z.div_mod (x mod weight (S n)) (weight n)) by auto. - rewrite <-Znumtheory.Zmod_div_mod by (try apply Z.mod_divide; auto). - rewrite partition_step, Positional.eval_snoc with (n:=n) by distr_length. - omega. } - Qed. - - Lemma flatten_partitions' inp n : - (forall row, In row inp -> length row = n) -> - fst (flatten n inp) = partition n (eval n inp). - Proof. auto using nth_default_partitions, flatten_partitions, length_flatten. Qed. - End Flatten. - - Section Ops. - Definition add n p q := flatten n [p; q]. - - (* TODO: Although cleaner, using Positional.negate snd inserts - dlets which prevent add-opp=>sub transformation in partial - evaluation. Should probably either make partial evaluation - handle that or remove the dlet in - Positional.from_associational. *) - Definition sub n p q := flatten n [p; map (fun x => dlet y := x in Z.opp y) q]. - - Hint Rewrite eval_cons eval_nil using solve [auto] : push_eval. - - Definition mul base n m (p q : list Z) := - let p_a := Positional.to_associational weight n p in - let q_a := Positional.to_associational weight n q in - let pq_a := Associational.sat_mul base p_a q_a in - flatten m (from_associational m pq_a). - - (* TODO : move sat_reduce and repeat_sat_reduce to Saturated.Associational *) - Definition sat_reduce base s c (p : list (Z * Z)) := - let lo_hi := Associational.split s p in - fst lo_hi ++ (Associational.sat_mul_const base c (snd lo_hi)). - - Definition repeat_sat_reduce base s c (p : list (Z * Z)) n := - fold_right (fun _ q => sat_reduce base s c q) p (seq 0 n). - - Definition mulmod base s c n nreductions (p q : list Z) := - let p_a := Positional.to_associational weight n p in - let q_a := Positional.to_associational weight n q in - let pq_a := Associational.sat_mul base p_a q_a in - let r_a := repeat_sat_reduce base s c pq_a nreductions in - flatten n (from_associational n r_a). - - Hint Rewrite Associational.eval_sat_mul_const Associational.eval_sat_mul Associational.eval_split using solve [auto] : push_eval. - Hint Rewrite eval_from_associational using solve [auto] : push_eval. - Hint Rewrite eval_partition using solve [auto] : push_eval. - Ltac solver := - intros; cbv [sub add mul mulmod sat_reduce]; - rewrite ?flatten_partitions' by (intros; In_cases; subst; distr_length; eauto using length_from_associational); - rewrite ?flatten_div by (intros; In_cases; subst; distr_length; eauto using length_from_associational); - autorewrite with push_eval; ring_simplify_subterms; - try reflexivity. - - Lemma add_partitions n p q : - n <> 0%nat -> length p = n -> length q = n -> - fst (add n p q) = partition n (Positional.eval weight n p + Positional.eval weight n q). - Proof. solver. Qed. - - Lemma add_div n p q : - n <> 0%nat -> length p = n -> length q = n -> - snd (add n p q) = (Positional.eval weight n p + Positional.eval weight n q) / weight n. - Proof. solver. Qed. - - Lemma eval_map_opp q : - forall n, length q = n -> - Positional.eval weight n (map Z.opp q) = - Positional.eval weight n q. - Proof. - induction q using rev_ind; intros; - repeat match goal with - | _ => progress autorewrite with push_map push_eval - | _ => erewrite !Positional.eval_snoc with (n:=length q) by distr_length - | _ => rewrite IHq by auto - | _ => ring - end. - Qed. Hint Rewrite eval_map_opp using solve [auto]: push_eval. - - Lemma sub_partitions n p q : - n <> 0%nat -> length p = n -> length q = n -> - fst (sub n p q) = partition n (Positional.eval weight n p - Positional.eval weight n q). - Proof. solver. Qed. - - Lemma sub_div n p q : - n <> 0%nat -> length p = n -> length q = n -> - snd (sub n p q) = (Positional.eval weight n p - Positional.eval weight n q) / weight n. - Proof. solver. Qed. - - Lemma mul_partitions base n m p q : - base <> 0 -> n <> 0%nat -> m <> 0%nat -> length p = n -> length q = n -> - fst (mul base n m p q) = partition m (Positional.eval weight n p * Positional.eval weight n q). - Proof. solver. Qed. - - Lemma eval_sat_reduce base s c p : - base <> 0 -> s - Associational.eval c <> 0 -> s <> 0 -> - Associational.eval (sat_reduce base s c p) mod (s - Associational.eval c) - = Associational.eval p mod (s - Associational.eval c). - Proof. - intros; cbv [sat_reduce]. - autorewrite with push_eval. - rewrite <-Associational.reduction_rule by omega. - autorewrite with push_eval; reflexivity. - Qed. - Hint Rewrite eval_sat_reduce using auto : push_eval. - - Lemma eval_repeat_sat_reduce base s c p n : - base <> 0 -> s - Associational.eval c <> 0 -> s <> 0 -> - Associational.eval (repeat_sat_reduce base s c p n) mod (s - Associational.eval c) - = Associational.eval p mod (s - Associational.eval c). - Proof. - intros; cbv [repeat_sat_reduce]. - apply fold_right_invariant; intros; autorewrite with push_eval; auto. - Qed. - Hint Rewrite eval_repeat_sat_reduce using auto : push_eval. - - Lemma eval_mulmod base s c n nreductions p q : - base <> 0 -> s <> 0 -> s - Associational.eval c <> 0 -> - n <> 0%nat -> length p = n -> length q = n -> - (Positional.eval weight n (fst (mulmod base s c n nreductions p q)) - + weight n * (snd (mulmod base s c n nreductions p q))) mod (s - Associational.eval c) - = (Positional.eval weight n p * Positional.eval weight n q) mod (s - Associational.eval c). - Proof. - solver. - rewrite <-Z.div_mod'' by auto. - autorewrite with push_eval; reflexivity. - Qed. - End Ops. - End Rows. -End Rows. - -Module BaseConversion. - Import Positional. - Section BaseConversion. - Hint Resolve Z.gt_lt. - Context (sw dw : nat -> Z) (* source/destination weight functions *) - {swprops : @weight_properties sw} - {dwprops : @weight_properties dw}. - - Definition convert_bases (sn dn : nat) (p : list Z) : list Z := - let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in - chained_carries_no_reduce dw dn p' (seq 0 (pred dn)). - - Lemma eval_convert_bases sn dn p : - (dn <> 0%nat) -> length p = sn -> - eval dw dn (convert_bases sn dn p) = eval sw sn p. - Proof. - cbv [convert_bases]; intros. - rewrite eval_chained_carries_no_reduce; auto using Z.positive_is_nonzero. - rewrite eval_from_associational; auto. - Qed. - - Hint Rewrite - @Rows.eval_from_associational - @Associational.eval_carry - @Associational.eval_mul - @Positional.eval_to_associational - Associational.eval_carryterm - @eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval. - - Ltac push_eval := intros; autorewrite with push_eval; auto with zarith. - - (* convert from positional in one weight to the other, then to associational *) - Definition to_associational n m p : list (Z * Z) := - let p' := convert_bases n m p in - Positional.to_associational dw m p'. - - (* TODO : move to Associational? *) - Section reorder. - Definition reordering_carry (w fw : Z) (p : list (Z * Z)) := - fold_right (fun t acc => - let r := Associational.carryterm w fw t in - if fst t =? w then acc ++ r else r ++ acc) nil p. - - Lemma eval_reordering_carry w fw p (_:fw<>0): - Associational.eval (reordering_carry w fw p) = Associational.eval p. - Proof. - cbv [reordering_carry]. induction p; [reflexivity |]. - autorewrite with push_fold_right. break_match; push_eval. - Qed. - End reorder. - Hint Rewrite eval_reordering_carry using solve [auto using Z.positive_is_nonzero] : push_eval. - - (* carry at specified indices in dw, then use Rows.flatten to convert to Positional with sw *) - Definition from_associational idxs n (p : list (Z * Z)) : list Z := - (* important not to use Positional.carry here; we don't want to accumulate yet *) - let p' := fold_right (fun i acc => reordering_carry (dw i) (dw (S i) / dw i) acc) (Associational.bind_snd p) (rev idxs) in - fst (Rows.flatten sw n (Rows.from_associational sw n p')). - - Lemma eval_carries p idxs : - Associational.eval (fold_right (fun i acc => reordering_carry (dw i) (dw (S i) / dw i) acc) p idxs) = - Associational.eval p. - Proof. apply fold_right_invariant; push_eval. Qed. - Hint Rewrite eval_carries: push_eval. - - Lemma eval_to_associational n m p : - m <> 0%nat -> length p = n -> - Associational.eval (to_associational n m p) = Positional.eval sw n p. - Proof. cbv [to_associational]; push_eval. Qed. - Hint Rewrite eval_to_associational using solve [push_eval; distr_length] : push_eval. - - Lemma eval_from_associational idxs n p : - n <> 0%nat -> 0 <= Associational.eval p < sw n -> - Positional.eval sw n (from_associational idxs n p) = Associational.eval p. - Proof. - cbv [from_associational]; intros. - rewrite Rows.flatten_mod by eauto using Rows.length_from_associational. - rewrite Associational.bind_snd_correct. - push_eval. - Qed. - Hint Rewrite eval_from_associational using solve [push_eval; distr_length] : push_eval. - - Lemma from_associational_partitions n idxs p (_:n<>0%nat): - forall i, (i < n)%nat -> - nth_default 0 (from_associational idxs n p) i = (Associational.eval p) mod (sw (S i)) / sw i. - Proof. - intros; cbv [from_associational]. - rewrite Rows.flatten_partitions with (n:=n) by (eauto using Rows.length_from_associational; omega). - rewrite Associational.bind_snd_correct. - push_eval. - Qed. - - Lemma from_associational_eq n idxs p (_:n<>0%nat): - from_associational idxs n p = Rows.partition sw n (Associational.eval p). - Proof. - intros. cbv [from_associational]. - rewrite Rows.flatten_partitions' with (n:=n) by eauto using Rows.length_from_associational. - rewrite Associational.bind_snd_correct. - push_eval. - Qed. - - Derive from_associational_inlined - SuchThat (forall idxs n p, - from_associational_inlined idxs n p = from_associational idxs n p) - As from_associational_inlined_correct. - Proof. - intros. - cbv beta iota delta [from_associational reordering_carry Associational.carryterm]. - cbv beta iota delta [Let_In]. (* inlines all shifts/lands from carryterm *) - cbv beta iota delta [from_associational Rows.from_associational Columns.from_associational]. - cbv beta iota delta [Let_In]. (* inlines the shifts from place *) - subst from_associational_inlined; reflexivity. - Qed. - - Derive to_associational_inlined - SuchThat (forall n m p, - to_associational_inlined n m p = to_associational n m p) - As to_associational_inlined_correct. - Proof. - intros. - cbv beta iota delta [ to_associational convert_bases - Positional.to_associational - Positional.from_associational - chained_carries_no_reduce - carry - Associational.carry - Associational.carryterm - ]. - cbv beta iota delta [Let_In]. - subst to_associational_inlined; reflexivity. - Qed. - - (* carry chain that aligns terms in the intermediate weight with the final weight *) - Definition aligned_carries (log_dw_sw nout : nat) - := (map (fun i => ((log_dw_sw * (i + 1)) - 1))%nat (seq 0 nout)). - - Section mul_converted. - Definition mul_converted - n1 n2 (* lengths in original format *) - m1 m2 (* lengths in converted format *) - (n3 : nat) (* final length *) - (idxs : list nat) (* carries to do -- this helps preemptively line up weights *) - (p1 p2 : list Z) := - let p1_a := to_associational n1 m1 p1 in - let p2_a := to_associational n2 m2 p2 in - let p3_a := Associational.mul p1_a p2_a in - from_associational idxs n3 p3_a. - - Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): - length p1 = n1 -> length p2 = n2 -> - 0 <= (Positional.eval sw n1 p1 * Positional.eval sw n2 p2) < sw n3 -> - Positional.eval sw n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval sw n1 p1) * (Positional.eval sw n2 p2). - Proof. cbv [mul_converted]; push_eval. Qed. - Hint Rewrite eval_mul_converted : push_eval. - - Lemma mul_converted_partitions n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): - length p1 = n1 -> length p2 = n2 -> - mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = Rows.partition sw n3 (Positional.eval sw n1 p1 * Positional.eval sw n2 p2). - Proof. - intros; cbv [mul_converted]. - rewrite from_associational_eq by auto. push_eval. - Qed. - End mul_converted. - End BaseConversion. - - (* multiply two (n*k)-bit numbers by converting them to n k-bit limbs each, multiplying, then converting back *) - Section widemul. - Context (log2base : Z) (log2base_pos : 0 < log2base). - Context (n : nat) (n_nz : n <> 0%nat) (n_le_log2base : Z.of_nat n <= log2base) - (nout : nat) (nout_2 : nout = 2%nat). (* nout is always 2, but partial evaluation is overeager if it's a constant *) - Let dw : nat -> Z := weight (log2base / Z.of_nat n) 1. - Let sw : nat -> Z := weight log2base 1. - - Local Lemma base_bounds : 0 < 1 <= log2base. Proof. auto with zarith. Qed. - Local Lemma dbase_bounds : 0 < 1 <= log2base / Z.of_nat n. Proof. auto with zarith. Qed. - Let dwprops : @weight_properties dw := wprops (log2base / Z.of_nat n) 1 dbase_bounds. - Let swprops : @weight_properties sw := wprops log2base 1 base_bounds. - - Hint Resolve Z.gt_lt Z.positive_is_nonzero Nat2Z.is_nonneg. - - Definition widemul a b := mul_converted sw dw 1 1 n n nout (aligned_carries n nout) [a] [b]. - - Lemma widemul_correct a b : - 0 <= a * b < 2^log2base * 2^log2base -> - widemul a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base]. - Proof. - cbv [widemul]; intros. - rewrite mul_converted_partitions by auto with zarith. - subst nout sw; cbv [weight]; cbn. - autorewrite with zsimplify. - rewrite Z.pow_mul_r, Z.pow_2_r by omega. - Z.rewrite_mod_small. reflexivity. - Qed. - - Derive widemul_inlined - SuchThat (forall a b, - 0 <= a * b < 2^log2base * 2^log2base -> - widemul_inlined a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base]) - As widemul_inlined_correct. - Proof. - intros. - rewrite <-widemul_correct by auto. - cbv beta iota delta [widemul mul_converted]. - rewrite <-to_associational_inlined_correct with (p:=[a]). - rewrite <-to_associational_inlined_correct with (p:=[b]). - rewrite <-from_associational_inlined_correct. - subst widemul_inlined; reflexivity. - Qed. - - Derive widemul_inlined_reverse - SuchThat (forall a b, - 0 <= a * b < 2^log2base * 2^log2base -> - widemul_inlined_reverse a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base]) - As widemul_inlined_reverse_correct. - Proof. - intros. - rewrite <-widemul_inlined_correct by assumption. - cbv [widemul_inlined]. - match goal with |- _ = from_associational_inlined sw dw ?idxs ?n ?p => - transitivity (from_associational_inlined sw dw idxs n (rev p)); - [ | transitivity (from_associational sw dw idxs n p); [ | reflexivity ] ](* reverse to make addc chains line up *) - end. - Focus 2. { - rewrite from_associational_inlined_correct by (subst nout; auto). - cbv [from_associational]. - rewrite !Rows.flatten_partitions' by eauto using Rows.length_from_associational. - rewrite !Rows.eval_from_associational by (subst nout; auto). - f_equal. - rewrite !eval_carries, !Associational.bind_snd_correct, !Associational.eval_rev by auto. - reflexivity. } Unfocus. - subst widemul_inlined_reverse; reflexivity. - Qed. - End widemul. -End BaseConversion. - -Module Import MOVEME. - Fixpoint fold_andb_map {A B} (f : A -> B -> bool) (ls1 : list A) (ls2 : list B) - : bool - := match ls1, ls2 with - | nil, nil => true - | nil, _ => false - | cons x xs, cons y ys => andb (f x y) (@fold_andb_map A B f xs ys) - | cons _ _, _ => false - end. - Lemma fold_andb_map_map {A B C} f g ls1 ls2 - : @fold_andb_map A B f ls1 (@List.map C _ g ls2) - = fold_andb_map (fun a b => f a (g b)) ls1 ls2. - Proof. revert ls1 ls2; induction ls1, ls2; cbn; congruence. Qed. - - Lemma fold_andb_map_length A B f ls1 ls2 - (H : @fold_andb_map A B f ls1 ls2 = true) - : length ls1 = length ls2. - Proof. - revert ls1 ls2 H; induction ls1, ls2; cbn; intros; Bool.split_andb; f_equal; - try congruence; auto. - Qed. -End MOVEME. - -Definition expanding_id (n : nat) (ls : list Z) := expand_list (-1)%Z ls n. - -Lemma expanding_id_id n ls (H : List.length ls = n) - : expanding_id n ls = ls. -Proof. - unfold expanding_id. rewrite expand_list_correct by assumption; reflexivity. -Qed. - -Module Ring. - Local Notation is_bounded_by0 r v - := ((lower r <=? v) && (v <=? upper r)). - Local Notation is_bounded_by0o r - := (match r with Some r' => fun v => is_bounded_by0 r' v | None => fun _ => true end). - Local Notation is_bounded_by bounds ls - := (fold_andb_map (fun r v => is_bounded_by0o r v) bounds ls). - Local Notation is_bounded_by2 bounds ls - := (let '(a, b) := ls in andb (is_bounded_by bounds a) (is_bounded_by bounds b)). - - Lemma length_is_bounded_by bounds ls - : is_bounded_by bounds ls = true -> length ls = length bounds. - Proof. - intro H. - apply fold_andb_map_length in H; congruence. - Qed. - - Section ring_goal. - Context (limbwidth_num limbwidth_den : Z) - (n : nat) - (s : Z) - (c : list (Z * Z)) - (tight_bounds : list (option zrange)) - (length_tight_bounds : length tight_bounds = n) - (loose_bounds : list (option zrange)) - (length_loose_bounds : length loose_bounds = n). - Local Notation weight := (weight limbwidth_num limbwidth_den). - Local Notation eval := (Positional.eval weight n). - Let prime_bound : zrange - := r[0~>(s - Associational.eval c - 1)]%zrange. - Let m := Z.to_pos (s - Associational.eval c). - Context (m_eq : Z.pos m = s - Associational.eval c) - (sc_pos : 0 < s - Associational.eval c) - (Interp_rrelaxv : list Z -> list Z) - (HInterp_rrelaxv : forall arg, - is_bounded_by tight_bounds arg = true - -> is_bounded_by loose_bounds (Interp_rrelaxv arg) = true - /\ Interp_rrelaxv arg = expanding_id n arg) - (carry_mulmod : list Z -> list Z -> list Z) - (Hcarry_mulmod - : forall f g, - length f = n -> length g = n -> - (eval (carry_mulmod f g)) mod (s - Associational.eval c) - = (eval f * eval g) mod (s - Associational.eval c)) - (Interp_rcarry_mulv : list Z * list Z -> list Z) - (HInterp_rcarry_mulv : forall arg, - is_bounded_by2 loose_bounds arg = true - -> is_bounded_by tight_bounds (Interp_rcarry_mulv arg) = true - /\ Interp_rcarry_mulv arg = carry_mulmod (fst arg) (snd arg)) - (carrymod : list Z -> list Z) - (Hcarrymod - : forall f, - length f = n -> - (eval (carrymod f)) mod (s - Associational.eval c) - = (eval f) mod (s - Associational.eval c)) - (Interp_rcarryv : list Z -> list Z) - (HInterp_rcarryv : forall arg, - is_bounded_by loose_bounds arg = true - -> is_bounded_by tight_bounds (Interp_rcarryv arg) = true - /\ Interp_rcarryv arg = carrymod arg) - (addmod : list Z -> list Z -> list Z) - (Haddmod - : forall f g, - length f = n -> length g = n -> - (eval (addmod f g)) mod (s - Associational.eval c) - = (eval f + eval g) mod (s - Associational.eval c)) - (Interp_raddv : list Z * list Z -> list Z) - (HInterp_raddv : forall arg, - is_bounded_by2 tight_bounds arg = true - -> is_bounded_by loose_bounds (Interp_raddv arg) = true - /\ Interp_raddv arg = addmod (fst arg) (snd arg)) - (submod : list Z -> list Z -> list Z) - (Hsubmod - : forall f g, - length f = n -> length g = n -> - (eval (submod f g)) mod (s - Associational.eval c) - = (eval f - eval g) mod (s - Associational.eval c)) - (Interp_rsubv : list Z * list Z -> list Z) - (HInterp_rsubv : forall arg, - is_bounded_by2 tight_bounds arg = true - -> is_bounded_by loose_bounds (Interp_rsubv arg) = true - /\ Interp_rsubv arg = submod (fst arg) (snd arg)) - (oppmod : list Z -> list Z) - (Hoppmod - : forall f, - length f = n -> - (eval (oppmod f)) mod (s - Associational.eval c) - = (- eval f) mod (s - Associational.eval c)) - (Interp_roppv : list Z -> list Z) - (HInterp_roppv : forall arg, - is_bounded_by tight_bounds arg = true - -> is_bounded_by loose_bounds (Interp_roppv arg) = true - /\ Interp_roppv arg = oppmod arg) - (zeromod : list Z) - (Hzeromod - : (eval zeromod) mod (s - Associational.eval c) - = 0 mod (s - Associational.eval c)) - (Interp_rzerov : list Z) - (HInterp_rzerov : is_bounded_by tight_bounds Interp_rzerov = true - /\ Interp_rzerov = zeromod) - (onemod : list Z) - (Honemod - : (eval onemod) mod (s - Associational.eval c) - = 1 mod (s - Associational.eval c)) - (Interp_ronev : list Z) - (HInterp_ronev : is_bounded_by tight_bounds Interp_ronev = true - /\ Interp_ronev = onemod) - (encodemod : Z -> list Z) - (Hencodemod - : forall f, - (eval (encodemod f)) mod (s - Associational.eval c) - = f mod (s - Associational.eval c)) - (Interp_rencodev : Z -> list Z) - (HInterp_rencodev : forall arg, - is_bounded_by0 prime_bound arg = true - -> is_bounded_by tight_bounds (Interp_rencodev arg) = true - /\ Interp_rencodev arg = encodemod arg). - - Local Notation T := (list Z) (only parsing). - Local Notation encoded_ok ls - := (is_bounded_by tight_bounds ls = true) (only parsing). - Local Notation encoded_okf := (fun ls => encoded_ok ls) (only parsing). - - Definition Fdecode (v : T) : F m - := F.of_Z m (Positional.eval weight n v). - Definition T_eq (x y : T) - := Fdecode x = Fdecode y. - - Definition encodedT := sig encoded_okf. - - Definition ring_mul (x y : T) : T - := Interp_rcarry_mulv (Interp_rrelaxv x, Interp_rrelaxv y). - Definition ring_add (x y : T) : T := Interp_rcarryv (Interp_raddv (x, y)). - Definition ring_sub (x y : T) : T := Interp_rcarryv (Interp_rsubv (x, y)). - Definition ring_opp (x : T) : T := Interp_rcarryv (Interp_roppv x). - Definition ring_encode (x : F m) : T := Interp_rencodev (F.to_Z x). - - Definition GoodT : Prop - := @subsetoid_ring - (list Z) encoded_okf T_eq - Interp_rzerov Interp_ronev ring_opp ring_add ring_sub ring_mul - /\ @is_subsetoid_homomorphism - (F m) (fun _ => True) eq 1%F F.add F.mul - (list Z) encoded_okf T_eq Interp_ronev ring_add ring_mul ring_encode - /\ @is_subsetoid_homomorphism - (list Z) encoded_okf T_eq Interp_ronev ring_add ring_mul - (F m) (fun _ => True) eq 1%F F.add F.mul - Fdecode. - - Hint Rewrite ->@F.to_Z_add : push_FtoZ. - Hint Rewrite ->@F.to_Z_mul : push_FtoZ. - Hint Rewrite ->@F.to_Z_opp : push_FtoZ. - Hint Rewrite ->@F.to_Z_of_Z : push_FtoZ. - - Lemma Fm_bounded_alt (x : F m) - : (0 <=? F.to_Z x) && (F.to_Z x <=? Z.pos m - 1) = true. - Proof using m_eq. - clear -m_eq. - destruct x as [x H]; cbn [F.to_Z proj1_sig]. - pose proof (Z.mod_pos_bound x (Z.pos m)). - rewrite andb_true_iff; split; Z.ltb_to_lt; lia. - Qed. - - Lemma Good : GoodT. - Proof. - split_and; simpl in *. - eapply subsetoid_ring_by_ring_isomorphism; - cbv [ring_opp ring_add ring_sub ring_mul ring_encode F.sub] in *; - repeat match goal with - | _ => solve [ auto using andb_true_intro, conj with nocore ] - | _ => progress intros - | _ => progress cbn [fst snd] - | [ H : _ |- is_bounded_by _ _ = true ] => apply H - | [ |- _ <-> _ ] => reflexivity - | [ |- _ = _ :> Z ] => first [ reflexivity | rewrite <- m_eq; reflexivity ] - | [ H : context[?x] |- Fdecode ?x = _ ] => rewrite H - | [ H : context[?x _] |- Fdecode (?x _) = _ ] => rewrite H - | [ H : context[?x _ _] |- Fdecode (?x _ _) = _ ] => rewrite H - | _ => progress cbv [Fdecode] - | [ |- _ = _ :> F _ ] => apply F.eq_to_Z_iff - | _ => progress autorewrite with push_FtoZ - | _ => rewrite m_eq - | [ H : context[?x _ _] |- context[eval (?x _ _)] ] => rewrite H - | [ H : context[?x _] |- context[eval (?x _)] ] => rewrite H - | [ H : context[?x] |- context[eval ?x] ] => rewrite H - | [ |- context[List.length ?x] ] - => erewrite (length_is_bounded_by _ x) - by eauto using andb_true_intro, conj with nocore - | [ |- _ = _ :> Z ] - => push_Zmod; reflexivity - | _ => pull_Zmod; rewrite Z.add_opp_r - | _ => rewrite expanding_id_id - | [ |- context[F.to_Z _ mod (_ - _)] ] - => rewrite <- m_eq, F.mod_to_Z - | _ => rewrite <- m_eq; apply Fm_bounded_alt - end. - Qed. - End ring_goal. -End Ring. - -Module Compilers. - Module type. - Variant primitive := unit | Z | nat | bool. - Inductive type := type_primitive (_:primitive) | prod (A B : type) | arrow (s d : type) | list (A : type). - Module Export Coercions. - Global Coercion type_primitive : primitive >-> type. - End Coercions. - - (** Denote [type]s into their interpretation in [Type]/[Set] *) - Fixpoint interp (t : type) - := match t with - | unit => Datatypes.unit - | prod A B => interp A * interp B - | arrow A B => interp A -> interp B - | list A => Datatypes.list (interp A) - | nat => Datatypes.nat - | type_primitive Z => BinInt.Z - | bool => Datatypes.bool - end%type. - - Fixpoint final_codomain (t : type) : type - := match t with - | type_primitive _ as t - | prod _ _ as t - | list _ as t - => t - | arrow s d => final_codomain d - end. - - Definition domain (t : type) : type - := match t with - | arrow s d => s - | _ => type_primitive unit - end. - - Definition codomain (t : type) : type - := match t with - | arrow s d => d - | t => t - end. - - Fixpoint try_transport (P : type -> Type) (t1 t2 : type) : P t1 -> option (P t2) - := match t1, t2 return P t1 -> option (P t2) with - | unit, unit - | Z, Z - | nat, nat - | bool, bool - => @Some _ - | prod A B, prod A' B' - => fun v - => (v <- try_transport (fun A => P (prod A B)) A A' v; - try_transport (fun B => P (prod A' B)) B B' v)%option - | arrow s d, arrow s' d' - => fun v - => (v <- try_transport (fun s => P (arrow s d)) s s' v; - try_transport (fun d => P (arrow s' d)) d d' v)%option - | list A, list A' - => @try_transport (fun A => P (list A)) A A' - | unit, _ - | Z, _ - | nat, _ - | bool, _ - | prod _ _, _ - | arrow _ _, _ - | list _, _ - => fun _ => None - end. - - Ltac reify_primitive ty := - lazymatch eval cbv beta in ty with - | Datatypes.unit => unit - | Datatypes.nat => nat - | Datatypes.bool => bool - | BinInt.Z => Z - | ?ty => let dummy := match goal with - | _ => fail 1 "Unrecognized type:" ty - end in - constr:(I : I) - end. - - Ltac reify ty := - lazymatch eval cbv beta in ty with - | Datatypes.prod ?A ?B - => let rA := reify A in - let rB := reify B in - constr:(prod rA rB) - | ?A -> ?B - => let rA := reify A in - let rB := reify B in - constr:(arrow rA rB) - | Datatypes.list ?T - => let rT := reify T in - constr:(list rT) - | type.interp ?T => T - | _ => let rt := reify_primitive ty in - constr:(type_primitive rt) - end. - - Notation reify t := (ltac:(let rt := reify t in exact rt)) (only parsing). - Notation reify_type_of e := (reify ((fun t (_ : t) => t) _ e)) (only parsing). - - Module Export Notations. - Export Coercions. - Delimit Scope ctype_scope with ctype. - Bind Scope ctype_scope with type. - Notation "()" := unit : ctype_scope. - Notation "A * B" := (prod A B) : ctype_scope. - Notation "A -> B" := (arrow A B) : ctype_scope. - Notation type := type. - End Notations. - End type. - Export type.Notations. - - Module Uncurried. - Module expr. - Inductive expr {ident : type -> type -> Type} {var : type -> Type} : type -> Type := - | Var {t} (v : var t) : expr t - | TT : expr type.unit - | AppIdent {s d} (idc : ident s d) (args : expr s) : expr d - | App {s d} (f : expr (s -> d)) (x : expr s) : expr d - | Pair {A B} (a : expr A) (b : expr B) : expr (A * B) - | Abs {s d} (f : var s -> expr d) : expr (s -> d). - - Definition Expr {ident : type -> type -> Type} t := forall var, @expr ident var t. - - Definition APP {ident s d} (f : Expr (s -> d)) (x : Expr s) : Expr d - := fun var => @App ident var s d (f var) (x var). - - Module Export Notations. - Bind Scope expr_scope with expr. - Delimit Scope expr_scope with expr. - Bind Scope Expr_scope with Expr. - Delimit Scope Expr_scope with Expr. - - Infix "@" := App : expr_scope. - Infix "@" := APP : Expr_scope. - Infix "@@" := AppIdent : expr_scope. - Notation "( x , y , .. , z )" := (Pair .. (Pair x%expr y%expr) .. z%expr) : expr_scope. - Notation "( )" := TT : expr_scope. - Notation "()" := TT : expr_scope. - Notation "'λ' x .. y , t" := (Abs (fun x => .. (Abs (fun y => t%expr)) ..)) : expr_scope. - End Notations. - - Section unexpr. - Context {ident : type -> type -> Type} - {var : type -> Type}. - - Fixpoint unexpr {t} (e : @expr ident (@expr ident var) t) : @expr ident var t - := match e in expr t return expr t with - | Var t v => v - | TT => TT - | AppIdent s d idc args => AppIdent idc (unexpr args) - | App s d f x => App (unexpr f) (unexpr x) - | Pair A B a b => Pair (unexpr a) (unexpr b) - | Abs s d f => Abs (fun x => unexpr (f (Var x))) - end. - End unexpr. - - Section with_ident. - Context {ident : type -> type -> Type} - (interp_ident : forall s d, ident s d -> type.interp s -> type.interp d). - - (** Denote expressions *) - Fixpoint interp {t} (e : @expr ident type.interp t) : type.interp t - := match e with - | Var t v => v - | TT => tt - | AppIdent s d idc args => interp_ident s d idc (@interp s args) - | App s d f x => @interp _ f (@interp _ x) - | Pair A B a b => (@interp A a, @interp B b) - | Abs s d f => fun v => interp (f v) - end. - - Definition Interp {t} (e : Expr t) := interp (e _). - - (** [Interp (APP _ _)] is the same thing as Gallina - application of the [Interp]retations of the two arguments - to [APP]. *) - Definition Interp_APP {s d} (f : @Expr ident (s -> d)) (x : @Expr ident s) - : Interp (f @ x)%Expr = Interp f (Interp x) - := eq_refl. - - (** Same as [Interp_APP], but for any reflexive relation, not - just [eq] *) - Definition Interp_APP_rel_reflexive {s d} {R} {H:Reflexive R} - (f : @Expr ident (s -> d)) (x : @Expr ident s) - : R (Interp (f @ x)%Expr) (Interp f (Interp x)) - := H _. - End with_ident. - - Ltac require_primitive_const term := - lazymatch term with - | S ?n => require_primitive_const n - | O => idtac - | true => idtac - | false => idtac - | tt => idtac - | Z0 => idtac - | Zpos ?p => require_primitive_const p - | Zneg ?p => require_primitive_const p - | xI ?p => require_primitive_const p - | xO ?p => require_primitive_const p - | xH => idtac - | ?term => fail 0 "Not a known const:" term - end. - Ltac is_primitive_const term := - match constr:(Set) with - | _ => let check := match goal with - | _ => require_primitive_const term - end in - true - | _ => false - end. - - Module var_context. - Inductive list {var : type -> Type} := - | nil - | cons {t} (gallina_v : type.interp t) (v : var t) (ctx : list). - End var_context. - - (* cf COQBUG(https://github.com/coq/coq/issues/5448) , COQBUG(https://github.com/coq/coq/issues/6315) , COQBUG(https://github.com/coq/coq/issues/6559) , COQBUG(https://github.com/coq/coq/issues/6534) , https://github.com/mit-plv/fiat-crypto/issues/320 *) - Ltac require_same_var n1 n2 := - (*idtac n1 n2;*) - let c1 := constr:(fun n1 n2 : Set => ltac:(exact n1)) in - let c2 := constr:(fun n1 n2 : Set => ltac:(exact n2)) in - (*idtac c1 c2;*) - first [ constr_eq c1 c2 | fail 1 "Not the same var:" n1 "and" n2 "(via constr_eq" c1 c2 ")" ]. - Ltac is_same_var n1 n2 := - match goal with - | _ => let check := match goal with _ => require_same_var n1 n2 end in - true - | _ => false - end. - Ltac is_underscore v := - let v' := fresh v in - let v' := fresh v' in - is_same_var v v'. - Ltac refresh n fresh_tac := - let n_is_underscore := is_underscore n in - let n' := lazymatch n_is_underscore with - | true => fresh - | false => fresh_tac n - end in - let n' := fresh_tac n' in - n'. - - Ltac type_of_first_argument_of f := - let f_ty := type of f in - lazymatch eval hnf in f_ty with - | forall x : ?T, _ => T - end. - - (** Forms of abstraction in Gallina that our reflective language - cannot handle get handled by specializing the code "template" to - each particular application of that abstraction. In particular, - type arguments (nat, Z, (λ _, nat), etc) get substituted into - lambdas and treated as a integral part of primitive operations - (such as [@List.app T], [@list_rect (λ _, nat)]). During - reification, we accumulate them in a right-associated tuple, - using [tt] as the "nil" base case. When we hit a λ or an - identifier, we plug in the template parameters as necessary. *) - Ltac require_template_parameter parameter_type := - first [ unify parameter_type Prop - | unify parameter_type Set - | unify parameter_type Type - | lazymatch eval hnf in parameter_type with - | forall x : ?T, @?P x - => let check := constr:(fun x : T - => ltac:(require_template_parameter (P x); - exact I)) in - idtac - end ]. - Ltac is_template_parameter parameter_type := - is_success_run_tactic ltac:(fun _ => require_template_parameter parameter_type). - Ltac plug_template_ctx f template_ctx := - lazymatch template_ctx with - | tt => f - | (?arg, ?template_ctx') - => - let T := type_of_first_argument_of f in - let x_is_template_parameter := is_template_parameter T in - lazymatch x_is_template_parameter with - | true - => plug_template_ctx (f arg) template_ctx' - | false - => constr:(fun x : T - => ltac:(let v := plug_template_ctx (f x) template_ctx in - exact v)) - end - end. - - Ltac reify_in_context ident reify_ident var term value_ctx template_ctx := - let reify_rec_gen term value_ctx template_ctx := reify_in_context ident reify_ident var term value_ctx template_ctx in - let reify_rec term := reify_rec_gen term value_ctx template_ctx in - let reify_rec_not_head term := reify_rec_gen term value_ctx tt in - let mkAppIdent idc args - := let rargs := reify_rec_not_head args in - constr:(@AppIdent ident var _ _ idc rargs) in - let do_reify_ident term else_tac - := let term_is_primitive_const := is_primitive_const term in - reify_ident - mkAppIdent - term_is_primitive_const - term - else_tac in - (*let dummy := match goal with _ => idtac "reify_in_context: attempting to reify:" term end in*) - lazymatch value_ctx with - | context[@var_context.cons _ ?rT term ?v _] - => constr:(@Var ident var rT v) - | _ - => - lazymatch term with - | match ?b with true => ?t | false => ?f end - => let T := type of t in - reify_rec (@bool_rect (fun _ => T) t f b) - | match ?x with Datatypes.pair a b => ?f end - => reify_rec (match Datatypes.fst x, Datatypes.snd x return _ with - | a, b => f - end) - | match ?x with nil => ?N | cons a b => @?C a b end - => let T := type of term in - reify_rec (@list_case _ (fun _ => T) N C x) - | let x := ?a in @?b x - => let A := type of a in - let B := lazymatch type of b with forall x, @?B x => B end in - reify_rec (b a) (*(@Let_In A B a b)*) - | Datatypes.pair ?x ?y - => let rx := reify_rec x in - let ry := reify_rec y in - constr:(Pair (ident:=ident) (var:=var) rx ry) - | tt - => constr:(@TT ident var) - | (fun x : ?T => ?f) - => - let x_is_template_parameter := is_template_parameter T in - lazymatch x_is_template_parameter with - | true - => - lazymatch template_ctx with - | (?arg, ?template_ctx) - => (* we pull a trick with [match] to plug in [arg] without running cbv β *) - lazymatch type of term with - | forall y, ?P - => reify_rec_gen (match arg as y return P with x => f end) value_ctx template_ctx - end - end - | false - => - let rT := type.reify T in - let not_x := fresh (* could be [refresh x ltac:(fun n => fresh n)] in 8.8; c.f. https://github.com/mit-plv/fiat-crypto/issues/320 and probably COQBUG(https://github.com/coq/coq/issues/6534) *) in - let not_x2 := fresh (* could be [refresh not_x ltac:(fun n => fresh n)] in 8.8; c.f. https://github.com/mit-plv/fiat-crypto/issues/320 and probably COQBUG(https://github.com/coq/coq/issues/6534) *) in - let not_x3 := fresh (* could be [refresh not_x2 ltac:(fun n => fresh n)] in 8.8; c.f. https://github.com/mit-plv/fiat-crypto/issues/320 and probably COQBUG(https://github.com/coq/coq/issues/6534) *) in - (*let dummy := match goal with _ => idtac "reify_in_context: λ case:" term "using vars:" not_x not_x2 not_x3 end in*) - let rf0 := - constr:( - fun (x : T) (not_x : var rT) - => match f, @var_context.cons var rT x not_x value_ctx return _ with (* c.f. COQBUG(https://github.com/coq/coq/issues/6252#issuecomment-347041995) for [return _] *) - | not_x2, not_x3 - => ltac:( - let f := (eval cbv delta [not_x2] in not_x2) in - let var_ctx := (eval cbv delta [not_x3] in not_x3) in - (*idtac "rec call" f "was" term;*) - let rf := reify_rec_gen f var_ctx template_ctx in - exact rf) - end) in - lazymatch rf0 with - | (fun _ => ?rf) - => constr:(@Abs ident var rT _ rf) - | _ - => (* This will happen if the reified term still - mentions the non-var variable. By chance, [cbv delta] - strips type casts, which are only places that I can - think of where such dependency might remain. However, - if this does come up, having a distinctive error message - is much more useful for debugging than the generic "no - matching clause" *) - let dummy := match goal with - | _ => fail 1 "Failure to eliminate functional dependencies of" rf0 - end in - constr:(I : I) - end - end - | _ - => - do_reify_ident - term - ltac:( - fun _ - => - lazymatch term with - | ?f ?x - => - let ty := type_of_first_argument_of f in - let x_is_template_parameter := is_template_parameter ty in - lazymatch x_is_template_parameter with - | true - => (* we can't reify things of type [Type], so we save it for later to plug in *) - reify_rec_gen f value_ctx (x, template_ctx) - | false - => let rx := reify_rec_gen x value_ctx tt in - let rf := reify_rec_gen f value_ctx template_ctx in - constr:(App (ident:=ident) (var:=var) rf rx) - end - | _ - => let term' := plug_template_ctx term template_ctx in - do_reify_ident - term' - ltac:(fun _ - => - (*let __ := match goal with _ => idtac "Attempting to unfold" term end in*) - let term - := match constr:(Set) with - | _ => (eval cbv delta [term] in term) (* might fail, so we wrap it in a match to give better error messages *) - | _ - => let dummy := match goal with - | _ => fail 2 "Unrecognized term:" term' - end in - constr:(I : I) - end in - reify_rec term) - end) - end - end. - Ltac reify ident reify_ident var term := - reify_in_context ident reify_ident var term (@var_context.nil var) tt. - Ltac Reify ident reify_ident term := - constr:(fun var : type -> Type - => ltac:(let r := reify ident reify_ident var term in - exact r)). - Ltac Reify_rhs ident reify_ident interp_ident _ := - let RHS := lazymatch goal with |- _ = ?RHS => RHS end in - let R := Reify ident reify_ident RHS in - transitivity (@Interp ident interp_ident _ R); - [ | cbv beta iota delta [Interp interp interp_ident Let_In type.interp bool_rect]; - reflexivity ]. - - Module for_reification. - Module ident. - Import type. - Inductive ident : type -> type -> Set := - | primitive {t:type.primitive} (v : interp t) : ident () t - | Let_In {tx tC} : ident (tx * (tx -> tC)) tC - | Nat_succ : ident nat nat - | Nat_max : ident (nat * nat) nat - | Nat_mul : ident (nat * nat) nat - | Nat_add : ident (nat * nat) nat - | Nat_sub : ident (nat * nat) nat - | nil {t} : ident () (list t) - | cons {t} : ident (t * list t) (list t) - | fst {A B} : ident (A * B) A - | snd {A B} : ident (A * B) B - | bool_rect {T} : ident ((unit -> T) * (unit -> T) * bool) T - | nat_rect {P} : ident ((unit -> P) * (nat * P -> P) * nat) P - | list_rect {A P} : ident ((unit -> P) * (A * list A * P -> P) * list A) P - | list_case {A P} : ident ((unit -> P) * (A * list A -> P) * list A) P - | pred : ident nat nat - | List_length {T} : ident (list T) nat - | List_seq : ident (nat * nat) (list nat) - | List_repeat {A} : ident (A * nat) (list A) - | List_combine {A B} : ident (list A * list B) (list (A * B)) - | List_map {A B} : ident ((A -> B) * list A) (list B) - | List_flat_map {A B} : ident ((A -> list B) * list A) (list B) - | List_partition {A} : ident ((A -> bool) * list A) (list A * list A) - | List_app {A} : ident (list A * list A) (list A) - | List_rev {A} : ident (list A) (list A) - | List_fold_right {A B} : ident ((B * A -> A) * A * list B) A - | List_update_nth {T} : ident (nat * (T -> T) * list T) (list T) - | List_nth_default {T} : ident (T * list T * nat) T - | Z_add : ident (Z * Z) Z - | Z_mul : ident (Z * Z) Z - | Z_pow : ident (Z * Z) Z - | Z_sub : ident (Z * Z) Z - | Z_opp : ident Z Z - | Z_div : ident (Z * Z) Z - | Z_modulo : ident (Z * Z) Z - | Z_eqb : ident (Z * Z) bool - | Z_leb : ident (Z * Z) bool - | Z_of_nat : ident nat Z - | Z_mul_split : ident (Z * Z * Z) (Z * Z) - | Z_add_get_carry : ident (Z * Z * Z) (Z * Z) - | Z_add_with_carry : ident (Z * Z * Z) Z - | Z_add_with_get_carry : ident (Z * Z * Z * Z) (Z * Z) - | Z_sub_get_borrow : ident (Z * Z * Z) (Z * Z) - | Z_sub_with_get_borrow : ident (Z * Z * Z * Z) (Z * Z) - | Z_zselect : ident (Z * Z * Z) Z - | Z_add_modulo : ident (Z * Z * Z) Z - | Z_rshi : ident (Z * Z * Z * Z) Z - | Z_cc_m : ident (Z * Z) Z - . - - Notation curry0 f - := (fun 'tt => f). - Notation curry2 f - := (fun '(a, b) => f a b). - Notation curry3 f - := (fun '(a, b, c) => f a b c). - Notation curry4 f - := (fun '(a, b, c, d) => f a b c d). - Notation uncurry2 f - := (fun a b => f (a, b)). - Notation uncurry3 f - := (fun a b c => f (a, b, c)). - Notation curry3_1 f - := (fun '(a, b, c) => f (uncurry2 a) b c). - Notation curry3_2 f - := (fun '(a, b, c) => f a (uncurry2 b) c). - Notation curry3_3 f - := (fun '(a, b, c) => f a (uncurry3 b) c). - - (** Denote identifiers *) - Definition interp {s d} (idc : ident s d) : type.interp s -> type.interp d - := match idc in ident s d return type.interp s -> type.interp d with - | primitive _ v => curry0 v - | Let_In tx tC => curry2 (@LetIn.Let_In (type.interp tx) (fun _ => type.interp tC)) - | Nat_succ => Nat.succ - | Nat_add => curry2 Nat.add - | Nat_sub => curry2 Nat.sub - | Nat_mul => curry2 Nat.mul - | Nat_max => curry2 Nat.max - | nil t => curry0 (@Datatypes.nil (type.interp t)) - | cons t => curry2 (@Datatypes.cons (type.interp t)) - | fst A B => @Datatypes.fst (type.interp A) (type.interp B) - | snd A B => @Datatypes.snd (type.interp A) (type.interp B) - | bool_rect T => curry3 (fun t f => @Datatypes.bool_rect (fun _ => type.interp T) (t tt) (f tt)) - | nat_rect P => curry3_2 (fun O_case => @Datatypes.nat_rect (fun _ => type.interp P) (O_case tt)) - | list_rect A P => curry3_3 (fun N_case => @Datatypes.list_rect (type.interp A) (fun _ => type.interp P) (N_case tt)) - | list_case A P => curry3_2 (fun N_case => @ListUtil.list_case (type.interp A) (fun _ => type.interp P) (N_case tt)) - | pred => Nat.pred - | List_length T => @List.length (type.interp T) - | List_seq => curry2 List.seq - | List_combine A B => curry2 (@List.combine (type.interp A) (type.interp B)) - | List_map A B => curry2 (@List.map (type.interp A) (type.interp B)) - | List_repeat A => curry2 (@repeat (type.interp A)) - | List_flat_map A B => curry2 (@List.flat_map (type.interp A) (type.interp B)) - | List_partition A => curry2 (@List.partition (type.interp A)) - | List_app A => curry2 (@List.app (type.interp A)) - | List_rev A => @List.rev (type.interp A) - | List_fold_right A B => curry3_1 (@List.fold_right (type.interp A) (type.interp B)) - | List_update_nth T => curry3 (@update_nth (type.interp T)) - | List_nth_default T => curry3 (@List.nth_default (type.interp T)) - | Z_add => curry2 Z.add - | Z_mul => curry2 Z.mul - | Z_pow => curry2 Z.pow - | Z_modulo => curry2 Z.modulo - | Z_opp => Z.opp - | Z_sub => curry2 Z.sub - | Z_div => curry2 Z.div - | Z_eqb => curry2 Z.eqb - | Z_leb => curry2 Z.leb - | Z_of_nat => Z.of_nat - | Z_mul_split => curry3 Z.mul_split - | Z_add_get_carry => curry3 Z.add_get_carry_full - | Z_add_with_carry => curry3 Z.add_with_carry - | Z_add_with_get_carry => curry4 Z.add_with_get_carry_full - | Z_sub_get_borrow => curry3 Z.sub_get_borrow_full - | Z_sub_with_get_borrow => curry4 Z.sub_with_get_borrow_full - | Z_zselect => curry3 Z.zselect - | Z_add_modulo => curry3 Z.add_modulo - | Z_rshi => curry4 Z.rshi - | Z_cc_m => curry2 Z.cc_m - end. - - Ltac reify - mkAppIdent - term_is_primitive_const - term - else_tac := - (*let dummy := match goal with _ => idtac "attempting to reify_op" term end in*) - lazymatch term with - | Nat.succ ?x => mkAppIdent Nat_succ x - | Nat.add ?x ?y => mkAppIdent Nat_add (x, y) - | Nat.sub ?x ?y => mkAppIdent Nat_sub (x, y) - | Nat.mul ?x ?y => mkAppIdent Nat_mul (x, y) - | Nat.max ?x ?y => mkAppIdent Nat_max (x, y) - | S ?x => mkAppIdent Nat_succ x - | @Datatypes.nil ?T - => let rT := type.reify T in - mkAppIdent (@ident.nil rT) tt - | @Datatypes.cons ?T ?x ?xs - => let rT := type.reify T in - mkAppIdent (@ident.cons rT) (x, xs) - | @Datatypes.fst ?A ?B ?x - => let rA := type.reify A in - let rB := type.reify B in - mkAppIdent (@ident.fst rA rB) x - | @Datatypes.snd ?A ?B ?x - => let rA := type.reify A in - let rB := type.reify B in - mkAppIdent (@ident.snd rA rB) x - | @Datatypes.bool_rect (fun _ => ?T) ?Ptrue ?Pfalse ?b - => let rT := type.reify T in - mkAppIdent (@ident.bool_rect rT) - ((fun _ : Datatypes.unit => Ptrue), (fun _ : Datatypes.unit => Pfalse), b) - | @Datatypes.nat_rect (fun _ => ?T) ?P0 (fun (n' : Datatypes.nat) Pn => ?PS) ?n - => let rT := type.reify T in - let pat := fresh "pat" in (* fresh for COQBUG(https://github.com/coq/coq/issues/6562) *) - mkAppIdent (@ident.nat_rect rT) ((fun _ : Datatypes.unit => P0), - (fun pat : Datatypes.nat * T - => let '(n', Pn) := pat in PS), - n) - | @Datatypes.nat_rect (fun _ => ?T) ?P0 ?PS ?n - => let dummy := match goal with _ => fail 1 "nat_rect successor case is not syntactically a function of two arguments:" PS end in - constr:(I : I) - | @Datatypes.list_rect ?A (fun _ => ?T) ?Pnil (fun a tl Ptl => ?PS) ?ls - => let rA := type.reify A in - let rT := type.reify T in - let pat := fresh "pat" in (* fresh for COQBUG(https://github.com/coq/coq/issues/6562) *) - mkAppIdent (@ident.list_rect rA rT) - ((fun _ : Datatypes.unit => Pnil), - (fun pat : A * Datatypes.list A * T - => let '(a, tl, Ptl) := pat in PS), - ls) - | @Datatypes.list_rect ?A (fun _ => ?T) ?Pnil ?PS ?ls - => let dummy := match goal with _ => fail 1 "list_rect successor case is not syntactically a function of three arguments:" PS end in - constr:(I : I) - | @ListUtil.list_case ?A (fun _ => ?T) ?Pnil (fun a tl => ?PS) ?ls - => let rA := type.reify A in - let rT := type.reify T in - let pat := fresh "pat" in (* fresh for COQBUG(https://github.com/coq/coq/issues/6562) *) - mkAppIdent (@ident.list_case rA rT) - ((fun _ : Datatypes.unit => Pnil), - (fun pat : A * Datatypes.list A - => let '(a, tl) := pat in PS), - ls) - | @ListUtil.list_case ?A (fun _ => ?T) ?Pnil ?PS ?ls - => let dummy := match goal with _ => fail 1 "list_case successor case is not syntactically a function of two arguments:" PS end in - constr:(I : I) - | Nat.pred ?x => mkAppIdent ident.pred x - | @List.length ?A ?x => - let rA := type.reify A in - mkAppIdent (@ident.List_length rA) x - | List.seq ?x ?y => mkAppIdent ident.List_seq (x, y) - | @repeat ?A ?x ?y - => let rA := type.reify A in - mkAppIdent (@ident.List_repeat rA) (x, y) - | @LetIn.Let_In ?A (fun _ => ?B) ?x ?f - => let rA := type.reify A in - let rB := type.reify B in - mkAppIdent (@ident.Let_In rA rB) (x, f) - | @LetIn.Let_In ?A ?B ?x ?f - => let dummy := match goal with _ => fail 1 "Let_In contains a dependent type λ as its second argument:" B end in - constr:(I : I) - | @combine ?A ?B ?ls1 ?ls2 - => let rA := type.reify A in - let rB := type.reify B in - mkAppIdent (@ident.List_combine rA rB) (ls1, ls2) - | @List.map ?A ?B ?f ?ls - => let rA := type.reify A in - let rB := type.reify B in - mkAppIdent (@ident.List_map rA rB) (f, ls) - | @List.flat_map ?A ?B ?f ?ls - => let rA := type.reify A in - let rB := type.reify B in - mkAppIdent (@ident.List_flat_map rA rB) (f, ls) - | @List.partition ?A ?f ?ls - => let rA := type.reify A in - mkAppIdent (@ident.List_partition rA) (f, ls) - | @List.app ?A ?ls1 ?ls2 - => let rA := type.reify A in - mkAppIdent (@ident.List_app rA) (ls1, ls2) - | @List.rev ?A ?ls - => let rA := type.reify A in - mkAppIdent (@ident.List_rev rA) ls - | @List.fold_right ?A ?B (fun b a => ?f) ?a0 ?ls - => let rA := type.reify A in - let rB := type.reify B in - let pat := fresh "pat" in (* fresh for COQBUG(https://github.com/coq/coq/issues/6562) *) - mkAppIdent (@ident.List_fold_right rA rB) ((fun pat : B * A => let '(b, a) := pat in f), a0, ls) - | @List.fold_right ?A ?B ?f ?a0 ?ls - => let dummy := match goal with _ => fail 1 "List.fold_right function argument is not syntactically a function of two arguments:" f end in - constr:(I : I) - | @update_nth ?T ?n ?f ?ls - => let rT := type.reify T in - mkAppIdent (@ident.List_update_nth rT) (n, f, ls) - | @List.nth_default ?T ?d ?ls ?n - => let rT := type.reify T in - mkAppIdent (@ident.List_nth_default rT) (d, ls, n) - | Z.add ?x ?y => mkAppIdent ident.Z_add (x, y) - | Z.mul ?x ?y => mkAppIdent ident.Z_mul (x, y) - | Z.pow ?x ?y => mkAppIdent ident.Z_pow (x, y) - | Z.sub ?x ?y => mkAppIdent ident.Z_sub (x, y) - | Z.opp ?x => mkAppIdent ident.Z_opp x - | Z.div ?x ?y => mkAppIdent ident.Z_div (x, y) - | Z.modulo ?x ?y => mkAppIdent ident.Z_modulo (x, y) - | Z.eqb ?x ?y => mkAppIdent ident.Z_eqb (x, y) - | Z.leb ?x ?y => mkAppIdent ident.Z_leb (x, y) - | Z.of_nat ?x => mkAppIdent ident.Z_of_nat x - | Z.mul_split ?x ?y ?z => mkAppIdent ident.Z_mul_split (x, y, z) - | Z.add_get_carry_full ?x ?y ?z => mkAppIdent ident.Z_add_get_carry (x, y, z) - | Z.add_with_carry ?x ?y ?z => mkAppIdent ident.Z_add_with_carry (x, y, z) - | Z.add_with_get_carry_full ?x ?y ?z ?a => mkAppIdent ident.Z_add_with_get_carry (x, y, z, a) - | Z.sub_get_borrow_full ?x ?y ?z => mkAppIdent ident.Z_sub_get_borrow (x, y, z) - | Z.sub_with_get_borrow_full ?x ?y ?z ?a => mkAppIdent ident.Z_sub_with_get_borrow (x, y, z, a) - | Z.zselect ?x ?y ?z => mkAppIdent ident.Z_zselect (x, y, z) - | Z.add_modulo ?x ?y ?z => mkAppIdent ident.Z_add_modulo (x,y,z) - | Z.rshi ?x ?y ?z ?a => mkAppIdent ident.Z_rshi (x,y,z,a) - | Z.cc_m ?x ?y => mkAppIdent ident.Z_cc_m (x,y) - | _ - => lazymatch term_is_primitive_const with - | true - => - let assert_const := match goal with - | _ => require_primitive_const term - end in - let T := type of term in - let rT := type.reify_primitive T in - mkAppIdent (@ident.primitive rT term) tt - | false => else_tac () - end - end. - - Module List. - Notation length := List_length. - Notation seq := List_seq. - Notation repeat := List_repeat. - Notation combine := List_combine. - Notation map := List_map. - Notation flat_map := List_flat_map. - Notation partition := List_partition. - Notation app := List_app. - Notation rev := List_rev. - Notation fold_right := List_fold_right. - Notation update_nth := List_update_nth. - Notation nth_default := List_nth_default. - End List. - - Module Z. - Notation add := Z_add. - Notation mul := Z_mul. - Notation pow := Z_pow. - Notation sub := Z_sub. - Notation opp := Z_opp. - Notation div := Z_div. - Notation modulo := Z_modulo. - Notation eqb := Z_eqb. - Notation leb := Z_leb. - Notation of_nat := Z_of_nat. - Notation mul_split := Z_mul_split. - Notation add_get_carry := Z_add_get_carry. - Notation add_with_carry := Z_add_with_carry. - Notation add_with_get_carry := Z_add_with_get_carry. - Notation sub_get_borrow := Z_sub_get_borrow. - Notation sub_with_get_borrow := Z_sub_with_get_borrow. - Notation zselect := Z_zselect. - Notation add_modulo := Z_add_modulo. - Notation rshi := Z_rshi. - Notation cc_m := Z_cc_m. - End Z. - - Module Nat. - Notation succ := Nat_succ. - Notation add := Nat_add. - Notation sub := Nat_sub. - Notation mul := Nat_mul. - Notation max := Nat_max. - End Nat. - - Module Export Notations. - Notation ident := ident. - End Notations. - End ident. - - Module Notations. - Include ident.Notations. - Notation expr := (@expr ident). - Notation Expr := (@Expr ident). - Notation interp := (@interp ident (@ident.interp)). - Notation Interp := (@Interp ident (@ident.interp)). - - (*Notation "( x , y , .. , z )" := (Pair .. (Pair x%expr y%expr) .. z%expr) : expr_scope.*) - Notation "'expr_let' x := A 'in' b" := (AppIdent ident.Let_In (Pair A%expr (Abs (fun x => b%expr)))) : expr_scope. - Notation "[ ]" := (AppIdent ident.nil _) : expr_scope. - Notation "x :: xs" := (AppIdent ident.cons (Pair x%expr xs%expr)) : expr_scope. - Notation "x" := (AppIdent (ident.primitive x) _) (only printing, at level 9) : expr_scope. - Notation "ls [[ n ]]" - := (AppIdent ident.List.nth_default (_, ls, AppIdent (ident.primitive n%nat) _)%expr) - : expr_scope. - - Module Reification. - Ltac reify var term := expr.reify ident ident.reify var term. - Ltac Reify term := expr.Reify ident ident.reify term. - Ltac Reify_rhs _ := - expr.Reify_rhs ident ident.reify ident.interp (). - End Reification. - Include Reification. - End Notations. - Include Notations. - End for_reification. - - Module Export default. - Module ident. - Import type. - Inductive ident : type -> type -> Set := - | primitive {t : type.primitive} (v : interp t) : ident () t - | Let_In {tx tC} : ident (tx * (tx -> tC)) tC - | Nat_succ : ident nat nat - | Nat_add : ident (nat * nat) nat - | Nat_sub : ident (nat * nat) nat - | Nat_mul : ident (nat * nat) nat - | Nat_max : ident (nat * nat) nat - | nil {t} : ident () (list t) - | cons {t} : ident (t * list t) (list t) - | fst {A B} : ident (A * B) A - | snd {A B} : ident (A * B) B - | bool_rect {T} : ident ((unit -> T) * (unit -> T) * bool) T - | nat_rect {P} : ident ((unit -> P) * (nat * P -> P) * nat) P - | pred : ident nat nat - | list_rect {A P} : ident ((unit -> P) * (A * list A * P -> P) * list A) P - | List_nth_default {T} : ident (T * list T * nat) T - | List_nth_default_concrete {T : type.primitive} (d : interp T) (n : Datatypes.nat) : ident (list T) T - | Z_shiftr (offset : BinInt.Z) : ident Z Z - | Z_shiftl (offset : BinInt.Z) : ident Z Z - | Z_land (mask : BinInt.Z) : ident Z Z - | Z_add : ident (Z * Z) Z - | Z_mul : ident (Z * Z) Z - | Z_pow : ident (Z * Z) Z - | Z_sub : ident (Z * Z) Z - | Z_opp : ident Z Z - | Z_div : ident (Z * Z) Z - | Z_modulo : ident (Z * Z) Z - | Z_eqb : ident (Z * Z) bool - | Z_leb : ident (Z * Z) bool - | Z_of_nat : ident nat Z - | Z_mul_split : ident (Z * Z * Z) (Z * Z) - | Z_mul_split_concrete (s:BinInt.Z) : ident (Z * Z) (Z * Z) - | Z_add_get_carry : ident (Z * Z * Z) (Z * Z) - | Z_add_get_carry_concrete (s:BinInt.Z) : ident (Z * Z) (Z * Z) - | Z_add_with_carry : ident (Z * Z * Z) Z - | Z_add_with_get_carry : ident (Z * Z * Z * Z) (Z * Z) - | Z_add_with_get_carry_concrete (s:BinInt.Z) : ident (Z * Z * Z) (Z * Z) - | Z_sub_get_borrow : ident (Z * Z * Z) (Z * Z) - | Z_sub_get_borrow_concrete (s:BinInt.Z) : ident (Z * Z) (Z * Z) - | Z_sub_with_get_borrow : ident (Z * Z * Z * Z) (Z * Z) - | Z_sub_with_get_borrow_concrete (s:BinInt.Z) : ident (Z * Z * Z) (Z * Z) - | Z_zselect : ident (Z * Z * Z) Z - | Z_add_modulo : ident (Z * Z * Z) Z - | Z_rshi : ident (Z * Z * Z * Z) Z - | Z_rshi_concrete (s offset:BinInt.Z) : ident (Z * Z) Z - | Z_cc_m : ident (Z * Z) Z - | Z_cc_m_concrete (s:BinInt.Z) : ident Z Z - | Z_cast (range : zrange) : ident Z Z - | Z_cast2 (range : zrange * zrange) : ident (Z * Z) (Z * Z) - . - - Notation curry0 f - := (fun 'tt => f). - Notation curry2 f - := (fun '(a, b) => f a b). - Notation curry3 f - := (fun '(a, b, c) => f a b c). - Notation curry4 f - := (fun '(a, b, c, d) => f a b c d). - Notation uncurry2 f - := (fun a b => f (a, b)). - Notation uncurry3 f - := (fun a b c => f (a, b, c)). - Notation curry3_23 f - := (fun '(a, b, c) => f a (uncurry3 b) c). - Notation curry3_2 f - := (fun '(a, b, c) => f a (uncurry2 b) c). - - Section gen. - Context (cast_outside_of_range : zrange -> BinInt.Z -> BinInt.Z). - - Definition cast (r : zrange) (x : BinInt.Z) - := if (lower r <=? x) && (x <=? upper r) - then x - else cast_outside_of_range r x. - - (** Interpret identifiers where the behavior of [Z_cast] - on a value that does not fit in the range is given by - a context variable. (This allows us to treat [Z_cast] - as "undefined behavior" when the value doesn't fit in - the range by quantifying over all possible - interpretations. *) - Definition gen_interp {s d} (idc : ident s d) : type.interp s -> type.interp d - := match idc in ident s d return type.interp s -> type.interp d with - | primitive _ v => curry0 v - | Let_In tx tC => curry2 (@LetIn.Let_In (type.interp tx) (fun _ => type.interp tC)) - | Nat_succ => Nat.succ - | Nat_add => curry2 Nat.add - | Nat_sub => curry2 Nat.sub - | Nat_mul => curry2 Nat.mul - | Nat_max => curry2 Nat.max - | nil t => curry0 (@Datatypes.nil (type.interp t)) - | cons t => curry2 (@Datatypes.cons (type.interp t)) - | fst A B => @Datatypes.fst (type.interp A) (type.interp B) - | snd A B => @Datatypes.snd (type.interp A) (type.interp B) - | bool_rect T => curry3 (fun t f => @Datatypes.bool_rect (fun _ => type.interp T) (t tt) (f tt)) - | nat_rect P => curry3_2 (fun O_case => @Datatypes.nat_rect (fun _ => type.interp P) (O_case tt)) - | pred => Nat.pred - | list_rect A P => curry3_23 (fun N_case => @Datatypes.list_rect (type.interp A) (fun _ => type.interp P) (N_case tt)) - | List_nth_default T => curry3 (@List.nth_default (type.interp T)) - | List_nth_default_concrete T d n => fun ls => @List.nth_default (type.interp T) d ls n - | Z_shiftr n => fun v => Z.shiftr v n - | Z_shiftl n => fun v => Z.shiftl v n - | Z_land mask => fun v => Z.land v mask - | Z_add => curry2 Z.add - | Z_mul => curry2 Z.mul - | Z_pow => curry2 Z.pow - | Z_modulo => curry2 Z.modulo - | Z_sub => curry2 Z.sub - | Z_opp => Z.opp - | Z_div => curry2 Z.div - | Z_eqb => curry2 Z.eqb - | Z_leb => curry2 Z.leb - | Z_of_nat => Z.of_nat - | Z_mul_split => curry3 Z.mul_split - | Z_mul_split_concrete s => curry2 (Z.mul_split s) - | Z_add_get_carry => curry3 Z.add_get_carry_full - | Z_add_get_carry_concrete s => curry2 (Z.add_get_carry_full s) - | Z_add_with_carry => curry3 Z.add_with_carry - | Z_add_with_get_carry => curry4 Z.add_with_get_carry_full - | Z_add_with_get_carry_concrete s => curry3 (Z.add_with_get_carry_full s) - | Z_sub_get_borrow => curry3 Z.sub_get_borrow_full - | Z_sub_get_borrow_concrete s => curry2 (Z.sub_get_borrow_full s) - | Z_sub_with_get_borrow => curry4 Z.sub_with_get_borrow_full - | Z_sub_with_get_borrow_concrete s => curry3 (Z.sub_with_get_borrow_full s) - | Z_zselect => curry3 Z.zselect - | Z_add_modulo => curry3 Z.add_modulo - | Z_rshi => curry4 Z.rshi - | Z_rshi_concrete s n => curry2 (fun x y => Z.rshi s x y n) - | Z_cc_m => curry2 Z.cc_m - | Z_cc_m_concrete s => Z.cc_m s - | Z_cast r => cast r - | Z_cast2 (r1, r2) => fun '(x1, x2) => (cast r1 x1, cast r2 x2) - end. - End gen. - - Definition cast_outside_of_range (r : zrange) (v : BinInt.Z) : BinInt.Z. - Proof. exact v. Qed. - - (** Interpret identifiers where [Z_cast] is an opaque - identity function when the value is not inside the range - *) - Definition interp {s d} (idc : ident s d) : type.interp s -> type.interp d - := @gen_interp cast_outside_of_range s d idc. - Global Arguments interp _ _ !_ _ / . - - Ltac reify - mkAppIdent - term_is_primitive_const - term - else_tac := - (*let dummy := match goal with _ => idtac "attempting to reify_op" term end in*) - lazymatch term with - | Nat.succ ?x => mkAppIdent Nat_succ x - | Nat.add ?x ?y => mkAppIdent Nat_add (x, y) - | Nat.sub ?x ?y => mkAppIdent Nat_sub (x, y) - | Nat.mul ?x ?y => mkAppIdent Nat_mul (x, y) - | Nat.max ?x ?y => mkAppIdent Nat_max (x, y) - | S ?x => mkAppIdent Nat_succ x - | @Datatypes.nil ?T - => let rT := type.reify T in - mkAppIdent (@ident.nil rT) tt - | @Datatypes.cons ?T ?x ?xs - => let rT := type.reify T in - mkAppIdent (@ident.cons rT) (x, xs) - | @Datatypes.fst ?A ?B ?x - => let rA := type.reify A in - let rB := type.reify B in - mkAppIdent (@ident.fst rA rB) x - | @Datatypes.snd ?A ?B ?x - => let rA := type.reify A in - let rB := type.reify B in - mkAppIdent (@ident.snd rA rB) x - | @Datatypes.bool_rect (fun _ => ?T) ?Ptrue ?Pfalse ?b - => let rT := type.reify T in - mkAppIdent (@ident.bool_rect rT) - ((fun _ : Datatypes.unit => Ptrue), (fun _ : Datatypes.unit => Pfalse), b) - | @Datatypes.nat_rect (fun _ => ?T) ?P0 (fun (n' : Datatypes.nat) Pn => ?PS) ?n - => let rT := type.reify T in - let pat := fresh "pat" in (* fresh for COQBUG(https://github.com/coq/coq/issues/6562) *) - mkAppIdent (@ident.nat_rect rT) - ((fun _ : Datatypes.unit => P0), - (fun pat : Datatypes.nat * T - => let '(n', Pn) := pat in PS), - n) - | @Datatypes.nat_rect (fun _ => ?T) ?P0 ?PS ?n - => let dummy := match goal with _ => fail 1 "nat_rect successor case is not syntactically a function of two arguments:" PS end in - constr:(I : I) - | Nat.pred ?x => mkAppIdent ident.pred x - | @LetIn.Let_In ?A (fun _ => ?B) ?x ?f - => let rA := type.reify A in - let rB := type.reify B in - mkAppIdent (@ident.Let_In rA rB) (x, f) - | @LetIn.Let_In ?A ?B ?x ?f - => let dummy := match goal with _ => fail 1 "Let_In contains a dependent type λ as its second argument:" B end in - constr:(I : I) - | @Datatypes.list_rect ?A (fun _ => ?B) ?Pnil (fun x xs rec => ?Pcons) ?ls - => let rA := type.reify A in - let rB := type.reify B in - let pat := fresh "pat" in (* fresh for COQBUG(https://github.com/coq/coq/issues/6562) *) - let pat' := fresh "pat" in (* fresh for COQBUG(https://github.com/coq/coq/issues/6562) (must also not overlap with [rec], but I think [fresh] handles that correctly, at least) *) - mkAppIdent (@ident.list_rect rA rB) - ((fun _ : Datatypes.unit => Pnil), - (fun pat : A * Datatypes.list A * B - => let '(pat', rec) := pat in - let '(x, xs) := pat' in - Pcons), - ls) - | @Datatypes.list_rect ?A (fun _ => ?B) ?Pnil ?Pcons ?ls - => let dummy := match goal with _ => fail 1 "list_rect cons case is not syntactically a function of three arguments:" Pcons end in - constr:(I : I) - | @List.nth_default ?T ?d ?ls ?n - => let rT := type.reify T in - mkAppIdent (@ident.List_nth_default rT) (d, ls, n) - | Z.add ?x ?y => mkAppIdent ident.Z_add (x, y) - | Z.mul ?x ?y => mkAppIdent ident.Z_mul (x, y) - | Z.pow ?x ?y => mkAppIdent ident.Z_pow (x, y) - | Z.sub ?x ?y => mkAppIdent ident.Z_sub (x, y) - | Z.opp ?x => mkAppIdent ident.Z_opp x - | Z.div ?x ?y => mkAppIdent ident.Z_div (x, y) - | Z.modulo ?x ?y => mkAppIdent ident.Z_modulo (x, y) - | Z.eqb ?x ?y => mkAppIdent ident.Z_eqb (x, y) - | Z.leb ?x ?y => mkAppIdent ident.Z_leb (x, y) - | Z.of_nat ?x => mkAppIdent ident.Z_of_nat x - | Z.mul_split ?x ?y ?z => mkAppIdent ident.Z_mul_split (x, y, z) - | Z.add_get_carry_full ?x ?y ?z => mkAppIdent ident.Z_add_get_carry (x, y, z) - | Z.add_with_carry ?x ?y ?z => mkAppIdent ident.Z_add_with_carry (x, y, z) - | Z.add_with_get_carry_full ?x ?y ?z ?a => mkAppIdent ident.Z_add_with_get_carry (x, y, z, a) - | Z.sub_get_borrow_full ?x ?y ?z => mkAppIdent ident.Z_sub_get_borrow (x, y, z) - | Z.sub_with_get_borrow_full ?x ?y ?z ?a => mkAppIdent ident.Z_sub_with_get_borrow (x, y, z, a) - | Z.zselect ?x ?y ?z => mkAppIdent ident.Z_zselect (x, y, z) - | Z.add_modulo ?x ?y ?z => mkAppIdent ident.Z_add_modulo (x,y,z) - | Z.rshi ?x ?y ?z ?a => mkAppIdent ident.Z_rshi (x,y,z,a) - | Z.cc_m ?x ?y => mkAppIdent ident.Z_cc_m (x,y) - | _ - => lazymatch term_is_primitive_const with - | true - => - let assert_const := match goal with - | _ => require_primitive_const term - end in - let T := type of term in - let rT := type.reify_primitive T in - mkAppIdent (@ident.primitive rT term) tt - | _ => else_tac () - end - end. - - Module List. - Notation nth_default := List_nth_default. - Notation nth_default_concrete := List_nth_default_concrete. - End List. - - Module Z. - Notation shiftr := Z_shiftr. - Notation shiftl := Z_shiftl. - Notation land := Z_land. - Notation add := Z_add. - Notation mul := Z_mul. - Notation pow := Z_pow. - Notation sub := Z_sub. - Notation opp := Z_opp. - Notation div := Z_div. - Notation modulo := Z_modulo. - Notation eqb := Z_eqb. - Notation leb := Z_leb. - Notation of_nat := Z_of_nat. - Notation mul_split := Z_mul_split. - Notation mul_split_concrete := Z_mul_split_concrete. - Notation add_get_carry := Z_add_get_carry. - Notation add_get_carry_concrete := Z_add_get_carry_concrete. - Notation add_with_carry := Z_add_with_carry. - Notation add_with_get_carry := Z_add_with_get_carry. - Notation add_with_get_carry_concrete := Z_add_with_get_carry_concrete. - Notation sub_get_borrow := Z_sub_get_borrow. - Notation sub_get_borrow_concrete := Z_sub_get_borrow_concrete. - Notation sub_with_get_borrow := Z_sub_with_get_borrow. - Notation sub_with_get_borrow_concrete := Z_sub_with_get_borrow_concrete. - Notation zselect := Z_zselect. - Notation add_modulo := Z_add_modulo. - Notation rshi := Z_rshi. - Notation rshi_concrete := Z_rshi_concrete. - Notation cc_m := Z_cc_m. - Notation cc_m_concrete := Z_cc_m_concrete. - Notation cast := Z_cast. - Notation cast2 := Z_cast2. - End Z. - - Module Nat. - Notation succ := Nat_succ. - Notation add := Nat_add. - Notation sub := Nat_sub. - Notation mul := Nat_mul. - Notation max := Nat_max. - End Nat. - - Module Export Notations. - Notation ident := ident. - End Notations. - End ident. - - Module Notations. - Include ident.Notations. - Notation expr := (@expr ident). - Notation Expr := (@Expr ident). - Notation interp := (@interp ident (@ident.interp)). - Notation Interp := (@Interp ident (@ident.interp)). - Notation gen_interp cast_outside_of_range := (@interp ident (@ident.gen_interp cast_outside_of_range)). - Notation GenInterp cast_outside_of_range := (@Interp ident (@ident.gen_interp cast_outside_of_range)). - - (*Notation "( x , y , .. , z )" := (Pair .. (Pair x%expr y%expr) .. z%expr) : expr_scope.*) - Notation "'expr_let' x := A 'in' b" := (AppIdent ident.Let_In (Pair A%expr (Abs (fun x => b%expr)))) : expr_scope. - Notation "[ ]" := (AppIdent ident.nil _) : expr_scope. - Notation "x :: xs" := (AppIdent ident.cons (Pair x%expr xs%expr)) : expr_scope. - Notation "x" := (AppIdent (ident.primitive x) _) (only printing, at level 9) : expr_scope. - Notation "ls [[ n ]]" - := (AppIdent ident.List.nth_default (_, ls, AppIdent (ident.primitive n%nat) _)%expr) - : expr_scope. - Notation "ls [[ n ]]" - := (AppIdent (ident.List.nth_default_concrete n) ls%expr) - : expr_scope. - - Ltac reify var term := expr.reify ident ident.reify var term. - Ltac Reify term := expr.Reify ident ident.reify term. - Ltac Reify_rhs _ := - expr.Reify_rhs ident ident.reify ident.interp (). - End Notations. - Include Notations. - End default. - End expr. - - Module canonicalize_list_recursion. - Import expr. - Import expr.default. - Module ident. - Local Ltac app_and_maybe_cancel term := - lazymatch term with - | Abs (fun x : @expr ?var ?T => ?f) - => eval cbv [unexpr] in (fun x : @expr var T => @unexpr ident.ident var _ f) - | Abs (fun x : ?T => ?f) - => let dummy := match goal with _ => fail 1 "Invalid var type:" T end in - constr:(I : I) - end. - - Definition transfer {var} {s d} (idc : for_reification.ident s d) : @expr var s -> @expr var d - := let List_app A := - list_rect - (fun _ => list (type.interp A) -> list (type.interp A)) - (fun m => m) - (fun a l1 app_l1 m => a :: app_l1 m) in - match idc in for_reification.ident s d return @expr var s -> @expr var d with - | for_reification.ident.Let_In tx tC - => AppIdent ident.Let_In - | for_reification.ident.Nat_succ - => AppIdent ident.Nat_succ - | for_reification.ident.Nat_add - => AppIdent ident.Nat_add - | for_reification.ident.Nat_sub - => AppIdent ident.Nat_sub - | for_reification.ident.Nat_mul - => AppIdent ident.Nat_mul - | for_reification.ident.Nat_max - => AppIdent ident.Nat_max - | for_reification.ident.nil t - => AppIdent ident.nil - | for_reification.ident.cons t - => AppIdent ident.cons - | for_reification.ident.fst A B - => AppIdent ident.fst - | for_reification.ident.snd A B - => AppIdent ident.snd - | for_reification.ident.bool_rect T - => AppIdent ident.bool_rect - | for_reification.ident.nat_rect P - => AppIdent ident.nat_rect - | for_reification.ident.list_rect A P - => AppIdent ident.list_rect - | for_reification.ident.pred - => AppIdent ident.pred - | for_reification.ident.primitive t v - => AppIdent (ident.primitive v) - | for_reification.ident.Z_add - => AppIdent ident.Z.add - | for_reification.ident.Z_mul - => AppIdent ident.Z.mul - | for_reification.ident.Z_pow - => AppIdent ident.Z.pow - | for_reification.ident.Z_sub - => AppIdent ident.Z.sub - | for_reification.ident.Z_opp - => AppIdent ident.Z.opp - | for_reification.ident.Z_div - => AppIdent ident.Z.div - | for_reification.ident.Z_modulo - => AppIdent ident.Z.modulo - | for_reification.ident.Z_eqb - => AppIdent ident.Z.eqb - | for_reification.ident.Z_leb - => AppIdent ident.Z.leb - | for_reification.ident.Z_of_nat - => AppIdent ident.Z.of_nat - | for_reification.ident.Z_mul_split - => AppIdent ident.Z.mul_split - | for_reification.ident.Z_add_get_carry - => AppIdent ident.Z.add_get_carry - | for_reification.ident.Z_add_with_carry - => AppIdent ident.Z.add_with_carry - | for_reification.ident.Z_add_with_get_carry - => AppIdent ident.Z.add_with_get_carry - | for_reification.ident.Z_sub_get_borrow - => AppIdent ident.Z.sub_get_borrow - | for_reification.ident.Z_sub_with_get_borrow - => AppIdent ident.Z.sub_with_get_borrow - | for_reification.ident.Z_zselect - => AppIdent ident.Z.zselect - | for_reification.ident.Z_add_modulo - => AppIdent ident.Z.add_modulo - | for_reification.ident.Z_rshi - => AppIdent ident.Z.rshi - | for_reification.ident.Z_cc_m - => AppIdent ident.Z.cc_m - | for_reification.ident.list_case A P - => ltac:( - let v := reify - (@expr var) - (fun '((Pnil, Pcons, ls) - : (unit -> type.interp P) - * (type.interp A * list (type.interp A) -> type.interp P) - * (list (type.interp A))) - => list_rect - (fun _ => type.interp P) - (Pnil tt) - (fun x xs _ => Pcons (x, xs)) - ls) in - let v := app_and_maybe_cancel v in exact v) - | for_reification.ident.List_length A - => ltac:( - let v := reify - (@expr var) - (fun (ls : list (type.interp A)) - => list_rect - (fun _ => nat) - 0%nat - (fun a t len_t => S len_t) - ls) in - let v := app_and_maybe_cancel v in exact v) - | for_reification.ident.List_seq - => ltac:( - let v - := - reify - (@expr var) - (fun start_len : nat * nat - => nat_rect - (fun _ => nat -> list nat) - (fun _ => nil) - (fun len seq_len start => cons start (seq_len (S start))) - (snd start_len) (fst start_len)) in - let v := app_and_maybe_cancel v in exact v) - | for_reification.ident.List_repeat A - => ltac:( - let v := reify - (@expr var) - (fun (xn : type.interp A * nat) - => nat_rect - (fun _ => list (type.interp A)) - nil - (fun k repeat_k => cons (fst xn) repeat_k) - (snd xn)) in - let v := app_and_maybe_cancel v in exact v) - | for_reification.ident.List_combine A B - => ltac:( - let v := reify - (@expr var) - (fun '((ls1, ls2) : list (type.interp A) * list (type.interp B)) - => list_rect - (fun _ => list (type.interp B) -> list (type.interp A * type.interp B)) - (fun l' => nil) - (fun x tl combine_tl rest - => list_rect - (fun _ => list (type.interp A * type.interp B)) - nil - (fun y tl' _ - => (x, y) :: combine_tl tl') - rest) - ls1 - ls2) in - let v := app_and_maybe_cancel v in exact v) - | for_reification.ident.List_map A B - => ltac:( - let v := reify - (@expr var) - (fun '((f, ls) : (type.interp A -> type.interp B) * Datatypes.list (type.interp A)) - => list_rect - (fun _ => list (type.interp B)) - nil - (fun a t map_t => f a :: map_t) - ls) in - let v := app_and_maybe_cancel v in exact v) - | for_reification.ident.List_flat_map A B - => ltac:( - let List_app := (eval cbv [List_app] in (List_app B)) in - let v := reify - (@expr var) - (fun '((f, ls) : (type.interp A -> list (type.interp B)) * list (type.interp A)) - => list_rect - (fun _ => list (type.interp B)) - nil - (fun x t flat_map_t => List_app (f x) flat_map_t) - ls) in - let v := app_and_maybe_cancel v in exact v) - | for_reification.ident.List_partition A - => ltac:( - let v := reify - (@expr var) - (fun '((f, ls) : (type.interp A -> bool) * list (type.interp A)) - => list_rect - (fun _ => list (type.interp A) * list (type.interp A))%type - (nil, nil) - (fun x tl partition_tl - => let g := fst partition_tl in - let d := snd partition_tl in - if f x then (x :: g, d) else (g, x :: d)) - ls) in - let v := app_and_maybe_cancel v in exact v) - | for_reification.ident.List_app A - => ltac:( - let List_app := (eval cbv [List_app] in (List_app A)) in - let v := reify (@expr var) (fun '(ls1, ls2) => List_app ls1 ls2) in - let v := app_and_maybe_cancel v in exact v) - | for_reification.ident.List_rev A - => ltac:( - let List_app := (eval cbv [List_app] in (List_app A)) in - let v := reify - (@expr var) - (fun ls - => list_rect - (fun _ => list (type.interp A)) - nil - (fun x l' rev_l' => List_app rev_l' [x]) - ls) in - let v := app_and_maybe_cancel v in exact v) - | for_reification.ident.List_fold_right A B - => ltac:( - let v := reify - (@expr var) - (fun '((f, a0, ls) - : (type.interp B * type.interp A -> type.interp A) * type.interp A * list (type.interp B)) - => list_rect - (fun _ => type.interp A) - a0 - (fun b t fold_right_t => f (b, fold_right_t)) - ls) in - let v := app_and_maybe_cancel v in exact v) - | for_reification.ident.List_update_nth T - => ltac:( - let v := reify - (@expr var) - (fun '((n, f, ls) : nat * (type.interp T -> type.interp T) * list (type.interp T)) - => nat_rect - (fun _ => list (type.interp T) -> list (type.interp T)) - (fun ls - => list_rect - (fun _ => list (type.interp T)) - nil - (fun x' xs' __ => f x' :: xs') - ls) - (fun n' update_nth_n' ls - => list_rect - (fun _ => list (type.interp T)) - nil - (fun x' xs' __ => x' :: update_nth_n' xs') - ls) - n - ls) in - let v := app_and_maybe_cancel v in exact v) - | for_reification.ident.List_nth_default T - => AppIdent ident.List_nth_default - (*ltac:( - let v := reify - var - (fun (default : type.interp T) (l : list (type.interp T)) (n : nat) - => nat_rect - (fun _ => list (type.interp T) -> type.interp T) - (list_rect - (fun _ => type.interp T) - default - (fun x __ __ => x)) - (fun n nth_error_n - => list_rect - (fun _ => type.interp T) - default - (fun __ l __ => nth_error_n l)) - n - l) in - exact v)*) - end%expr. - End ident. - - Module expr. - Section with_var. - Context {var : type -> Type}. - - Fixpoint transfer {t} (e : @for_reification.Notations.expr var t) - : @expr var t - := match e with - | Var t v => Var v - | TT => TT - | Pair A B a b => Pair (@transfer A a) (@transfer B b) - | AppIdent s d idc args => @ident.transfer var s d idc (@transfer _ args) - | App s d f x => App (@transfer _ f) (@transfer _ x) - | Abs s d f => Abs (fun x => @transfer d (f x)) - end. - End with_var. - - Definition Transfer {t} (e : for_reification.Notations.Expr t) : Expr t - := fun var => transfer (e _). - End expr. - End canonicalize_list_recursion. - Notation canonicalize_list_recursion := canonicalize_list_recursion.expr.Transfer. - Export expr. - Export expr.default. - End Uncurried. - - Import Uncurried. - Section invert. - Context {var : type -> Type}. - - Definition invert_Var {t} (e : @expr var t) : option (var t) - := match e with - | Var t v => Some v - | _ => None - end. - - Local Notation if_arrow f - := (fun t => match t return Type with - | type.arrow s d => f s d - | _ => True - end) (only parsing). - Local Notation if_arrow_s f := (if_arrow (fun s d => f s)) (only parsing). - Local Notation if_arrow_d f := (if_arrow (fun s d => f d)) (only parsing). - Local Notation if_prod f - := (fun t => match t return Type with - | type.prod A B => f A B - | _ => True - end). - - Definition invert_Abs {s d} (e : @expr var (type.arrow s d)) : option (var s -> @expr var d) - := match e in expr.expr t return option (if_arrow (fun _ _ => _) t) with - | Abs s d f => Some f - | _ => None - end. - - Definition invert_App {d} (e : @expr var d) : option { s : _ & @expr var (s -> d) * @expr var s }%type - := match e with - | App s d f x => Some (existT _ s (f, x)) - | _ => None - end. - - Definition invert_AppIdent {d} (e : @expr var d) : option { s : _ & @ident s d * @expr var s }%type - := match e with - | AppIdent s d idc args - => Some (existT _ s (idc, args)) - | _ => None - end. - - Definition invert_App2 {d} (e : @expr var d) : option { s1s2 : _ * _ & @expr var (fst s1s2 -> snd s1s2 -> d) * @expr var (fst s1s2) * @expr var (snd s1s2) }%type - := match invert_App e with - | Some (existT s (f, y)) - => match invert_App f with - | Some (existT s' (f', x)) - => Some (existT _ (s', s) (f', x, y)) - | None => None - end - | None => None - end. - - Local Notation expr_prod - := (fun t => match t return Type with - | type.prod A B => prod (expr A) (expr B) - | _ => True - end) (only parsing). - - Definition invert_Pair {A B} (e : @expr var (type.prod A B)) : option (@expr var A * @expr var B) - := match e in expr.expr t return option (if_prod (fun A B => expr A * expr B)%type t) with - | Pair A B a b - => Some (a, b) - | _ => None - end. - - Definition invert_or_expand_Pair {A B} (e : @expr var (type.prod A B)) : @expr var A * @expr var B - := match invert_Pair e with - | Some p => p - | None => (ident.fst @@ e, ident.snd @@ e) - end%core%expr. - - (* if we want more code for the below, I would suggest [reify_base_type] and [reflect_base_type] *) - Definition reify_primitive {t} (v : type.interp (type.type_primitive t)) : @expr var (type.type_primitive t) - := AppIdent (ident.primitive v) TT. - Definition reflect_primitive {t} (e : @expr var (type.type_primitive t)) : option (type.interp (type.type_primitive t)) - := match invert_AppIdent e with - | Some (existT s (idc, args)) - => match idc in ident _ t return option (type.interp t) with - | ident.primitive _ v => Some v - | _ => None - end - | None => None - end. - Definition invert_Z_opp (e : @expr var type.Z) : option (@expr var type.Z) - := match invert_AppIdent e with - | Some (existT s (idc, args)) - => match idc in ident s t return expr s -> option (expr type.Z) with - | ident.Z_opp => fun v => Some v - | _ => fun _ => None - end args - | None => None - end. - - Definition invert_Z_cast (e : @expr var type.Z) : option (zrange * @expr var type.Z) - := match invert_AppIdent e with - | Some (existT s (idc, args)) - => match idc in ident s t return expr s -> option (zrange * expr type.Z) with - | ident.Z_cast r => fun v => Some (r, v) - | _ => fun _ => None - end args - | None => None - end. - - Definition invert_Z_cast2 (e : @expr var (type.Z * type.Z)) : option ((zrange * zrange) * @expr var (type.Z * type.Z)) - := match invert_AppIdent e with - | Some (existT s (idc, args)) - => match idc in ident s t return expr s -> option ((zrange * zrange) * expr (type.Z * type.Z)) with - | ident.Z_cast2 r => fun v => Some (r, v) - | _ => fun _ => None - end args - | None => None - end. - - Local Notation list_expr - := (fun t => match t return Type with - | type.list T => list (expr T) - | _ => True - end) (only parsing). - - (* oh, the horrors of not being able to use non-linear deep pattern matches. c.f. COQBUG(https://github.com/coq/coq/issues/6320) *) - Fixpoint reflect_list {t} (e : @expr var (type.list t)) - : option (list (@expr var t)) - := match e in expr.expr t return option (list_expr t) with - | AppIdent s (type.list t) idc x_xs - => match x_xs in expr.expr s return ident s (type.list t) -> option (list (expr t)) with - | Pair A (type.list B) x xs - => match @reflect_list B xs with - | Some xs - => fun idc - => match idc in ident s d - return if_prod (fun A B => expr A) s - -> if_prod (fun A B => list_expr B) s - -> option (list_expr d) - with - | ident.cons A - => fun x xs => Some (cons x xs) - | _ => fun _ _ => None - end x xs - | None => fun _ => None - end - | _ - => fun idc - => match idc in ident _ t return option (list_expr t) with - | ident.nil _ => Some nil - | _ => None - end - end idc - | _ => None - end. - End invert. - - Section gallina_reify. - Context {var : type -> Type}. - Definition reify_list {t} (ls : list (@expr var t)) : @expr var (type.list t) - := list_rect - (fun _ => _) - (ident.nil @@ TT)%expr - (fun x _ xs => ident.cons @@ (x, xs))%expr - ls. - End gallina_reify. - - Lemma interp_reify_list {t} ls - : interp (@reify_list _ t ls) = List.map interp ls. - Proof. - unfold reify_list. - induction ls as [|x xs IHxs]; cbn in *; [ reflexivity | ]. - rewrite IHxs; reflexivity. - Qed. - - Module GallinaReify. - Section value. - Context (var : type -> Type). - Fixpoint value (t : type) - := match t return Type with - | type.prod A B as t => value A * value B - | type.arrow s d => var s -> value d - | type.list A => list (value A) - | type.type_primitive _ as t - => type.interp t - end%type. - End value. - - Section reify. - Context {var : type -> Type}. - Fixpoint reify {t : type} {struct t} - : value var t -> @expr var t - := match t return value var t -> expr t with - | type.prod A B as t - => fun '((a, b) : value var A * value var B) - => (@reify A a, @reify B b)%expr - | type.arrow s d - => fun (f : var s -> value var d) - => Abs (fun x - => @reify d (f x)) - | type.list A as t - => fun x : list (value var A) - => reify_list (List.map (@reify A) x) - | type.type_primitive _ as t - => fun x : type.interp t - => (ident.primitive x @@ TT)%expr - end. - End reify. - - Definition Reify_as (t : type) (v : forall var, value var t) : Expr t - := fun var => reify (v _). - - (** [Reify] does Ltac type inference to get the type *) - Notation Reify v - := (Reify_as (type.reify_type_of v) (fun _ => v)) (only parsing). - End GallinaReify. - - Module Uncurry. - Module type. - Fixpoint uncurried_domain (t : type) : type - := match t with - | type.arrow s d - => match d with - | type.arrow _ _ - => s * uncurried_domain d - | _ => s - end - | _ => type.type_primitive type.unit - end%ctype. - - Definition uncurry (t : type) : type - := type.arrow (uncurried_domain t) (type.final_codomain t). - End type. - - Fixpoint app_curried {t : type} - : type.interp t -> type.interp (type.uncurried_domain t) -> type.interp (type.final_codomain t) - := match t return type.interp t -> type.interp (type.uncurried_domain t) -> type.interp (type.final_codomain t) with - | type.arrow s d - => match d - return (type.interp d -> type.interp (type.uncurried_domain d) -> type.interp (type.final_codomain d)) - -> type.interp (type.arrow s d) - -> type.interp (type.uncurried_domain (type.arrow s d)) - -> type.interp (type.final_codomain d) - with - | type.arrow _ _ as d - => fun app_curried_d - (f : type.interp s -> type.interp d) - (x : type.interp s * type.interp (type.uncurried_domain d)) - => app_curried_d (f (fst x)) (snd x) - | d - => fun _ - (f : type.interp s -> type.interp d) - (x : type.interp s) - => f x - end (@app_curried d) - | _ => fun f _ => f - end. - - Module expr. - Section with_var. - Context {var : type -> Type}. - - Fixpoint uncurry' {t} - : @expr (@expr var) t -> @expr var (type.uncurried_domain t) -> @expr var (type.final_codomain t) - := match t return expr t -> expr (type.uncurried_domain t) -> expr (type.final_codomain t) with - | type.arrow s d - => fun e - => let f := fun v - => @uncurry' - d - match invert_Abs e with - | Some f => f v - | None => e @ Var v - end%expr in - match d return (expr s -> expr (type.uncurried_domain d) -> expr (type.final_codomain d)) -> expr (type.uncurried_domain (s -> d)) -> expr (type.final_codomain d) with - | type.arrow _ _ as d - => fun f sdv - => f (ident.fst @@ sdv) (ident.snd @@ sdv) - | _ - => fun f sv => f sv TT - end f - | type.type_primitive _ - | type.prod _ _ - | type.list _ - => fun e _ => unexpr e - end%expr. - - Definition uncurry {t} (e : @expr (@expr var) t) - : @expr var (type.uncurry t) - := Abs (fun v => @uncurry' t e (Var v)). - End with_var. - - Definition Uncurry {t} (e : Expr t) : Expr (type.uncurry t) - := fun var => uncurry (e _). - End expr. - End Uncurry. - - Module CPS. - Import Uncurried. - Module Import Output. - Module type. - Import Compilers.type. - Inductive type := type_primitive (_:primitive) | prod (A B : type) | continuation (A : type) | list (A : type). - Module Export Coercions. - Global Coercion type_primitive : primitive >-> type. - End Coercions. - - Module Export Notations. - Export Coercions. - Delimit Scope cpstype_scope with cpstype. - Bind Scope cpstype_scope with type. - Notation "()" := unit : cpstype_scope. - Notation "A * B" := (prod A B) : cpstype_scope. - Notation "A --->" := (continuation A) : cpstype_scope. - Notation type := type. - End Notations. - - Section interp. - Context (R : Type). - (** denote CPS types *) - Fixpoint interp (t : type) - := match t return Type with - | type_primitive t => Compilers.type.interp t - | prod A B => interp A * interp B - | continuation A => interp A -> R - | list A => Datatypes.list (interp A) - end%type. - End interp. - End type. - Export type.Notations. - - Module expr. - Section expr. - Context {ident : type -> Type} {var : type -> Type} {R : type}. - - Inductive expr := - | Halt (v : var R) - | App {A} (f : var (A --->)) (x : var A) - | Bind {A} (x : primop A) (f : var A -> expr) - with - primop : type -> Type := - | Var {t} (v : var t) : primop t - | Abs {t} (f : var t -> expr) : primop (t --->) - | Pair {A B} (x : var A) (y : var B) : primop (A * B) - | Fst {A B} (x : var (A * B)) : primop A - | Snd {A B} (x : var (A * B)) : primop B - | TT : primop () - | Ident {t} (idc : ident t) : primop t. - End expr. - Global Arguments expr {ident var} R. - Global Arguments primop {ident var} R _. - - Definition Expr {ident : type -> Type} R := forall var, @expr ident var R. - - Section with_ident. - Context {ident : type -> Type} - (r : type) - (R : Type) - (interp_ident - : forall t, ident t -> type.interp R t). - - (** denote CPS exprs *) - Fixpoint interp (e : @expr ident (type.interp R) r) (k : type.interp R r -> R) - {struct e} - : R - := match e with - | Halt v => k v - | App A f x => f x - | Bind A x f => interp (f (@interp_primop _ x k)) k - end - with interp_primop {t} (e : @primop ident (type.interp R) r t) (k : type.interp R r -> R) - {struct e} - : type.interp R t - := match e with - | Var t v => v - | Abs t f => fun x : type.interp _ t => interp (f x) k - | Pair A B x y => (x, y) - | Fst A B x => fst x - | Snd A B x => snd x - | TT => tt - | Ident t idc => interp_ident t idc - end. - - Definition Interp (e : Expr r) (k : type.interp R r -> R) : R := interp (e _) k. - End with_ident. - - Module Export Notations. - Delimit Scope cpsexpr_scope with cpsexpr. - Bind Scope cpsexpr_scope with expr. - Bind Scope cpsexpr_scope with primop. - - Infix "@" := App : cpsexpr_scope. - Notation "v <- x ; f" := (Bind x (fun v => f)) : cpsexpr_scope. - Notation "'λ' x .. y , t" := (Abs (fun x => .. (Abs (fun y => t%cpsexpr)) ..)) : cpsexpr_scope. - Notation "( x , y , .. , z )" := (Pair .. (Pair x%cpsexpr y%cpsexpr) .. z%cpsexpr) : cpsexpr_scope. - Notation "( )" := TT : cpsexpr_scope. - Notation "()" := TT : cpsexpr_scope. - End Notations. - End expr. - Export expr.Notations. - End Output. - - Module type. - Section translate. - Fixpoint translate (t : Compilers.type.type) : type - := match t with - | A * B => (translate A * translate B)%cpstype - | s -> d => (translate s * (translate d --->) --->)%cpstype - | Compilers.type.list A => type.list (translate A) - | Compilers.type.type_primitive t - => t - end%ctype. - Fixpoint untranslate (R : Compilers.type.type) (t : type) - : Compilers.type.type - := match t with - | type.type_primitive t => t - | A * B => (untranslate R A * untranslate R B)%ctype - | (t --->) - => (untranslate R t -> R)%ctype - | type.list A => Compilers.type.list (untranslate R A) - end%cpstype. - End translate. - End type. - - Module expr. - Import Output.expr. - Import Output.expr.Notations. - Import Compilers.type. - Import Compilers.Uncurried.expr. - Section with_ident. - Context {ident : Output.type.type -> Type} - {ident' : type -> type -> Type} - {var : Output.type.type -> Type} - (translate_ident : forall s d, ident' s d -> ident (type.translate (s -> d))). - Notation var' := (fun t => var (type.translate t)). - Local Notation oexpr := (@Output.expr.expr ident var). - - Section splice. - Context {r1 r2 : Output.type.type}. - Fixpoint splice (e1 : oexpr r1) (e2 : var r1 -> oexpr r2) - {struct e1} - : oexpr r2 - := match e1 with - | Halt v => e2 v - | f @ x => f @ x - | Bind A x f => v <- @splice_primop _ x e2; @splice (f v) e2 - end%cpsexpr - with - splice_primop {t} (f : @primop ident var r1 t) (e2 : var r1 -> oexpr r2) - {struct f} - : @primop ident var r2 t - := match f with - | Output.expr.Var t v => Output.expr.Var v - | Output.expr.Pair A B x y as e => Output.expr.Pair x y - | Output.expr.Fst A B x => Output.expr.Fst x - | Output.expr.Snd A B x => Output.expr.Snd x - | Output.expr.TT => Output.expr.TT - | Output.expr.Ident t idc => Output.expr.Ident idc - | Output.expr.Abs t f - => Output.expr.Abs (fun x => @splice (f x) e2) - end. - End splice. - - Local Notation "x <-- e1 ; e2" := (splice e1 (fun x => e2%cpsexpr)) : cpsexpr_scope. - - (** Note: We used to special-case [bool_rect] because - reduction of the bodies of eliminators should block on the - branching. We would like to just write: - -<< -| AppIdent (A * A * type.bool) A ident.bool_rect (Ptrue, Pfalse, b) - => b' <-- @translate _ b; - App_bool_rect (@translate _ Ptrue) (@translate _ Pfalse) b' -| AppIdent s d idc args - => args' <-- @translate _ args; - k <- Output.expr.Abs (fun r => Halt r); - p <- (args', k); - f <- Output.expr.Ident (translate_ident s d idc); - f @ p ->> - but due do deficiencies in non-linear deep pattern - matching (and the fact that we're generic over the type of - identifiers), we cannot, and must write something - significantly more verbose. Because this is so painful, - we do not special-case [nat_rect] nor [list_rect], which - anyway do not need special casing except in cases where - they never hit the base case; it is already the case that - functions get a sort of "free pass" and do get evaluated - until applied to arguments, and the base case ought to be - hit exactly once. - - However, now that [bool_rect]'s arguments are thunked, we - no longer need to do this. *) - Fixpoint translate {t} - (e : @Compilers.Uncurried.expr.expr ident' var' t) - : @Output.expr.expr ident var (type.translate t) - := match e with - | Var t v => Halt v - | TT => x <- () ; Halt x - | AppIdent s d idc args - => (args' <-- @translate _ args; - k <- Output.expr.Abs (fun r => Halt r); - p <- (args', k); - f <- Output.expr.Ident (translate_ident s d idc); - f @ p) - | Pair A B a b - => (a' <-- @translate _ a; - b' <-- @translate _ b; - p <- (a', b'); - Halt p) - | App s d e1 e2 - => (f <-- @translate _ e1; - x <-- @translate _ e2; - k <- Output.expr.Abs (fun r => Halt r); - p <- (x, k); - f @ p) - | Abs s d f - => f <- (Output.expr.Abs - (fun p - => x <- Fst p; - k <- Snd p; - r <-- @translate _ (f x); - k @ r)); - Halt f - end%cpsexpr. - End with_ident. - - Definition Translate - {ident : Output.type.type -> Type} - {ident' : type -> type -> Type} - (translate_ident : forall s d, ident' s d -> ident (type.translate (s -> d))) - {t} (e : @Compilers.Uncurried.expr.Expr ident' t) - : @Output.expr.Expr ident (type.translate t) - := fun var => translate translate_ident (e _). - - Section call_with_cont. - Context {ident' : Output.type.type -> Type} - {ident : type -> type -> Type} - {var : type -> Type} - {r : Output.type.type} - {R : type}. - Notation ucexpr := (@Compilers.Uncurried.expr.expr ident var). - Notation ucexprut t := (ucexpr (type.untranslate R t)) (only parsing). - Notation var' := (fun t => ucexprut t). - Context (untranslate_ident : forall t, ident' t -> ucexprut t) - (ifst : forall A B, ident (A * B)%ctype A) - (isnd : forall A B, ident (A * B)%ctype B). - - Fixpoint call_with_continuation - (e : @Output.expr.expr ident' var' r) - (k : ucexprut r -> ucexpr R) - {struct e} - : ucexpr R - := match e with - | Halt v => k v - | expr.App A f x - => @App _ _ (type.untranslate R A) R - f x - | Bind A x f - => @call_with_continuation - (f (@call_primop_with_continuation A x k)) - k - end%expr - with - call_primop_with_continuation - {t} - (e : @Output.expr.primop ident' var' r t) - (k : ucexprut r -> ucexpr R) - {struct e} - : ucexprut t - := match e in Output.expr.primop _ t return ucexprut t with - | expr.Var t v => v - | expr.Abs t f => Abs (fun x : var (type.untranslate _ _) - => @call_with_continuation - (f (Var x)) k) - | expr.Pair A B x y => (x, y) - | Fst A B x => ifst (type.untranslate _ A) (type.untranslate _ B) - @@ x - | Snd A B x => isnd (type.untranslate _ A) (type.untranslate _ B) - @@ x - | expr.TT => TT - | Ident t idc => untranslate_ident t idc - end%expr. - End call_with_cont. - - Definition CallWithContinuation - {ident' : Output.type.type -> Type} - {ident : type -> type -> Type} - {R : type} - (untranslate_ident : forall t, ident' t -> @Compilers.Uncurried.expr.Expr ident (type.untranslate R t)) - (ifst : forall A B, ident (A * B)%ctype A) - (isnd : forall A B, ident (A * B)%ctype B) - {t} (e : @Output.expr.Expr ident' t) - (k : forall var, @Uncurried.expr.expr ident var (type.untranslate R t) -> @Uncurried.expr.expr ident var R) - : @Compilers.Uncurried.expr.Expr ident R - := fun var => call_with_continuation - (fun t idc => untranslate_ident t idc _) ifst isnd (e _) (k _). - End expr. - - Module ident. - Import CPS.Output.type. - - Inductive ident : type -> Set := - | wrap {s d} (idc : Uncurried.expr.default.ident s d) : ident (type.translate (s -> d)). - - Notation cps_of f - := (fun x k => k (f x)). - Notation curry0 f - := (fun 'tt => f). - Notation curry2 f - := (fun '(a, b) => f a b). - Notation curry3 f - := (fun '(a, b, c) => f a b c). - Notation uncurry2 f - := (fun a b => f (a, b)). - Notation uncurry3 f - := (fun a b c => f (a, b, c)). - Notation curry3_23 f - := (fun '(a, b, c) => f a (uncurry3 b) c). - Notation curry3_2 f - := (fun '(a, b, c) => f a (uncurry2 b) c). - - (** denote CPS identifiers *) - Definition interp {R} {t} (idc : ident t) : type.interp R t - := match idc in ident t return type.interp R t with - | wrap s d idc - => fun '((x, k) : type.interp R (type.translate s) * (type.interp R (type.translate d) -> R)) - => - match idc in Uncurried.expr.default.ident s d return type.interp R (type.translate s) -> (type.interp R (type.translate d) -> R) -> R with - | ident.primitive _ _ as idc - | ident.Nat_succ as idc - | ident.Nat_add as idc - | ident.Nat_sub as idc - | ident.Nat_mul as idc - | ident.Nat_max as idc - | ident.pred as idc - | ident.Z_shiftr _ as idc - | ident.Z_shiftl _ as idc - | ident.Z_land _ as idc - | ident.Z_add as idc - | ident.Z_mul as idc - | ident.Z_pow as idc - | ident.Z_sub as idc - | ident.Z_opp as idc - | ident.Z_div as idc - | ident.Z_modulo as idc - | ident.Z_eqb as idc - | ident.Z_leb as idc - | ident.Z_of_nat as idc - | ident.Z_mul_split as idc - | ident.Z_add_get_carry as idc - | ident.Z_add_with_carry as idc - | ident.Z_add_with_get_carry as idc - | ident.Z_sub_with_get_borrow as idc - | ident.Z_sub_get_borrow as idc - | ident.Z_zselect as idc - | ident.Z_add_modulo as idc - | ident.Z_rshi as idc - | ident.Z_cc_m as idc - | ident.Z_cast _ as idc - | ident.Z_cast2 _ as idc - => cps_of (Uncurried.expr.default.ident.interp idc) - | ident.Z_mul_split_concrete s - => cps_of (curry2 (Z.mul_split s)) - | ident.Z_add_get_carry_concrete s - => cps_of (curry2 (Z.add_get_carry_full s)) - | ident.Z_add_with_get_carry_concrete s - => cps_of (curry3 (Z.add_with_get_carry_full s)) - | ident.Z_sub_get_borrow_concrete s - => cps_of (curry2 (Z.sub_get_borrow_full s)) - | ident.Z_sub_with_get_borrow_concrete s - => cps_of (curry3 (Z.sub_with_get_borrow_full s)) - | ident.Z_rshi_concrete s n - => cps_of (curry2 (fun x y => Z.rshi s x y n)) - | ident.Z_cc_m_concrete s - => cps_of (Z.cc_m s) - | ident.Let_In tx tC - => fun '((x, f) : (interp R (type.translate tx) - * (interp R (type.translate tx) * (interp R (type.translate tC) -> R) -> R))) - (k : interp R (type.translate tC) -> R) - => @LetIn.Let_In - (type.interp R (type.translate tx)) (fun _ => R) - x - (fun v => f (v, k)) - | ident.nil t - => cps_of (curry0 (@Datatypes.nil (interp R (type.translate t)))) - | ident.cons t - => cps_of (curry2 (@Datatypes.cons (interp R (type.translate t)))) - | ident.fst A B - => cps_of (@Datatypes.fst (interp R (type.translate A)) (interp R (type.translate B))) - | ident.snd A B - => cps_of (@Datatypes.snd (interp R (type.translate A)) (interp R (type.translate B))) - | ident.bool_rect T - => fun '((tc, fc, b) : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) ((unit * (type.interp R (type.translate T) -> R) -> R) * (unit * (type.interp R (type.translate T) -> R) -> R) * bool)) - k - => @Datatypes.bool_rect - (fun _ => R) - (tc (tt, k)) - (fc (tt, k)) - b - | ident.nat_rect P - => fun '((PO, PS, n) : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) ((unit * (interp R (type.translate P) -> R) -> R) * (nat * interp R (type.translate P) * (interp R (type.translate P) -> R) -> R) * nat)) - k - => @Datatypes.nat_rect - (fun _ => (interp R (type.translate P) -> R) -> R) - (fun k => PO (tt, k)) - (fun n' rec k - => rec (fun rec => PS (n', rec, k))) - n - k - | ident.list_rect A P - => fun '((Pnil, Pcons, ls) : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) ((unit * (interp R (type.translate P) -> R) -> R) * (interp R (type.translate A) * Datatypes.list (interp R (type.translate A)) * interp R (type.translate P) * (interp R (type.translate P) -> R) -> R) * Datatypes.list (interp R (type.translate A)))) - k - => @Datatypes.list_rect - (interp R (type.translate A)) - (fun _ => (interp R (type.translate P) -> R) -> R) - (fun k => Pnil (tt, k)) - (fun x xs rec k - => rec (fun rec => Pcons (x, xs, rec, k))) - ls - k - | ident.List_nth_default T - => cps_of (curry3 (@List.nth_default (interp R (type.translate T)))) - | ident.List_nth_default_concrete T d n - => cps_of (fun ls => @List.nth_default (interp R (type.translate T)) d ls n) - end x k - end. - - Local Notation var_eta x := (ident.fst @@ x, ident.snd @@ x)%core%expr. - - Definition untranslate {R} {t} (idc : ident t) - : @Compilers.Uncurried.expr.Expr Uncurried.expr.default.ident (type.untranslate R t) - := fun var - => match idc in ident t return @Compilers.Uncurried.expr.expr Uncurried.expr.default.ident var (type.untranslate R t) with - | wrap s d idc - => - match idc in default.ident s d return @Compilers.Uncurried.expr.expr Uncurried.expr.default.ident var (type.untranslate R (type.translate (s -> d))) with - | ident.primitive t v - => λ (_k : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (() * (t -> R))%ctype) , - (ident.snd @@ (Var _k)) - @ (ident.primitive v @@ TT) - | ident.Let_In tx tC - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.untranslate _ (type.translate tx) * (type.untranslate _ (type.translate tx) * (type.untranslate _ (type.translate tC) -> R) -> R) * (type.untranslate _ (type.translate tC) -> R))%ctype) , - ident.Let_In - @@ (ident.fst @@ (ident.fst @@ (Var xyk)), - (λ (x : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.untranslate _ (type.translate tx))) , - (ident.snd @@ (ident.fst @@ (Var xyk))) - @ (Var x, ident.snd @@ Var xyk))) - | ident.nat_rect P - => λ (PO_PS_n_k : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var ((Compilers.type.type_primitive ()%cpstype * (type.untranslate R (type.translate P) -> R) -> R) * (Compilers.type.type_primitive type.nat * type.untranslate R (type.translate P) * (type.untranslate R (type.translate P) -> R) -> R) * Compilers.type.type_primitive type.nat * (type.untranslate R (type.translate P) -> R))%ctype) , - let (PO_PS_n, k) := var_eta (Var PO_PS_n_k) in - let (PO_PS, n) := var_eta PO_PS_n in - let (PO, PS) := var_eta PO_PS in - ((@ident.nat_rect ((type.untranslate _ (type.translate P) -> R) -> R)) - @@ ((λ tt k , PO @ (Var tt, Var k)), - (λ n'rec k , - let (n', rec) := var_eta (Var n'rec) in - rec @ (λ rec , PS @ (n', Var rec, Var k))), - n)) - @ k - | ident.list_rect A P - => λ (Pnil_Pcons_ls_k : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var ((Compilers.type.type_primitive ()%cpstype * (type.untranslate R (type.translate P) -> R) -> R) * (type.untranslate R (type.translate A) * Compilers.type.list (type.untranslate R (type.translate A)) * type.untranslate R (type.translate P) * (type.untranslate R (type.translate P) -> R) -> R) * Compilers.type.list (type.untranslate R (type.translate A)) * (type.untranslate R (type.translate P) -> R))%ctype) , - let (Pnil_Pcons_ls, k) := var_eta (Var Pnil_Pcons_ls_k) in - let (Pnil_Pcons, ls) := var_eta Pnil_Pcons_ls in - let (Pnil, Pcons) := var_eta Pnil_Pcons in - ((@ident.list_rect - (type.untranslate _ (type.translate A)) - ((type.untranslate _ (type.translate P) -> R) -> R)) - @@ ((λ tt k, Pnil @ (Var tt, Var k)), - (λ x_xs_rec k, - let (x_xs, rec) := var_eta (Var x_xs_rec) in - let (x, xs) := var_eta x_xs in - rec @ (λ rec , Pcons @ (x, xs, Var rec, Var k))), - ls)) - @ k - | ident.List_nth_default T - => λ (xyzk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.untranslate _ (type.translate T) * Compilers.type.list (type.untranslate _ (type.translate T)) * type.nat * (type.untranslate _ (type.translate T) -> R))%ctype) , - (ident.snd @@ Var xyzk) - @ (ident.List_nth_default @@ (ident.fst @@ Var xyzk)) - | ident.List_nth_default_concrete T d n - => λ (xk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (Compilers.type.list (type.untranslate R (type.translate T)) * (type.untranslate R (type.translate T) -> R))%ctype) , - (ident.snd @@ Var xk) - @ (ident.List_nth_default_concrete d n @@ (ident.fst @@ Var xk)) - | ident.bool_rect T - => λ (xyzk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var ((Compilers.type.type_primitive ()%cpstype * (type.untranslate R (type.translate T) -> R) -> R) * (Compilers.type.type_primitive ()%cpstype * (type.untranslate R (type.translate T) -> R) -> R) * Compilers.type.type_primitive type.bool * (type.untranslate R (type.translate T) -> R))%ctype) , - ident.bool_rect - @@ ((λ tt, - (ident.fst @@ (ident.fst @@ (ident.fst @@ (Var xyzk)))) - @ (Var tt, (ident.snd @@ (Var xyzk)))), - (λ tt, - (ident.snd @@ (ident.fst @@ (ident.fst @@ (Var xyzk)))) - @ (Var tt, (ident.snd @@ (Var xyzk)))), - ident.snd @@ (ident.fst @@ (Var xyzk))) - | ident.nil t - => λ (_k : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (() * (Compilers.type.list (type.untranslate _ (type.translate t)) -> R))%ctype) , - (ident.snd @@ (Var _k)) - @ (ident.nil @@ TT) - | ident.cons t - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.untranslate _ (type.translate t) * Compilers.type.list (type.untranslate _ (type.translate t)) * (Compilers.type.list (type.untranslate _ (type.translate t)) -> R))%ctype) , - (ident.snd @@ (Var xyk)) - @ (ident.cons - @@ (ident.fst @@ (Var xyk))) - | ident.fst A B - => λ (xk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.untranslate _ (type.translate A) * type.untranslate _ (type.translate B) * (type.untranslate _ (type.translate A) -> R))%ctype) , - (ident.snd @@ (Var xk)) - @ (ident.fst - @@ (ident.fst @@ (Var xk))) - | ident.snd A B - => λ (xk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.untranslate _ (type.translate A) * type.untranslate _ (type.translate B) * (type.untranslate _ (type.translate B) -> R))%ctype) , - (ident.snd @@ (Var xk)) - @ (ident.snd - @@ (ident.fst @@ (Var xk))) - | ident.Nat_succ as idc - | ident.pred as idc - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.nat * (type.nat -> R))%ctype) , - (ident.snd @@ (Var xyk)) - @ ((idc : default.ident _ type.nat) - @@ (ident.fst @@ (Var xyk))) - | ident.Nat_add as idc - | ident.Nat_sub as idc - | ident.Nat_mul as idc - | ident.Nat_max as idc - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.nat * type.nat * (type.nat -> R))%ctype) , - (ident.snd @@ (Var xyk)) - @ ((idc : default.ident _ type.nat) - @@ (ident.fst @@ (Var xyk))) - | ident.Z_shiftr _ as idc - | ident.Z_shiftl _ as idc - | ident.Z_land _ as idc - | ident.Z_opp as idc - | ident.Z_cast _ as idc - | ident.Z.cc_m_concrete _ as idc - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.Z * (type.Z -> R))%ctype) , - (ident.snd @@ (Var xyk)) - @ ((idc : default.ident _ type.Z) - @@ (ident.fst @@ (Var xyk))) - | ident.Z_add as idc - | ident.Z_mul as idc - | ident.Z_sub as idc - | ident.Z_pow as idc - | ident.Z_div as idc - | ident.Z_modulo as idc - | ident.Z.cc_m as idc - | ident.Z_rshi_concrete _ _ as idc - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.Z * type.Z * (type.Z -> R))%ctype) , - (ident.snd @@ (Var xyk)) - @ ((idc : default.ident _ type.Z) - @@ (ident.fst @@ (Var xyk))) - | ident.Z_eqb as idc - | ident.Z_leb as idc - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.Z * type.Z * (type.bool -> R))%ctype) , - (ident.snd @@ (Var xyk)) - @ ((idc : default.ident _ type.bool) - @@ (ident.fst @@ (Var xyk))) - | ident.Z_of_nat as idc - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.nat * (type.Z -> R))%ctype) , - (ident.snd @@ (Var xyk)) - @ ((idc : default.ident _ type.Z) - @@ (ident.fst @@ (Var xyk))) - | ident.Z_mul_split as idc - | ident.Z_add_get_carry as idc - | ident.Z_sub_get_borrow as idc - | ident.Z_add_with_get_carry_concrete _ as idc - | ident.Z_sub_with_get_borrow_concrete _ as idc - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.Z * type.Z * type.Z * ((type.Z * type.Z) -> R))%ctype) , - (ident.snd @@ (Var xyk)) - @ ((idc : default.ident _ (type.Z * type.Z)) - @@ (ident.fst @@ (Var xyk))) - | ident.Z_cast2 _ as idc - | ident.Z_mul_split_concrete _ as idc - | ident.Z_add_get_carry_concrete _ as idc - | ident.Z_sub_get_borrow_concrete _ as idc - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.Z * type.Z * ((type.Z * type.Z) -> R))%ctype) , - (ident.snd @@ (Var xyk)) - @ ((idc : default.ident _ (type.Z * type.Z)) - @@ (ident.fst @@ (Var xyk))) - | ident.Z_add_with_carry as idc - | ident.Z_zselect as idc - | ident.Z_add_modulo as idc - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.Z * type.Z * type.Z * (type.Z -> R))%ctype) , - (ident.snd @@ (Var xyk)) - @ ((idc : default.ident _ type.Z) - @@ (ident.fst @@ (Var xyk))) - | ident.Z_add_with_get_carry as idc - | ident.Z_sub_with_get_borrow as idc - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.Z * type.Z * type.Z * type.Z * ((type.Z * type.Z) -> R))%ctype) , - (ident.snd @@ (Var xyk)) - @ ((idc : default.ident _ (type.Z * type.Z)) - @@ (ident.fst @@ (Var xyk))) - | ident.Z_rshi as idc - => λ (xyk : - (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.Z * type.Z * type.Z * type.Z * (type.Z -> R))%ctype) , - (ident.snd @@ (Var xyk)) - @ ((idc : default.ident _ type.Z) - @@ (ident.fst @@ (Var xyk))) - end%expr - end. - End ident. - Notation ident := ident.ident. - - Module default. - Notation expr := (@Output.expr.expr ident). - Notation Expr := (@Output.expr.Expr ident). - - Definition Translate - {t} (e : @Compilers.Uncurried.expr.default.Expr t) - : Expr (type.translate t) - := expr.Translate (@ident.wrap) e. - - Definition call_with_continuation - {var} - {R : Compilers.type.type} - {t} (e : @expr _ t) - (k : @Uncurried.expr.default.expr var (type.untranslate R t) -> @Uncurried.expr.default.expr var R) - : @Compilers.Uncurried.expr.default.expr var R - := expr.call_with_continuation (fun t idc => @ident.untranslate _ t idc _) (@ident.fst) (@ident.snd) e k. - - Definition CallWithContinuation - {R : Compilers.type.type} - {t} (e : Expr t) - (k : forall var, @Uncurried.expr.default.expr var (type.untranslate R t) -> @Uncurried.expr.default.expr var R) - : @Compilers.Uncurried.expr.default.Expr R - := expr.CallWithContinuation (@ident.untranslate _) (@ident.fst) (@ident.snd) e k. - - Local Notation iffT A B := ((A -> B) * (B -> A))%type. - (** We can only "plug in the identity continuation" for flat - (arrow-free) types. (Actually, we know how to do it in a - very ad-hoc way for types of at-most second-order functions; - see git history. This is much simpler.) *) - Fixpoint try_untranslate_translate {R} {t} - : option (forall (P : Compilers.type.type -> Type), - iffT (P (type.untranslate R (type.translate t))) (P t)) - := match t return option (forall (P : Compilers.type.type -> Type), - iffT (P (type.untranslate R (type.translate t))) (P t)) with - | Compilers.type.type_primitive x - => Some (fun P => ((fun v => v), (fun v => v))) - | type.arrow s d => None - | Compilers.type.prod A B - => (fA <- (@try_untranslate_translate _ A); - fB <- (@try_untranslate_translate _ B); - Some - (fun P - => let fA := fA (fun A => P (Compilers.type.prod A (type.untranslate R (type.translate B)))) in - let fB := fB (fun B => P (Compilers.type.prod A B)) in - ((fun v => fst fB (fst fA v)), - (fun v => snd fA (snd fB v)))))%option - | Compilers.type.list A - => (fA <- (@try_untranslate_translate R A); - Some (fun P => fA (fun A => P (Compilers.type.list A))))%option - end. - - Local Notation "x <-- e1 ; e2" := (expr.splice e1 (fun x => e2%cpsexpr)) : cpsexpr_scope. - - Definition call_fun_with_id_continuation' - {s d} - : option (forall var - (e : @expr _ (type.translate (s -> d))), - @Compilers.Uncurried.expr.default.expr var (s -> d)) - := (fs <- (@try_untranslate_translate _ s); - fd <- (@try_untranslate_translate _ d); - Some - (fun var e - => let P := @Compilers.Uncurried.expr.default.expr var in - Abs - (fun v : var s - => call_with_continuation - ((f <-- e; - k <- (λ r, expr.Halt r); - p <- (snd (fs P) (Var v), k); - f @ p)%cpsexpr) - (fst (fd P)))))%option. - - Definition call_fun_with_id_continuation - {var} - {s d} (e : @expr _ (type.translate (s -> d))) - : option (@Compilers.Uncurried.expr.default.expr var (s -> d)) - := option_map - (fun f => f _ e) - (@call_fun_with_id_continuation' s d). - - Definition CallFunWithIdContinuation - {s d} - (e : Expr (type.translate (s -> d))) - : option (@Compilers.Uncurried.expr.default.Expr (s -> d)) - := option_map - (fun f var => f _ (e _)) - (@call_fun_with_id_continuation' s d). - End default. - Include default. - End CPS. - - Module ZRange. - Module type. - Module primitive. - (** turn a [type.primitive] into a [Set] describing the type - of bounds on that primitive *) - Definition interp (t : type.primitive) : Set - := match t with - | type.unit => unit - | type.Z => zrange - | type.nat => unit - | type.bool => unit - end. - Definition is_neg {t} : interp t -> bool - := match t with - | type.Z => fun r => (lower r <? 0) && (upper r <=? 0) - | _ => fun _ => false - end. - Definition is_tighter_than {t} : interp t -> interp t -> bool - := match t with - | type.Z => is_tighter_than_bool - | type.unit - | type.nat - | type.bool - => fun _ _ => true - end. - Definition is_bounded_by {t} : interp t -> type.interp t -> bool - := match t with - | type.Z => fun r z => (lower r <=? z) && (z <=? upper r) - | type.unit - | type.nat - | type.bool - => fun _ _ => true - end. - Module option. - (** turn a [type.primitive] into a [Set] describing the type - of optional bounds on that primitive; bounds on a [Z] - may be either a range, or [None], generally indicating - that the [Z] is unbounded. *) - Definition interp (t : type.primitive) : Set - := match t with - | type.unit => unit - | type.Z => option zrange - | type.nat => unit - | type.bool => unit - end. - Definition None {t} : interp t - := match t with - | type.Z => None - | _ => tt - end. - Definition Some {t} : primitive.interp t -> interp t - := match t with - | type.Z => Some - | _ => id - end. - Definition is_neg {t} : interp t -> bool - := match t with - | type.Z => fun v => match v with - | Datatypes.Some v => @is_neg type.Z v - | Datatypes.None => false - end - | t => @primitive.is_neg t - end. - Definition is_tighter_than {t} : interp t -> interp t -> bool - := match t with - | type.Z - => fun r1 r2 - => match r1, r2 with - | _, Datatypes.None => true - | Datatypes.None, Datatypes.Some _ => false - | Datatypes.Some r1, Datatypes.Some r2 => is_tighter_than (t:=type.Z) r1 r2 - end - | t => @is_tighter_than t - end. - Definition is_bounded_by {t} : interp t -> type.interp t -> bool - := match t with - | type.Z - => fun r - => match r with - | Datatypes.Some r => @is_bounded_by type.Z r - | Datatypes.None => fun _ => true - end - | t => @is_bounded_by t - end. - End option. - End primitive. - (** turn a [type] into a [Set] describing the type of bounds on - that type; this lifts [primitive.interp] from - [type.primitive] to [type] *) - Fixpoint interp (t : type) : Set - := match t with - | type.type_primitive x => primitive.interp x - | type.prod A B => interp A * interp B - | type.arrow s d => interp s -> interp d - | type.list A => list (interp A) - end. - Fixpoint is_tighter_than {t} : interp t -> interp t -> bool - := match t with - | type.type_primitive x => @primitive.is_tighter_than x - | type.prod A B - => fun '((ra, rb) : interp A * interp B) - '((ra', rb') : interp A * interp B) - => @is_tighter_than A ra ra' && @is_tighter_than B rb rb' - | type.arrow s d => fun _ _ => false - | type.list A - => fold_andb_map (@is_tighter_than A) - end. - Fixpoint is_bounded_by {t} : interp t -> Compilers.type.interp t -> bool - := match t return interp t -> Compilers.type.interp t -> bool with - | type.type_primitive x => @primitive.is_bounded_by x - | type.prod A B - => fun '((ra, rb) : interp A * interp B) - '((ra', rb') : Compilers.type.interp A * Compilers.type.interp B) - => @is_bounded_by A ra ra' && @is_bounded_by B rb rb' - | type.arrow s d => fun _ _ => false - | type.list A - => fold_andb_map (@is_bounded_by A) - end. - Module option. - (** turn a [type] into a [Set] describing the type of optional - bounds on that primitive; bounds on a [Z] may be either a - range, or [None], generally indicating that the [Z] is - unbounded. This lifts [primitive.option.interp] from - [type.primitive] to [type] *) - Fixpoint interp (t : type) : Set - := match t with - | type.type_primitive x => primitive.option.interp x - | type.prod A B => interp A * interp B - | type.arrow s d => interp s -> interp d - | type.list A => option (list (interp A)) - end. - Fixpoint None {t : type} : interp t - := match t with - | type.type_primitive x => @primitive.option.None x - | type.prod A B => (@None A, @None B) - | type.arrow s d => fun _ => @None d - | type.list A => Datatypes.None - end. - Fixpoint Some {t : type} : type.interp t -> interp t - := match t with - | type.type_primitive x => @primitive.option.Some x - | type.prod A B - => fun x : type.interp A * type.interp B - => (@Some A (fst x), @Some B (snd x)) - | type.arrow s d => fun _ _ => @None d - | type.list A => fun ls => Datatypes.Some (List.map (@Some A) ls) - end. - Fixpoint is_tighter_than {t} : interp t -> interp t -> bool - := match t with - | type.type_primitive x => @primitive.option.is_tighter_than x - | type.prod A B - => fun '((ra, rb) : interp A * interp B) - '((ra', rb') : interp A * interp B) - => @is_tighter_than A ra ra' && @is_tighter_than B rb rb' - | type.arrow s d => fun _ _ => false - | type.list A - => fun ls1 ls2 - => match ls1, ls2 with - | Datatypes.None, Datatypes.None => true - | Datatypes.Some _, Datatypes.None => true - | Datatypes.None, Datatypes.Some _ => false - | Datatypes.Some ls1, Datatypes.Some ls2 => fold_andb_map (@is_tighter_than A) ls1 ls2 - end - end. - Fixpoint is_bounded_by {t} : interp t -> Compilers.type.interp t -> bool - := match t return interp t -> Compilers.type.interp t -> bool with - | type.type_primitive x => @primitive.option.is_bounded_by x - | type.prod A B - => fun '((ra, rb) : interp A * interp B) - '((ra', rb') : Compilers.type.interp A * Compilers.type.interp B) - => @is_bounded_by A ra ra' && @is_bounded_by B rb rb' - | type.arrow s d => fun _ _ => false - | type.list A - => fun ls1 ls2 - => match ls1 with - | Datatypes.None => true - | Datatypes.Some ls1 => fold_andb_map (@is_bounded_by A) ls1 ls2 - end - end. - - Lemma is_bounded_by_Some {t} r val - : is_bounded_by (@Some t r) val = type.is_bounded_by r val. - Proof. - induction t; - repeat first [ reflexivity - | progress cbn in * - | progress destruct_head'_prod - | progress destruct_head' type.primitive - | match goal with H : _ |- _ => rewrite H end ]. - { lazymatch goal with - | [ r : list (type.interp t), val : list (Compilers.type.interp t) |- _ ] - => revert r val IHt - end; intros r val; revert r val. - induction r, val; cbn; auto with nocore; try congruence; []. - intro H'; rewrite H', IHr by auto. - reflexivity. } - Qed. - - Lemma is_tighter_than_is_bounded_by {t} r1 r2 val - (Htight : @is_tighter_than t r1 r2 = true) - (Hbounds : is_bounded_by r1 val = true) - : is_bounded_by r2 val = true. - Proof. - induction t; - repeat first [ progress destruct_head'_prod - | progress destruct_head'_and - | progress destruct_head' type.primitive - | progress cbn in * - | progress destruct_head' option - | solve [ eauto with nocore ] - | progress cbv [is_tighter_than_bool] in * - | progress rewrite ?Bool.andb_true_iff in * - | discriminate - | apply conj - | Z.ltb_to_lt; omega - | rewrite @fold_andb_map_map in * ]. - { lazymatch goal with - | [ r1 : list (interp t), r2 : list (interp t), val : list (Compilers.type.interp t) |- _ ] - => revert r1 r2 val Htight Hbounds IHt - end; intros r1 r2 val; revert r1 r2 val. - induction r1, r2, val; cbn; auto with nocore; try congruence; []. - rewrite !Bool.andb_true_iff; intros; destruct_head'_and; split; eauto with nocore. } - Qed. - - Lemma is_tighter_than_Some_is_bounded_by {t} r1 r2 val - (Htight : @is_tighter_than t r1 (Some r2) = true) - (Hbounds : is_bounded_by r1 val = true) - : type.is_bounded_by r2 val = true. - Proof. - rewrite <- is_bounded_by_Some. - eapply is_tighter_than_is_bounded_by; eassumption. - Qed. - End option. - End type. - - Module ident. - Module option. - Local Open Scope zrange_scope. - - Notation curry0 f - := (fun 'tt => f). - Notation curry2 f - := (fun '(a, b) => f a b). - Notation uncurry2 f - := (fun a b => f (a, b)). - Notation curry3 f - := (fun '(a, b, c) => f a b c). - - (** do bounds analysis on identifiers; take in optional bounds - on arguments, return optional bounds on outputs. *) - Definition interp {s d} (idc : ident s d) : type.option.interp s -> type.option.interp d - := match idc in ident.ident s d return type.option.interp s -> type.option.interp d with - | ident.primitive type.Z v => fun _ => Some r[v ~> v] - | ident.Let_In tx tC => fun '(x, C) => C x - | ident.primitive _ _ - | ident.Nat_succ - | ident.Nat_add - | ident.Nat_sub - | ident.Nat_mul - | ident.Nat_max - | ident.bool_rect _ - | ident.nat_rect _ - | ident.pred - | ident.list_rect _ _ - | ident.List_nth_default _ - | ident.Z_pow - | ident.Z_div - | ident.Z_eqb - | ident.Z_leb - | ident.Z_of_nat - | ident.Z_mul_split - | ident.Z_add_get_carry - | ident.Z_add_with_get_carry - | ident.Z_sub_get_borrow - | ident.Z_sub_with_get_borrow - | ident.Z_modulo - | ident.Z_rshi - | ident.Z_cc_m - => fun _ => type.option.None - | ident.nil t => curry0 (Some (@nil (type.option.interp t))) - | ident.cons t => curry2 (fun a => option_map (@Datatypes.cons (type.option.interp t) a)) - | ident.fst A B => @Datatypes.fst (type.option.interp A) (type.option.interp B) - | ident.snd A B => @Datatypes.snd (type.option.interp A) (type.option.interp B) - | ident.List_nth_default_concrete T d n - => fun ls - => match ls with - | Datatypes.Some ls - => @nth_default (type.option.interp T) type.option.None ls n - | Datatypes.None - => type.option.None - end - | ident.Z_shiftr _ as idc - | ident.Z_shiftl _ as idc - | ident.Z_opp as idc - | ident.Z_cc_m_concrete _ as idc - => option_map (ZRange.two_corners (ident.interp idc)) - | ident.Z_land mask - => option_map - (fun r : zrange - => ZRange.land_bounds r r[mask~>mask]) - | ident.Z_add as idc - | ident.Z_mul as idc - | ident.Z_sub as idc - | ident.Z.rshi_concrete _ _ as idc - => fun '((x, y) : option zrange * option zrange) - => match x, y with - | Some x, Some y - => Some (ZRange.four_corners (uncurry2 (ident.interp idc)) x y) - | Some _, None | None, Some _ | None, None => None - end - | ident.Z_cast range - => fun r : option zrange - => Some match r with - | Some r => ZRange.intersection r range - | None => range - end - | ident.Z_cast2 (r1, r2) - => fun '((r1', r2') : option zrange * option zrange) - => (Some match r1' with - | Some r => ZRange.intersection r r1 - | None => r1 - end, - Some match r2' with - | Some r => ZRange.intersection r r2 - | None => r2 - end) - | ident.Z_mul_split_concrete split_at - => fun '((x, y) : option zrange * option zrange) - => match x, y with - | Some x, Some y - => type.option.Some - (t:=(type.Z*type.Z)%ctype) - (ZRange.split_bounds (ZRange.four_corners BinInt.Z.mul x y) split_at) - | Some _, None | None, Some _ | None, None => type.option.None - end - | ident.Z_add_get_carry_concrete split_at - => fun '((x, y) : option zrange * option zrange) - => match x, y with - | Some x, Some y - => type.option.Some - (t:=(type.Z*type.Z)%ctype) - (ZRange.split_bounds (ZRange.four_corners BinInt.Z.add x y) split_at) - | Some _, None | None, Some _ | None, None => type.option.None - end - | ident.Z_add_with_carry - => fun '((x, y, z) : option zrange * option zrange * option zrange) - => match x, y, z with - | Some x, Some y, Some z - => type.option.Some - (t:=type.Z) - (ZRange.eight_corners (fun x y z => (x + y + z)%Z) x y z) - | _, _, _ => type.option.None - end - | ident.Z_add_with_get_carry_concrete split_at - => fun '((x, y, z) : option zrange * option zrange * option zrange) - => match x, y, z with - | Some x, Some y, Some z - => type.option.Some - (t:=(type.Z*type.Z)%ctype) - (ZRange.split_bounds - (ZRange.eight_corners (fun x y z => (x + y + z)%Z) x y z) - split_at) - | _, _, _ => type.option.None - end - | ident.Z_sub_get_borrow_concrete split_at - => fun '((x, y) : option zrange * option zrange) - => match x, y with - | Some x, Some y - => type.option.Some - (t:=(type.Z*type.Z)%ctype) - (let b := ZRange.split_bounds (ZRange.four_corners BinInt.Z.sub x y) split_at in - (* N.B. sub_get_borrow returns - ((x - y) / split_at) as the borrow, so we need to negate *) - (fst b, ZRange.opp (snd b))) - | Some _, None | None, Some _ | None, None => type.option.None - end - | ident.Z_sub_with_get_borrow_concrete split_at - => fun '((x, y, z) : option zrange * option zrange * option zrange) - => match x, y, z with - | Some x, Some y, Some z - => type.option.Some - (t:=(type.Z*type.Z)%ctype) - (let b := ZRange.split_bounds (ZRange.eight_corners (fun x y z => (y - z - x)%Z) x y z) split_at in - (* N.B. sub_get_borrow returns - ((x - y) / split_at) as the borrow, so we need to negate *) - (fst b, ZRange.opp (snd b))) - | _, _, _ => type.option.None - end - | ident.Z_zselect - => fun '((x, y, z) : option zrange * option zrange * option zrange) - => match y, z with - | Some y, Some z => Some (ZRange.union y z) - | Some _, None | None, Some _ | None, None => None - end - | ident.Z_add_modulo - => fun '((x, y, z) : option zrange * option zrange * option zrange) - => match x, y, z with - | Some x, Some y, Some m - => Some (ZRange.union - (ZRange.four_corners BinInt.Z.add x y) - (ZRange.eight_corners (fun x y m => Z.max 0 (x + y - m)) - x y m)) - | _, _, _ => None - end - end. - End option. - End ident. - End ZRange. - - Module DefaultValue. - (** This module provides "default" inhabitants for the - interpretation of PHOAS types and for the PHOAS [expr] type. - These values are used for things like [nth_default] and in - other places where we need to provide a dummy value in cases - that will never actually be reached in correctly used code. *) - Module type. - Module primitive. - Definition default {t : type.primitive} : type.interp t - := match t with - | type.unit => tt - | type.Z => (-1)%Z - | type.nat => 0%nat - | type.bool => true - end. - End primitive. - Fixpoint default {t} : type.interp t - := match t with - | type.type_primitive x => @primitive.default x - | type.prod A B => (@default A, @default B) - | type.arrow s d => fun _ => @default d - | type.list A => @nil (type.interp A) - end. - End type. - - Module expr. - Section with_var. - Context {var : type -> Type}. - Fixpoint default {t : type} : @expr var t - := match t with - | type.type_primitive x - => AppIdent (ident.primitive type.primitive.default) TT - | type.prod A B - => (@default A, @default B) - | type.arrow s d => Abs (fun _ => @default d) - | type.list A => AppIdent ident.nil TT - end. - End with_var. - - Definition Default {t} : Expr t := fun _ => default. - End expr. - End DefaultValue. - - Module GeneralizeVar. - (** In both lazy and cbv evaluation strategies, reduction under - lambdas is only done at the very end. This means that if we - have a computation which returns a PHOAS syntax tree, and we - plug in two different values for [var], the computation is run - twice. This module provides a way of computing a - representation of terms which does not suffer from this issue. - By computing a flat representation, and then going back to - PHOAS, the cbv strategy will fully compute the preceeding - PHOAS passes only once, and the lazy strategy will share - computation among the various uses of [var] (because there are - no lambdas to get blocked on) and thus will also compute the - preceeding PHOAS passes only once. *) - Module Flat. - Inductive expr : type -> Set := - | Var (t : type) (n : positive) : expr t - | TT : expr type.unit - | AppIdent {s d} (idc : ident s d) (arg : expr s) : expr d - | App {s d} (f : expr (s -> d)) (x : expr s) : expr d - | Pair {A B} (a : expr A) (b : expr B) : expr (A * B) - | Abs (s : type) (n : positive) {d} (f : expr d) : expr (s -> d). - End Flat. - - Definition ERROR {T} (v : T) : T. exact v. Qed. - - Fixpoint to_flat' {t} (e : @expr (fun _ => PositiveMap.key) t) - (cur_idx : PositiveMap.key) - : Flat.expr t - := match e in expr.expr t return Flat.expr t with - | Var t v => Flat.Var t v - | TT => Flat.TT - | AppIdent s d idc args - => Flat.AppIdent idc (@to_flat' _ args cur_idx) - | App s d f x => Flat.App - (@to_flat' _ f cur_idx) - (@to_flat' _ x cur_idx) - | Pair A B a b => Flat.Pair - (@to_flat' _ a cur_idx) - (@to_flat' _ b cur_idx) - | Abs s d f - => Flat.Abs s cur_idx - (@to_flat' - d (f cur_idx) - (Pos.succ cur_idx)) - end. - - Fixpoint from_flat {t} (e : Flat.expr t) - : forall var, PositiveMap.t { t : type & var t } -> @expr var t - := match e in Flat.expr t return forall var, _ -> expr t with - | Flat.Var t v - => fun var ctx - => match (tv <- PositiveMap.find v ctx; - type.try_transport var _ _ (projT2 tv))%option with - | Some v => Var v - | None => ERROR DefaultValue.expr.default - end - | Flat.TT => fun _ _ => TT - | Flat.AppIdent s d idc args - => let args' := @from_flat _ args in - fun var ctx => AppIdent idc (args' var ctx) - | Flat.App s d f x - => let f' := @from_flat _ f in - let x' := @from_flat _ x in - fun var ctx => App (f' var ctx) (x' var ctx) - | Flat.Pair A B a b - => let a' := @from_flat _ a in - let b' := @from_flat _ b in - fun var ctx => Pair (a' var ctx) (b' var ctx) - | Flat.Abs s cur_idx d f - => let f' := @from_flat d f in - fun var ctx - => Abs (fun v => f' var (PositiveMap.add cur_idx (existT _ s v) ctx)) - end. - - Definition to_flat {t} (e : expr t) : Flat.expr t - := to_flat' e 1%positive. - Definition ToFlat {t} (E : Expr t) : Flat.expr t - := to_flat (E _). - Definition FromFlat {t} (e : Flat.expr t) : Expr t - := let e' := @from_flat t e in - fun var => e' var (PositiveMap.empty _). - Definition GeneralizeVar {t} (e : @expr (fun _ => PositiveMap.key) t) : Expr t - := FromFlat (to_flat e). - End GeneralizeVar. - - Module partial. - Notation data := ZRange.type.option.interp. - Section value. - Context (var : type -> Type). - Definition value_prestep (value : type -> Type) (t : type) - := match t return Type with - | type.prod A B as t => value A * value B - | type.arrow s d => value s -> value d - | type.list A => list (value A) - | type.type_primitive _ as t - => type.interp t - end%type. - Definition value_step (value : type -> Type) (t : type) - := match t return Type with - | type.arrow _ _ as t - => value_prestep value t - | type.prod _ _ as t - | type.list _ as t - | type.type_primitive _ as t - => data t * @expr var t + value_prestep value t - end%type. - Fixpoint value (t : type) - := value_step value t. - - Fixpoint value_default {t} : value t - := match t return value t with - | type.type_primitive type.Z - | type.type_primitive _ - => inr DefaultValue.type.primitive.default - | type.prod A B - => inr (@value_default A, @value_default B) - | type.arrow s d => fun _ => @value_default d - | type.list A => inr (@nil (value A)) - end. - - Fixpoint data_from_value {t} : value t -> data t - := match t return value t -> data t with - | type.arrow _ _ as t - => fun _ => ZRange.type.option.None - | type.prod A B as t - => fun v - => match v with - | inl (data, _) => data - | inr (a, b) - => (@data_from_value A a, @data_from_value B b) - end - | type.list A as t - => fun v - => match v with - | inl (data, _) => data - | inr ls - => Some (List.map (@data_from_value A) ls) - end - | type.type_primitive type.Z as t - => fun v - => match v with - | inl (data, _) => data - | inr v => Some r[v~>v]%zrange - end - | type.type_primitive _ as t - => fun v - => match v with - | inl (data, _) => data - | inr _ => ZRange.type.option.None - end - end. - End value. - - Module expr. - Section reify. - Context {var : type -> Type}. - Fixpoint reify {t : type} {struct t} - : value var t -> @expr var t - := match t return value var t -> expr t with - | type.prod A B as t - => fun x : (data A * data B) * expr t + value var A * value var B - => match x with - | inl ((da, db), v) - => match A, B return data A -> data B -> expr (A * B) -> expr (A * B) with - | type.Z, type.Z - => fun da db v - => match da, db with - | Some r1, Some r2 - => (ident.Z.cast2 (r1, r2)%core @@ v)%expr - | _, _ => v - end - | _, _ => fun _ _ v => v - end da db v - | inr (a, b) => (@reify A a, @reify B b)%expr - end - | type.arrow s d - => fun (f : value var s -> value var d) - => Abs (fun x - => @reify d (f (@reflect s (Var x)))) - | type.list A as t - => fun x : _ * expr t + list (value var A) - => match x with - | inl (_, v) => v - | inr v => reify_list (List.map (@reify A) v) - end - | type.type_primitive type.Z as t - => fun x : _ * expr t + type.interp t - => match x with - | inl (Some r, v) => ident.Z.cast r @@ v - | inl (None, v) => v - | inr v => ident.primitive v @@ TT - end%core%expr - | type.type_primitive _ as t - => fun x : _ * expr t + type.interp t - => match x with - | inl (_, v) => v - | inr v => ident.primitive v @@ TT - end%core%expr - end - with reflect {t : type} - : @expr var t -> value var t - := match t return expr t -> value var t with - | type.arrow s d - => fun (f : expr (s -> d)) (x : value var s) - => @reflect d (App f (@reify s x)) - | type.prod A B as t - => fun v : expr t - => let inr := @inr (data t * expr t) (value_prestep (value var) t) in - let inl := @inl (data t * expr t) (value_prestep (value var) t) in - match invert_Pair v with - | Some (a, b) - => inr (@reflect A a, @reflect B b) - | None - => inl - (match A, B return expr (A * B) -> data (A * B) * expr (A * B) with - | type.Z, type.Z - => fun v - => match invert_Z_cast2 v with - | Some (r, v) - => (ZRange.type.option.Some (t:=type.Z*type.Z) r, v) - | None - => (ZRange.type.option.None, v) - end - | _, _ => fun v => (ZRange.type.option.None, v) - end v) - end - | type.list A as t - => fun v : expr t - => let inr := @inr (data t * expr t) (value_prestep (value var) t) in - let inl := @inl (data t * expr t) (value_prestep (value var) t) in - match reflect_list v with - | Some ls - => inr (List.map (@reflect A) ls) - | None - => inl (None, v) - end - | type.type_primitive type.Z as t - => fun v : expr t - => let inr' := @inr (data t * expr t) (value_prestep (value var) t) in - let inl' := @inl (data t * expr t) (value_prestep (value var) t) in - match reflect_primitive v, invert_Z_cast v with - | Some v, _ => inr' v - | None, Some (r, v) => inl' (Some r, v) - | None, None => inl' (None, v) - end - | type.type_primitive _ as t - => fun v : expr t - => let inr := @inr (data t * expr t) (value_prestep (value var) t) in - let inl := @inl (data t * expr t) (value_prestep (value var) t) in - match reflect_primitive v with - | Some v => inr v - | None => inl (tt, v) - end - end. - End reify. - End expr. - - Module ident. - Section interp. - Context (inline_var_nodes : bool) - {var : type -> Type}. - Fixpoint is_var_like {t} (e : @expr var t) : bool - := match e with - | Var t v => true - | TT => true - | AppIdent _ _ (ident.fst _ _) args => @is_var_like _ args - | AppIdent _ _ (ident.snd _ _) args => @is_var_like _ args - | AppIdent _ _ (ident.Z.cast _) args => @is_var_like _ args - | AppIdent _ _ (ident.Z.cast2 _) args => @is_var_like _ args - | Pair A B a b => @is_var_like A a && @is_var_like B b - | AppIdent _ _ _ _ => false - | App _ _ _ _ - | Abs _ _ _ - => false - end. - (** do partial evaluation on let-in, controlling what gets - inlined and what doesn't *) - Fixpoint interp_let_in {tC tx : type} {struct tx} : value var tx -> (value var tx -> value var tC) -> value var tC - := match tx return value var tx -> (value var tx -> value var tC) -> value var tC with - | type.arrow _ _ - => fun x f => f x - | type.list T as t - => fun (x : data t * expr t + list (value var T)) (f : data t * expr t + list (value var T) -> value var tC) - => match x with - | inr ls - => list_rect - (fun _ => (list (value var T) -> value var tC) -> value var tC) - (fun f => f nil) - (fun x _ rec f - => rec (fun ls => @interp_let_in - _ T x - (fun x => f (cons x ls)))) - ls - (fun ls => f (inr ls)) - | inl e => f (inl e) - end - | type.prod A B as t - => fun (x : data t * expr t + value var A * value var B) (f : data t * expr t + value var A * value var B -> value var tC) - => match x with - | inr (a, b) - => @interp_let_in - _ A a - (fun a - => @interp_let_in - _ B b - (fun b => f (inr (a, b)))) - | inl (data, e) - => if inline_var_nodes && is_var_like e - then f x - else partial.expr.reflect - (expr_let y := partial.expr.reify (t:=t) x in - partial.expr.reify (f (inl (data, Var y)%core)))%expr - end - | type.type_primitive _ as t - => fun (x : data t * expr t + type.interp t) (f : data t * expr t + type.interp t -> value var tC) - => match x with - | inl (data, e) - => if inline_var_nodes && is_var_like e - then f x - else partial.expr.reflect - (expr_let y := (partial.expr.reify (t:=t) x) in - partial.expr.reify (f (inl (data, Var y)%core)))%expr - | inr v => f (inr v) (* FIXME: do not substitute [S (big stuck term)] *) - end - end. - - Let default_interp - {s d} - : ident s d -> value var s -> value var d - := match d return ident s d -> value var s -> value var d with - | type.arrow _ _ - => fun idc args => expr.reflect (AppIdent idc (expr.reify args)) - | _ - => fun idc args - => inl (ZRange.ident.option.interp idc (data_from_value var args), - AppIdent idc (expr.reify args)) - end. - - (** do partial evaluation on identifiers *) - Definition interp {s d} (idc : ident s d) : value var (s -> d) - := match idc in ident s d return value var (s -> d) with - | ident.Let_In tx tC as idc - => fun (xf : data (tx * (tx -> tC)) * expr (tx * (tx -> tC)) + value var tx * value var (tx -> tC)) - => match xf with - | inr (x, f) => interp_let_in x f - | _ => expr.reflect (AppIdent idc (expr.reify (t:=tx * (tx -> tC)) xf)) - end - | ident.nil t - => fun _ => inr (@nil (value var t)) - | ident.primitive t v - => fun _ => inr v - | ident.cons t as idc - => fun (x_xs : data (t * type.list t) * expr (t * type.list t) + value var t * (data (type.list t) * expr (type.list t) + list (value var t))) - => match x_xs return data (type.list t) * expr (type.list t) + list (value var t) with - | inr (x, inr xs) => inr (cons x xs) - | _ - => default_interp idc x_xs - end - | ident.fst A B as idc - => fun x : data (A * B) * expr (A * B) + value var A * value var B - => match x with - | inr x => fst x - | _ => default_interp idc x - end - | ident.snd A B as idc - => fun x : data (A * B) * expr (A * B) + value var A * value var B - => match x with - | inr x => snd x - | _ => default_interp idc x - end - | ident.bool_rect T as idc - => fun (true_case_false_case_b : data ((type.unit -> T) * (type.unit -> T) * type.bool) * expr ((type.unit -> T) * (type.unit -> T) * type.bool) + (data ((type.unit -> T) * (type.unit -> T)) * expr ((type.unit -> T) * (type.unit -> T)) + (_ + Datatypes.unit -> value var T) * (_ + Datatypes.unit -> value var T)) * (data type.bool * expr type.bool + bool)) - => match true_case_false_case_b with - | inr (inr (true_case, false_case), inr b) - => if b then true_case (inr tt) else false_case (inr tt) - | _ => default_interp idc true_case_false_case_b - end - | ident.nat_rect P as idc - => fun (O_case_S_case_n : _ * expr ((type.unit -> P) * (type.nat * P -> P) * type.nat) + (_ * expr ((type.unit -> P) * (type.nat * P -> P)) + (_ + Datatypes.unit -> value var P) * value var (type.nat * P -> P)) * (_ * expr type.nat + nat)) - => match O_case_S_case_n with - | inr (inr (O_case, S_case), inr n) - => @nat_rect (fun _ => value var P) - (O_case (inr tt)) - (fun n' rec => S_case (inr (inr n', rec))) - n - | _ => default_interp idc O_case_S_case_n - end - | ident.list_rect A P as idc - => fun (nil_case_cons_case_ls : _ * expr ((type.unit -> P) * (A * type.list A * P -> P) * type.list A) + (_ * expr ((type.unit -> P) * (A * type.list A * P -> P)) + (_ + Datatypes.unit -> value var P) * value var (A * type.list A * P -> P)) * (_ * expr (type.list A) + list (value var A))) - => match nil_case_cons_case_ls with - | inr (inr (nil_case, cons_case), inr ls) - => @list_rect - (value var A) - (fun _ => value var P) - (nil_case (inr tt)) - (fun x xs rec => cons_case (inr (inr (x, inr xs), rec))) - ls - | _ => default_interp idc nil_case_cons_case_ls - end - | ident.List.nth_default type.Z as idc - => fun (default_ls_idx : _ * expr (type.Z * type.list type.Z * type.nat) + (_ * expr (type.Z * type.list type.Z) + (_ * expr type.Z + type.interp type.Z) * (_ * expr (type.list type.Z) + list (value var type.Z))) * (_ * expr type.nat + nat)) - => match default_ls_idx with - | inr (inr (default, inr ls), inr idx) - => List.nth_default default ls idx - | inr (inr (inr default, ls), inr idx) - => default_interp (ident.List.nth_default_concrete default idx) ls - | _ => default_interp idc default_ls_idx - end - | ident.List.nth_default (type.type_primitive A) as idc - => fun (default_ls_idx : _ * expr (A * type.list A * type.nat) + (_ * expr (A * type.list A) + (_ * expr A + type.interp A) * (_ * expr (type.list A) + list (value var A))) * (_ * expr type.nat + nat)) - => match default_ls_idx with - | inr (inr (default, inr ls), inr idx) - => List.nth_default default ls idx - | inr (inr (inr default, ls), inr idx) - => default_interp (ident.List.nth_default_concrete default idx) ls - | _ => default_interp idc default_ls_idx - end - | ident.List.nth_default A as idc - => fun (default_ls_idx : _ * expr (A * type.list A * type.nat) + (_ * expr (A * type.list A) + value var A * (_ * expr (type.list A) + list (value var A))) * (_ * expr type.nat + nat)) - => match default_ls_idx with - | inr (inr (default, inr ls), inr idx) - => List.nth_default default ls idx - | _ => default_interp idc default_ls_idx - end - | ident.List.nth_default_concrete A default idx as idc - => fun (ls : _ * expr (type.list A) + list (value var A)) - => match ls with - | inr ls - => List.nth_default (expr.reflect (t:=A) (AppIdent (ident.primitive default) TT)) ls idx - | _ => default_interp idc ls - end - | ident.Z_mul_split as idc - => fun (x_y_z : (_ * expr (type.Z * type.Z * type.Z) + - (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) - => match x_y_z return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with - | inr (inr (inr x, inr y), inr z) => - let result := ident.interp idc (x, y, z) in - inr (inr (fst result), inr (snd result)) - | inr (inr (inr x, y), z) - => let default _ := default_interp (ident.Z.mul_split_concrete x) (inr (y, z)) in - match (y, z) with - | (inr xx, inl (data, e) as y) - | (inl (data, e) as y, inr xx) - => if Z.eqb xx 0 - then inr (inr 0%Z, inr 0%Z) - else if Z.eqb xx 1 - then inr (y, inr 0%Z) - else if Z.eqb xx (-1) - then inr (inl (data, AppIdent ident.Z.opp (expr.reify (t:=type.Z) y)), inr 0%Z) - else default tt - | _ => default tt - end - | _ => default_interp idc x_y_z - end - | ident.Z_rshi as idc - => fun (x_y_z_a : - (_ * expr (_ * _ * _ * _) + (_ * expr (_ * _ * _) + (_ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _))%type) - => match x_y_z_a return _ * expr _ + type.interp _ with - | inr (inr (inr (inr x, inr y), inr z), inr a) => inr (ident.interp idc (x, y, z, a)) - | inr (inr (inr (inr x, y), z), inr a) - => default_interp (ident.Z.rshi_concrete x a) (inr (y, z)) - | _ => default_interp idc x_y_z_a - end - | ident.Z_cc_m as idc - => fun (x_y : data (_ * _) * expr (_ * _) + (_ + type.interp _) * (_ + type.interp _)) - => match x_y return _ + type.interp _ with - | inr (inr x, inr y) => inr (ident.interp idc (x, y)) - | inr (inr x, y) - => default_interp (ident.Z.cc_m_concrete x) y - | _ => default_interp idc x_y - end - | ident.Z_add_get_carry as idc - => fun (x_y_z : (_ * expr (type.Z * type.Z * type.Z) + - (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) - => match x_y_z return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with - | inr (inr (inr x, inr y), inr z) => - let result := ident.interp idc (x, y, z) in - inr (inr (fst result), inr (snd result)) - | inr (inr (inr x, y), z) - => let default _ := default_interp (ident.Z.add_get_carry_concrete x) (inr (y, z)) in - match (y, z) with - | (inr xx, inl e) - | (inl e, inr xx) - => if Z.eqb xx 0 - then inr (inl e, inr 0%Z) - else default tt - | _ => default tt - end - | _ => default_interp idc x_y_z - end - | ident.Z_add_with_carry as idc - => fun (x_y_z : (_ * expr (type.Z * type.Z * type.Z) + - (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) - => match x_y_z return ( _ * expr _ + type.interp _) with - | inr (inr (inr x, inr y), inr z) => inr (ident.interp idc (x, y, z)) - | inr (inr (inr x, y), z) - => if Z.eqb x 0 then default_interp (ident.Z.add) (inr (y,z)) else default_interp idc x_y_z - | _ => default_interp idc x_y_z - end - | ident.Z_add_with_get_carry as idc - => fun (x_y_z_a : (_ * expr (_ * _ * _ * _) + - (_ * expr (_ * _ * _) + - (_ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) * - (_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _))%type) - => match x_y_z_a return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with - | inr (inr (inr (inr x, inr y), inr z), inr a) => - let result := ident.interp idc (x, y, z, a) in - inr (inr (fst result), inr (snd result)) - | inr (inr (inr (inr x, y), z), a) - => - let default _ := default_interp (ident.Z.add_with_get_carry_concrete x) (inr (inr (y, z), a)) in - let default_add _ := default_interp (ident.Z.add_get_carry_concrete x) (inr (z,a)) in - let default_adx _ := default_interp (ident.Z.add_with_carry) (inr (inr (y, z), a)) in - match y, z, a with - | inr cc, inr xx, inl e - | inr cc, inl e, inr xx - => if Z.eqb cc 0 - then if Z.eqb xx 0 - then inr (inl e, inr 0%Z) - else default_add tt (* carry = 0: ADC x y -> ADD x y *) - else default tt - | inr cc, inl xx, inl yy - => if Z.eqb cc 0 - then default_add tt (* carry = 0: ADC x y -> ADD x y *) - else default tt - | inl _, inr xx, inr yy - => if Z.eqb xx 0 - then if Z.eqb yy 0 - then inr (default_adx tt, inr 0%Z) (* ADC 0 0 -> (ADX 0 0, 0) *) - else default tt - else default tt - | _, _, _ => default tt - end - | _ => default_interp idc x_y_z_a - end - | ident.Z_sub_get_borrow as idc - => fun (x_y_z : (_ * expr (type.Z * type.Z * type.Z) + - (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) - => match x_y_z return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with - | inr (inr (inr x, inr y), inr z) => - let result := ident.interp idc (x, y, z) in - inr (inr (fst result), inr (snd result)) - | inr (inr (inr x, y), z) - => default_interp (ident.Z.sub_get_borrow_concrete x) (inr (y, z)) - | _ => default_interp idc x_y_z - end - | ident.Z_sub_with_get_borrow as idc - => fun (x_y_z_a : (_ * expr (_ * _ * _ * _) + (_ * expr (_ * _ * _) + (_ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _))%type) - => match x_y_z_a return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with - | inr (inr (inr (inr x, inr y), inr z), inr a) => - let '(r, b) := ident.interp idc (x, y, z, a) in - inr (inr r, inr b) - | inr (inr (inr (inr x, y), z), a) - => default_interp (ident.Z.sub_with_get_borrow_concrete x) (inr (inr (y, z), a)) - | _ => default_interp idc x_y_z_a - end - | ident.Z_mul_split_concrete _ as idc - | ident.Z.sub_get_borrow_concrete _ as idc - => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => let default _ := default_interp idc x_y in - match x_y return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with - | inr (inr x, inr y) => - let result := ident.interp idc (x, y) in - inr (inr (fst result), inr (snd result)) - | _ => default tt - end - | ident.Z.add_get_carry_concrete _ as idc - => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => let default _ := default_interp idc x_y in - match x_y return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with - | inr (inr x, inr y) => - let result := ident.interp idc (x, y) in - inr (inr (fst result), inr (snd result)) - | inr (inr x, inl e) - | inr (inl e, inr x) => - if Z.eqb x 0%Z - then inr (inl e, inr 0%Z) - else default tt - | _ => default tt - end - | ident.Z.add_with_get_carry_concrete _ as idc - | ident.Z.sub_with_get_borrow_concrete _ as idc - => fun (x_y_z : - (_ * expr (type.Z * type.Z * type.Z) + (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) - => match x_y_z return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with - | inr (inr (inr x, inr y), inr z) => - let result := ident.interp idc (x, y, z) in - inr (inr (fst result), inr (snd result)) - | _ => default_interp idc x_y_z - end - | ident.pred as idc - | ident.Nat_succ as idc - => fun x : _ * expr _ + type.interp _ - => match x return _ * expr _ + type.interp _ with - | inr x => inr (ident.interp idc x) - | _ => default_interp idc x - end - | ident.Z_of_nat as idc - => fun x : _ * expr _ + type.interp _ - => match x return _ * expr _ + type.interp _ with - | inr x => inr (ident.interp idc x) - | _ => default_interp idc x - end - | ident.Z_opp as idc - => fun x : _ * expr _ + type.interp _ - => match x return _ * expr _ + type.interp _ with - | inr x => inr (ident.interp idc x) - | inl (r, x) - => match invert_Z_opp x with - | Some x => inl (r, x) - | None => inl (ZRange.ident.option.interp idc r, AppIdent idc x) - end - end - | ident.Z_shiftr _ as idc - | ident.Z_shiftl _ as idc - | ident.Z_land _ as idc - | ident.Z_cc_m_concrete _ as idc - => fun x : _ * expr _ + type.interp _ - => match x return _ * expr _ + type.interp _ with - | inr x => inr (ident.interp idc x) - | inl (data, e) - => inl (ZRange.ident.option.interp idc data, - AppIdent idc (expr.reify (t:=type.Z) x)) - end - | ident.Nat_add as idc - | ident.Nat_sub as idc - | ident.Nat_mul as idc - | ident.Nat_max as idc - | ident.Z_eqb as idc - | ident.Z_leb as idc - | ident.Z_pow as idc - | ident.Z_rshi_concrete _ _ as idc - => fun (x_y : data (_ * _) * expr (_ * _) + (_ + type.interp _) * (_ + type.interp _)) - => match x_y return _ + type.interp _ with - | inr (inr x, inr y) => inr (ident.interp idc (x, y)) - | _ => default_interp idc x_y - end - | ident.Z_div as idc - => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => let default _ := default_interp idc x_y in - match x_y return _ * expr _ + type.interp _ with - | inr (inr x, inr y) => inr (ident.interp idc (x, y)) - | inr (x, inr y) - => if Z.eqb y (2^Z.log2 y) - then default_interp (ident.Z.shiftr (Z.log2 y)) x - else default tt - | _ => default tt - end - | ident.Z_modulo as idc - => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => let default _ := default_interp idc x_y in - match x_y return _ * expr _ + type.interp _ with - | inr (inr x, inr y) => inr (ident.interp idc (x, y)) - | inr (x, inr y) - => if Z.eqb y (2^Z.log2 y) - then default_interp (ident.Z.land (y-1)) x - else default tt - | _ => default tt - end - | ident.Z_mul as idc - => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => let default _ := default_interp idc x_y in - match x_y return _ * expr _ + type.interp _ with - | inr (inr x, inr y) => inr (ident.interp idc (x, y)) - | inr (inr x, inl (data, e) as y) - | inr (inl (data, e) as y, inr x) - => let data' _ := ZRange.ident.option.interp idc (data, Some r[x~>x]%zrange) in - if Z.eqb x 0 - then inr 0%Z - else if Z.eqb x 1 - then y - else if Z.eqb x (-1) - then inl (data' tt, AppIdent ident.Z.opp (expr.reify (t:=type.Z) y)) - else if Z.eqb x (2^Z.log2 x) - then inl (data' tt, - AppIdent (ident.Z.shiftl (Z.log2 x)) (expr.reify (t:=type.Z) y)) - else inl (data' tt, - AppIdent idc (ident.primitive (t:=type.Z) x @@ TT, expr.reify (t:=type.Z) y)) - | inr (inl (dataa, a), inl (datab, b)) - => inl (ZRange.ident.option.interp idc (dataa, datab), - AppIdent idc (a, b)) - | inl _ => default tt - end - | ident.Z_add as idc - => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => let default0 _ := AppIdent idc (expr.reify (t:=_*_) x_y) in - let default _ := expr.reflect (default0 tt) in - match x_y return _ * expr _ + type.interp _ with - | inr (inr x, inr y) => inr (ident.interp idc (x, y)) - | inr (inr x, inl (data, e) as y) - | inr (inl (data, e) as y, inr x) - => let data' _ := ZRange.ident.option.interp idc (data, Some r[x~>x]%zrange) in - if Z.eqb x 0 - then y - else inl (data' tt, - match invert_Z_opp e with - | Some e => AppIdent - ident.Z.sub - (ident.primitive (t:=type.Z) x @@ TT, - e) - | None => default0 tt - end) - | inr (inl (dataa, a), inl (datab, b)) - => inl (ZRange.ident.option.interp idc (dataa, datab), - match invert_Z_opp a, invert_Z_opp b with - | Some a, Some b - => AppIdent - ident.Z.opp - (idc @@ (a, b)) - | Some a, None - => AppIdent ident.Z.sub (b, a) - | None, Some b - => AppIdent ident.Z.sub (a, b) - | None, None => default0 tt - end) - | inl _ => default tt - end - | ident.Z_sub as idc - => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => let default0 _ := AppIdent idc (expr.reify (t:=_*_) x_y) in - let default _ := expr.reflect (default0 tt) in - match x_y return _ * expr _ + type.interp _ with - | inr (inr x, inr y) => inr (ident.interp idc (x, y)) - | inr (inr x, inl (data, e)) - => let data' _ := ZRange.ident.option.interp idc (Some r[x~>x]%zrange, data) in - if Z.eqb x 0 - then inl (data' tt, AppIdent ident.Z.opp e) - else inl (data' tt, default0 tt) - | inr (inl (data, e), inr x) - => let data' _ := ZRange.ident.option.interp idc (data, Some r[x~>x]%zrange) in - if Z.eqb x 0 - then inl (data' tt, e) - else inl (data' tt, default0 tt) - | inr (inl (dataa, a), inl (datab, b)) - => inl (ZRange.ident.option.interp idc (dataa, datab), - match invert_Z_opp a, invert_Z_opp b with - | Some a, Some b - => AppIdent - ident.Z.opp - (idc @@ (a, b)) - | Some a, None - => AppIdent ident.Z.add (b, a) - | None, Some b - => AppIdent ident.Z.add (a, b) - | None, None => default0 tt - end) - | inl _ => default tt - end - | ident.Z_zselect as idc - | ident.Z_add_modulo as idc - => fun (x_y_z : (_ * expr (_ * _ * _) + - (_ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _))%type) - => match x_y_z return _ * expr _ + type.interp _ with - | inr (inr (inr x, inr y), inr z) => inr (ident.interp idc (x, y, z)) - | _ => default_interp idc x_y_z - end - | ident.Z_cast r as idc - => fun (x : _ * expr _ + type.interp _) - => match x with - | inr x => inr (ident.interp idc x) - | inl (data, e) - => inl (ZRange.ident.option.interp idc data, e) - end - | ident.Z_cast2 (r1, r2) as idc - => fun (x : _ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => match x with - | inr (inr a, inr b) - => inr (inr (ident.interp (ident.Z.cast r1) a), - inr (ident.interp (ident.Z.cast r2) b)) - | inr (inr a, inl (r2', b)) - => inr (inr (ident.interp (ident.Z.cast r1) a), - inl (ZRange.ident.option.interp (ident.Z.cast r2) r2', b)) - | inr (inl (r1', a), inr b) - => inr (inl (ZRange.ident.option.interp (ident.Z.cast r1) r1', a), - inr (ident.interp (ident.Z.cast r2) b)) - | inr (inl (r1', a), inl (r2', b)) - => inr (inl (ZRange.ident.option.interp (ident.Z.cast r1) r1', a), - inl (ZRange.ident.option.interp (ident.Z.cast r2) r2', b)) - | inl (data, e) - => inl (ZRange.ident.option.interp idc data, e) - end - end. - End interp. - End ident. - - Module bounds. - Section with_var. - Context {var : type -> Type}. - - Fixpoint extend_concrete_list_with_obounds {t} - (extend_with_obounds : ZRange.type.option.interp t -> partial.value var t -> partial.value var t ) - (ls : list (ZRange.type.option.interp t)) - (e : list (partial.value var t)) - {struct ls} - : list (partial.value var t) - := match ls with - | nil => nil - | cons b bs - => cons (extend_with_obounds - b - (hd (partial.value_default _) e)) - (@extend_concrete_list_with_obounds - t extend_with_obounds bs (tl e)) - end. - - Fixpoint extend_list_expr_with_obounds {t} - (extend_with_obounds : ZRange.type.primitive.option.interp t -> partial.value var t -> partial.value var t ) - (starting_index : nat) - (ls : list (ZRange.type.option.interp t)) - (e : @expr var (type.list t)) - {struct ls} - : list (partial.value var t) - := match ls with - | nil => nil - | cons b bs - => cons (extend_with_obounds - b - (partial.expr.reflect - (AppIdent - (ident.List_nth_default_concrete - DefaultValue.type.default starting_index) - e))) - (@extend_list_expr_with_obounds - t extend_with_obounds (S starting_index) bs e) - end. - - Fixpoint extend_with_obounds {t} : ZRange.type.option.interp t -> partial.value var t -> partial.value var t - := match t return ZRange.type.option.interp t -> partial.value var t -> partial.value var t with - | type.type_primitive type.Z - => fun (r : option zrange) (e : option zrange * expr _ + type.interp _) - => match r, e with - | Some r, inr v => inr (default.ident.interp (ident.Z.cast r) v) - | Some r, inl (data, e) - => inl (ZRange.ident.option.interp (ident.Z.cast r) data, e) - | None, e => e - end - | type.type_primitive t => fun _ => id - | type.prod A B - => fun '((ra, rb) : ZRange.type.option.interp A * ZRange.type.option.interp B) - (e : _ * expr _ + partial.value var A * partial.value var B) - => match e with - | inr (a, b) - => inr (@extend_with_obounds A ra a, - @extend_with_obounds B rb b) - | inl ((dataa, datab), e) - => if partial.ident.is_var_like e - then inr (@extend_with_obounds A ra (partial.expr.reflect (AppIdent ident.fst e)), - @extend_with_obounds B rb (partial.expr.reflect (AppIdent ident.snd e))) - else inl - (match A, B return ZRange.type.option.interp A -> ZRange.type.option.interp B -> data A -> data B -> expr (A * B) -> data (A * B) * expr (A * B) with - | type.Z, type.Z - => fun ra rb da db e - => let da' - := match ra with - | Some ra - => ZRange.ident.option.interp - (ident.Z.cast ra) da - | None => da - end in - let db' - := match rb with - | Some rb - => ZRange.ident.option.interp - (ident.Z.cast rb) db - | None => db - end in - ((da', db'), e) - | _, _ - => fun _ _ da db e => ((da, db), e) - end ra rb dataa datab e) - end - | type.arrow s d => fun _ => id - | type.list A - => fun (ls : option (Datatypes.list (ZRange.type.option.interp A))) - (e : data _ * expr _ + list (partial.value var A)) - => match ls with - | None => e - | Some ls - => - match e with - | inl (data, e) - => match A return (ZRange.type.option.interp A -> partial.value var A -> partial.value var A) - -> Datatypes.list (ZRange.type.option.interp A) - -> option (Datatypes.list (ZRange.type.option.interp A)) - -> expr (type.list A) - -> partial.value var (type.list A) - with - | type.type_primitive A - => fun extend_with_obounds ls data e - => match data with - | Some data - => inr - (extend_concrete_list_with_obounds - extend_with_obounds ls - (extend_list_expr_with_obounds - extend_with_obounds 0 data e)) - | None - => inr (extend_list_expr_with_obounds - extend_with_obounds 0 ls e) - end - | A' - (* N.B. We clobber the existing bounds here, rather than fusing them *) - => fun _ ls data e => inl (Some ls, e) - end (@extend_with_obounds A) ls data e - | inr e => inr (extend_concrete_list_with_obounds - (@extend_with_obounds A) ls e) - end - end - end. - Definition extend_with_bounds {t} - (b : ZRange.type.interp t) - (e : partial.value var t) - : partial.value var t - := @extend_with_obounds t (ZRange.type.option.Some b) e. - End with_var. - - Module ident. - Definition extract {s d} (idc : ident s d) : ZRange.type.option.interp s -> ZRange.type.option.interp d - := match idc in ident s d return ZRange.type.option.interp s -> ZRange.type.option.interp d with - | ident.Let_In tx tC - => fun '((x, f) : ZRange.type.option.interp tx * (ZRange.type.option.interp tx -> ZRange.type.option.interp tC)) - => f x - | ident.Z_cast range => fun _ => Some range - | ident.Z_cast2 (r1, r2) => fun _ => (Some r1, Some r2) - | ident.primitive type.Z v - => fun _ => Some r[v~>v]%zrange - | ident.nil _ => fun _ => Some nil - | ident.cons t - => fun '((x, xs) : ZRange.type.option.interp t * option (list (ZRange.type.option.interp t))) - => option_map (cons x) xs - | _ => fun _ => ZRange.type.option.None - end. - End ident. - - Module expr. - Section with_var. - Context {var : type -> Type} - (fill_var : forall t, ZRange.type.option.interp t -> var t). - Fixpoint extract' {t} (e : @expr var t) : ZRange.type.option.interp t - := match e in expr.expr t return ZRange.type.option.interp t with - | Var _ _ - | TT - => ZRange.type.option.None - | AppIdent s d idc args => ident.extract idc (@extract' s args) - | App s d f x => @extract' _ f (@extract' s x) - | Pair A B a b => (@extract' A a, @extract' B b) - | Abs s d f => fun bs : ZRange.type.option.interp s - => @extract' d (f (fill_var s bs)) - end. - End with_var. - - Definition extract {t} (e : expr t) : ZRange.type.option.interp t - := extract' (fun _ => id) e. - - Definition Extract {t} (e : Expr t) : ZRange.type.option.interp t - := extract (e _). - End expr. - End bounds. - End partial. - - Section partial_evaluate. - Context (inline_var_nodes : bool) - {var : type -> Type}. - - Definition partial_evaluate'_step - (partial_evaluate' : forall {t} (e : @expr (partial.value var) t), - partial.value var t) - {t} (e : @expr (partial.value var) t) - : partial.value var t - := match e in expr.expr t return partial.value var t with - | Var t v => v - | TT => inr tt - | AppIdent s d idc args => partial.ident.interp inline_var_nodes idc (@partial_evaluate' _ args) - | Pair A B a b => inr (@partial_evaluate' A a, @partial_evaluate' B b) - | App s d f x => @partial_evaluate' _ f (@partial_evaluate' _ x) - | Abs s d f => fun x => @partial_evaluate' d (f x) - end. - Fixpoint partial_evaluate' {t} (e : @expr (partial.value var) t) - : partial.value var t - := @partial_evaluate'_step (@partial_evaluate') t e. - - Definition partial_evaluate {t} (e : @expr (partial.value var) t) : @expr var t - := partial.expr.reify (@partial_evaluate' t e). - - Definition partial_evaluate_with_bounds1' {s d} (e : @expr (partial.value var) (s -> d)) - (b : ZRange.type.option.interp s) - : partial.value var (s -> d) - := fun x : partial.value var s - => partial_evaluate' e (partial.bounds.extend_with_obounds b x). - - Definition partial_evaluate_with_bounds1 {s d} (e : @expr (partial.value var) (s -> d)) - (b : ZRange.type.option.interp s) - := partial.expr.reify (@partial_evaluate_with_bounds1' s d e b). - - End partial_evaluate. - - Definition PartialEvaluate (inline_var_nodes : bool) {t} (e : Expr t) : Expr t - := fun var => @partial_evaluate inline_var_nodes var t (e _). - - Module RelaxZRange. - Module ident. - Section relax. - Context (relax_zrange : zrange -> option zrange) - {var : type -> Type}. - - Definition relax {s d} (idc : ident s d) : @expr var s -> @expr var d - := match idc in ident s d return expr s -> expr d with - | ident.Z_cast range - => match relax_zrange range with - | Some r => AppIdent (ident.Z.cast r) - | None => id - end - | ident.Z_cast2 (r1, r2) - => match relax_zrange r1, relax_zrange r2 with - | Some r1, Some r2 - => AppIdent (ident.Z.cast2 (r1, r2)) - | Some _, None | None, Some _ | None, None => id - end - | idc => AppIdent idc - end. - End relax. - End ident. - - Module expr. - Section relax. - Context (relax_zrange : zrange -> option zrange). - Section with_var. - Context {var : type -> Type}. - - Fixpoint relax {t} (e : @expr var t) : @expr var t - := match e with - | Var t v => Var v - | TT => TT - | AppIdent s d idc args => @ident.relax relax_zrange var s d idc - (@relax s args) - | App s d f x => App (@relax _ f) (@relax _ x) - | Pair A B a b => Pair (@relax A a) (@relax B b) - | Abs s d f => Abs (fun v => @relax d (f v)) - end. - End with_var. - - Definition Relax {t} (e : Expr t) : Expr t - := fun var => relax (e _). - End relax. - End expr. - End RelaxZRange. - - Definition PartialEvaluateWithBounds1 - {s d} (e : Expr (s -> d)) (b : ZRange.type.option.interp s) - : Expr (s -> d) - := fun var => @partial_evaluate_with_bounds1 true var s d (e _) b. - - Definition CheckPartialEvaluateWithBounds1 - (relax_zrange : zrange -> option zrange) - {s d} (E : Expr (s -> d)) - (b_in : ZRange.type.option.interp s) - (b_out : ZRange.type.option.interp d) - : Expr (s -> d) + (ZRange.type.option.interp d * Expr (s -> d)) - := let b_computed := partial.bounds.expr.Extract E b_in in - if ZRange.type.option.is_tighter_than b_computed b_out - then @inl (Expr (s -> d)) _ (RelaxZRange.expr.Relax relax_zrange E) - else @inr _ (ZRange.type.option.interp d * Expr (s -> d)) (b_computed, E). - - Definition CheckPartialEvaluateWithBounds0 - (relax_zrange : zrange -> option zrange) - {t} (E : Expr t) - (b_out : ZRange.type.option.interp t) - : Expr t + (ZRange.type.option.interp t * Expr t) - := let b_computed := partial.bounds.expr.Extract E in - if ZRange.type.option.is_tighter_than b_computed b_out - then @inl (Expr t) _ (RelaxZRange.expr.Relax relax_zrange E) - else @inr _ (ZRange.type.option.interp t * Expr t) (b_computed, E). - - Definition CheckedPartialEvaluateWithBounds1 - (relax_zrange : zrange -> option zrange) - {s d} (e : Expr (s -> d)) - (b_in : ZRange.type.option.interp s) - (b_out : ZRange.type.option.interp d) - : Expr (s -> d) + (ZRange.type.option.interp d * Expr (s -> d)) - := let E := PartialEvaluateWithBounds1 e b_in in - dlet_nd e := GeneralizeVar.ToFlat E in - let E := GeneralizeVar.FromFlat e in - CheckPartialEvaluateWithBounds1 relax_zrange E b_in b_out. - - Definition CheckedPartialEvaluateWithBounds0 - (relax_zrange : zrange -> option zrange) - {t} (e : Expr t) - (b_out : ZRange.type.option.interp t) - : Expr t + (ZRange.type.option.interp t * Expr t) - := let E := PartialEvaluate true e in - dlet_nd e := GeneralizeVar.ToFlat E in - let E := GeneralizeVar.FromFlat e in - CheckPartialEvaluateWithBounds0 relax_zrange E b_out. - - Axiom admit_pf : False. - Local Notation admit := (match admit_pf with end). - - Theorem CheckedPartialEvaluateWithBounds1_Correct - (relax_zrange : zrange -> option zrange) - (Hrelax : forall r r' z, is_tighter_than_bool z r = true - -> relax_zrange r = Some r' - -> is_tighter_than_bool z r' = true) - {s d} (e : Expr (s -> d)) - (b_in : ZRange.type.option.interp s) - (b_out : ZRange.type.option.interp d) - rv (Hrv : CheckedPartialEvaluateWithBounds1 relax_zrange e b_in b_out = inl rv) - : forall arg - (Harg : ZRange.type.option.is_bounded_by b_in arg = true), - Interp rv arg = Interp e arg - /\ ZRange.type.option.is_bounded_by b_out (Interp rv arg) = true. - Proof. - cbv [CheckedPartialEvaluateWithBounds1 CheckPartialEvaluateWithBounds1 Let_In] in *; - break_innermost_match_hyps; inversion_sum; subst. - intros arg Harg. - split. - { exact admit. (* correctness of interp *) } - { eapply ZRange.type.option.is_tighter_than_is_bounded_by; [ eassumption | ]. - cbv [expr.Interp]. - revert Harg. - exact admit. (* boundedness *) } - Qed. - - Theorem CheckedPartialEvaluateWithBounds0_Correct - (relax_zrange : zrange -> option zrange) - (Hrelax : forall r r' z, is_tighter_than_bool z r = true - -> relax_zrange r = Some r' - -> is_tighter_than_bool z r' = true) - {t} (e : Expr t) - (b_out : ZRange.type.option.interp t) - rv (Hrv : CheckedPartialEvaluateWithBounds0 relax_zrange e b_out = inl rv) - : Interp rv = Interp e - /\ ZRange.type.option.is_bounded_by b_out (Interp rv) = true. - Proof. - cbv [CheckedPartialEvaluateWithBounds0 CheckPartialEvaluateWithBounds0 Let_In] in *; - break_innermost_match_hyps; inversion_sum; subst. - split. - { exact admit. (* correctness of interp *) } - { eapply ZRange.type.option.is_tighter_than_is_bounded_by; [ eassumption | ]. - cbv [expr.Interp]. - exact admit. (* boundedness *) } - Qed. - - Module DeadCodeElimination. - Fixpoint compute_live' {t} (e : @expr (fun _ => PositiveSet.t) t) (cur_idx : positive) - : positive * PositiveSet.t - := match e with - | Var t v => (cur_idx, v) - | TT => (cur_idx, PositiveSet.empty) - | AppIdent s d idc args - => let default _ := @compute_live' _ args cur_idx in - match args in expr.expr t return ident.ident t d -> _ with - | Pair A B x (Abs s d f) - => fun idc - => match idc with - | ident.Let_In _ _ - => let '(idx, live) := @compute_live' A x cur_idx in - let '(_, live) := @compute_live' _ (f (PositiveSet.add idx live)) (Pos.succ idx) in - (Pos.succ idx, live) - | _ => default tt - end - | _ => fun _ => default tt - end idc - | App s d f x - => let '(idx, live1) := @compute_live' _ f cur_idx in - let '(idx, live2) := @compute_live' _ x idx in - (idx, PositiveSet.union live1 live2) - | Pair A B a b - => let '(idx, live1) := @compute_live' A a cur_idx in - let '(idx, live2) := @compute_live' B b idx in - (idx, PositiveSet.union live1 live2) - | Abs s d f - => let '(_, live) := @compute_live' _ (f PositiveSet.empty) cur_idx in - (cur_idx, live) - end. - Definition compute_live {t} e : PositiveSet.t := snd (@compute_live' t e 1). - Definition ComputeLive {t} (e : Expr t) := compute_live (e _). - - Section with_var. - Context {var : type -> Type} - (live : PositiveSet.t). - Definition OUGHT_TO_BE_UNUSED {T1 T2} (v : T1) (v' : T2) := v. - Global Opaque OUGHT_TO_BE_UNUSED. - Fixpoint eliminate_dead' {t} (e : @expr (@expr var) t) (cur_idx : positive) - : positive * @expr var t - := match e with - | Var t v => (cur_idx, v) - | TT => (cur_idx, TT) - | AppIdent s d idc args - => let default _ - := let default' := @eliminate_dead' _ args cur_idx in - (fst default', AppIdent idc (snd default')) in - match args in expr.expr t return ident.ident t d -> (unit -> positive * expr d) -> positive * expr d with - | Pair A B x y - => match y in expr.expr Y return ident.ident (A * Y) d -> (unit -> positive * expr d) -> positive * expr d with - | Abs s' d' f - => fun idc - => let '(idx, x') := @eliminate_dead' A x cur_idx in - let f' := fun v => snd (@eliminate_dead' _ (f v) (Pos.succ idx)) in - match idc in ident.ident s d - return (match s return Type with - | A * _ => expr A - | _ => unit - end%ctype - -> match s return Type with - | _ * (s -> d) => (expr s -> expr d)%type - | _ => unit - end%ctype - -> (unit -> positive * expr d) - -> positive * expr d) - with - | ident.Let_In _ _ - => fun x' f' _ - => if PositiveSet.mem idx live - then (Pos.succ idx, AppIdent ident.Let_In (Pair x' (Abs (fun v => f' (Var v))))) - else (Pos.succ idx, f' (OUGHT_TO_BE_UNUSED x' (Pos.succ idx, PositiveSet.elements live))) - | _ => fun _ _ default => default tt - end x' f' - | _ => fun _ default => default tt - end - | _ => fun _ default => default tt - end idc default - | App s d f x - => let '(idx, f') := @eliminate_dead' _ f cur_idx in - let '(idx, x') := @eliminate_dead' _ x idx in - (idx, App f' x') - | Pair A B a b - => let '(idx, a') := @eliminate_dead' A a cur_idx in - let '(idx, b') := @eliminate_dead' B b idx in - (idx, Pair a' b') - | Abs s d f - => (cur_idx, Abs (fun v => snd (@eliminate_dead' _ (f (Var v)) cur_idx))) - end. - - Definition eliminate_dead {t} e : expr t - := snd (@eliminate_dead' t e 1). - End with_var. - - Definition EliminateDead {t} (e : Expr t) : Expr t - := fun var => eliminate_dead (ComputeLive e) (e _). - End DeadCodeElimination. - - Module Subst01. - Local Notation PositiveMap_incr idx m - := (PositiveMap.add idx (match PositiveMap.find idx m with - | Some n => S n - | None => S O - end) m). - Local Notation PositiveMap_union m1 m2 - := (PositiveMap.map2 - (fun c1 c2 - => match c1, c2 with - | Some n1, Some n2 => Some (n1 + n2)%nat - | Some n, None - | None, Some n - => Some n - | None, None => None - end) m1 m2). - Fixpoint compute_live_counts' {t} (e : @expr (fun _ => positive) t) (cur_idx : positive) - : positive * PositiveMap.t nat - := match e with - | Var t v => (cur_idx, PositiveMap_incr v (PositiveMap.empty _)) - | TT => (cur_idx, PositiveMap.empty _) - | AppIdent s d idc args - => @compute_live_counts' _ args cur_idx - | App s d f x - => let '(idx, live1) := @compute_live_counts' _ f cur_idx in - let '(idx, live2) := @compute_live_counts' _ x idx in - (idx, PositiveMap_union live1 live2) - | Pair A B a b - => let '(idx, live1) := @compute_live_counts' A a cur_idx in - let '(idx, live2) := @compute_live_counts' B b idx in - (idx, PositiveMap_union live1 live2) - | Abs s d f - => let '(idx, live) := @compute_live_counts' _ (f cur_idx) (Pos.succ cur_idx) in - (cur_idx, live) - end. - Definition compute_live_counts {t} e : PositiveMap.t _ := snd (@compute_live_counts' t e 1). - Definition ComputeLiveCounts {t} (e : Expr t) := compute_live_counts (e _). - - Section with_var. - Context {var : type -> Type} - (live : PositiveMap.t nat). - Fixpoint subst01' {t} (e : @expr (@expr var) t) (cur_idx : positive) - : positive * @expr var t - := match e with - | Var t v => (cur_idx, v) - | TT => (cur_idx, TT) - | AppIdent s d idc args - => let default _ - := let default := @subst01' _ args cur_idx in - (fst default, AppIdent idc (snd default)) in - match args in expr.expr t return ident.ident t d -> (unit -> positive * expr d) -> positive * expr d with - | Pair A B x y - => match y in expr.expr Y return ident.ident (A * Y) d -> (unit -> positive * expr d) -> positive * expr d with - | Abs s' d' f - => fun idc - => let '(idx, x') := @subst01' A x cur_idx in - let f' := fun v => snd (@subst01' _ (f v) (Pos.succ idx)) in - match idc in ident.ident s d - return (match s return Type with - | A * _ => expr A - | _ => unit - end%ctype - -> match s return Type with - | _ * (s -> d) => (expr s -> expr d)%type - | _ => unit - end%ctype - -> (unit -> positive * expr d) - -> positive * expr d) - with - | ident.Let_In _ _ - => fun x' f' _ - => if match PositiveMap.find idx live with - | Some n => (n <=? 1)%nat - | None => true - end - then (Pos.succ idx, f' x') - else (Pos.succ idx, AppIdent ident.Let_In (Pair x' (Abs (fun v => f' (Var v))))) - | _ => fun _ _ default => default tt - end x' f' - | _ => fun _ default => default tt - end - | _ => fun _ default => default tt - end idc default - | App s d f x - => let '(idx, f') := @subst01' _ f cur_idx in - let '(idx, x') := @subst01' _ x idx in - (idx, App f' x') - | Pair A B a b - => let '(idx, a') := @subst01' A a cur_idx in - let '(idx, b') := @subst01' B b idx in - (idx, Pair a' b') - | Abs s d f - => (cur_idx, Abs (fun v => snd (@subst01' _ (f (Var v)) (Pos.succ cur_idx)))) - end. - - Definition subst01 {t} e : expr t - := snd (@subst01' t e 1). - End with_var. - - Definition Subst01 {t} (e : Expr t) : Expr t - := fun var => subst01 (ComputeLiveCounts e) (e _). - End Subst01. - - Module ReassociateSmallConstants. - Import Compilers.Uncurried.expr.default. - - Section with_var. - Context (max_const_val : Z) - {var : type -> Type}. - - Fixpoint to_mul_list (e : @expr var type.Z) : list (@expr var type.Z) - := match e in expr.expr t return list (@expr var t) with - | AppIdent s type.Z ident.Z_mul (Pair type.Z type.Z x y) - => to_mul_list x ++ to_mul_list y - | Var _ _ as e - | TT as e - | App _ _ _ _ as e - | Abs _ _ _ as e - | Pair _ _ _ _ as e - | AppIdent _ _ _ _ as e - => [e] - end. - - Definition is_small_prim (e : @expr var type.Z) : bool - := match e with - | AppIdent _ _ (ident.primitive type.Z v) _ - => Z.abs v <=? Z.abs max_const_val - | _ => false - end. - Definition is_not_small_prim (e : @expr var type.Z) : bool - := negb (is_small_prim e). - - Definition reorder_mul_list (ls : list (@expr var type.Z)) - : list (@expr var type.Z) - := filter is_not_small_prim ls ++ filter is_small_prim ls. - - Fixpoint of_mul_list (ls : list (@expr var type.Z)) : @expr var type.Z - := match ls with - | nil => AppIdent (ident.primitive (t:=type.Z) 1) TT - | cons x nil - => x - | cons x xs - => AppIdent ident.Z_mul (x, of_mul_list xs) - end. - - Fixpoint reassociate {t} (e : @expr var t) : @expr var t - := match e in expr.expr t return expr t with - | Var _ _ as e - | TT as e - => e - | Pair A B a b - => Pair (@reassociate A a) (@reassociate B b) - | App s d f x => App (@reassociate _ f) (@reassociate _ x) - | Abs s d f => Abs (fun v => @reassociate _ (f v)) - | AppIdent s type.Z idc args - => of_mul_list (reorder_mul_list (to_mul_list (AppIdent idc (@reassociate s args)))) - | AppIdent s d idc args - => AppIdent idc (@reassociate s args) - end. - End with_var. - - Definition Reassociate (max_const_val : Z) {t} (e : Expr t) : Expr t - := fun var => reassociate max_const_val (e _). - End ReassociateSmallConstants. -End Compilers. -Import Associational Positional Compilers. -Local Coercion Z.of_nat : nat >-> Z. -Local Coercion QArith_base.inject_Z : Z >-> Q. - -(** TODO: FILES: -- Uncurried expr + reification + list canonicalization -- cps -- partial evaluation + DCE -- reassociation -- indexed + bounds analysis + of phoas *) - -Import Uncurried. -Import expr. -Import for_reification.Notations.Reification. - -Notation "x + y" - := (AppIdent ident.Z.add (x, y)%expr) - : expr_scope. -Notation "x * y" - := (AppIdent ident.Z.mul (x, y)%expr) - : expr_scope. -Notation "x" := (Var x) (only printing, at level 9) : expr_scope. - -Example test1 : True. -Proof. - let v := Reify ((fun x => 2^x) 255)%Z in - pose v as E. - vm_compute in E. - pose (PartialEvaluate false (canonicalize_list_recursion E)) as E'. - vm_compute in E'. - lazymatch (eval cbv delta [E'] in E') with - | (fun var => AppIdent (ident.primitive ?v) TT) => idtac - end. - constructor. -Qed. -Module test2. - Example test2 : True. - Proof. - let v := Reify (fun y : Z - => (fun k : Z * Z -> Z * Z - => dlet_nd x := (y * y) in - dlet_nd z := (x * x) in - k (z, z)) - (fun v => v)) in - pose v as E. - vm_compute in E. - pose (PartialEvaluate false (canonicalize_list_recursion E)) as E'. - vm_compute in E'. - lazymatch (eval cbv delta [E'] in E') with - | (fun var : type -> Type => - (λ x : var (type.type_primitive type.Z), - expr_let x0 := (Var x * Var x) in - expr_let x1 := (Var x0 * Var x0) in - (Var x1, Var x1))%expr) => idtac - end. - pose (PartialEvaluateWithBounds1 E' (Some r[0~>10]%zrange)) as E''. - lazy in E''. - lazymatch (eval cbv delta [E''] in E'') with - | (fun var : type -> Type => - (λ x : var (type.type_primitive type.Z), - expr_let y := ident.Z.cast r[0 ~> 100] @@ (Var x * Var x) in - expr_let y0 := ident.Z.cast r[0 ~> 10000] @@ (Var y * Var y) in - (ident.Z.cast r[0 ~> 10000] @@ Var y0, ident.Z.cast r[0 ~> 10000] @@ Var y0))%expr) - => idtac - end. - constructor. - Qed. -End test2. -Module test3. - Example test3 : True. - Proof. - let v := Reify (fun y : Z - => dlet_nd x := dlet_nd x := (y * y) in - (x * x) in - dlet_nd z := dlet_nd z := (x * x) in - (z * z) in - (z * z)) in - pose v as E. - vm_compute in E. - pose (option_map (PartialEvaluate false) (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. - vm_compute in E'. - lazymatch (eval cbv delta [E'] in E') with - | (Some - (fun var : type -> Type => - (λ x : var (type.type_primitive type.Z), - expr_let x0 := Var x * Var x in - expr_let x1 := Var x0 * Var x0 in - expr_let x2 := Var x1 * Var x1 in - expr_let x3 := Var x2 * Var x2 in - Var x3 * Var x3)%expr)) - => idtac - end. - pose (PartialEvaluateWithBounds1 (invert_Some E') (Some r[0~>10]%zrange)) as E'''. - lazy in E'''. - lazymatch (eval cbv delta [E'''] in E''') with - | (fun var : type -> Type => - (λ x : var (type.type_primitive type.Z), - expr_let y := ident.Z.cast r[0 ~> 100] @@ (Var x * Var x) in - expr_let y0 := ident.Z.cast r[0 ~> 10000] @@ (Var y * Var y) in - expr_let y1 := ident.Z.cast r[0 ~> 100000000] @@ (Var y0 * Var y0) in - expr_let y2 := ident.Z.cast r[0 ~> 10000000000000000] @@ (Var y1 * Var y1) in - ident.Z.cast r[0 ~> 100000000000000000000000000000000] @@ (Var y2 * Var y2))%expr) - => idtac - end. - constructor. - Qed. -End test3. -Module test4. - Example test4 : True. - Proof. - let v := Reify (fun y : (list Z * list Z) - => dlet_nd x := List.nth_default (-1) (fst y) 0 in - dlet_nd z := List.nth_default (-1) (snd y) 0 in - dlet_nd xz := (x * z) in - (xz :: xz :: nil)) in - pose v as E. - vm_compute in E. - pose (option_map (PartialEvaluate false) (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. - lazy in E'. - clear E. - pose (PartialEvaluateWithBounds1 (invert_Some E') (Some [Some r[0~>10]%zrange],Some [Some r[0~>10]%zrange])) as E''. - lazy in E''. - lazymatch (eval cbv delta [E''] in E'') with - | (fun var : type -> Type => - (λ x : var (type.list (type.type_primitive type.Z) * type.list (type.type_primitive type.Z))%ctype, - expr_let y := ident.Z.cast r[0 ~> 10] @@ - (ident.List.nth_default_concrete (-1) 0 @@ (ident.fst @@ Var x)) in - expr_let y0 := ident.Z.cast r[0 ~> 10] @@ - (ident.List.nth_default_concrete (-1) 0 @@ (ident.snd @@ Var x)) in - expr_let y1 := ident.Z.cast r[0 ~> 100] @@ (Var y * Var y0) in - ident.Z.cast r[0 ~> 100] @@ Var y1 :: ident.Z.cast r[0 ~> 100] @@ Var y1 :: [])%expr) - => idtac - end. - constructor. - Qed. -End test4. -Module test5. - Example test5 : True. - Proof. - let v := Reify (fun y : (Z * Z) - => dlet_nd x := (13 * (fst y * snd y)) in - x) in - pose v as E. - vm_compute in E. - pose (ReassociateSmallConstants.Reassociate (2^8) (PartialEvaluate false (invert_Some (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))))) as E'. - lazy in E'. - clear E. - lazymatch (eval cbv delta [E'] in E') with - | (fun var => - Abs (fun v - => (expr_let v0 := ident.Z.mul @@ (ident.fst @@ Var v, ident.Z.mul @@ (ident.snd @@ Var v, ident.primitive 13 @@ TT)) in - Var v0)%expr)) - => idtac - end. - constructor. - Qed. -End test5. -Module test6. - (* check for no dead code with if *) - Example test6 : True. - Proof. - let v := Reify (fun y : Z - => if 0 =? 1 - then dlet_nd x := (y * y) in - x - else y) in - pose v as E. - vm_compute in E. - pose (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))) as E'. - lazy in E'. - clear E. - pose (PartialEvaluate false (invert_Some E')) as E''. - lazy in E''. - lazymatch eval cbv delta [E''] in E'' with - | fun var : type -> Type => (λ x : var (type.type_primitive type.Z), Var x)%expr - => idtac - end. - exact I. - Qed. -End test6. -Module test7. - Example test7 : True. - Proof. - let v := Reify (fun y : Z - => dlet_nd x := y + y in - dlet_nd z := x in - dlet_nd z' := z in - dlet_nd z'' := z in - z'' + z'') in - pose v as E. - vm_compute in E. - pose (canonicalize_list_recursion E) as E'. - lazy in E'. - clear E. - pose (Subst01.Subst01 (DeadCodeElimination.EliminateDead E')) as E''. - lazy in E''. - lazymatch eval cbv delta [E''] in E'' with - | fun var : type -> Type => (λ x : var (type.type_primitive type.Z), expr_let v0 := Var x + Var x in Var v0 + Var v0)%expr - => idtac - end. - exact I. - Qed. -End test7. -Module test8. - Example test8 : True. - Proof. - let v := Reify (fun y : Z - => dlet_nd x := y + y in - dlet_nd z := x in - dlet_nd z' := z in - dlet_nd z'' := z in - z'' + z'') in - pose v as E. - vm_compute in E. - pose (canonicalize_list_recursion E) as E'. - lazy in E'. - clear E. - pose (GeneralizeVar.GeneralizeVar (E' _)) as E''. - lazy in E''. - unify E' E''. - exact I. - Qed. -End test8. -Module test9. - Example test9 : True. - Proof. - let v := Reify (fun y : list Z => (hd 0%Z y, tl y)) in - pose v as E. - vm_compute in E. - pose (PartialEvaluate true (canonicalize_list_recursion E)) as E'. - lazy in E'. - clear E. - lazymatch (eval cbv delta [E'] in E') with - | (fun var - => (λ x, - (ident.list_rect - @@ - ((λ _, ident.primitive 0%Z @@ TT), - (λ x0, ident.fst @@ (ident.fst @@ Var x0)), - Var x), - ident.list_rect - @@ - ((λ _, ident.nil @@ TT), - (λ x0, ident.snd @@ (ident.fst @@ Var x0)), - Var x)))%expr) - => idtac - end. - exact I. - Qed. -End test9. -Module test10. - Example test10 : True. - Proof. - let v := Reify (fun (f : Z -> Z -> Z) x y => f (x + y) (x * y))%Z in - pose v as E. - vm_compute in E. - pose (Uncurry.expr.Uncurry (PartialEvaluate true (canonicalize_list_recursion E))) as E'. - lazy in E'. - clear E. - lazymatch (eval cbv delta [E'] in E') with - | (fun var => - (λ v, - ident.fst @@ Var v @ - (ident.fst @@ (ident.snd @@ Var v) + ident.snd @@ (ident.snd @@ Var v)) @ - (ident.fst @@ (ident.snd @@ Var v) * ident.snd @@ (ident.snd @@ Var v)))%expr) - => idtac - end. - constructor. - Qed. -End test10. -Module test11. - Example test11 : True. - Proof. - let v := Reify (fun x y => (fun f a b => f a b) (fun a b => a + b) (x + y) (x * y))%Z in - pose v as E. - vm_compute in E. - pose (Uncurry.expr.Uncurry (PartialEvaluate true (canonicalize_list_recursion E))) as E'. - lazy in E'. - clear E. - lazymatch (eval cbv delta [E'] in E') with - | (fun var => - (λ x, - ident.fst @@ Var x + ident.snd @@ Var x + ident.fst @@ Var x * ident.snd @@ Var x)%expr) - => idtac - end. - constructor. - Qed. -End test11. -Axiom admit_pf : False. -Notation admit := (match admit_pf with end). -Ltac cache_reify _ := - intros; - etransitivity; - [ - | repeat apply (f_equal (fun f => f _)); - Reify_rhs (); - reflexivity ]; - cbv beta; - let RHS := match goal with |- _ = ?RHS => RHS end in - let e := match RHS with context[expr.Interp _ ?e] => e end in - let E := fresh "E" in - set (E := e); - let E' := constr:(canonicalize_list_recursion E) in - let LHS := match goal with |- ?LHS = _ => LHS end in - lazymatch LHS with - | context LHS[@expr.Interp ?ident ?interp_ident ?t ?e] - => let LHS := context LHS[@expr.Interp ident interp_ident t E'] in - transitivity LHS; [ | clear e ] - end; - [ repeat match goal with |- context[expr.Interp _ _ _] => apply (f_equal (fun f => f _)) end; - apply f_equal; - lazymatch goal with |- ?LHS = ?RHS => subst LHS end; - let RHS := lazymatch goal with |- ?LHS = ?RHS => RHS end in - time (let RHS' := (eval vm_compute in RHS) in (* [vm_compute] is much faster than [lazy] here on large things *) - time instantiate (1:=RHS'); - vm_cast_no_check (eq_refl RHS')) - | clearbody E ]. - -Create HintDb reify_gen_cache. - -Derive carry_mul_gen - SuchThat (forall (limbwidth_num limbwidth_den : Z) - (f g : list Z) - (n : nat) - (s : Z) - (c : list (Z * Z)) - (len_c : nat) - (idxs : list nat) - (len_idxs : nat), - Interp (t:=type.reify_type_of carry_mulmod) - carry_mul_gen limbwidth_num limbwidth_den s c n len_c idxs len_idxs f g - = carry_mulmod limbwidth_num limbwidth_den s c n len_c idxs len_idxs f g) - As carry_mul_gen_correct. -Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed. -Hint Extern 1 (_ = carry_mulmod _ _ _ _ _ _ _ _ _ _) => simple apply carry_mul_gen_correct : reify_gen_cache. - -Derive carry_gen - SuchThat (forall (limbwidth_num limbwidth_den : Z) - (f : list Z) - (n : nat) - (s : Z) - (c : list (Z * Z)) - (len_c : nat) - (idxs : list nat) - (len_idxs : nat), - Interp (t:=type.reify_type_of carrymod) - carry_gen limbwidth_num limbwidth_den s c n len_c idxs len_idxs f - = carrymod limbwidth_num limbwidth_den s c n len_c idxs len_idxs f) - As carry_gen_correct. -Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. -Hint Extern 1 (_ = carrymod _ _ _ _ _ _ _ _ _) => simple apply carry_gen_correct : reify_gen_cache. - -Derive encode_gen - SuchThat (forall (limbwidth_num limbwidth_den : Z) - (v : Z) - (n : nat) - (s : Z) - (c : list (Z * Z)) - (len_c : nat), - Interp (t:=type.reify_type_of encodemod) - encode_gen limbwidth_num limbwidth_den s c n len_c v - = encodemod limbwidth_num limbwidth_den s c n len_c v) - As encode_gen_correct. -Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. -Hint Extern 1 (_ = encodemod _ _ _ _ _ _ _) => simple apply encode_gen_correct : reify_gen_cache. - -Derive add_gen - SuchThat (forall (limbwidth_num limbwidth_den : Z) - (f g : list Z) - (n : nat), - Interp (t:=type.reify_type_of addmod) - add_gen limbwidth_num limbwidth_den n f g - = addmod limbwidth_num limbwidth_den n f g) - As add_gen_correct. -Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. -Hint Extern 1 (_ = addmod _ _ _ _ _) => simple apply add_gen_correct : reify_gen_cache. -Derive sub_gen - SuchThat (forall (limbwidth_num limbwidth_den : Z) - (n : nat) - (s : Z) - (c : list (Z * Z)) - (len_c : nat) - (coef : Z) - (f g : list Z), - Interp (t:=type.reify_type_of submod) - sub_gen limbwidth_num limbwidth_den s c n len_c coef f g - = submod limbwidth_num limbwidth_den s c n len_c coef f g) - As sub_gen_correct. -Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. -Hint Extern 1 (_ = submod _ _ _ _ _ _ _ _ _) => simple apply sub_gen_correct : reify_gen_cache. - -Derive opp_gen - SuchThat (forall (limbwidth_num limbwidth_den : Z) - (n : nat) - (s : Z) - (c : list (Z * Z)) - (len_c : nat) - (coef : Z) - (f : list Z), - Interp (t:=type.reify_type_of oppmod) - opp_gen limbwidth_num limbwidth_den s c n len_c coef f - = oppmod limbwidth_num limbwidth_den s c n len_c coef f) - As opp_gen_correct. -Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. -Hint Extern 1 (_ = oppmod _ _ _ _ _ _ _ _) => simple apply opp_gen_correct : reify_gen_cache. - -Definition zeromod limbwidth_num limbwidth_den n s c len_c := encodemod limbwidth_num limbwidth_den n s c len_c 0. -Definition onemod limbwidth_num limbwidth_den n s c len_c := encodemod limbwidth_num limbwidth_den n s c len_c 1. - -Derive zero_gen - SuchThat (forall (limbwidth_num limbwidth_den : Z) - (n : nat) - (s : Z) - (c : list (Z * Z)) - (len_c : nat), - Interp (t:=type.reify_type_of zeromod) - zero_gen limbwidth_num limbwidth_den s c n len_c - = zeromod limbwidth_num limbwidth_den s c n len_c) - As zero_gen_correct. -Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. -Hint Extern 1 (_ = zeromod _ _ _ _ _ _) => simple apply zero_gen_correct : reify_gen_cache. - -Derive one_gen - SuchThat (forall (limbwidth_num limbwidth_den : Z) - (n : nat) - (s : Z) - (c : list (Z * Z)) - (len_c : nat), - Interp (t:=type.reify_type_of onemod) - one_gen limbwidth_num limbwidth_den s c n len_c - = onemod limbwidth_num limbwidth_den s c n len_c) - As one_gen_correct. -Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. -Hint Extern 1 (_ = onemod _ _ _ _ _ _) => simple apply one_gen_correct : reify_gen_cache. - -Derive id_gen - SuchThat (forall (n : nat) - (ls : list Z), - Interp (t:=type.reify_type_of expanding_id) - id_gen n ls - = expanding_id n ls) - As id_gen_correct. -Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. -Hint Extern 1 (_ = expanding_id _ _) => simple apply id_gen_correct : reify_gen_cache. - -Import Uncurry. -Module Pipeline. - Import GeneralizeVar. - Inductive ErrorMessage := - | Computed_bounds_are_not_tight_enough - {t} (computed_bounds expected_bounds : ZRange.type.option.interp t) - {s} (syntax_tree : Expr (s -> t)) (arg_bounds : ZRange.type.option.interp s) - | Bounds_analysis_failed - | Type_too_complicated_for_cps (t : type) - | Value_not_le (descr : string) {T'} (lhs rhs : T') - | Value_not_lt (descr : string) {T'} (lhs rhs : T') - | Values_not_provably_distinct (descr : string) {T'} (lhs rhs : T') - | Values_not_provably_equal (descr : string) {T'} (lhs rhs : T'). - - Inductive ErrorT {T} := - | Success (v : T) - | Error (msg : ErrorMessage). - Global Arguments ErrorT : clear implicits. - - Definition invert_result {T} (v : ErrorT T) - := match v return match v with Success _ => T | _ => ErrorMessage end with - | Success v => v - | Error msg => msg - end. - - Definition PrePipeline - {t} - (E : for_reification.Expr t) - : Expr t - := canonicalize_list_recursion E. - - Lemma PrePipeline_correct {t} (E : for_reification.Expr t) v - : expr.Interp (@ident.interp) v = - expr.Interp (@for_reification.ident.interp) E. - Admitted. - - Definition BoundsPipeline - (with_dead_code_elimination : bool := true) - (with_subst01 : bool) - relax_zrange - {t} - (E : Expr t) - arg_bounds - out_bounds - : ErrorT (Expr (type.uncurry t)) - := let E := expr.Uncurry E in - let E := CPS.CallFunWithIdContinuation (CPS.Translate E) in - match E with - | Some E - => (let E := PartialEvaluate false E in - (* Note that DCE evaluates the expr with two different - [var] arguments, and so results in a pipeline that is - 2x slower unless we pass through a uniformly concrete - [var] type first *) - dlet_nd e := ToFlat E in - let E := FromFlat e in - let E := if with_dead_code_elimination then DeadCodeElimination.EliminateDead E else E in - dlet_nd e := ToFlat E in - let E := FromFlat e in - let E := if with_subst01 then Subst01.Subst01 E else E in - let E := ReassociateSmallConstants.Reassociate (2^8) E in - let E := CheckedPartialEvaluateWithBounds1 relax_zrange E arg_bounds out_bounds in - match E with - | inl E => Success E - | inr (b, E) - => Error (Computed_bounds_are_not_tight_enough b out_bounds E arg_bounds) - end) - | None => Error (Type_too_complicated_for_cps t) - end. - - Lemma BoundsPipeline_correct - (with_dead_code_elimination : bool := true) - (with_subst01 : bool) - relax_zrange - (Hrelax : forall r r' z : zrange, - (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) - {t} - (e : Expr t) - arg_bounds - out_bounds - rv - (Hrv : BoundsPipeline (*with_dead_code_elimination*) with_subst01 relax_zrange e arg_bounds out_bounds = Success rv) - : forall arg - (Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true), - ZRange.type.option.is_bounded_by out_bounds (Interp rv arg) = true - /\ Interp rv arg = app_curried (Interp e) arg. - Proof. - cbv [BoundsPipeline Let_In] in *; - repeat match goal with - | [ H : match ?x with _ => _ end = Success _ |- _ ] - => destruct x eqn:?; cbv beta iota in H; [ | destruct_head'_prod; congruence ]; - let H' := fresh in - inversion H as [H']; clear H; rename H' into H - end. - { intros; - match goal with - | [ H : _ = _ |- _ ] - => eapply CheckedPartialEvaluateWithBounds1_Correct in H; - [ destruct H as [H0 H1] | .. ] - end; - [ - | eassumption || (try reflexivity).. ]. - refine (let H' := admit (* interp correctness *) in - conj _ (eq_trans H' _)); - clearbody H'. - { rewrite H'; eassumption. } - { rewrite H0. - exact admit. (* interp correctness *) } } - Qed. - - Definition BoundsPipeline_correct_transT - {t} - arg_bounds - out_bounds - (InterpE : type.interp t) - (rv : Expr (type.uncurry t)) - := forall arg - (Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true), - ZRange.type.option.is_bounded_by out_bounds (Interp rv arg) = true - /\ Interp rv arg = app_curried InterpE arg. - - Lemma BoundsPipeline_correct_trans - (with_dead_code_elimination : bool := true) - (with_subst01 : bool) - relax_zrange - (Hrelax - : forall r r' z : zrange, - (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) - {t} - (e : Expr t) - arg_bounds out_bounds - (InterpE : type.interp t) - (InterpE_correct - : forall arg - (Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true), - app_curried (Interp e) arg = app_curried InterpE arg) - rv - (Hrv : BoundsPipeline (*with_dead_code_elimination*) with_subst01 relax_zrange e arg_bounds out_bounds = Success rv) - : BoundsPipeline_correct_transT arg_bounds out_bounds InterpE rv. - Proof. - intros arg Harg; rewrite <- InterpE_correct by assumption. - eapply @BoundsPipeline_correct; eassumption. - Qed. - - Definition BoundsPipeline_full - (with_dead_code_elimination : bool := true) - (with_subst01 : bool) - relax_zrange - {t} - (E : for_reification.Expr t) - arg_bounds - out_bounds - : ErrorT (Expr (type.uncurry t)) - := let E := PrePipeline E in - @BoundsPipeline - (*with_dead_code_elimination*) - with_subst01 - relax_zrange - t E arg_bounds out_bounds. - - Lemma BoundsPipeline_full_correct - (with_dead_code_elimination : bool := true) - (with_subst01 : bool) - relax_zrange - (Hrelax : forall r r' z : zrange, - (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) - {t} - (E : for_reification.Expr t) - arg_bounds - out_bounds - rv - (Hrv : BoundsPipeline_full (*with_dead_code_elimination*) with_subst01 relax_zrange E arg_bounds out_bounds = Success rv) - : forall arg - (Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true), - ZRange.type.option.is_bounded_by out_bounds (Interp rv arg) = true - /\ Interp rv arg = app_curried (for_reification.Interp E) arg. - Proof. - cbv [BoundsPipeline_full] in *. - eapply BoundsPipeline_correct_trans; [ eassumption | | eassumption.. ]. - intros; erewrite PrePipeline_correct; reflexivity. - Qed. -End Pipeline. - -Definition round_up_bitwidth_gen (possible_values : list Z) (bitwidth : Z) : option Z - := List.fold_right - (fun allowed cur - => if bitwidth <=? allowed - then Some allowed - else cur) - None - possible_values. - -Lemma round_up_bitwidth_gen_le possible_values bitwidth v - : round_up_bitwidth_gen possible_values bitwidth = Some v - -> bitwidth <= v. -Proof. - cbv [round_up_bitwidth_gen]. - induction possible_values as [|x xs IHxs]; cbn; intros; inversion_option. - break_innermost_match_hyps; Z.ltb_to_lt; inversion_option; subst; trivial. - specialize_by_assumption; omega. -Qed. - -Definition relax_zrange_gen (possible_values : list Z) : zrange -> option zrange - := (fun '(r[ l ~> u ]) - => if (0 <=? l)%Z - then option_map (fun u => r[0~>2^u-1]) - (round_up_bitwidth_gen possible_values (Z.log2_up (u+1))) - else None)%zrange. - -Lemma relax_zrange_gen_good - (possible_values : list Z) - : forall r r' z : zrange, - (z <=? r)%zrange = true -> relax_zrange_gen possible_values r = Some r' -> (z <=? r')%zrange = true. -Proof. - cbv [is_tighter_than_bool relax_zrange_gen]; intros *. - pose proof (Z.log2_up_nonneg (upper r + 1)). - rewrite !Bool.andb_true_iff; destruct_head' zrange; cbn [ZRange.lower ZRange.upper] in *. - cbv [fold_right option_map]. - break_innermost_match; intros; destruct_head'_and; - try match goal with - | [ H : _ |- _ ] => apply round_up_bitwidth_gen_le in H - end; - inversion_option; inversion_zrange; - subst; - repeat apply conj; - Z.ltb_to_lt; try omega; - try (rewrite <- Z.log2_up_le_pow2_full in *; omega). -Qed. - -(** XXX TODO: Translate Jade's python script *) -Section rcarry_mul. - Context (n : nat) - (s : Z) - (c : list (Z * Z)) - (machine_wordsize : Z). - - Let limbwidth := (Z.log2_up (s - Associational.eval c) / Z.of_nat n)%Q. - Let idxs := (seq 0 n ++ [0; 1])%list%nat. - Let coef := 2. - Let tight_upperbounds : list Z - := List.map - (fun v : Z => Qceiling (11/10 * v)) - (encode (weight (Qnum limbwidth) (Qden limbwidth)) n s c (s-1)). - Let prime_bound : ZRange.type.option.interp (type.Z) - := Some r[0~>(s - Associational.eval c - 1)]%zrange. - - Definition relax_zrange_of_machine_wordsize - := relax_zrange_gen [machine_wordsize; 2 * machine_wordsize]%Z. - - Let relax_zrange := relax_zrange_of_machine_wordsize. - Let tight_bounds : list (ZRange.type.option.interp type.Z) - := List.map (fun u => Some r[0~>u]%zrange) tight_upperbounds. - Let loose_bounds : list (ZRange.type.option.interp type.Z) - := List.map (fun u => Some r[0 ~> 3*u]%zrange) tight_upperbounds. - - Definition check_args {T} (res : Pipeline.ErrorT T) - : Pipeline.ErrorT T - := if negb (Qle_bool 1 limbwidth)%Q - then Pipeline.Error (Pipeline.Value_not_le "1 ≤ limbwidth" 1%Q limbwidth) - else if (negb (0 <? s - Associational.eval c))%Z - then Pipeline.Error (Pipeline.Value_not_lt "s - Associational.eval c ≤ 0" 0 (s - Associational.eval c)) - else if (s =? 0)%Z - then Pipeline.Error (Pipeline.Values_not_provably_distinct "s ≠ 0" s 0) - else if (n =? 0)%nat - then Pipeline.Error (Pipeline.Values_not_provably_distinct "n ≠ 0" n 0%nat) - else if (negb (0 <? machine_wordsize)) - then Pipeline.Error (Pipeline.Value_not_lt "0 < machine_wordsize" 0 machine_wordsize) - else res. - - Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _). - - Notation BoundsPipeline rop in_bounds out_bounds - := (Pipeline.BoundsPipeline - (*false*) true - relax_zrange - rop%Expr in_bounds out_bounds). - - Notation BoundsPipeline_correct in_bounds out_bounds op - := (fun rv (rop : Expr (type.reify_type_of op)) Hrop - => @Pipeline.BoundsPipeline_correct_trans - (*false*) true - relax_zrange - (relax_zrange_gen_good _) - _ - rop - in_bounds - out_bounds - op - Hrop rv) - (only parsing). - - (* N.B. We only need [rcarry_mul] if we want to extract the Pipeline; otherwise we can just use [rcarry_mul_correct] *) - Definition rcarry_mul - := BoundsPipeline - (carry_mul_gen - @ GallinaReify.Reify (Qnum limbwidth) @ GallinaReify.Reify (Z.pos (Qden limbwidth)) @ GallinaReify.Reify s @ GallinaReify.Reify c @ GallinaReify.Reify n @ GallinaReify.Reify (length c) @ GallinaReify.Reify idxs @ GallinaReify.Reify (length idxs)) - (Some loose_bounds, Some loose_bounds) - (Some tight_bounds). - - Definition rcarry_mul_correct - := BoundsPipeline_correct - (Some loose_bounds, Some loose_bounds) - (Some tight_bounds) - (carry_mulmod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c) idxs (List.length idxs)). - - Definition rcarry_correct - := BoundsPipeline_correct - (Some loose_bounds) - (Some tight_bounds) - (carrymod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c) idxs (List.length idxs)). - - Definition rrelax_correct - := BoundsPipeline_correct - (Some tight_bounds) - (Some loose_bounds) - (expanding_id n). - - Definition radd_correct - := BoundsPipeline_correct - (Some tight_bounds, Some tight_bounds) - (Some loose_bounds) - (addmod (Qnum limbwidth) (Z.pos (Qden limbwidth)) n). - - Definition rsub_correct - := BoundsPipeline_correct - (Some tight_bounds, Some tight_bounds) - (Some loose_bounds) - (submod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c) coef). - - Definition ropp_correct - := BoundsPipeline_correct - (Some tight_bounds) - (Some loose_bounds) - (oppmod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c) coef). - - Definition rencode_correct - := BoundsPipeline_correct - prime_bound - (Some tight_bounds) - (encodemod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c)). - - Definition rzero_correct - := BoundsPipeline_correct - tt - (Some tight_bounds) - (zeromod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c)). - - Definition rone_correct - := BoundsPipeline_correct - tt - (Some tight_bounds) - (onemod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c)). - - (* we need to strip off [Hrv : ... = Pipeline.Success rv] and related arguments *) - Definition rcarry_mul_correctT rv : Prop - := type_of_strip_3arrow (@rcarry_mul_correct rv). - Definition rcarry_correctT rv : Prop - := type_of_strip_3arrow (@rcarry_correct rv). - Definition rrelax_correctT rv : Prop - := type_of_strip_3arrow (@rrelax_correct rv). - Definition radd_correctT rv : Prop - := type_of_strip_3arrow (@radd_correct rv). - Definition rsub_correctT rv : Prop - := type_of_strip_3arrow (@rsub_correct rv). - Definition ropp_correctT rv : Prop - := type_of_strip_3arrow (@ropp_correct rv). - Definition rencode_correctT rv : Prop - := type_of_strip_3arrow (@rencode_correct rv). - Definition rzero_correctT rv : Prop - := type_of_strip_3arrow (@rzero_correct rv). - Definition rone_correctT rv : Prop - := type_of_strip_3arrow (@rone_correct rv). - - Section make_ring. - Let m : positive := Z.to_pos (s - Associational.eval c). - Context (curve_good : check_args (Pipeline.Success tt) = Pipeline.Success tt) - {rcarry_mulv} (Hrmulv : rcarry_mul_correctT rcarry_mulv) - {rcarryv} (Hrcarryv : rcarry_correctT rcarryv) - {rrelaxv} (Hrrelaxv : rrelax_correctT rrelaxv) - {raddv} (Hraddv : radd_correctT raddv) - {rsubv} (Hrsubv : rsub_correctT rsubv) - {roppv} (Hroppv : ropp_correctT roppv) - {rzerov} (Hrzerov : rzero_correctT rzerov) - {ronev} (Hronev : rone_correctT ronev) - {rencodev} (Hrencodev : rencode_correctT rencodev). - - Local Ltac use_curve_good_t := - repeat first [ progress rewrite ?map_length, ?Z.mul_0_r, ?Pos.mul_1_r, ?Z.mul_1_r in * - | reflexivity - | lia - | rewrite interp_reify_list, ?map_map - | rewrite map_ext with (g:=id), map_id - | progress distr_length - | progress cbv [Qceiling Qfloor Qopp Qdiv Qplus inject_Z Qmult Qinv] in * - | progress cbv [Qle] in * - | progress cbn -[reify_list] in * - | progress intros - | solve [ auto ] ]. - - Lemma use_curve_good - : Z.pos m = s - Associational.eval c - /\ Z.pos m <> 0 - /\ s - Associational.eval c <> 0 - /\ s <> 0 - /\ 0 < machine_wordsize - /\ n <> 0%nat - /\ List.length tight_bounds = n - /\ List.length loose_bounds = n - /\ 0 < Qden limbwidth <= Qnum limbwidth. - Proof. - clear -curve_good. - cbv [check_args] in curve_good. - break_innermost_match_hyps; try discriminate. - rewrite negb_false_iff in *. - Z.ltb_to_lt. - rewrite Qle_bool_iff in *. - rewrite NPeano.Nat.eqb_neq in *. - intros. - cbv [Qnum Qden limbwidth Qceiling Qfloor Qopp Qdiv Qplus inject_Z Qmult Qinv] in *. - rewrite ?map_length, ?Z.mul_0_r, ?Pos.mul_1_r, ?Z.mul_1_r in *. - specialize_by lia. - repeat match goal with H := _ |- _ => subst H end. - repeat apply conj. - { destruct (s - Associational.eval c); cbn; lia. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - Qed. - - Definition GoodT : Prop - := @Ring.GoodT - (Qnum limbwidth) - (Z.pos (Qden limbwidth)) - n s c - tight_bounds - (Interp rrelaxv) - (Interp rcarry_mulv) - (Interp rcarryv) - (Interp raddv) - (Interp rsubv) - (Interp roppv) - (Interp rzerov tt) - (Interp ronev tt) - (Interp rencodev). - - Theorem Good : GoodT. - Proof. - pose proof use_curve_good; destruct_head'_and; destruct_head_hnf' ex. - eapply Ring.Good; - repeat first [ assumption - | intros; apply eval_carry_mulmod - | intros; apply eval_carrymod - | intros; apply eval_addmod - | intros; apply eval_submod - | intros; apply eval_oppmod - | intros; apply eval_encodemod - | eassumption - | apply conj - | progress intros - | progress cbv [onemod zeromod] - | eapply Hrzerov (* to handle diff with whether or not correctness asks for boundedness of tt *) - | eapply Hronev (* to handle diff with whether or not correctness asks for boundedness of tt *) - | match goal with - | [ |- ?x = ?x ] => reflexivity - | [ |- ?x = ?ev ] => is_evar ev; reflexivity - | [ |- ZRange.type.option.is_bounded_by tt tt = true ] => reflexivity - end ]. - Qed. - End make_ring. -End rcarry_mul. - -Ltac peel_interp_app _ := - lazymatch goal with - | [ |- ?R' (?InterpE ?arg) (?f ?arg) ] - => apply fg_equal_rel; [ | reflexivity ]; - try peel_interp_app () - | [ |- ?R' (Interp ?ev) (?f ?x) ] - => let sv := type of x in - let fx := constr:(f x) in - let dv := type of fx in - let rs := type.reify sv in - let rd := type.reify dv in - etransitivity; - [ apply @Interp_APP_rel_reflexive with (s:=rs) (d:=rd) (R:=R'); - typeclasses eauto - | apply fg_equal_rel; - [ try peel_interp_app () - | try lazymatch goal with - | [ |- ?R (Interp ?ev) (Interp _) ] - => reflexivity - | [ |- ?R (Interp ?ev) ?c ] - => let rc := constr:(GallinaReify.Reify c) in - unify ev rc; reflexivity - end ] ] - end. -Ltac pre_cache_reify _ := - cbv [app_curried]; - let arg := fresh "arg" in - intros arg _; - peel_interp_app (); - [ lazymatch goal with - | [ |- ?R (Interp ?ev) _ ] - => (tryif is_evar ev - then let ev' := fresh "ev" in set (ev' := ev) - else idtac) - end; - cbv [pointwise_relation]; intros; clear - | .. ]. -Ltac do_inline_cache_reify do_if_not_cached := - pre_cache_reify (); - [ try solve [ - repeat match goal with H := ?e |- _ => is_evar e; subst H end; - eauto with nocore reify_gen_cache; - do_if_not_cached () - ]; - cache_reify (); exact admit - | .. ]. - -(* TODO: MOVE ME *) -Ltac vm_compute_lhs_reflexivity := - lazymatch goal with - | [ |- ?LHS = ?RHS ] - => let x := (eval vm_compute in LHS) in - (* we cannot use the unify tactic, which just gives "not - unifiable" as the error message, because we want to see the - terms that were not unifable. See also - COQBUG(https://github.com/coq/coq/issues/7291) *) - let _unify := constr:(ltac:(reflexivity) : RHS = x) in - vm_cast_no_check (eq_refl x) - end. - -Ltac solve_rop' rop_correct do_if_not_cached machine_wordsizev := - eapply rop_correct with (machine_wordsize:=machine_wordsizev); - [ do_inline_cache_reify do_if_not_cached - | subst_evars; vm_compute_lhs_reflexivity (* lazy; reflexivity *) ]. -Ltac solve_rop_nocache rop_correct := - solve_rop' rop_correct ltac:(fun _ => idtac). -Ltac solve_rop rop_correct := - solve_rop' - rop_correct - ltac:(fun _ => let G := get_goal in fail 2 "Could not find a solution in reify_gen_cache for" G). -Ltac solve_rcarry_mul := solve_rop rcarry_mul_correct. -Ltac solve_rcarry_mul_nocache := solve_rop_nocache rcarry_mul_correct. -Ltac solve_rcarry := solve_rop rcarry_correct. -Ltac solve_radd := solve_rop radd_correct. -Ltac solve_rsub := solve_rop rsub_correct. -Ltac solve_ropp := solve_rop ropp_correct. -Ltac solve_rencode := solve_rop rencode_correct. -Ltac solve_rrelax := solve_rop rrelax_correct. -Ltac solve_rzero := solve_rop rzero_correct. -Ltac solve_rone := solve_rop rone_correct. - -Module PrintingNotations. - Export ident. - (*Global Set Printing Width 100000.*) - Open Scope zrange_scope. - Notation "'uint256'" - := (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange) : zrange_scope. - Notation "'uint128'" - := (r[0 ~> 340282366920938463463374607431768211455]%zrange) : zrange_scope. - Notation "'uint64'" - := (r[0 ~> 18446744073709551615]) : zrange_scope. - Notation "'uint32'" - := (r[0 ~> 4294967295]) : zrange_scope. - Notation "'bool'" - := (r[0 ~> 1]%zrange) : zrange_scope. - Notation "ls [[ n ]]" - := ((List.nth_default_concrete _ n @@ ls)%expr) - (at level 30, format "ls [[ n ]]") : expr_scope. - Notation "( range )( ls [[ n ]] )" - := ((ident.Z.cast range @@ (List.nth_default_concrete _ n @@ ls))%expr) - (format "( range )( ls [[ n ]] )") : expr_scope. - (*Notation "( range )( v )" := (ident.Z.cast range @@ v)%expr : expr_scope.*) - Notation "x *₂₅₆ y" - := (ident.Z.cast uint256 @@ (ident.Z.mul @@ (x, y)))%expr (at level 40) : expr_scope. - Notation "x *₁₂₈ y" - := (ident.Z.cast uint128 @@ (ident.Z.mul @@ (x, y)))%expr (at level 40) : expr_scope. - Notation "x *₆₄ y" - := (ident.Z.cast uint64 @@ (ident.Z.mul @@ (x, y)))%expr (at level 40) : expr_scope. - Notation "x *₃₂ y" - := (ident.Z.cast uint32 @@ (ident.Z.mul @@ (x, y)))%expr (at level 40) : expr_scope. - Notation "x +₂₅₆ y" - := (ident.Z.cast uint256 @@ (ident.Z.add @@ (x, y)))%expr (at level 50) : expr_scope. - Notation "x +₁₂₈ y" - := (ident.Z.cast uint128 @@ (ident.Z.add @@ (x, y)))%expr (at level 50) : expr_scope. - Notation "x +₆₄ y" - := (ident.Z.cast uint64 @@ (ident.Z.add @@ (x, y)))%expr (at level 50) : expr_scope. - Notation "x +₃₂ y" - := (ident.Z.cast uint32 @@ (ident.Z.add @@ (x, y)))%expr (at level 50) : expr_scope. - Notation "x -₁₂₈ y" - := (ident.Z.cast uint128 @@ (ident.Z.sub @@ (x, y)))%expr (at level 50) : expr_scope. - Notation "x -₆₄ y" - := (ident.Z.cast uint64 @@ (ident.Z.sub @@ (x, y)))%expr (at level 50) : expr_scope. - Notation "x -₃₂ y" - := (ident.Z.cast uint32 @@ (ident.Z.sub @@ (x, y)))%expr (at level 50) : expr_scope. - Notation "( out_t )( v >> count )" - := ((ident.Z.cast out_t @@ (ident.Z.shiftr count @@ v))%expr) - (format "( out_t )( v >> count )") : expr_scope. - Notation "( out_t )( v << count )" - := ((ident.Z.cast out_t @@ (ident.Z.shiftl count @@ v))%expr) - (format "( out_t )( v << count )") : expr_scope. - Notation "( range )( v )" - := ((ident.Z.cast range @@ Var v)%expr) - (format "( range )( v )") : expr_scope. - Notation "( ( out_t )( v ) & mask )" - := ((ident.Z.cast out_t @@ (ident.Z.land mask @@ v))%expr) - (format "( ( out_t )( v ) & mask )") - : expr_scope. - - Notation "x" := (ident.Z.cast _ @@ Var x)%expr (only printing, at level 9) : expr_scope. - Notation "x" := (ident.Z.cast2 _ @@ Var x)%expr (only printing, at level 9) : expr_scope. - Notation "v ₁" := (ident.fst @@ Var v)%expr (at level 10, format "v ₁") : expr_scope. - Notation "v ₂" := (ident.snd @@ Var v)%expr (at level 10, format "v ₂") : expr_scope. - Notation "v ₁" := (ident.Z.cast _ @@ (ident.fst @@ Var v))%expr (at level 10, format "v ₁") : expr_scope. - Notation "v ₂" := (ident.Z.cast _ @@ (ident.snd @@ Var v))%expr (at level 10, format "v ₂") : expr_scope. - Notation "v ₁" := (ident.Z.cast _ @@ (ident.fst @@ (ident.Z.cast2 _ @@ Var v)))%expr (at level 10, format "v ₁") : expr_scope. - Notation "v ₂" := (ident.Z.cast _ @@ (ident.snd @@ (ident.Z.cast2 _ @@ Var v)))%expr (at level 10, format "v ₂") : expr_scope. - - (*Notation "ls [[ n ]]" := (List.nth_default_concrete _ n @@ ls)%expr : expr_scope. - Notation "( range )( v )" := (ident.Z.cast range @@ v)%expr : expr_scope. - Notation "x *₁₂₈ y" - := (ident.Z.cast uint128 @@ (ident.Z.mul (x, y)))%expr (at level 40) : expr_scope. - Notation "( out_t )( v >> count )" - := (ident.Z.cast out_t (ident.Z.shiftr count @@ v)%expr) - (format "( out_t )( v >> count )") : expr_scope. - Notation "( out_t )( v >> count )" - := (ident.Z.cast out_t (ident.Z.shiftr count @@ v)%expr) - (format "( out_t )( v >> count )") : expr_scope. - Notation "v ₁" := (ident.fst @@ v)%expr (at level 10, format "v ₁") : expr_scope. - Notation "v ₂" := (ident.snd @@ v)%expr (at level 10, format "v ₂") : expr_scope.*) - (* - Notation "'ℤ'" - := BoundsAnalysis.type.Z : zrange_scope. - Notation "ls [[ n ]]" := (List.nth n @@ ls)%nexpr : nexpr_scope. - Notation "x *₆₄₋₆₄₋₁₂₈ y" - := (mul uint64 uint64 uint128 @@ (x, y))%nexpr (at level 40) : nexpr_scope. - Notation "x *₆₄₋₆₄₋₆₄ y" - := (mul uint64 uint64 uint64 @@ (x, y))%nexpr (at level 40) : nexpr_scope. - Notation "x *₃₂₋₃₂₋₃₂ y" - := (mul uint32 uint32 uint32 @@ (x, y))%nexpr (at level 40) : nexpr_scope. - Notation "x *₃₂₋₁₂₈₋₁₂₈ y" - := (mul uint32 uint128 uint128 @@ (x, y))%nexpr (at level 40) : nexpr_scope. - Notation "x *₃₂₋₆₄₋₆₄ y" - := (mul uint32 uint64 uint64 @@ (x, y))%nexpr (at level 40) : nexpr_scope. - Notation "x *₃₂₋₃₂₋₆₄ y" - := (mul uint32 uint32 uint64 @@ (x, y))%nexpr (at level 40) : nexpr_scope. - Notation "x +₁₂₈ y" - := (add uint128 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x +₆₄₋₁₂₈₋₁₂₈ y" - := (add uint64 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x +₃₂₋₆₄₋₆₄ y" - := (add uint32 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x +₆₄ y" - := (add uint64 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x +₃₂ y" - := (add uint32 uint32 uint32 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x -₁₂₈ y" - := (sub uint128 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x -₆₄₋₁₂₈₋₁₂₈ y" - := (sub uint64 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x -₃₂₋₆₄₋₆₄ y" - := (sub uint32 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x -₆₄ y" - := (sub uint64 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x -₃₂ y" - := (sub uint32 uint32 uint32 @@ (x, y))%nexpr (at level 50) : nexpr_scope. - Notation "x" := ({| BoundsAnalysis.type.value := x |}) (only printing) : nexpr_scope. - Notation "( out_t )( v >> count )" - := ((shiftr _ out_t count @@ v)%nexpr) - (format "( out_t )( v >> count )") - : nexpr_scope. - Notation "( out_t )( v << count )" - := ((shiftl _ out_t count @@ v)%nexpr) - (format "( out_t )( v << count )") - : nexpr_scope. - Notation "( ( out_t ) v & mask )" - := ((land _ out_t mask @@ v)%nexpr) - (format "( ( out_t ) v & mask )") - : nexpr_scope. -*) - (* TODO: come up with a better notation for arithmetic with carries - that still distinguishes it from arithmetic without carries? *) - Local Notation "'TwoPow256'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 (only parsing). - Notation "'ADD_256' ( x , y )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.add_get_carry_concrete TwoPow256 @@ (x, y)))%expr : expr_scope. - Notation "'ADD_128' ( x , y )" := (ident.Z.cast2 (uint128, bool)%core @@ (ident.Z.add_get_carry_concrete TwoPow256 @@ (x, y)))%expr : expr_scope. - Notation "'ADDC_256' ( x , y , z )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.add_with_get_carry_concrete TwoPow256 @@ (x, y, z)))%expr : expr_scope. - Notation "'ADDC_128' ( x , y , z )" := (ident.Z.cast2 (uint128, bool)%core @@ (ident.Z.add_with_get_carry_concrete TwoPow256 @@ (x, y, z)))%expr : expr_scope. - Notation "'SUB_256' ( x , y )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.sub_get_borrow_concrete TwoPow256 @@ (x, y)))%expr : expr_scope. - Notation "'SUBB_256' ( x , y , z )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.sub_with_get_borrow_concrete TwoPow256 @@ (x, y, z)))%expr : expr_scope. - Notation "'ADDM' ( x , y , z )" := (ident.Z.cast uint256 @@ (ident.Z.add_modulo @@ (x, y, z)))%expr : expr_scope. - Notation "'RSHI' ( x , y , z )" := (ident.Z.cast _ @@ (ident.Z.rshi_concrete _ z @@ (x, y)))%expr : expr_scope. - Notation "'SELC' ( x , y , z )" := (ident.Z.cast uint256 @@ (ident.Z.zselect @@ (x, y, z)))%expr : expr_scope. - Notation "'SELM' ( x , y , z )" := (ident.Z.cast uint256 @@ (ident.Z.zselect @@ (Z.cast bool @@ (Z.cc_m_concrete _ @@ x), y, z)))%expr : expr_scope. - Notation "'SELL' ( x , y , z )" := (ident.Z.cast uint256 @@ (ident.Z.zselect @@ (Z.cast bool @@ (Z.land 1 @@ x), y, z)))%expr : expr_scope. -End PrintingNotations. - -(* -Notation "a ∈ b" := (ZRange.type.is_bounded_by b%zrange a = true) (at level 10) : type_scope. -Notation Interp := (expr.Interp _). -Notation "'ℤ'" := (type.type_primitive type.Z). -Set Printing Width 70. -Goal False. - let rop' := Reify (fun v1v2 : Z * Z => fst v1v2 + snd v1v2) in - pose rop' as rop. - pose (@Pipeline.BoundsPipeline_full - false (fun v => Some v) (type.Z * type.Z) type.Z - rop - (r[0~>10], r[0~>10])%zrange - r[0~>20]%zrange - ) as E. - simple refine (let Ev := _ in - let compiler_outputs_Ev : E = Pipeline.Success Ev := _ in - _); [ shelve | .. ]; revgoals. - clearbody compiler_outputs_Ev. - refine (let H' := - (fun H'' => - @Pipeline.BoundsPipeline_full_correct - _ _ - H'' _ _ _ _ _ _ compiler_outputs_Ev) _ - in _); - clearbody H'. - Focus 2. - { cbv [Pipeline.BoundsPipeline_full] in E. - remember (Pipeline.PrePipeline rop) as cache eqn:Hcache in (value of E). - lazy in Hcache. - subst cache. - lazy in E. - subst E Ev; reflexivity. - } Unfocus. - cbv [rop] in H'; cbn [expr.Interp expr.interp for_reification.ident.interp] in H'. -(* - H' : forall arg : type.interp (ℤ * ℤ), - arg ∈ (r[0 ~> 10], r[0 ~> 10]) -> - (Interp Ev arg) ∈ r[0 ~> 20] /\ - Interp Ev arg = fst arg + snd arg -*) -Abort. -*) - -Time Compute - (Pipeline.BoundsPipeline_full - true (relax_zrange_gen [64; 128]) - ltac:(let r := Reify (to_associational (weight 51 1) 5) in - exact r) - ZRange.type.option.None ZRange.type.option.None). - -(* N.B. When the uncurrying PR lands, we will no longer need to - manually uncurry this function example before reification *) -Time Compute - (Pipeline.BoundsPipeline_full - true (relax_zrange_gen [64; 128]) - ltac:(let r := Reify (fun '(x, y) => scmul (weight 51 1) 5 x y) in - exact r) - ZRange.type.option.None ZRange.type.option.None). - -Module X25519_64. - Definition n := 5%nat. - Definition s := 2^255. - Definition c := [(1, 19)]. - Definition machine_wordsize := 64. - - Derive base_51_relax - SuchThat (rrelax_correctT n s c machine_wordsize base_51_relax) - As base_51_relax_correct. - Proof. Time solve_rrelax machine_wordsize. Time Qed. - Derive base_51_carry_mul - SuchThat (rcarry_mul_correctT n s c machine_wordsize base_51_carry_mul) - As base_51_carry_mul_correct. - Proof. Time solve_rcarry_mul machine_wordsize. Time Qed. - Derive base_51_carry - SuchThat (rcarry_correctT n s c machine_wordsize base_51_carry) - As base_51_carry_correct. - Proof. Time solve_rcarry machine_wordsize. Time Qed. - Derive base_51_add - SuchThat (radd_correctT n s c machine_wordsize base_51_add) - As base_51_add_correct. - Proof. Time solve_radd machine_wordsize. Time Qed. - Derive base_51_sub - SuchThat (rsub_correctT n s c machine_wordsize base_51_sub) - As base_51_sub_correct. - Proof. Time solve_rsub machine_wordsize. Time Qed. - Derive base_51_opp - SuchThat (ropp_correctT n s c machine_wordsize base_51_opp) - As base_51_opp_correct. - Proof. Time solve_ropp machine_wordsize. Time Qed. - Derive base_51_encode - SuchThat (rencode_correctT n s c machine_wordsize base_51_encode) - As base_51_encode_correct. - Proof. Time solve_rencode machine_wordsize. Time Qed. - Derive base_51_zero - SuchThat (rzero_correctT n s c machine_wordsize base_51_zero) - As base_51_zero_correct. - Proof. Time solve_rzero machine_wordsize. Time Qed. - Derive base_51_one - SuchThat (rone_correctT n s c machine_wordsize base_51_one) - As base_51_one_correct. - Proof. Time solve_rone machine_wordsize. Time Qed. - Lemma base_51_curve_good - : check_args n s c machine_wordsize (Pipeline.Success tt) = Pipeline.Success tt. - Proof. vm_compute; reflexivity. Qed. - - Definition base_51_good : GoodT n s c - := Good n s c machine_wordsize - base_51_curve_good - base_51_carry_mul_correct - base_51_carry_correct - base_51_relax_correct - base_51_add_correct - base_51_sub_correct - base_51_opp_correct - base_51_zero_correct - base_51_one_correct - base_51_encode_correct. - - Print Assumptions base_51_good. - Import PrintingNotations. - Set Printing Width 80. - Print base_51_carry_mul. -(*base_51_carry_mul = -fun var : type -> Type => -(λ x : var - (type.list (type.type_primitive type.Z) * - type.list (type.type_primitive type.Z))%ctype, - expr_let x0 := x₁ [[0]] *₁₂₈ x₂ [[0]] +₁₂₈ - (x₁ [[1]] *₁₂₈ (19 * (uint64)(x₂[[4]])) +₁₂₈ - (x₁ [[2]] *₁₂₈ (19 * (uint64)(x₂[[3]])) +₁₂₈ - (x₁ [[3]] *₁₂₈ (19 * (uint64)(x₂[[2]])) +₁₂₈ - x₁ [[4]] *₁₂₈ (19 * (uint64)(x₂[[1]]))))) in - expr_let x1 := (uint64)(x0 >> 51) +₁₂₈ - (x₁ [[0]] *₁₂₈ x₂ [[1]] +₁₂₈ - (x₁ [[1]] *₁₂₈ x₂ [[0]] +₁₂₈ - (x₁ [[2]] *₁₂₈ (19 * (uint64)(x₂[[4]])) +₁₂₈ - (x₁ [[3]] *₁₂₈ (19 * (uint64)(x₂[[3]])) +₁₂₈ - x₁ [[4]] *₁₂₈ (19 * (uint64)(x₂[[2]])))))) in - expr_let x2 := (uint64)(x1 >> 51) +₁₂₈ - (x₁ [[0]] *₁₂₈ x₂ [[2]] +₁₂₈ - (x₁ [[1]] *₁₂₈ x₂ [[1]] +₁₂₈ - (x₁ [[2]] *₁₂₈ x₂ [[0]] +₁₂₈ - (x₁ [[3]] *₁₂₈ (19 * (uint64)(x₂[[4]])) +₁₂₈ - x₁ [[4]] *₁₂₈ (19 * (uint64)(x₂[[3]])))))) in - expr_let x3 := (uint64)(x2 >> 51) +₁₂₈ - (x₁ [[0]] *₁₂₈ x₂ [[3]] +₁₂₈ - (x₁ [[1]] *₁₂₈ x₂ [[2]] +₁₂₈ - (x₁ [[2]] *₁₂₈ x₂ [[1]] +₁₂₈ - (x₁ [[3]] *₁₂₈ x₂ [[0]] +₁₂₈ - x₁ [[4]] *₁₂₈ (19 * (uint64)(x₂[[4]])))))) in - expr_let x4 := (uint64)(x3 >> 51) +₁₂₈ - (x₁ [[0]] *₁₂₈ x₂ [[4]] +₁₂₈ - (x₁ [[1]] *₁₂₈ x₂ [[3]] +₁₂₈ - (x₁ [[2]] *₁₂₈ x₂ [[2]] +₁₂₈ - (x₁ [[3]] *₁₂₈ x₂ [[1]] +₁₂₈ x₁ [[4]] *₁₂₈ x₂ [[0]])))) in - expr_let x5 := ((uint64)(x0) & 2251799813685247) +₆₄ 19 *₆₄ (uint64)(x4 >> 51) in - expr_let x6 := (uint64)(x5 >> 51) +₆₄ ((uint64)(x1) & 2251799813685247) in - ((uint64)(x5) & 2251799813685247) - :: ((uint64)(x6) & 2251799813685247) - :: (uint64)(x6 >> 51) +₆₄ ((uint64)(x2) & 2251799813685247) - :: ((uint64)(x3) & 2251799813685247) - :: ((uint64)(x4) & 2251799813685247) :: [])%expr - : Expr - (type.uncurry - (type.list (type.type_primitive type.Z) -> - type.list (type.type_primitive type.Z) -> - type.list (type.type_primitive type.Z))) -*) - Print base_51_sub. - (* -base_51_sub = -fun var : type -> Type => -(λ x : var - (type.list (type.type_primitive type.Z) * - type.list (type.type_primitive type.Z))%ctype, - (4503599627370458 + (uint64)(x₁[[0]])) -₆₄ x₂ [[0]] - :: (4503599627370494 + (uint64)(x₁[[1]])) -₆₄ x₂ [[1]] - :: (4503599627370494 + (uint64)(x₁[[2]])) -₆₄ x₂ [[2]] - :: (4503599627370494 + (uint64)(x₁[[3]])) -₆₄ x₂ [[3]] - :: (4503599627370494 + (uint64)(x₁[[4]])) -₆₄ x₂ [[4]] :: [])%expr - : Expr - (type.uncurry - (type.list (type.type_primitive type.Z) -> - type.list (type.type_primitive type.Z) -> - type.list (type.type_primitive type.Z))) -*) -End X25519_64. - -(** TODO: factor out bounds analysis pipeline as a single definition / proof *) -(** TODO: factor out apply one argument in the fst of a pair *) -(** TODO: write a definition that specializes the PHOAS thing and composes with bounds analysis, + proof *) -(** TODO: write wrappers for (string) prime -> reified arguments *) -(** TODO: write indexed AST interp that returns default value, prove that given correctness for specialized thing, the defaulting interp is totally equal to the original operation *) -(** TODO: write a lemma that for things equal to all operations using defaulting interp, we get a ring isomorphic to F m *) -(** TODO: compose with stringification + wrappers for prime, extract to OCaml/Haskell *) -(** TODO: proofs *) -(* -Module X25519_32. - Definition n := 10%nat. - Definition s := 2^255. - Definition c := [(1, 19)]. - Definition machine_wordsize := 32. - - Derive base_25p5_sub - SuchThat (rsub_correctT n s c machine_wordsize base_25p5_sub) - As base_25p5_sub_correct. - Proof. Time solve_rsub machine_wordsize. Time Qed. - - Derive base_25p5_carry_mul - SuchThat (rcarry_mul_correctT n s c machine_wordsize base_25p5_carry_mul) - As base_25p5_carry_mul_correct. - Proof. Time solve_rcarry_mul machine_wordsize. Time Qed. - - Import PrintingNotations. - Print base_25p5_carry_mul. -(* -base_25p5_carry_mul = -fun var : type -> Type => -(λ x : var (type.list (type.type_primitive type.Z) * type.list (type.type_primitive type.Z))%ctype, - expr_let x0 := x₁ [[0]] *₆₄ x₂ [[0]] +₆₄ - ((uint64)(x₁ [[1]] *₆₄ (19 * (uint32)(x₂[[9]])) << 1) +₆₄ - (x₁ [[2]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄ - ((uint64)(x₁ [[3]] *₆₄ (19 * (uint32)(x₂[[7]])) << 1) +₆₄ - (x₁ [[4]] *₆₄ (19 * (uint32)(x₂[[6]])) +₆₄ - ((uint64)(x₁ [[5]] *₆₄ (19 * (uint32)(x₂[[5]])) << 1) +₆₄ - (x₁ [[6]] *₆₄ (19 * (uint32)(x₂[[4]])) +₆₄ - ((uint64)(x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[3]])) << 1) +₆₄ - (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[2]])) +₆₄ (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[1]])) << 1))))))))) in - expr_let x1 := (uint64)(x0 >> 26) +₆₄ - (x₁ [[0]] *₆₄ x₂ [[1]] +₆₄ - (x₁ [[1]] *₆₄ x₂ [[0]] +₆₄ - (x₁ [[2]] *₆₄ (19 * (uint32)(x₂[[9]])) +₆₄ - (x₁ [[3]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄ - (x₁ [[4]] *₆₄ (19 * (uint32)(x₂[[7]])) +₆₄ - (x₁ [[5]] *₆₄ (19 * (uint32)(x₂[[6]])) +₆₄ - (x₁ [[6]] *₆₄ (19 * (uint32)(x₂[[5]])) +₆₄ - (x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[4]])) +₆₄ - (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[3]])) +₆₄ x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[2]]))))))))))) in - expr_let x2 := (uint64)(x1 >> 25) +₆₄ - (x₁ [[0]] *₆₄ x₂ [[2]] +₆₄ - ((uint64)(x₁ [[1]] *₆₄ x₂ [[1]] << 1) +₆₄ - (x₁ [[2]] *₆₄ x₂ [[0]] +₆₄ - ((uint64)(x₁ [[3]] *₆₄ (19 * (uint32)(x₂[[9]])) << 1) +₆₄ - (x₁ [[4]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄ - ((uint64)(x₁ [[5]] *₆₄ (19 * (uint32)(x₂[[7]])) << 1) +₆₄ - (x₁ [[6]] *₆₄ (19 * (uint32)(x₂[[6]])) +₆₄ - ((uint64)(x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[5]])) << 1) +₆₄ - (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[4]])) +₆₄ (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[3]])) << 1)))))))))) in - expr_let x3 := (uint64)(x2 >> 26) +₆₄ - (x₁ [[0]] *₆₄ x₂ [[3]] +₆₄ - (x₁ [[1]] *₆₄ x₂ [[2]] +₆₄ - (x₁ [[2]] *₆₄ x₂ [[1]] +₆₄ - (x₁ [[3]] *₆₄ x₂ [[0]] +₆₄ - (x₁ [[4]] *₆₄ (19 * (uint32)(x₂[[9]])) +₆₄ - (x₁ [[5]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄ - (x₁ [[6]] *₆₄ (19 * (uint32)(x₂[[7]])) +₆₄ - (x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[6]])) +₆₄ - (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[5]])) +₆₄ x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[4]]))))))))))) in - expr_let x4 := (uint64)(x3 >> 25) +₆₄ - (x₁ [[0]] *₆₄ x₂ [[4]] +₆₄ - ((uint64)(x₁ [[1]] *₆₄ x₂ [[3]] << 1) +₆₄ - (x₁ [[2]] *₆₄ x₂ [[2]] +₆₄ - ((uint64)(x₁ [[3]] *₆₄ x₂ [[1]] << 1) +₆₄ - (x₁ [[4]] *₆₄ x₂ [[0]] +₆₄ - ((uint64)(x₁ [[5]] *₆₄ (19 * (uint32)(x₂[[9]])) << 1) +₆₄ - (x₁ [[6]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄ - ((uint64)(x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[7]])) << 1) +₆₄ - (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[6]])) +₆₄ (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[5]])) << 1)))))))))) in - expr_let x5 := (uint64)(x4 >> 26) +₆₄ - (x₁ [[0]] *₆₄ x₂ [[5]] +₆₄ - (x₁ [[1]] *₆₄ x₂ [[4]] +₆₄ - (x₁ [[2]] *₆₄ x₂ [[3]] +₆₄ - (x₁ [[3]] *₆₄ x₂ [[2]] +₆₄ - (x₁ [[4]] *₆₄ x₂ [[1]] +₆₄ - (x₁ [[5]] *₆₄ x₂ [[0]] +₆₄ - (x₁ [[6]] *₆₄ (19 * (uint32)(x₂[[9]])) +₆₄ - (x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄ - (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[7]])) +₆₄ x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[6]]))))))))))) in - expr_let x6 := (uint64)(x5 >> 25) +₆₄ - (x₁ [[0]] *₆₄ x₂ [[6]] +₆₄ - ((uint64)(x₁ [[1]] *₆₄ x₂ [[5]] << 1) +₆₄ - (x₁ [[2]] *₆₄ x₂ [[4]] +₆₄ - ((uint64)(x₁ [[3]] *₆₄ x₂ [[3]] << 1) +₆₄ - (x₁ [[4]] *₆₄ x₂ [[2]] +₆₄ - ((uint64)(x₁ [[5]] *₆₄ x₂ [[1]] << 1) +₆₄ - (x₁ [[6]] *₆₄ x₂ [[0]] +₆₄ - ((uint64)(x₁ [[7]] *₆₄ (19 * (uint32)(x₂[[9]])) << 1) +₆₄ - (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[8]])) +₆₄ (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[7]])) << 1)))))))))) in - expr_let x7 := (uint64)(x6 >> 26) +₆₄ - (x₁ [[0]] *₆₄ x₂ [[7]] +₆₄ - (x₁ [[1]] *₆₄ x₂ [[6]] +₆₄ - (x₁ [[2]] *₆₄ x₂ [[5]] +₆₄ - (x₁ [[3]] *₆₄ x₂ [[4]] +₆₄ - (x₁ [[4]] *₆₄ x₂ [[3]] +₆₄ - (x₁ [[5]] *₆₄ x₂ [[2]] +₆₄ - (x₁ [[6]] *₆₄ x₂ [[1]] +₆₄ - (x₁ [[7]] *₆₄ x₂ [[0]] +₆₄ (x₁ [[8]] *₆₄ (19 * (uint32)(x₂[[9]])) +₆₄ x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[8]]))))))))))) in - expr_let x8 := (uint64)(x7 >> 25) +₆₄ - (x₁ [[0]] *₆₄ x₂ [[8]] +₆₄ - ((uint64)(x₁ [[1]] *₆₄ x₂ [[7]] << 1) +₆₄ - (x₁ [[2]] *₆₄ x₂ [[6]] +₆₄ - ((uint64)(x₁ [[3]] *₆₄ x₂ [[5]] << 1) +₆₄ - (x₁ [[4]] *₆₄ x₂ [[4]] +₆₄ - ((uint64)(x₁ [[5]] *₆₄ x₂ [[3]] << 1) +₆₄ - (x₁ [[6]] *₆₄ x₂ [[2]] +₆₄ - ((uint64)(x₁ [[7]] *₆₄ x₂ [[1]] << 1) +₆₄ - (x₁ [[8]] *₆₄ x₂ [[0]] +₆₄ (uint64)(x₁ [[9]] *₆₄ (19 * (uint32)(x₂[[9]])) << 1)))))))))) in - expr_let x9 := (uint64)(x8 >> 26) +₆₄ - (x₁ [[0]] *₆₄ x₂ [[9]] +₆₄ - (x₁ [[1]] *₆₄ x₂ [[8]] +₆₄ - (x₁ [[2]] *₆₄ x₂ [[7]] +₆₄ - (x₁ [[3]] *₆₄ x₂ [[6]] +₆₄ - (x₁ [[4]] *₆₄ x₂ [[5]] +₆₄ - (x₁ [[5]] *₆₄ x₂ [[4]] +₆₄ - (x₁ [[6]] *₆₄ x₂ [[3]] +₆₄ (x₁ [[7]] *₆₄ x₂ [[2]] +₆₄ (x₁ [[8]] *₆₄ x₂ [[1]] +₆₄ x₁ [[9]] *₆₄ x₂ [[0]]))))))))) in - expr_let x10 := ((uint32)(x0) & 67108863) +₆₄ 19 *₆₄ (uint64)(x9 >> 25) in - expr_let x11 := (uint32)(x10 >> 26) +₃₂ ((uint32)(x1) & 33554431) in - ((uint32)(x10) & 67108863) - :: ((uint32)(x11) & 33554431) - :: (uint32)(x11 >> 25) +₃₂ ((uint32)(x2) & 67108863) - :: ((uint32)(x3) & 33554431) - :: ((uint32)(x4) & 67108863) - :: ((uint32)(x5) & 33554431) - :: ((uint32)(x6) & 67108863) - :: ((uint32)(x7) & 33554431) :: ((uint32)(x8) & 67108863) :: ((uint32)(x9) & 33554431) :: [])%expr - : Expr - (type.uncurry - (type.list (type.type_primitive type.Z) -> - type.list (type.type_primitive type.Z) -> - type.list (type.type_primitive type.Z))) - *) - Print base_25p5_sub. - (* -base_25p5_sub = -fun var : type -> Type => -(λ x : var - (type.list (type.type_primitive type.Z) * - type.list (type.type_primitive type.Z))%ctype, - (134217690 + (uint32)(x₁[[0]])) -₃₂ x₂ [[0]] - :: (67108862 + (uint32)(x₁[[1]])) -₃₂ x₂ [[1]] - :: (134217726 + (uint32)(x₁[[2]])) -₃₂ x₂ [[2]] - :: (67108862 + (uint32)(x₁[[3]])) -₃₂ x₂ [[3]] - :: (134217726 + (uint32)(x₁[[4]])) -₃₂ x₂ [[4]] - :: (67108862 + (uint32)(x₁[[5]])) -₃₂ x₂ [[5]] - :: (134217726 + (uint32)(x₁[[6]])) -₃₂ x₂ [[6]] - :: (67108862 + (uint32)(x₁[[7]])) -₃₂ x₂ [[7]] - :: (134217726 + (uint32)(x₁[[8]])) -₃₂ x₂ [[8]] - :: (67108862 + (uint32)(x₁[[9]])) -₃₂ x₂ [[9]] :: [])%expr - : Expr - (type.uncurry - (type.list (type.type_primitive type.Z) -> - type.list (type.type_primitive type.Z) -> - type.list (type.type_primitive type.Z))) -*) -End X25519_32. - *) - -Module Straightline. - Module expr. - (* TODO: move these to a better location *) - Module type. - Definition primitive_eq_dec (a b : type.primitive) : {a = b} + {a <> b}. - Proof. destruct a,b; auto; right; congruence. Defined. - Fixpoint type_eq_dec (a b : type) : {a = b} + {a <> b}. - Proof. - destruct a, b; try solve [right; congruence]; [ | | | ]. - { destruct (primitive_eq_dec p p0); subst; [left | right]; congruence. } - { destruct (type_eq_dec a1 b1); destruct (type_eq_dec a2 b2); subst; try solve [right; congruence]. - left; congruence. } - { destruct (type_eq_dec a1 b1); destruct (type_eq_dec a2 b2); subst; try solve [right; congruence]. - left; congruence. } - { destruct (type_eq_dec a b); [left | right]; congruence. } - Defined. - End type. - - Section with_var. - Context {var : type.type -> Type}. - Context {dummy_arrow : forall s d, var (s -> d)}. (* TODO: remove once arrow-containing pairs are removed at type level *) - - Let uexpr t := @Uncurried.expr.expr ident.ident var t. - - Section with_ident. - Context {ident : type.type -> type.type -> Type}. - Inductive scalar : type.type -> Type := - | Var t : var t -> scalar t - | TT : scalar (type.type_primitive type.unit) - | Nil t : scalar (type.list t) - | Pair {a b} : scalar a -> scalar b -> scalar (a * b) - | Cast : zrange -> scalar type.Z -> scalar type.Z - | Cast2 : zrange * zrange -> scalar (type.Z*type.Z) -> scalar (type.Z*type.Z) - | Fst {a b} : scalar (a * b) -> scalar a - | Snd {a b} : scalar (a * b) -> scalar b - | Shiftr : Z -> scalar type.Z -> scalar type.Z - | Shiftl : Z -> scalar type.Z -> scalar type.Z - | Land : Z -> scalar type.Z -> scalar type.Z - | CC_m : Z -> scalar type.Z -> scalar type.Z - | Primitive {t} : type.interp (type.type_primitive t) -> scalar t - . - - Inductive expr : type.type -> Type := - | Scalar {t} : scalar t -> expr t - | LetInAppIdentZ {s d} : zrange -> ident s (type.Z) -> scalar s -> (var (type.Z) -> expr d) -> expr d - | LetInAppIdentZZ {s d} : zrange * zrange -> ident s (type.Z*type.Z) -> scalar s -> (var (type.Z*type.Z) -> expr d) -> expr d - . - - Fixpoint dummy_scalar t : scalar t := - match t with - | type.type_primitive p => Primitive (@DefaultValue.type.primitive.default p) - | type.prod A B => Pair (dummy_scalar A) (dummy_scalar B) - | type.arrow A B => Var _ (dummy_arrow A B) - | type.list A => Nil A - end. - - Definition dummy t : expr t := Scalar (dummy_scalar t). - End with_ident. - - Definition of_uncurried_scalar_ident {s d} (idc : ident.ident s d) - : scalar s -> option (scalar d) := - match idc in ident.ident s d return scalar s -> option (scalar d) with - | ident.Z.cast r => fun args => Some (Cast r args) - | ident.Z.cast2 r => fun args => Some (Cast2 r args) - | @ident.fst A B => fun args => Some (Fst args) - | @ident.snd A B => fun args => Some (Snd args) - | ident.Z.shiftr n => fun args => Some (Shiftr n args) - | ident.Z.shiftl n => fun args => Some (Shiftl n args) - | ident.Z.land n => fun args => Some (Land n args) - | ident.Z.cc_m_concrete s => fun args => Some (CC_m s args) - | @ident.primitive p x => fun _ => Some (Primitive x) - | _ => fun _ => None - end. - - Fixpoint of_uncurried_scalar {t} (e : uexpr t) : option (scalar t) := - match e in Uncurried.expr.expr t return option (scalar t) with - | expr.Var t v as e => Some (Var t v) - | expr.TT as e => Some TT - | expr.Pair A B a b - => match of_uncurried_scalar a, of_uncurried_scalar b with - | Some x, Some y => Some (Pair x y) - | _, _ => None - end - | expr.AppIdent _ _ idc args - => match of_uncurried_scalar args with - | Some x => of_uncurried_scalar_ident idc x - | None => None - end - | _ => None - end. - - Fixpoint range_type t : Type := - match t with - | type.type_primitive type.Z => zrange - | type.prod x y => range_type x * range_type y - | _ => unit - end. - - Definition invert_cast {t} (e : uexpr t) - : option (range_type t * uexpr t) := - match invert_AppIdent e with - | Some (existT s (idc, x)) => - (match idc in ident.ident s t return uexpr s -> option (range_type t * uexpr t) with - | ident.Z.cast r => fun x => Some (r, x) - | ident.Z.cast2 r => fun x => Some (r, x) - | _ => fun _ => None - end) x - | None => None - end. - - (* ident.Let_In @@ (cast r x) => r, x *) - Definition invert_LetInCast {tx tC} (args : uexpr (tx * (tx -> tC))) - : option (range_type tx * uexpr tx * uexpr (tx -> tC)) := - match invert_Pair args with - | Some (x, e) => - match invert_cast x with - | Some (r, x') => Some (r, x', e) - | None => None - end - | None => None - end. - - Definition invert_LetInAppIdent {tx tC} (args : uexpr (tx * (tx -> tC))) - : option { s : type.type & (range_type tx * ident.ident s tx * scalar s * (var tx -> uexpr tC))%type } := - match invert_LetInCast args with - | Some (r, x, e) => - match invert_AppIdent x with - | Some (existT s idc_x') => - match of_uncurried_scalar (snd idc_x') with - | Some x'' => - match invert_Abs e with - | Some k => Some (existT _ s (r, fst idc_x', x'', k)) - | None => None - end - | None => None - end - | None => None - end - | None => None - end. - - Definition mk_LetInAppIdent {s d t} (default : expr t) - : range_type d -> ident.ident s d -> scalar s -> (var d -> expr t) -> expr t := - match d as d0 return range_type d0 -> ident.ident s d0 -> scalar s -> (var d0 -> expr t) -> expr t with - | type.type_primitive type.Z => - fun r idc x k => @LetInAppIdentZ ident.ident s t r idc x k - | type.prod type.Z type.Z => - fun r idc x k => @LetInAppIdentZZ ident.ident s t r idc x k - | _ => fun _ _ _ _ => default - end. - - Definition of_uncurried_ident - (of_uncurried : forall t, uexpr t -> expr t) - {s d} (idc : ident.ident s d) - : uexpr s -> expr d -> expr d := - match idc in ident.ident s d return uexpr s -> expr d -> expr d with - | ident.Let_In tx tC => - fun args default => - match invert_LetInAppIdent args return expr tC with - | Some (existT s (r, idc, x, k)) => - @mk_LetInAppIdent s tx tC default r idc x (fun y : var tx => of_uncurried _ (k y)) - | None => default - end - | ident.Z.cast r => - fun (args : uexpr _) default => - match invert_AppIdent args with - | Some (existT s idc_x') => - match of_uncurried_scalar (snd idc_x') with - | Some x'' => - @mk_LetInAppIdent s type.Z type.Z default r (fst idc_x') x'' (fun y => Scalar (Var _ y)) - | None => default - end - | None => default - end - | ident.Z.cast2 r => - fun (args : uexpr _) default => - match invert_AppIdent args with - | Some (existT s idc_x') => - match of_uncurried_scalar (snd idc_x') with - | Some x'' => - @mk_LetInAppIdent s (type.Z*type.Z) (type.Z*type.Z) default r (fst idc_x') x'' (fun y => Scalar (Var _ y)) - | None => default - end - | None => default - end - | _ => fun _ default => default - end. - - Definition of_uncurried_step {t} (e : uexpr t) - (of_uncurried : forall t, uexpr t -> expr t) - : expr t -> expr t := - match e in Uncurried.expr.expr t return expr t -> expr t with - | AppIdent s d idc args => - fun default => - of_uncurried_ident of_uncurried idc args - (match of_uncurried_scalar (AppIdent idc args) with - | Some s => Scalar s - | None => default - end) - | _ as e => - (fun default => - match of_uncurried_scalar e with - | Some s => Scalar s - | None => default - end) - end. - - (* TODO : uses fuel; ideally want a cleaner termination proof *) - Fixpoint of_uncurried (fuel : nat) {t} (e : uexpr t) - : expr t := - match fuel with - | S fuel' => of_uncurried_step e (@of_uncurried fuel') (dummy t) - | O => dummy t - end. - End with_var. - - Section depth. - Context (var : type -> Type) (dummy_var : forall t, var t). - Fixpoint depth {t} (e : @Uncurried.expr.expr ident var t) : nat := - match e with - | Uncurried.expr.Var _ _ => 1 - | Uncurried.expr.TT => 1 - | Uncurried.expr.AppIdent _ _ idc args => S (depth args) - | Uncurried.expr.App _ _ f x => S (Nat.max (depth f) (depth x)) - | Uncurried.expr.Pair _ _ x y => S (Nat.max (depth x) (depth y)) - | Uncurried.expr.Abs _ _ f => S (depth (f (dummy_var _))) - end. - - Definition Expr_depth {t} (e : Expr t) : nat := depth (e _). - End depth. - - Section interp. - Context {ident : type -> type -> Type} {interp_ident : forall s d, ident s d -> type.interp s -> type.interp d}. - Context {interp_cast : zrange -> Z -> Z}. - - Definition interp_cast2 (r : zrange * zrange) (x : Z * Z) : Z * Z := - (interp_cast (fst r) (fst x), interp_cast (snd r) (snd x)). - - Fixpoint interp_scalar {t} (s : @scalar type.interp t) : type.interp t := - match s with - | Var t v => v - | TT => tt - | Nil _ => [] - | Pair _ _ x y => (interp_scalar x, interp_scalar y) - | Cast r x => interp_cast r (interp_scalar x) - | Cast2 r x => interp_cast2 r (interp_scalar x) - | Fst _ _ p => fst (interp_scalar p) - | Snd _ _ p => snd (interp_scalar p) - | Shiftr n x => Z.shiftr (interp_scalar x) n - | Shiftl n x => Z.shiftl (interp_scalar x) n - | Land n x => Z.land (interp_scalar x) n - | CC_m n x => Z.cc_m n (interp_scalar x) - | Primitive _ x => x - end. - - Fixpoint interp {t} (e : @expr type.interp ident t) : type.interp t := - match e with - | Scalar _ s => interp_scalar s - | LetInAppIdentZ _ _ r idc x f => - interp (f (interp_cast r (interp_ident _ _ idc (interp_scalar x)))) - | LetInAppIdentZZ _ _ r idc x f => - interp (f (interp_cast2 r (interp_ident _ _ idc (interp_scalar x)))) - end. - End interp. - - Section proofs. - Local Notation straightline_interp := (expr.interp (ident:=default.ident) (interp_ident:=@ident.interp) (interp_cast:=ident.cast (@ident.cast_outside_of_range))). - Local Notation uinterp := (Uncurried.expr.interp (@ident.interp)). - Local Notation uexpr := (@Uncurried.expr.expr ident type.interp). - Local Notation interp_scalar := (interp_scalar (interp_cast:=ident.cast (@ident.cast_outside_of_range))). - - Inductive ok_scalar_ident : forall {s d}, ident.ident s d -> Prop := - | ok_si_cast : forall r, ok_scalar_ident (ident.Z.cast r) - | ok_si_cast2 : forall r, ok_scalar_ident (ident.Z.cast2 r) - | ok_si_fst : forall A B, ok_scalar_ident (@ident.fst A B) - | ok_si_snd : forall A B, ok_scalar_ident (@ident.snd A B) - | ok_si_shiftr : forall n, ok_scalar_ident (@ident.Z.shiftr n) - | ok_si_shiftl : forall n, ok_scalar_ident (@ident.Z.shiftl n) - | ok_si_land : forall n, ok_scalar_ident (@ident.Z.land n) - | ok_si_cc_m : forall n, ok_scalar_ident (@ident.Z.cc_m_concrete n) - | ok_prim : forall p x, ok_scalar_ident (@ident.primitive p x) - . - - Inductive ok_scalar: forall {t}, uexpr t -> Prop := - | ok_Var : forall t v, @ok_scalar t (Uncurried.expr.Var v) - | ok_TT : ok_scalar Uncurried.expr.TT - | ok_AppIdent : - forall s d idc args, - ok_scalar args -> - @ok_scalar_ident s d idc -> - ok_scalar (AppIdent idc args) - | ok_Pair : - forall A B a b, - @ok_scalar A a -> - @ok_scalar B b -> - ok_scalar (Uncurried.expr.Pair a b) - . - - Inductive ok_expr : forall {t}, uexpr t -> Prop := - | ok_LetInAppIdentZ : - forall tC r s (idc : ident s type.Z) x k, - ok_scalar x -> (forall y, @ok_expr tC (k y)) -> - @ok_expr tC (AppIdent (@ident.Let_In _ tC) (Uncurried.expr.Pair (AppIdent (ident.Z.cast r) (AppIdent idc x)) (Abs k))) - | ok_LetInAppIdentZZ : - forall tC r s (idc : ident s (type.prod type.Z type.Z)) x k, - ok_scalar x -> (forall y, @ok_expr tC (k y)) -> - @ok_expr tC (AppIdent (@ident.Let_In _ tC) (Uncurried.expr.Pair (AppIdent (ident.Z.cast2 r) (AppIdent idc x)) (Abs k))) - | ok_scalar_cast : - forall r s (idc : ident s _) x, - ok_scalar x -> - @ok_expr type.Z (AppIdent (ident.Z.cast r) (AppIdent idc x)) - | ok_scalar_cast2 : - forall r s (idc : ident s _) x, - ok_scalar x -> - @ok_expr (type.prod type.Z type.Z) (AppIdent (ident.Z.cast2 r) (AppIdent idc x)) - | ok_scalar_nocast : - forall t x, @ok_scalar t x -> @ok_expr t x - . - - Lemma interp_cast_correct r (x : uexpr type.Z) : - ident.cast ident.cast_outside_of_range r (uinterp x) = uinterp (AppIdent (ident.Z.cast r) x). - Proof. reflexivity. Qed. - - Lemma interp_cast2_correct r (x : uexpr (type.prod type.Z type.Z)) : - @interp_cast2 (ident.cast ident.cast_outside_of_range) r (uinterp x) = uinterp (AppIdent (ident.Z.cast2 r) x). - Proof. cbn; break_match; reflexivity. Qed. - - Ltac invert H := - inversion H; subst; - repeat match goal with - | H : existT _ _ _ = existT _ _ _ |- _ => apply (Eqdep_dec.inj_pair2_eq_dec _ type.type_eq_dec) in H; subst - end. - - Ltac invert_ok_expr := - match goal with H : ok_expr _ |- _ => invert H end. - Ltac invert_ok_scalar := - match goal with H : ok_scalar _ |- _ => invert H end. - Ltac invert_ok_scalar_ident := - match goal with H : ok_scalar_ident _ |- _ => invert H end. - Ltac simpl_inversions := - cbn [invert_LetInAppIdent invert_LetInCast invert_Pair invert_cast invert_AppIdent invert_Abs]. - - Lemma invert_AppIdent_correct {d} (e : uexpr d) x p : - invert_AppIdent e = Some (existT (fun s : type => (ident s d * default.expr s)%type) x p) -> - e = AppIdent (fst p) (snd p). - Proof. - cbv [invert_AppIdent]. - break_match; try discriminate. - intro H; invert H. reflexivity. - Qed. - - Lemma depth_positive {var t} dummy_var (e : Uncurried.expr.expr t) : 0 < depth var dummy_var e. - Proof. destruct e; cbn [depth]; rewrite Nat2Z.inj_succ; omega. Qed. - - Lemma of_uncurried_scalar_ident_correct {s d} (idc : ident s d) args args': - ok_scalar_ident idc -> - of_uncurried_scalar args = Some args' -> - interp_scalar args' = uinterp args -> - exists s, - of_uncurried_scalar_ident idc args' = Some s - /\ interp_scalar s = uinterp (AppIdent idc args). - Proof. - destruct 1; intros; - repeat match goal with - | _ => eexists; split; [ reflexivity | cbn [interp_scalar] ] - | H : interp_scalar _ = _ |- _ => rewrite H - | _ => reflexivity - | _ => solve [auto using interp_cast2_correct] - | |- context [@Uncurried.expr.interp _ _ (type.type_primitive _)] => - cbn; break_match; reflexivity - end. - Qed. - - Lemma of_uncurried_scalar_correct {t} (e : uexpr t) : - ok_scalar e -> - exists s, - of_uncurried_scalar e = Some s - /\ interp_scalar s = uinterp e. - Proof. - induction 1; cbn [of_uncurried_scalar]; intros; - repeat match goal with - | _ => progress cbn [interp_scalar] - | IH : exists _, _ /\ _ |- _ => destruct IH as [? [? ?] ] - | H : of_uncurried_scalar _ = _ |- _ => rewrite H - | H : interp_scalar _ = _ |- _ => rewrite H - | _ => apply of_uncurried_scalar_ident_correct; solve [auto] - | _ => eexists; split; [ reflexivity | ] - | _ => reflexivity - end. - Qed. - - Ltac rewrite_ok_scalar := - match goal with H : ok_scalar _ |- _ => - let P := fresh in destruct (of_uncurried_scalar_correct _ H) as [? [P ?] ]; rewrite P in * - end; - repeat match goal with - | H : Some _ = Some _ |- _ => inversion H; progress subst - | _ => progress break_match; - match goal with | H: Some _ = Some _ |- _ => inversion H; progress subst end - end. - - Lemma of_uncurried_correct dummy_arrow fuel dummy_var : - forall {t} (e : uexpr t), - (depth _ dummy_var e <= fuel)%nat -> - ok_expr e -> - uinterp e = straightline_interp (@of_uncurried _ dummy_arrow fuel _ e). - Proof. - induction fuel; intros; [ pose proof (depth_positive dummy_var e); omega | ]. - destruct e; cbn [depth of_uncurried expr.interp interp]; intros; invert_ok_expr; - repeat match goal with - | |- context [of_uncurried_scalar _ ] => progress rewrite_ok_scalar - | _ => progress (cbn [of_uncurried_step of_uncurried_ident fst snd mk_LetInAppIdent expr.interp interp depth] in * ) - | _ => progress simpl_inversions - | _ => congruence - end; [ | | | | ]. - { - match goal with H : interp_scalar _ = _ |- _ => rewrite H end. - rewrite <-IHfuel. - { reflexivity. } - { cbn [depth] in *. - (* here we have to reason about the depth calculation for arrows; this will probably be unnecessary with new compilers setup *) - admit. } - { auto. } } - { - match goal with H : interp_scalar _ = _ |- _ => rewrite H end. - rewrite <-IHfuel. - { cbn; break_match; reflexivity. } - { cbn [depth] in *. - (* here we have to reason about the depth calculation for arrows; this will probably be unnecessary with new compilers setup *) - admit. } - { auto. } } - { - match goal with H : interp_scalar _ = _ |- _ => rewrite H end. - rewrite <-interp_cast_correct. - reflexivity. } - { - match goal with H : interp_scalar _ = _ |- _ => rewrite H end. - rewrite <-interp_cast2_correct. - cbn; break_match; reflexivity. } - { invert_ok_scalar. - rewrite <-H2. - invert_ok_scalar_ident; try reflexivity. - { match goal with H : context [of_uncurried_scalar _ = Some _ ] |- _ => cbn in H end. - rewrite_ok_scalar. - cbn [of_uncurried_ident]. - cbn [interp_scalar]. - cbn. - break_match; cbn; auto. - match goal with H : _ |- _ => apply invert_AppIdent_correct in H end. - subst. - invert_ok_scalar. - rewrite_ok_scalar. - repeat match goal with H : interp_scalar _ = _ |- _ => rewrite H end. - reflexivity. } - { match goal with H : context [of_uncurried_scalar _ = Some _ ] |- _ => cbn in H end. - rewrite_ok_scalar. - cbn [of_uncurried_ident]. - break_match; cbn; auto. - match goal with H : _ |- _ => apply invert_AppIdent_correct in H end. - subst. - invert_ok_scalar. - rewrite_ok_scalar. - repeat match goal with H : interp_scalar _ = _ |- _ => rewrite H end. - destruct r; reflexivity. } - Admitted. - End proofs. - End expr. - - Definition of_Expr {s d} (e : Expr (s->d)) (var : type -> Type) (x:var s) dummy_arrow: expr.expr d - := - match invert_Abs (e var) with - | Some f => - expr.of_uncurried (dummy_arrow:=dummy_arrow) (expr.depth (fun _ => unit) (fun _ => tt) (e (fun _ => unit))) (f x) - | None => expr.dummy (dummy_arrow:=dummy_arrow) d - end. - -End Straightline. - -Module StraightlineTest. - Definition test : Expr (type.Z -> type.Z) := - fun var => - Abs - (fun (x : var type.Z) => - AppIdent (var:=var) ident.Let_In - (Pair (AppIdent (var:=var) (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent (var:=var) (ident.Z.shiftr 8) (Var x))) - (Abs (fun x : var type.Z => expr.Var x)))). - - Check eq_refl : - Straightline.of_Expr test = - fun var x _ => - Straightline.expr.LetInAppIdentZ r[0 ~> 4294967295] (ident.Z.shiftr 8) (Straightline.expr.Var _ x) - (fun x0 => Straightline.expr.Scalar (Straightline.expr.Var _ x0)). - - Definition test_mul : Expr (type.Z -> type.Z) := - fun var => - Abs - (fun (x : var type.Z) => - AppIdent (var:=var) ident.Let_In - (Pair (AppIdent (var:=var) (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent (var:=var) (ident.Z.shiftr 8) (Var x))) - (Abs (fun y : var type.Z => - AppIdent ident.Let_In - (Pair (AppIdent (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent ident.Z.mul (Pair (AppIdent (@ident.primitive type.Z 12) TT) (Var y)))) - (Abs (fun z : var type.Z => (AppIdent (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent (ident.Z.shiftr 3) (Var z))))) - ))))). - - Check eq_refl : - Straightline.of_Expr test_mul = - fun var x _ => - Straightline.expr.LetInAppIdentZ r[0 ~> 4294967295] (ident.Z.shiftr 8) (Straightline.expr.Var _ x) - (fun x0 => - Straightline.expr.LetInAppIdentZ r[0 ~> 4294967295] ident.Z.mul - (Straightline.expr.Pair (Straightline.expr.Primitive (t:=type.Z) 12) (Straightline.expr.Var _ x0)) - (fun x1 => - Straightline.expr.LetInAppIdentZ r[0 ~> 4294967295] (ident.Z.shiftr 3) - (Straightline.expr.Var _ x1) - (fun x2 => Straightline.expr.Scalar (Straightline.expr.Var _ x2)))). - - Definition test_selm : Expr (type.Z -> type.Z) := - fun var => - Abs (fun x : var type.Z => - AppIdent (var:=var) ident.Let_In - (Pair (AppIdent (var:=var) (ident.Z.cast r[0~>4294967295]%zrange) - (AppIdent (var:=var) ident.Z.zselect - (Pair - (Pair - (AppIdent (var:=var) (ident.Z.cast r[0~>1]%zrange) - (AppIdent (var:=var) (ident.Z.cc_m_concrete 4294967296) - (AppIdent (ident.Z.cast r[0~>4294967295]%zrange) (Var x)))) - (AppIdent (@ident.primitive type.Z 0) TT)) - (AppIdent (@ident.primitive type.Z 100) TT)))) - (Abs (fun z : var type.Z => Var z)))). - - Check eq_refl : - Straightline.of_Expr test_selm = - fun var x _ => - Straightline.expr.LetInAppIdentZ r[0 ~> 4294967295] ident.Z.zselect - (Straightline.expr.Pair - (Straightline.expr.Pair - (Straightline.expr.Cast r[0 ~> 1] - (Straightline.expr.CC_m 4294967296 - (Straightline.expr.Cast r[0 ~> 4294967295] (Straightline.expr.Var _ x)))) - (Straightline.expr.Primitive (t:=type.Z) 0)) (Straightline.expr.Primitive (t:=type.Z) 100)) - (fun x0 => Straightline.expr.Scalar (Straightline.expr.Var _ x0)). -End StraightlineTest. - -(* Convert straightline code to code that uses only a certain set of identifiers *) -Module PreFancy. - Import Straightline.expr. - Section with_wordmax. - Context (log2wordmax : Z) (log2wordmax_pos : 1 < log2wordmax) (log2wordmax_even : log2wordmax mod 2 = 0). - Let wordmax := 2 ^ log2wordmax. - Lemma wordmax_gt_2 : 2 < wordmax. - Proof. - apply Z.le_lt_trans with (m:=2 ^ 1); [ reflexivity | ]. - apply Z.pow_lt_mono_r; omega. - Qed. - - Lemma wordmax_even : wordmax mod 2 = 0. - Proof. - replace 2 with (2 ^ 1) by reflexivity. - subst wordmax. apply Z.mod_same_pow; omega. - Qed. - - Let half_bits := log2wordmax / 2. - - Lemma half_bits_nonneg : 0 <= half_bits. - Proof. subst half_bits; Z.zero_bounds. Qed. - - Let wordmax_half_bits := 2 ^ half_bits. - - Lemma wordmax_half_bits_pos : 0 < wordmax_half_bits. - Proof. subst wordmax_half_bits half_bits. Z.zero_bounds. Qed. - - Lemma half_bits_squared : (wordmax_half_bits - 1) * (wordmax_half_bits - 1) <= wordmax - 1. - Proof. - pose proof wordmax_half_bits_pos. - subst wordmax_half_bits. - transitivity (2 ^ (half_bits + half_bits) - 2 * 2 ^ half_bits + 1). - { rewrite Z.pow_add_r by (subst half_bits; Z.zero_bounds). - autorewrite with push_Zmul; omega. } - { transitivity (wordmax - 2 * 2 ^ half_bits + 1); [ | lia]. - subst wordmax. - apply Z.add_le_mono_r. - apply Z.sub_le_mono_r. - apply Z.pow_le_mono_r; [ omega | ]. - rewrite Z.add_diag; subst half_bits. - apply BinInt.Z.mul_div_le; omega. } - Qed. - - Lemma wordmax_half_bits_le_wordmax : wordmax_half_bits <= wordmax. - Proof. - subst wordmax half_bits wordmax_half_bits. - apply Z.pow_le_mono_r; [lia|]. - apply Z.div_le_upper_bound; lia. - Qed. - - Lemma ones_half_bits : wordmax_half_bits - 1 = Z.ones half_bits. - Proof. - subst wordmax_half_bits. cbv [Z.ones]. - rewrite Z.shiftl_mul_pow2, <-Z.sub_1_r by auto using half_bits_nonneg. - lia. - Qed. - - Lemma wordmax_half_bits_squared : wordmax_half_bits * wordmax_half_bits = wordmax. - Proof. - subst wordmax half_bits wordmax_half_bits. - rewrite <-Z.pow_add_r by Z.zero_bounds. - rewrite Z.add_diag, Z.mul_div_eq by omega. - f_equal; lia. - Qed. - - Section with_var. - Context {var : type -> Type} (dummy_arrow : forall s d, var (type.arrow s d)) (consts : list Z). - Local Notation Z := (type.type_primitive type.Z). - - Inductive ident : type -> type -> Type := - | add (imm : BinInt.Z) : ident (Z * Z) (Z * Z) - | addc (imm : BinInt.Z) : ident (Z * Z * Z) (Z * Z) - | sub (imm : BinInt.Z) : ident (Z * Z) (Z * Z) - | subb (imm : BinInt.Z) : ident (Z * Z * Z) (Z * Z) - | mulll : ident (Z * Z) Z - | mullh : ident (Z * Z) Z - | mulhl : ident (Z * Z) Z - | mulhh : ident (Z * Z) Z - | rshi : BinInt.Z -> ident (Z * Z) Z - | selc : ident (Z * Z * Z) Z - | selm : ident (Z * Z * Z) Z - | sell : ident (Z * Z * Z) Z - | addm : ident (Z * Z * Z) Z - . - Definition dummy t : @expr var ident t := Scalar (dummy_scalar (dummy_arrow:=dummy_arrow) t). - - Definition constant_to_scalar_single (const x : BinInt.Z) : option (@scalar var Z) := - if x =? (BinInt.Z.shiftr const half_bits) - then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Shiftr half_bits (Primitive (t:=type.Z) const))) - else if x =? (BinInt.Z.land const (wordmax_half_bits - 1)) - then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Land (wordmax_half_bits-1) (Primitive (t:=type.Z) const))) - else None. - - Definition constant_to_scalar (x : BinInt.Z) - : option (Straightline.expr.scalar Z) := - fold_right (fun c res => match res with - | Some s => Some s - | None => constant_to_scalar_single c x - end) None consts. - - Definition invert_lower' {t} (e : @scalar var t) : - option (@scalar var Z) := - match e in scalar t return option (@scalar var Z) with - | Cast r (Land n x) => - if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? wordmax_half_bits-1) - then Some x - else None - | _ => None - end. - - Definition invert_upper' {t} (e : @scalar var t) : - option (@scalar var Z) := - match e in scalar t return option (@scalar var Z) with - | Cast r (Shiftr n x) => - if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? half_bits) - then Some x - else None - | _ => None - end. - - Definition invert_lower {t} (e : @scalar var t) : - option (@scalar var Z) := - match e in scalar t return option (@scalar var Z) with - | Primitive type.Z x => - match constant_to_scalar x with - | Some y => invert_lower' y - | None => None - end - | _ => invert_lower' e - end. - - Definition invert_upper {t} (e : @scalar var t) : - option (@scalar var Z) := - match e in scalar t return option (@scalar var Z) with - | Primitive type.Z x => - match constant_to_scalar x with - | Some y => invert_upper' y - | None => None - end - | _ => invert_upper' e - end. - - Definition invert_sell {t} (e : @scalar var t) : - option (@scalar var Z * @scalar var Z * @scalar var Z) := - match e return _ with - | Pair _ Z (Pair Z Z x y) z => - match x return option (@scalar var Z * @scalar var Z * @scalar var Z) with - | Cast r (Land n x') => - if (lower r =? 0) && (upper r =? 1) && (n =? 1) - then Some (x', y, z) - else None - | _ => (@None _) - end - | _ => None - end. - - Definition invert_selm {t} (e : @scalar var t) : - option (@scalar var Z * @scalar var Z * @scalar var Z) := - match e return _ with - | Pair _ Z (Pair Z Z x y) z => - match x return option (@scalar var Z * @scalar var Z * @scalar var Z) with - | Cast r (CC_m n x') => - if (lower r =? 0) && (upper r =? 1) && (n =? wordmax) - then Some (x', y, z) - else None - | _ => (@None _) - end - | _ => None - end. - - Definition invert_shift {t} (s : @scalar var t) - : option (@scalar var Z * BinInt.Z) := - match s return option (@scalar var Z * BinInt.Z) with - | Cast r (Shiftl n x) => - match invert_lower x return option (@scalar var Z * BinInt.Z) with - | Some x' => - if (lower r =? 0) && (upper r =? wordmax-1) && (n =? half_bits) - then Some (x', half_bits) - else None - | None => None - end - | _ => - match invert_upper s return _ with - | Some x => Some (x, -half_bits) - | None => None - end - end. - - Definition of_straightline_ident {s d} (idc : ident.ident s d) - : forall t, range_type d -> @scalar var s -> (var d -> @expr var ident t) -> @expr var ident t := - match idc in ident.ident s d return forall t, range_type d -> scalar s -> (var d -> @expr var ident t) -> @expr var ident t with - | ident.Z.add_get_carry_concrete w => - fun t r x f => - if w =? wordmax - then - match x with - | Pair Z Z xl xr => - match invert_shift xl, invert_shift xr with - | _, Some (xr', imm) => LetInAppIdentZZ r (add imm) (Pair xl xr') f - | Some (xl', imm), None => LetInAppIdentZZ r (add imm) (Pair xr xl') f - - | None, None => LetInAppIdentZZ r (add 0) (Pair xl xr) f - end - | _ => dummy _ - end - else dummy _ - | ident.Z.add_with_get_carry_concrete w => - fun t r x f => - if w =? wordmax - then - match x with - | Pair (type.prod Z Z) Z (Pair Z Z xc xl) xr => - match invert_shift xl, invert_shift xr with - | _, Some (xr', imm) => LetInAppIdentZZ r (addc imm) (Pair (Pair xc xl) xr') f - | Some (xl', imm), None => LetInAppIdentZZ r (addc imm) (Pair (Pair xc xr) xl') f - - | None, None => LetInAppIdentZZ r (addc 0) (Pair (Pair xc xl) xr) f - end - | _ => dummy _ - end - else dummy _ - | ident.Z.sub_get_borrow_concrete w => - fun t r x f => - if w =? wordmax - then - match x with - | Pair Z Z xl xr => - match invert_shift xr with - | Some (xr', imm) => LetInAppIdentZZ r (sub imm) (Pair xl xr') f - | None => LetInAppIdentZZ r (sub 0) (Pair xl xr) f - end - | _ => dummy _ - end - else dummy _ - | ident.Z.sub_with_get_borrow_concrete w => - fun t r x f => - if w =? wordmax - then - match x with - | Pair (type.prod Z Z) Z (Pair Z Z xb xl) xr => - match invert_shift xr with - | Some (xr', imm) => LetInAppIdentZZ r (subb imm) (Pair (Pair xb xl) xr') f - | None => LetInAppIdentZZ r (subb 0) (Pair (Pair xb xl) xr) f - end - | _ => dummy _ - end - else dummy _ - | ident.Z.rshi_concrete w n => - fun _ r x f => - if w =? wordmax - then LetInAppIdentZ r (rshi n) x f - else dummy _ - | ident.Z.zselect => - fun t r x f => - match invert_selm x with - | Some (x, y, z) => LetInAppIdentZ r selm (Pair (Pair x y) z) f - | None => match invert_sell x with - | Some (x, y, z) => LetInAppIdentZ r sell (Pair (Pair x y) z) f - | None => LetInAppIdentZ r selc x f - end - end - | ident.Z.add_modulo => fun _ r => LetInAppIdentZ r addm - | ident.Z.mul => - fun t r x f => - match x return expr t with - | Pair _ _ x0 x1 => - match invert_lower x0, invert_lower x1 with - | Some y0, Some y1 => LetInAppIdentZ r mulll (Pair y0 y1) f - | Some y0, None => - match invert_upper x1 with - | Some y1 => LetInAppIdentZ r mullh (Pair y0 y1) f - | None => dummy _ - end - | None, Some y1 => - match invert_upper x0 with - | Some y0 => LetInAppIdentZ r mulhl (Pair y0 y1) f - | None => dummy _ - end - | None, None => - match invert_upper x0, invert_upper x1 with - | Some y0, Some y1 => LetInAppIdentZ r mulhh (Pair y0 y1) f - | _,_ => dummy _ - end - end - | _ => dummy _ - end - | _ => fun t _ _ _ => dummy t - end. - - Fixpoint of_straightline {t} (e : @expr var ident.ident t) - : @expr var ident t := - match e with - | Scalar _ s => Scalar s - | LetInAppIdentZ _ t r idc x f => - of_straightline_ident idc t r[0~>wordmax-1]%zrange x (fun y => of_straightline (f y)) - | LetInAppIdentZZ _ t r idc x f => - of_straightline_ident idc t (r[0~>wordmax-1], r[0~>1])%zrange x (fun y => of_straightline (f y)) - end. - End with_var. - - Section interp. - Context {interp_cast : zrange -> Z -> Z}. - Local Notation interp_scalar := (interp_scalar (interp_cast:=interp_cast)). - Local Notation interp_cast2 := (interp_cast2 (interp_cast:=interp_cast)). - Local Notation low x := (Z.land x (wordmax_half_bits - 1)). - Local Notation high x := (x >> half_bits). - Local Notation shift x imm := ((x << imm) mod wordmax). - - Definition interp_ident {s d} (idc : ident s d) : type.interp s -> type.interp d := - match idc with - | add imm => fun x => Z.add_get_carry_full wordmax (fst x) (shift (snd x) imm) - | addc imm => fun x => Z.add_with_get_carry_full wordmax (fst (fst x)) (snd (fst x)) (shift (snd x) imm) - | sub imm => fun x => Z.sub_get_borrow_full wordmax (fst x) (shift (snd x) imm) - | subb imm => fun x => Z.sub_with_get_borrow_full wordmax (fst (fst x)) (snd (fst x)) (shift (snd x) imm) - | mulll => fun x => low (fst x) * low (snd x) - | mullh => fun x => low (fst x) * high (snd x) - | mulhl => fun x => high (fst x) * low (snd x) - | mulhh => fun x => high (fst x) * high (snd x) - | rshi n => fun x => Z.rshi wordmax (fst x) (snd x) n - | selc => fun x => Z.zselect (fst (fst x)) (snd (fst x)) (snd x) - | selm => fun x => Z.zselect (Z.cc_m wordmax (fst (fst x))) (snd (fst x)) (snd x) - | sell => fun x => Z.zselect (Z.land (fst (fst x)) 1) (snd (fst x)) (snd x) - | addm => fun x => Z.add_modulo (fst (fst x)) (snd (fst x)) (snd x) - end. - - Fixpoint interp {t} (e : @expr type.interp ident t) : type.interp t := - match e with - | Scalar t s => interp_scalar s - | LetInAppIdentZ s d r idc x f => - interp (f (interp_cast r (interp_ident idc (interp_scalar x)))) - | LetInAppIdentZZ s d r idc x f => - interp (f (interp_cast2 r (interp_ident idc (interp_scalar x)))) - end. - End interp. - - Section proofs. - Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype) (consts : list Z) - (consts_ok : forall x, In x consts -> 0 <= x <= wordmax - 1). - Context {interp_cast : zrange -> Z -> Z} {interp_cast_correct : forall r x, lower r <= x <= upper r -> interp_cast r x = x}. - Local Notation interp_scalar := (interp_scalar (interp_cast:=interp_cast)). - Local Notation interp_cast2 := (interp_cast2 (interp_cast:=interp_cast)). - - Local Notation word_range := (r[0~>wordmax-1])%zrange. - Local Notation half_word_range := (r[0~>wordmax_half_bits-1])%zrange. - Local Notation flag_range := (r[0~>1])%zrange. - - Definition in_word_range (r : zrange) := is_tighter_than_bool r word_range = true. - Definition in_flag_range (r : zrange) := is_tighter_than_bool r flag_range = true. - - Fixpoint get_range_var (t : type) : type.interp t -> range_type t := - match t with - | type.type_primitive type.Z => - fun x => {| lower := x; upper := x |} - | type.prod a b => - fun x => (get_range_var a (fst x), get_range_var b (snd x)) - | _ => fun _ => tt - end. - - Fixpoint get_range {t} (x : @scalar type.interp t) : range_type t := - match x with - | Var t v => get_range_var t v - | TT => tt - | Nil _ => tt - | Pair _ _ x y => (get_range x, get_range y) - | Cast r _ => r - | Cast2 r _ => r - | Fst _ _ p => fst (get_range p) - | Snd _ _ p => snd (get_range p) - | Shiftr n x => ZRange.map (fun y => Z.shiftr y n) (get_range x) - | Shiftl n x => ZRange.map (fun y => Z.shiftl y n) (get_range x) - | Land n x => r[0~>n]%zrange - | CC_m n x => ZRange.map (Z.cc_m n) (get_range x) - | Primitive type.Z x => {| lower := x; upper := x |} - | Primitive p x => tt - end. - - Fixpoint has_range {t} : range_type t -> type.interp t -> Prop := - match t with - | type.type_primitive type.Z => - fun r x => - lower r <= x <= upper r - | type.prod a b => - fun r x => - has_range (fst r) (fst x) /\ has_range (snd r) (snd x) - | _ => fun _ _ => True - end. - - Inductive ok_scalar : forall {t}, @scalar type.interp t -> Prop := - | sc_ok_var : forall t v, ok_scalar (Var t v) - | sc_ok_unit : ok_scalar TT - | sc_ok_nil : forall t, ok_scalar (Nil t) - | sc_ok_pair : forall A B x y, - @ok_scalar A x -> - @ok_scalar B y -> - ok_scalar (Pair x y) - | sc_ok_cast : forall r (x : scalar type.Z), - ok_scalar x -> - is_tighter_than_bool (get_range x) r = true -> - ok_scalar (Cast r x) - | sc_ok_cast2 : forall r (x : scalar (type.prod type.Z type.Z)), - ok_scalar x -> - is_tighter_than_bool (fst (get_range x)) (fst r) = true -> - is_tighter_than_bool (snd (get_range x)) (snd r) = true -> - ok_scalar (Cast2 r x) - | sc_ok_fst : - forall A B p, @ok_scalar (A * B) p -> ok_scalar (Fst p) - | sc_ok_snd : - forall A B p, @ok_scalar (A * B) p -> ok_scalar (Snd p) - | sc_ok_shiftr : - forall n x, 0 <= n -> ok_scalar x -> ok_scalar (Shiftr n x) - | sc_ok_shiftl : - forall n x, 0 <= n -> 0 <= lower (@get_range type.Z x) -> ok_scalar x -> ok_scalar (Shiftl n x) - | sc_ok_land : - forall n x, 0 <= n -> 0 <= lower (@get_range type.Z x) -> ok_scalar x -> ok_scalar (Land n x) - | sc_ok_cc_m : - forall x, ok_scalar x -> ok_scalar (CC_m wordmax x) - | sc_ok_prim : forall p x, ok_scalar (@Primitive _ p x) - . - - Inductive is_halved : scalar type.Z -> Prop := - | is_halved_lower : - forall x : scalar type.Z, - in_word_range (get_range x) -> - is_halved (Cast half_word_range (Land (wordmax_half_bits - 1) x)) - | is_halved_upper : - forall x : scalar type.Z, - in_word_range (get_range x) -> - is_halved (Cast half_word_range (Shiftr half_bits x)) - | is_halved_constant : - forall y z, - constant_to_scalar consts z = Some y -> - is_halved y -> - is_halved (Primitive (t:=type.Z) z) - . - - Inductive ok_ident : forall s d, scalar s -> range_type d -> ident.ident s d -> Prop := - | ok_add : - forall x y : scalar type.Z, - in_word_range (get_range x) -> - in_word_range (get_range y) -> - ok_ident _ - (type.prod type.Z type.Z) - (Pair x y) - (word_range, flag_range) - (ident.Z.add_get_carry_concrete wordmax) - | ok_addc : - forall (c x y : scalar type.Z) outr, - in_flag_range (get_range c) -> - in_word_range (get_range x) -> - in_word_range (get_range y) -> - lower outr = 0 -> - (0 <= upper (get_range c) + upper (get_range x) + upper (get_range y) <= upper outr \/ outr = word_range) -> - ok_ident _ - (type.prod type.Z type.Z) - (Pair (Pair c x) y) - (outr, flag_range) - (ident.Z.add_with_get_carry_concrete wordmax) - | ok_sub : - forall x y : scalar type.Z, - in_word_range (get_range x) -> - in_word_range (get_range y) -> - ok_ident _ - (type.prod type.Z type.Z) - (Pair x y) - (word_range, flag_range) - (ident.Z.sub_get_borrow_concrete wordmax) - | ok_subb : - forall b x y : scalar type.Z, - in_flag_range (get_range b) -> - in_word_range (get_range x) -> - in_word_range (get_range y) -> - ok_ident _ - (type.prod type.Z type.Z) - (Pair (Pair b x) y) - (word_range, flag_range) - (ident.Z.sub_with_get_borrow_concrete wordmax) - | ok_rshi : - forall (x : scalar (type.prod type.Z type.Z)) n outr, - in_word_range (fst (get_range x)) -> - in_word_range (snd (get_range x)) -> - (* note : using [outr] rather than [word_range] allows for cases where the result has been put in a smaller word size. *) - lower outr = 0 -> - 0 <= n -> - ((0 <= (upper (snd (get_range x)) + upper (fst (get_range x)) * wordmax) / 2^n <= upper outr) - \/ outr = word_range) -> - ok_ident (type.prod type.Z type.Z) type.Z x outr (ident.Z.rshi_concrete wordmax n) - | ok_selc : - forall (x : scalar (type.prod type.Z type.Z)) (y z : scalar type.Z), - in_flag_range (snd (get_range x)) -> - in_word_range (get_range y) -> - in_word_range (get_range z) -> - ok_ident _ - type.Z - (Pair (Pair (Cast flag_range (Snd x)) y) z) - word_range - ident.Z.zselect - | ok_selm : - forall x y z : scalar type.Z, - in_word_range (get_range x) -> - in_word_range (get_range y) -> - in_word_range (get_range z) -> - ok_ident _ - type.Z - (Pair (Pair (Cast flag_range (CC_m wordmax x)) y) z) - word_range - ident.Z.zselect - | ok_sell : - forall x y z : scalar type.Z, - in_word_range (get_range x) -> - in_word_range (get_range y) -> - in_word_range (get_range z) -> - ok_ident _ - type.Z - (Pair (Pair (Cast flag_range (Land 1 x)) y) z) - word_range - ident.Z.zselect - | ok_addm : - forall (x : scalar (type.prod (type.prod type.Z type.Z) type.Z)), - in_word_range (fst (fst (get_range x))) -> - in_word_range (snd (fst (get_range x))) -> - in_word_range (snd (get_range x)) -> - upper (fst (fst (get_range x))) + upper (snd (fst (get_range x))) - lower (snd (get_range x)) < wordmax -> - ok_ident _ - type.Z - x - word_range - ident.Z.add_modulo - | ok_mul : - forall x y : scalar type.Z, - is_halved x -> - is_halved y -> - ok_ident (type.prod type.Z type.Z) - type.Z - (Pair x y) - word_range - ident.Z.mul - . - - Inductive ok_expr : forall {t}, @expr type.interp ident.ident t -> Prop := - | ok_of_scalar : forall t s, ok_scalar s -> @ok_expr t (Scalar s) - | ok_letin_z : forall s d r idc x f, - ok_ident _ type.Z x r idc -> - (r <=? word_range)%zrange = true -> - ok_scalar x -> - (forall y, has_range (t:=type.Z) r y -> ok_expr (f y)) -> - ok_expr (@LetInAppIdentZ _ _ s d r idc x f) - | ok_letin_zz : forall s d r idc x f, - ok_ident _ (type.prod type.Z type.Z) x (r, flag_range) idc -> - (r <=? word_range)%zrange = true -> - ok_scalar x -> - (forall y, has_range (t:=type.Z * type.Z) (r, flag_range) y -> ok_expr (f y)) -> - ok_expr (@LetInAppIdentZZ _ _ s d (r, flag_range) idc x f) - . - - Ltac invert H := - inversion H; subst; - repeat match goal with - | H : existT _ _ _ = existT _ _ _ |- _ => apply (Eqdep_dec.inj_pair2_eq_dec _ type.type_eq_dec) in H; subst - end. - - Lemma has_range_get_range_var {t} (v : type.interp t) : - has_range (get_range_var _ v) v. - Proof. - induction t; cbn [get_range_var has_range fst snd]; auto. - destruct p; auto; cbn [upper lower]; omega. - Qed. - - Lemma has_range_loosen r1 r2 (x : Z) : - @has_range type.Z r1 x -> - is_tighter_than_bool r1 r2 = true -> - @has_range type.Z r2 x. - Proof. - cbv [is_tighter_than_bool has_range]; intros; - match goal with H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H end; - Z.ltb_to_lt; omega. - Qed. - - Lemma interp_cast_noop x r : - @has_range type.Z r x -> - interp_cast r x = x. - Proof. cbv [has_range]; intros; auto. Qed. - - Lemma interp_cast2_noop x r : - @has_range (type.prod type.Z type.Z) r x -> - interp_cast2 r x = x. - Proof. - cbv [has_range interp_cast2]; intros. - rewrite !interp_cast_correct by tauto. - destruct x; reflexivity. - Qed. - - Lemma has_range_shiftr n (x : scalar type.Z) : - 0 <= n -> - has_range (get_range x) (interp_scalar x) -> - @has_range type.Z (ZRange.map (fun y : Z => y >> n) (get_range x)) (interp_scalar x >> n). - Proof. cbv [has_range]; intros; cbn. auto using Z.shiftr_le with omega. Qed. - Hint Resolve has_range_shiftr : has_range. - - Lemma has_range_shiftl n r x : - 0 <= n -> 0 <= lower r -> - @has_range type.Z r x -> - @has_range type.Z (ZRange.map (fun y : Z => y << n) r) (x << n). - Proof. cbv [has_range]; intros; cbn. auto using Z.shiftl_le_mono with omega. Qed. - Hint Resolve has_range_shiftl : has_range. - - Lemma has_range_land n (x : scalar type.Z) : - 0 <= n -> 0 <= lower (get_range x) -> - has_range (get_range x) (interp_scalar x) -> - @has_range type.Z (r[0~>n])%zrange (Z.land (interp_scalar x) n). - Proof. - cbv [has_range]; intros; cbn. - split; [ apply Z.land_nonneg | apply Z.land_upper_bound_r ]; omega. - Qed. - Hint Resolve has_range_land : has_range. - - Lemma has_range_interp_scalar {t} (x : scalar t) : - ok_scalar x -> - has_range (get_range x) (interp_scalar x). - Proof. - induction 1; cbn [interp_scalar get_range]; - auto with has_range; - try solve [try inversion IHok_scalar; cbn [has_range]; - auto using has_range_get_range_var]; [ | | | ]. - { rewrite interp_cast_noop by eauto using has_range_loosen. - eapply has_range_loosen; eauto. } - { inversion IHok_scalar. - rewrite interp_cast2_noop; - cbn [has_range]; split; eapply has_range_loosen; eauto. } - { cbn. cbv [has_range] in *. - pose proof wordmax_gt_2. - rewrite !Z.cc_m_eq by omega. - split; apply Z.div_le_mono; Z.zero_bounds; omega. } - { destruct p; cbn [has_range upper lower]; auto; omega. } - Qed. - Hint Resolve has_range_interp_scalar : has_range. - - Lemma has_word_range_interp_scalar (x : scalar type.Z) : - ok_scalar x -> - in_word_range (get_range x) -> - @has_range type.Z word_range (interp_scalar x). - Proof. eauto using has_range_loosen, has_range_interp_scalar. Qed. - - Lemma in_word_range_nonneg r : in_word_range r -> 0 <= lower r. - Proof. - cbv [in_word_range is_tighter_than_bool]. - rewrite andb_true_iff; intuition. - Qed. - - Lemma in_word_range_upper_nonneg r x : @has_range type.Z r x -> in_word_range r -> 0 <= upper r. - Proof. - cbv [in_word_range is_tighter_than_bool]; cbn. - rewrite andb_true_iff; intuition. - Z.ltb_to_lt. omega. - Qed. - - Lemma has_word_range_shiftl n r x : - 0 <= n -> upper r * 2 ^ n <= wordmax - 1 -> - @has_range type.Z r x -> - in_word_range r -> - @has_range type.Z word_range (x << n). - Proof. - intros. - eapply has_range_loosen; - [ apply has_range_shiftl; eauto using in_word_range_nonneg with has_range; omega | ]. - cbv [is_tighter_than_bool]. cbn. - apply andb_true_iff; split; apply Z.leb_le; - [ apply Z.shiftl_nonneg; solve [auto using in_word_range_nonneg] | ]. - rewrite Z.shiftl_mul_pow2 by omega. - auto. - Qed. - - Lemma has_range_rshi r n x y : - 0 <= n -> - 0 <= x -> - 0 <= y -> - lower r = 0 -> - (0 <= (y + x * wordmax) / 2^n <= upper r \/ r = word_range) -> - @has_range type.Z r (Z.rshi wordmax x y n). - Proof. - pose proof wordmax_gt_2. - intros. cbv [has_range]. - rewrite Z.rshi_correct by omega. - match goal with |- context [?x mod ?m] => - pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - split; [lia|]. - intuition. - { destruct (Z_lt_dec (upper r) wordmax); [ | lia]. - rewrite Z.mod_small by (split; Z.zero_bounds; omega). - omega. } - { subst r. cbn [upper]. omega. } - Qed. - - Lemma in_word_range_spec r : - (0 <= lower r /\ upper r <= wordmax - 1) - <-> in_word_range r. - Proof. - intros; cbv [in_word_range is_tighter_than_bool]. - rewrite andb_true_iff. - intuition; apply Z.leb_le; cbn [upper lower]; try omega. - Qed. - - Ltac destruct_scalar := - match goal with - | x : scalar (type.prod (type.prod _ _) _) |- _ => - match goal with |- context [interp_scalar x] => - destruct (interp_scalar x) as [ [? ?] ?]; - destruct (get_range x) as [ [? ?] ?] - end - | x : scalar (type.prod _ _) |- _ => - match goal with |- context [interp_scalar x] => - destruct (interp_scalar x) as [? ?]; destruct (get_range x) as [? ?] - end - end. - - Ltac extract_ok_scalar' level x := - match goal with - | H : ok_scalar (Pair (Pair (?f (?g x)) _) _) |- _ => - match (eval compute in (4 <=? level)) with - | true => invert H; extract_ok_scalar' 3 x - | _ => fail - end - | H : ok_scalar (Pair (?f (?g x)) _) |- _ => - match (eval compute in (3 <=? level)) with - | true => invert H; extract_ok_scalar' 2 x - | _ => fail - end - | H : ok_scalar (Pair _ (?f (?g x))) |- _ => - match (eval compute in (3 <=? level)) with - | true => invert H; extract_ok_scalar' 2 x - | _ => fail - end - | H : ok_scalar (?f (?g x)) |- _ => - match (eval compute in (2 <=? level)) with - | true => invert H; extract_ok_scalar' 1 x - | _ => fail - end - | H : ok_scalar (Pair (Pair x _) _) |- _ => - match (eval compute in (2 <=? level)) with - | true => invert H; extract_ok_scalar' 1 x - | _ => fail - end - | H : ok_scalar (Pair (Pair _ x) _) |- _ => - match (eval compute in (2 <=? level)) with - | true => invert H; extract_ok_scalar' 1 x - | _ => fail - end - | H : ok_scalar (?g x) |- _ => invert H - | H : ok_scalar (Pair x _) |- _ => invert H - | H : ok_scalar (Pair _ x) |- _ => invert H - end. - - Ltac extract_ok_scalar := - match goal with |- ok_scalar ?x => extract_ok_scalar' 4 x; assumption end. - - Lemma has_half_word_range_shiftr r x : - in_word_range r -> - @has_range type.Z r x -> - @has_range type.Z half_word_range (x >> half_bits). - Proof. - cbv [in_word_range is_tighter_than_bool]. - rewrite andb_true_iff. - cbn [has_range upper lower]; intros; intuition; Z.ltb_to_lt. - { apply Z.shiftr_nonneg. omega. } - { pose proof half_bits_nonneg. - pose proof half_bits_squared. - assert (x >> half_bits < wordmax_half_bits); [|omega]. - rewrite Z.shiftr_div_pow2 by auto. - apply Z.div_lt_upper_bound; Z.zero_bounds. - subst wordmax_half_bits half_bits. - rewrite <-Z.pow_add_r by omega. - rewrite Z.add_diag, Z.mul_div_eq, log2wordmax_even by omega. - autorewrite with zsimplify_fast. subst wordmax. omega. } - Qed. - - Lemma has_half_word_range_land r x : - in_word_range r -> - @has_range type.Z r x -> - @has_range type.Z half_word_range (x &' (wordmax_half_bits - 1)). - Proof. - pose proof wordmax_half_bits_pos. - cbv [in_word_range is_tighter_than_bool]. - rewrite andb_true_iff. - cbn [has_range upper lower]; intros; intuition; Z.ltb_to_lt. - { apply Z.land_nonneg; omega. } - { apply Z.land_upper_bound_r; omega. } - Qed. - - Section constant_to_scalar. - Lemma constant_to_scalar_single_correct s x z : - 0 <= x <= wordmax - 1 -> - constant_to_scalar_single x z = Some s -> interp_scalar s = z. - Proof. - cbv [constant_to_scalar_single]. - break_match; try discriminate; intros; Z.ltb_to_lt; subst; - try match goal with H : Some _ = Some _ |- _ => inversion H; subst end; - cbn [interp_scalar]; apply interp_cast_noop. - { apply has_half_word_range_shiftr with (r:=r[x~>x]%zrange); - cbv [in_word_range is_tighter_than_bool upper lower has_range]; try omega. - apply andb_true_iff; split; apply Z.leb_le; omega. } - { apply has_half_word_range_land with (r:=r[x~>x]%zrange); - cbv [in_word_range is_tighter_than_bool upper lower has_range]; try omega. - apply andb_true_iff; split; apply Z.leb_le; omega. } - Qed. - - Lemma constant_to_scalar_correct s z : - constant_to_scalar consts z = Some s -> interp_scalar s = z. - Proof. - cbv [constant_to_scalar]. - apply fold_right_invariant; try discriminate. - intros until 2; break_match; eauto using constant_to_scalar_single_correct. - Qed. - - Lemma constant_to_scalar_single_cases x y z : - @constant_to_scalar_single type.interp x z = Some y -> - (y = Cast half_word_range (Land (wordmax_half_bits - 1) (Primitive (t:=type.Z) x))) - \/ (y = Cast half_word_range (Shiftr half_bits (Primitive (t:=type.Z) x))). - Proof. - cbv [constant_to_scalar_single]. - break_match; try discriminate; intros; Z.ltb_to_lt; subst; - try match goal with H : Some _ = Some _ |- _ => inversion H; subst end; - tauto. - Qed. - - Lemma constant_to_scalar_cases y z : - @constant_to_scalar type.interp consts z = Some y -> - (exists x, - @has_range type.Z word_range x - /\ y = Cast half_word_range (Land (wordmax_half_bits - 1) (Primitive x))) - \/ (exists x, - @has_range type.Z word_range x - /\ y = Cast half_word_range (Shiftr half_bits (Primitive x))). - Proof. - cbv [constant_to_scalar]. - apply fold_right_invariant; try discriminate. - intros until 2; break_match; eauto; intros. - match goal with H : constant_to_scalar_single _ _ = _ |- _ => - destruct (constant_to_scalar_single_cases _ _ _ H); subst end. - { left; eexists; split; eauto. - apply consts_ok; auto. } - { right; eexists; split; eauto. - apply consts_ok; auto. } - Qed. - - Lemma ok_scalar_constant_to_scalar y z : constant_to_scalar consts z = Some y -> ok_scalar y. - Proof. - pose proof wordmax_half_bits_pos. pose proof half_bits_nonneg. - let H := fresh in - intro H; apply constant_to_scalar_cases in H; destruct H as [ [? ?] | [? ?] ]; intuition; subst; - cbn [has_range lower upper] in *; repeat constructor; cbn [lower get_range]; try apply Z.leb_refl; try omega. - assert (in_word_range r[x~>x]) by (apply in_word_range_spec; cbn [lower upper]; omega). - pose proof (has_half_word_range_shiftr r[x~>x] x ltac:(assumption) ltac:(cbv [has_range lower upper]; omega)). - cbn [has_range ZRange.map is_tighter_than_bool lower upper] in *. - apply andb_true_iff; cbn [lower upper]; split; apply Z.leb_le; omega. - Qed. - End constant_to_scalar. - Hint Resolve ok_scalar_constant_to_scalar. - - Lemma is_halved_has_range x : - ok_scalar x -> - is_halved x -> - @has_range type.Z half_word_range (interp_scalar x). - Proof. - intro; pose proof (has_range_interp_scalar x ltac:(assumption)). - induction 1; cbn [interp_scalar] in *; intros; try assumption; [ ]. - rewrite <-(constant_to_scalar_correct y z) by assumption. - eauto using has_range_interp_scalar. - Qed. - - Lemma ident_interp_has_range s d x r idc: - ok_scalar x -> - ok_ident s d x r idc -> - has_range r (ident.interp idc (interp_scalar x)). - Proof. - intro. - pose proof (has_range_interp_scalar x ltac:(assumption)). - pose proof wordmax_gt_2. - induction 1; cbn [ident.interp ident.gen_interp]; intros; try destruct_scalar; - repeat match goal with - | H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt - | H : _ /\ _ |- _ => destruct H - | H : is_halved _ |- _ => apply is_halved_has_range in H; [ | extract_ok_scalar ] - | _ => progress subst - | _ => progress (cbv [in_word_range in_flag_range is_tighter_than_bool] in * ) - | _ => progress (cbn [interp_scalar get_range has_range upper lower fst snd] in * ) - end. - { - autorewrite with to_div_mod. - match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - rewrite Z.div_between_0_if by omega. - split; break_match; lia. } - { - autorewrite with to_div_mod. - match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - rewrite Z.div_between_0_if by omega. - match goal with H : _ \/ _ |- _ => destruct H; subst end. - { split; break_match; try lia. - destruct (Z_lt_dec (upper outr) wordmax). - { match goal with |- _ <= ?y mod _ <= ?u => - assert (y <= u) by nia end. - rewrite Z.mod_small by omega. omega. } - { match goal with|- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - omega. } } - { split; break_match; cbn; lia. } } - { - autorewrite with to_div_mod. - match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - rewrite Z.div_sub_small by omega. - split; break_match; lia. } - { - autorewrite with to_div_mod. - match goal with |- context [?a - ?b - ?c] => replace (a - b - c) with (a - (b + c)) by ring end. - match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - rewrite Z.div_sub_small by omega. - split; break_match; lia. } - { apply has_range_rshi; try nia; [ ]. - match goal with H : context [upper ?ra + upper ?rb * wordmax] |- context [?a + ?b * wordmax] => - assert ((a + b * wordmax) / 2^n <= (upper ra + upper rb * wordmax) / 2^n) by (apply Z.div_le_mono; Z.zero_bounds; nia) - end. - match goal with H : _ \/ ?P |- _ \/ ?P => destruct H; [left|tauto] end. - split; Z.zero_bounds; nia. } - { rewrite Z.zselect_correct. break_match; omega. } - { cbn [interp_scalar fst snd get_range] in *. - rewrite Z.zselect_correct. break_match; omega. } - { cbn [interp_scalar fst snd get_range] in *. - rewrite Z.zselect_correct. break_match; omega. } - { rewrite Z.add_modulo_correct. - break_match; Z.ltb_to_lt; omega. } - { cbn [interp_scalar has_range fst snd get_range upper lower] in *. - pose proof half_bits_squared. nia. } - Qed. - - Lemma has_flag_range_cc_m r x : - @has_range type.Z r x -> - in_word_range r -> - @has_range type.Z flag_range (Z.cc_m wordmax x). - Proof. - cbv [has_range in_word_range is_tighter_than_bool]. - cbn [upper lower]; rewrite andb_true_iff; intros. - match goal with H : _ /\ _ |- _ => destruct H; Z.ltb_to_lt end. - pose proof wordmax_gt_2. pose proof wordmax_even. - pose proof (Z.cc_m_small wordmax x). omega. - Qed. - - Lemma has_flag_range_cc_m' (x : scalar type.Z) : - ok_scalar x -> - in_word_range (get_range x) -> - @has_range type.Z flag_range (Z.cc_m wordmax (interp_scalar x)). - Proof. eauto using has_flag_range_cc_m with has_range. Qed. - - Lemma has_flag_range_land r x : - @has_range type.Z r x -> - in_word_range r -> - @has_range type.Z flag_range (Z.land x 1). - Proof. - cbv [has_range in_word_range is_tighter_than_bool]. - cbn [upper lower]; rewrite andb_true_iff; intuition; Z.ltb_to_lt. - { apply Z.land_nonneg. left; omega. } - { apply Z.land_upper_bound_r; omega. } - Qed. - - Lemma has_flag_range_land' (x : scalar type.Z) : - ok_scalar x -> - in_word_range (get_range x) -> - @has_range type.Z flag_range (Z.land (interp_scalar x) 1). - Proof. eauto using has_flag_range_land with has_range. Qed. - - Ltac rewrite_cast_noop_in_mul := - repeat match goal with - | _ => rewrite interp_cast_noop with (r:=half_word_range) in * - by (eapply has_range_loosen; auto using has_range_land, has_range_interp_scalar) - | _ => rewrite interp_cast_noop with (r:=half_word_range) in * - by (eapply has_range_loosen; try apply has_range_shiftr; auto using has_range_interp_scalar; - cbn [ZRange.map get_range] in *; auto) - | _ => rewrite interp_cast_noop by assumption - end. - - Lemma is_halved_cases x : - is_halved x -> - ok_scalar x -> - (exists y, - invert_lower consts x = Some y - /\ invert_upper consts x = None - /\ interp_scalar y &' (wordmax_half_bits - 1) = interp_scalar x) - \/ (exists y, - invert_lower consts x = None - /\ invert_upper consts x = Some y - /\ interp_scalar y >> half_bits = interp_scalar x). - Proof. - induction 1; intros; cbn; rewrite ?Z.eqb_refl; cbn. - { left. eexists; repeat split; auto. - rewrite interp_cast_noop; [ reflexivity | ]. - apply has_half_word_range_land with (r:=get_range x); auto. - apply has_range_interp_scalar; extract_ok_scalar. } - { right. eexists; repeat split; auto. - rewrite interp_cast_noop; [ reflexivity | ]. - apply has_half_word_range_shiftr with (r:=get_range x); auto. - apply has_range_interp_scalar; extract_ok_scalar. } - { match goal with H : constant_to_scalar _ _ = Some _ |- _ => - rewrite H; - let P := fresh in - destruct (constant_to_scalar_cases _ _ H) as [ [? [? ?] ] | [? [? ?] ] ]; - subst; cbn; rewrite ?Z.eqb_refl; cbn - end. - { left; eexists; repeat split; auto. - erewrite <-constant_to_scalar_correct by eassumption. - subst. cbn. - rewrite interp_cast_noop; [ reflexivity | ]. - eapply has_half_word_range_land with (r:=word_range); auto. - cbv [in_word_range is_tighter_than_bool]. - rewrite !Z.leb_refl; reflexivity. } - { right; eexists; repeat split; auto. - erewrite <-constant_to_scalar_correct by eassumption. - subst. cbn. - rewrite interp_cast_noop; [ reflexivity | ]. - eapply has_half_word_range_shiftr with (r:=word_range); auto. - cbv [in_word_range is_tighter_than_bool]. - rewrite !Z.leb_refl; reflexivity. } } - Qed. - - Lemma halved_mul_range x y : - ok_scalar (Pair x y) -> - is_halved x -> - is_halved y -> - 0 <= interp_scalar x * interp_scalar y < wordmax. - Proof. - intro Hok; invert Hok. intros. - repeat match goal with H : _ |- _ => apply is_halved_has_range in H; [|assumption] end. - cbv [has_range lower upper] in *. - pose proof half_bits_squared. nia. - Qed. - - Lemma of_straightline_ident_mul_correct r t x y g : - is_halved x -> - is_halved y -> - ok_scalar (Pair x y) -> - (word_range <=? r)%zrange = true -> - @has_range type.Z word_range (ident.interp ident.Z.mul (interp_scalar (Pair x y))) -> - @interp interp_cast _ (of_straightline_ident dummy_arrow consts ident.Z.mul t r (Pair x y) g) = - @interp interp_cast _ (g (ident.interp ident.Z.mul (interp_scalar (Pair x y)))). - Proof. - intros Hx Hy Hok ? ?; invert Hok; cbn [interp_scalar of_straightline_ident]; - destruct (is_halved_cases x Hx ltac:(assumption)) as [ [? [Pxlow [Pxhigh Pxi] ] ] | [? [Pxlow [Pxhigh Pxi] ] ] ]; - rewrite ?Pxlow, ?Pxhigh; - destruct (is_halved_cases y Hy ltac:(assumption)) as [ [? [Pylow [Pyhigh Pyi] ] ] | [? [Pylow [Pyhigh Pyi] ] ] ]; - rewrite ?Pylow, ?Pyhigh; - cbn; rewrite Pxi, Pyi; assert (0 <= interp_scalar x * interp_scalar y < wordmax) by (auto using halved_mul_range); - rewrite interp_cast_noop by (cbv [is_tighter_than_bool] in *; cbn [has_range upper lower] in *; rewrite andb_true_iff in *; intuition; Z.ltb_to_lt; lia); reflexivity. - Qed. - - Lemma has_word_range_mod_small x: - @has_range type.Z word_range x -> - x mod wordmax = x. - Proof. - cbv [has_range upper lower]. - intros. apply Z.mod_small; omega. - Qed. - - Lemma half_word_range_le_word_range r : - upper r = wordmax_half_bits - 1 -> - lower r = 0 -> - (r <=? word_range)%zrange = true. - Proof. - pose proof wordmax_half_bits_le_wordmax. - destruct r; cbv [is_tighter_than_bool ZRange.lower ZRange.upper]. - intros; subst. - apply andb_true_iff; split; Z.ltb_to_lt; lia. - Qed. - - Lemma and_shiftl_half_bits_eq x : - (x &' (wordmax_half_bits - 1)) << half_bits = x << half_bits mod wordmax. - Proof. - rewrite ones_half_bits. - rewrite Z.land_ones, !Z.shiftl_mul_pow2 by auto using half_bits_nonneg. - rewrite <-wordmax_half_bits_squared. - subst wordmax_half_bits. - rewrite Z.mul_mod_distr_r_full. - reflexivity. - Qed. - - Lemma in_word_range_word_range : in_word_range word_range. - Proof. - cbv [in_word_range is_tighter_than_bool]. - rewrite !Z.leb_refl; reflexivity. - Qed. - - Lemma invert_shift_correct (s : scalar type.Z) x imm : - ok_scalar s -> - invert_shift consts s = Some (x, imm) -> - interp_scalar s = (interp_scalar x << imm) mod wordmax. - Proof. - intros Hok ?; invert Hok; - try match goal with H : ok_scalar ?x, H' : context[Cast _ ?x] |- _ => - invert H end; - try match goal with H : ok_scalar ?x, H' : context[Shiftl _ ?x] |- _ => - invert H end; - try match goal with H : ok_scalar ?x, H' : context[Shiftl _ (Cast _ ?x)] |- _ => - invert H end; - try (cbn [invert_shift invert_upper invert_upper'] in *; discriminate); - repeat match goal with - | _ => progress (cbn [invert_shift invert_lower invert_lower' invert_upper invert_upper' interp_scalar fst snd] in * ) - | _ => rewrite interp_cast_noop by eauto using has_half_word_range_land, has_half_word_range_shiftr, in_word_range_word_range, has_range_loosen - | H : ok_scalar (Shiftr _ _) |- _ => apply has_range_interp_scalar in H - | H : ok_scalar (Shiftl _ _) |- _ => apply has_range_interp_scalar in H - | H : ok_scalar (Land _ _) |- _ => apply has_range_interp_scalar in H - | H : context [if ?x then _ else _] |- _ => - let Heq := fresh in case_eq x; intro Heq; rewrite Heq in H - | H : context [match @constant_to_scalar ?v ?consts ?x with _ => _ end] |- _ => - let Heq := fresh in - case_eq (@constant_to_scalar v consts x); intros until 0; intro Heq; rewrite Heq in *; [|discriminate]; - destruct (constant_to_scalar_cases _ _ Heq) as [ [? [? ?] ] | [? [? ?] ] ]; subst; - pose proof (ok_scalar_constant_to_scalar _ _ Heq) - | H : constant_to_scalar _ _ = Some _ |- _ => erewrite <-(constant_to_scalar_correct _ _ H) - | H : _ |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt - | H : Some _ = Some _ |- _ => progress (invert H) - | _ => rewrite has_word_range_mod_small by eauto using has_range_loosen, half_word_range_le_word_range - | _ => rewrite has_word_range_mod_small by - (eapply has_range_loosen with (r1:=half_word_range); - [ eapply has_half_word_range_shiftr with (r:=word_range) | ]; - eauto using in_word_range_word_range, half_word_range_le_word_range) - | _ => rewrite and_shiftl_half_bits_eq - | _ => progress subst - | _ => reflexivity - | _ => discriminate - end. - Qed. - - Local Ltac solve_commutative_replace := - match goal with - | |- @eq (_ * _) ?x ?y => - replace x with (fst x, snd x) by (destruct x; reflexivity); - replace y with (fst y, snd y) by (destruct y; reflexivity) - end; autorewrite with to_div_mod; solve [repeat (f_equal; try ring)]. - - Fixpoint is_tighter_than_bool_range_type t : range_type t -> range_type t -> bool := - match t with - | type.type_primitive type.Z => (fun r1 r2 => (r1 <=? r2)%zrange) - | type.prod a b => fun r1 r2 => - (is_tighter_than_bool_range_type a (fst r1) (fst r2)) - && (is_tighter_than_bool_range_type b (snd r1) (snd r2)) - | _ => fun _ _ => true - end. - - Definition range_ok {t} : range_type t -> Prop := - match t with - | type.type_primitive type.Z => fun r => in_word_range r - | type.prod type.Z type.Z => fun r => in_word_range (fst r) /\ snd r = flag_range - | _ => fun _ => False - end. - - Lemma of_straightline_ident_correct s d t x r r' (idc : ident.ident s d) g : - ok_ident s d x r idc -> - range_ok r' -> - is_tighter_than_bool_range_type d r r' = true -> - ok_scalar x -> - @interp interp_cast _ (of_straightline_ident dummy_arrow consts idc t r' x g) = - @interp interp_cast _ (g (ident.interp idc (interp_scalar x))). - Proof. - intros. - pose proof wordmax_half_bits_pos. - pose proof (ident_interp_has_range _ _ x r idc ltac:(assumption) ltac:(assumption)). - match goal with H : ok_ident _ _ _ _ _ |- _ => induction H end; - try solve [auto using of_straightline_ident_mul_correct]; - cbv [is_tighter_than_bool_range_type is_tighter_than_bool range_ok] in *; - cbn [of_straightline_ident ident.interp ident.gen_interp - invert_selm invert_sell] in *; - intros; rewrite ?Z.eqb_refl; cbn [andb]; - try match goal with |- context [invert_shift] => break_match end; - cbn [interp interp_ident]; try destruct_scalar; - repeat match goal with - | _ => progress (cbn [fst snd interp_scalar] in * ) - | _ => progress break_match; [ ] - | _ => progress autorewrite with zsimplify_fast - | _ => progress Z.ltb_to_lt - | H : _ /\ _ |- _ => destruct H - | _ => rewrite andb_true_iff in * - | _ => rewrite interp_cast_noop with (r:=flag_range) in * - by (apply has_flag_range_cc_m'; auto; extract_ok_scalar) - | _ => rewrite interp_cast_noop with (r:=flag_range) in * - by (apply has_flag_range_land'; auto; extract_ok_scalar) - | H : _ = (_,_) |- _ => progress (inversion H; subst) - | H : invert_shift _ _ = Some _ |- _ => - apply invert_shift_correct in H; [|extract_ok_scalar]; - rewrite <-H - | H : has_range ?r (?f ?x ?y) |- context [?f ?y ?x] => - replace (f y x) with (f x y) by solve_commutative_replace - | _ => rewrite has_word_range_mod_small - by (eapply has_range_loosen; - [apply has_range_interp_scalar; extract_ok_scalar|]; - assumption) - | _ => rewrite interp_cast_noop by (cbn [has_range fst snd] in *; split; lia) - | _ => rewrite interp_cast2_noop by (cbn [has_range fst snd] in *; split; lia) - | _ => reflexivity - end. - Qed. - - Lemma of_straightline_correct {t} (e : expr t) : - ok_expr e -> - @interp interp_cast _ (of_straightline dummy_arrow consts e) - = Straightline.expr.interp (interp_ident:=@ident.interp) (interp_cast:=interp_cast) e. - Proof. - induction 1; cbn [of_straightline]; intros; - repeat match goal with - | _ => progress cbn [Straightline.expr.interp] - | _ => erewrite of_straightline_ident_correct - by (cbv [range_ok is_tighter_than_bool_range_type]; - eauto using in_word_range_word_range; - try apply andb_true_iff; auto) - | _ => rewrite interp_cast_noop by eauto using has_range_loosen, ident_interp_has_range - | _ => rewrite interp_cast2_noop by eauto using has_range_loosen, ident_interp_has_range - | H : forall y, has_range _ y -> interp _ = _ |- _ => rewrite H by eauto using has_range_loosen, ident_interp_has_range - | _ => reflexivity - end. - Qed. - End proofs. - - Section no_interp_cast. - Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype) (consts : list Z) - (consts_ok : forall x, In x consts -> 0 <= x <= wordmax - 1). - - Local Arguments interp _ {_} _. - Local Arguments interp_scalar _ {_} _. - - Local Ltac tighter_than_to_le := - repeat match goal with - | _ => progress (cbv [is_tighter_than_bool] in * ) - | _ => rewrite andb_true_iff in * - | H : _ /\ _ |- _ => destruct H - end; Z.ltb_to_lt. - - Lemma replace_interp_cast_scalar {t} (x : scalar t) interp_cast interp_cast' - (interp_cast_correct : forall r x, lower r <= x <= upper r -> interp_cast r x = x) - (interp_cast'_correct : forall r x, lower r <= x <= upper r -> interp_cast' r x = x) : - ok_scalar x -> - interp_scalar interp_cast x = interp_scalar interp_cast' x. - Proof. - induction 1; cbn [interp_scalar Straightline.expr.interp_scalar]; - repeat match goal with - | _ => progress (cbv [has_range interp_cast2] in * ) - | _ => progress tighter_than_to_le - | H : ok_scalar _ |- _ => apply (has_range_interp_scalar (interp_cast_correct:=interp_cast_correct)) in H - | _ => rewrite <-IHok_scalar - | _ => rewrite interp_cast_correct by omega - | _ => rewrite interp_cast'_correct by omega - | _ => congruence - end. - Qed. - - Lemma replace_interp_cast {t} (e : expr t) interp_cast interp_cast' - (interp_cast_correct : forall r x, lower r <= x <= upper r -> interp_cast r x = x) - (interp_cast'_correct : forall r x, lower r <= x <= upper r -> interp_cast' r x = x) : - ok_expr consts e -> - interp interp_cast (of_straightline dummy_arrow consts e) = - interp interp_cast' (of_straightline dummy_arrow consts e). - Proof. - induction 1; intros; cbn [of_straightline interp]. - { apply replace_interp_cast_scalar; auto. } - { erewrite !of_straightline_ident_correct by (eauto; cbv [range_ok]; apply in_word_range_word_range). - rewrite replace_interp_cast_scalar with (interp_cast'0:=interp_cast') by auto. - eauto using ident_interp_has_range. } - { erewrite !of_straightline_ident_correct by - (eauto; try solve [cbv [range_ok]; split; auto using in_word_range_word_range]; - cbv [is_tighter_than_bool_range_type]; apply andb_true_iff; split; auto). - rewrite replace_interp_cast_scalar with (interp_cast'0:=interp_cast') by auto. - eauto using ident_interp_has_range. } - Qed. - End no_interp_cast. - End with_wordmax. - - Definition of_Expr {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d)) - (var : type -> Type) (x : var s) dummy_arrow : @Straightline.expr.expr var ident d := - @of_straightline log2wordmax var dummy_arrow consts _ (Straightline.of_Expr e var x dummy_arrow). - - Definition interp_cast_mod w r x := if (lower r =? 0) - then if (upper r =? 2^w - 1) - then x mod (2^w) - else if (upper r =? 1) - then x mod 2 - else x - else x. - - Lemma interp_cast_mod_correct w r x : - lower r <= x <= upper r -> - interp_cast_mod w r x = x. - Proof. - cbv [interp_cast_mod]. - intros; break_match; rewrite ?andb_true_iff in *; intuition; Z.ltb_to_lt; - apply Z.mod_small; omega. - Qed. - - Lemma of_Expr_correct {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d)) - (e' : (type.interp s -> Uncurried.expr.expr d)) - (x : type.interp s) dummy_arrow : - e type.interp = Abs e' -> - 1 < log2wordmax -> - log2wordmax mod 2 = 0 -> - Straightline.expr.ok_expr (e' x) -> - (forall x0 : Z, In x0 consts -> 0 <= x0 <= 2 ^ log2wordmax - 1) -> - ok_expr log2wordmax consts - (of_uncurried (dummy_arrow:=dummy_arrow) (depth (fun _ : type => unit) (fun _ : type => tt) (e _)) (e' x)) -> - (depth type.interp (@DefaultValue.type.default) (e' x) <= depth (fun _ : type => unit) (fun _ : type => tt) (e _))%nat -> - @interp log2wordmax (interp_cast_mod log2wordmax) _ (of_Expr log2wordmax consts e type.interp x dummy_arrow) = @Uncurried.expr.interp _ (@ident.interp) _ (e type.interp) x. - Proof. - intro He'; intros; cbv [of_Expr Straightline.of_Expr]. - rewrite He'; cbn [invert_Abs expr.interp]. - assert (forall r z, lower r <= z <= upper r -> ident.cast ident.cast_outside_of_range r z = z) as interp_cast_correct. - { cbv [ident.cast]; intros; break_match; rewrite ?andb_true_iff, ?andb_false_iff in *; intuition; Z.ltb_to_lt; omega. } - erewrite replace_interp_cast with (interp_cast':=ident.cast ident.cast_outside_of_range) by auto using interp_cast_mod_correct. - rewrite of_straightline_correct by auto. - erewrite Straightline.expr.of_uncurried_correct by eassumption. - reflexivity. - Qed. - - Module Notations. - Import PrintingNotations. - Import Straightline.expr. - - Local Notation "'tZ'" := (type.type_primitive type.Z). - Notation "'RegZero'" := (Primitive (t:=type.Z) 0). - Notation "$ x" := (Cast uint256 (Fst (Cast2 (uint256,bool)%core (Var (tZ * tZ) x)))) (at level 10, format "$ x"). - Notation "$ x" := (Cast uint128 (Fst (Cast2 (uint128,bool)%core (Var (tZ * tZ) x)))) (at level 10, format "$ x"). - Notation "$ x ₁" := (Cast uint256 (Fst (Var (tZ * tZ) x))) (at level 10, format "$ x ₁"). - Notation "$ x ₂" := (Cast uint256 (Snd (Var (tZ * tZ) x))) (at level 10, format "$ x ₂"). - Notation "$ x" := (Cast uint256 (Var tZ x)) (at level 10, format "$ x"). - Notation "$ x" := (Cast uint128 (Var tZ x)) (at level 10, format "$ x"). - Notation "$ x" := (Cast bool (Var tZ x)) (at level 10, format "$ x"). - Notation "carry{ $ x }" := (Cast bool (Snd (Cast2 (uint256, bool)%core (Var (tZ * tZ) x)))) - (at level 10, format "carry{ $ x }"). - Notation "Lower{ x }" := (Cast uint128 (Land 340282366920938463463374607431768211455 x)) - (at level 10, format "Lower{ x }"). - Notation "f @( y , x1 , x2 ); g " - := (LetInAppIdentZZ (uint256, bool)%core f (Pair x1 x2) (fun y => g)) - (at level 10, g at level 200, format "f @( y , x1 , x2 ); '//' g "). - Notation "f @( y , x1 , x2 , x3 ); g " - := (LetInAppIdentZZ (uint256, bool)%core f (Pair (Pair x1 x2) x3) (fun y => g)) - (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g "). - Notation "f @( y , x1 , x2 , x3 ); '#128' g " - := (LetInAppIdentZZ (uint128, bool)%core f (Pair (Pair x1 x2) x3) (fun y => g)) - (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '#128' '//' g "). - Notation "f @( y , x1 , x2 ); g " - := (LetInAppIdentZ uint256 f (Pair x1 x2) (fun y => g)) - (at level 10, g at level 200, format "f @( y , x1 , x2 ); '//' g "). - Notation "f @( y , x1 , x2 , x3 ); g " - := (LetInAppIdentZ uint256 f (Pair (Pair x1 x2) x3) (fun y => g)) - (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g "). - (* special cases for when the ident constructor takes a constant argument *) - Notation "add@( y , x1 , x2 , n ); g" - := (LetInAppIdentZZ (uint256, bool) (add n) (Pair x1 x2) (fun y => g)) - (at level 10, g at level 200, format "add@( y , x1 , x2 , n ); '//' g"). - Notation "addc@( y , x1 , x2 , x3 , n ); g" - := (LetInAppIdentZZ (uint256, bool) (addc n) (Pair (Pair x1 x2) x3) (fun y => g)) - (at level 10, g at level 200, format "addc@( y , x1 , x2 , x3 , n ); '//' g"). - Notation "addc@( y , x1 , x2 , x3 , n ); '#128' g" - := (LetInAppIdentZZ (uint128, bool) (addc n) (Pair (Pair x1 x2) x3) (fun y => g)) - (at level 10, g at level 200, format "addc@( y , x1 , x2 , x3 , n ); '#128' '//' g"). - Notation "sub@( y , x1 , x2 , n ); g" - := (LetInAppIdentZZ (uint256, bool) (sub n) (Pair x1 x2) (fun y => g)) - (at level 10, g at level 200, format "sub@( y , x1 , x2 , n ); '//' g"). - Notation "subb@( y , x1 , x2 , x3 , n ); g" - := (LetInAppIdentZZ (uint256, bool) (subb n) (Pair (Pair x1 x2) x3) (fun y => g)) - (at level 10, g at level 200, format "subb@( y , x1 , x2 , x3 , n ); '//' g"). - Notation "rshi@( y , x1 , x2 , n ); g" - := (LetInAppIdentZ _ (rshi n) (Pair x1 x2) (fun y => g)) - (at level 10, g at level 200, format "rshi@( y , x1 , x2 , n ); '//' g "). - Notation "'ret' $ x" := (Scalar (Var tZ x)) (at level 10, format "'ret' $ x"). - Notation "( x , y )" := (Pair x y) (at level 10, left associativity). - End Notations. - - Module Tactics. - Ltac ok_expr_step' := - match goal with - | _ => assumption - | |- _ <= _ <= _ \/ @eq zrange _ _ => - right; lazy; try split; congruence - | |- _ <= _ <= _ \/ @eq zrange _ _ => - left; lazy; try split; congruence - | |- context [PreFancy.ok_ident] => constructor - | |- context [PreFancy.ok_scalar] => constructor; try omega - | |- context [PreFancy.is_halved] => eapply PreFancy.is_halved_constant; [lazy; reflexivity | ] - | |- context [PreFancy.is_halved] => constructor - | |- context [PreFancy.in_word_range] => lazy; reflexivity - | |- context [PreFancy.in_flag_range] => lazy; reflexivity - | |- context [PreFancy.get_range] => - cbn [PreFancy.get_range lower upper fst snd ZRange.map] - | x : type.interp (type.prod _ _) |- _ => destruct x - | |- (_ <=? _)%zrange = true => - match goal with - | |- context [PreFancy.get_range_var] => - cbv [is_tighter_than_bool PreFancy.has_range fst snd upper lower] in *; cbn; - apply andb_true_iff; split; apply Z.leb_le - | _ => lazy - end; omega || reflexivity - | |- @eq zrange _ _ => lazy; reflexivity - | |- _ <= _ => omega - | |- _ <= _ <= _ => omega - end; intros. - - Ltac ok_expr_step := - match goal with - | |- context [PreFancy.ok_expr] => constructor; cbn [fst snd]; repeat ok_expr_step' - end; intros; cbn [Nat.max]. - End Tactics. -End PreFancy. - -Module Fancy. - Import Straightline.expr. - - Module CC. - Inductive code : Type := - | C : code - | M : code - | L : code - | Z : code - . - - Record state := - { cc_c : bool; cc_m : bool; cc_l : bool; cc_z : bool }. - - Definition code_dec (x y : code) : {x = y} + {x <> y}. - Proof. destruct x, y; try apply (left eq_refl); right; congruence. Defined. - - Definition update (to_write : list code) (result : BinInt.Z) (cc_spec : code -> BinInt.Z -> bool) (old_state : state) - : state := - {| - cc_c := if (In_dec code_dec C to_write) - then cc_spec C result - else old_state.(cc_c); - cc_m := if (In_dec code_dec M to_write) - then cc_spec M result - else old_state.(cc_m); - cc_l := if (In_dec code_dec L to_write) - then cc_spec L result - else old_state.(cc_l); - cc_z := if (In_dec code_dec Z to_write) - then cc_spec Z result - else old_state.(cc_z) - |}. - - End CC. - - Record instruction := - { - num_source_regs : nat; - writes_conditions : list CC.code; - spec : tuple Z num_source_regs -> CC.state -> Z - }. - - Section expr. - Context {name : Type} (name_eqb : name -> name -> bool) (wordmax : Z) (cc_spec : CC.code -> Z -> bool). - - Inductive expr := - | Ret : name -> expr - | Instr (i : instruction) - (rd : name) (* destination register *) - (args : tuple name i.(num_source_regs)) (* source registers *) - (cont : expr) (* next line *) - : expr - . - - Fixpoint interp (e : expr) (cc : CC.state) (ctx : name -> Z) : Z := - match e with - | Ret n => ctx n - | Instr i rd args cont => - let result := i.(spec) (Tuple.map ctx args) cc in - let new_cc := CC.update i.(writes_conditions) result cc_spec cc in - let new_ctx := (fun n : name => if name_eqb n rd then result mod wordmax else ctx n) in - interp cont new_cc new_ctx - end. - End expr. - - Section ISA. - Import CC. - - (* For the C flag, we have to consider cases with a negative result (like the one returned by an underflowing borrow). - In these cases, we want to set the C flag to true. *) - Definition cc_spec (x : CC.code) (result : BinInt.Z) : bool := - match x with - | CC.C => if result <? 0 then true else Z.testbit result 256 - | CC.M => Z.testbit result 255 - | CC.L => Z.testbit result 0 - | CC.Z => result =? 0 - end. - - Local Definition lower128 x := (Z.land x (Z.ones 128)). - Local Definition upper128 x := (Z.shiftr x 128). - Local Notation "x '[C]'" := (if x.(cc_c) then 1 else 0) (at level 20). - Local Notation "x '[M]'" := (if x.(cc_m) then 1 else 0) (at level 20). - Local Notation "x '[L]'" := (if x.(cc_l) then 1 else 0) (at level 20). - Local Notation "x '[Z]'" := (if x.(cc_z) then 1 else 0) (at level 20). - Local Notation "'int'" := (BinInt.Z). - Local Notation "x << y" := ((x << y) mod (2^256)) : Z_scope. (* truncating left shift *) - - - (* Note: In the specification document, argument order gets a bit - confusing. Like here, r0 is always the first argument "source 0" - and r1 the second. But the specification of MUL128LU is: - (R[RS1][127:0] * R[RS0][255:128]) - - while the specification of SUB is: - (R[RS0] - shift(R[RS1], imm)) - - In the SUB case, r0 is really treated the first argument, but in - MUL128LU the order seems to be reversed; rather than low-high, we - take the high part of the first argument r0 and the low parts of - r1. This is also true for MUL128UL. *) - - Definition ADD (imm : int) : instruction := - {| - num_source_regs := 2; - writes_conditions := [C; M; L; Z]; - spec := (fun '(r0, r1) cc => - r0 + (r1 << imm)) - |}. - - Definition ADDC (imm : int) : instruction := - {| - num_source_regs := 2; - writes_conditions := [C; M; L; Z]; - spec := (fun '(r0, r1) cc => - r0 + (r1 << imm) + cc[C]) - |}. - - Definition SUB (imm : int) : instruction := - {| - num_source_regs := 2; - writes_conditions := [C; M; L; Z]; - spec := (fun '(r0, r1) cc => - r0 - (r1 << imm)) - |}. - - Definition SUBC (imm : int) : instruction := - {| - num_source_regs := 2; - writes_conditions := [C; M; L; Z]; - spec := (fun '(r0, r1) cc => - r0 - (r1 << imm) - cc[C]) - |}. - - - Definition MUL128LL : instruction := - {| - num_source_regs := 2; - writes_conditions := [M; L; Z]; - spec := (fun '(r0, r1) cc => - (lower128 r0) * (lower128 r1)) - |}. - - Definition MUL128LU : instruction := - {| - num_source_regs := 2; - writes_conditions := [M; L; Z]; - spec := (fun '(r0, r1) cc => - (lower128 r1) * (upper128 r0)) (* see note *) - |}. - - Definition MUL128UL : instruction := - {| - num_source_regs := 2; - writes_conditions := [M; L; Z]; - spec := (fun '(r0, r1) cc => - (upper128 r1) * (lower128 r0)) (* see note *) - |}. - - Definition MUL128UU : instruction := - {| - num_source_regs := 2; - writes_conditions := [M; L; Z]; - spec := (fun '(r0, r1) cc => - (upper128 r0) * (upper128 r1)) - |}. - - (* Note : Unlike the other operations, the output of RSHI is - truncated in the specification. This is not strictly necessary, - since the interpretation function truncates the output - anyway. However, it is useful to make the definition line up - exactly with Z.rshi. *) - Definition RSHI (imm : int) : instruction := - {| - num_source_regs := 2; - writes_conditions := [M; L; Z]; - spec := (fun '(r0, r1) cc => - (((2^256 * r0) + r1) >> imm) mod (2^256)) - |}. - - Definition SELC : instruction := - {| - num_source_regs := 2; - writes_conditions := []; - spec := (fun '(r0, r1) cc => - if cc[C] =? 1 then r0 else r1) - |}. - - Definition SELM : instruction := - {| - num_source_regs := 2; - writes_conditions := []; - spec := (fun '(r0, r1) cc => - if cc[M] =? 1 then r0 else r1) - |}. - - Definition SELL : instruction := - {| - num_source_regs := 2; - writes_conditions := []; - spec := (fun '(r0, r1) cc => - if cc[L] =? 1 then r0 else r1) - |}. - - (* TODO : treat the MOD register specially, like CC *) - Definition ADDM : instruction := - {| - num_source_regs := 3; - writes_conditions := [M; L; Z]; - spec := (fun '(r0, r1, MOD) cc => - let ra := r0 + r1 in - if ra >=? MOD - then ra - MOD - else ra) - |}. - - End ISA. - - Module Registers. - Inductive register : Type := - | r0 : register - | r1 : register - | r2 : register - | r3 : register - | r4 : register - | r5 : register - | r6 : register - | r7 : register - | r8 : register - | r9 : register - | r10 : register - | r11 : register - | r12 : register - | r13 : register - | r14 : register - | r15 : register - | r16 : register - | r17 : register - | r18 : register - | r19 : register - | r20 : register - | r21 : register - | r22 : register - | r23 : register - | r24 : register - | r25 : register - | r26 : register - | r27 : register - | r28 : register - | r29 : register - | r30 : register - | RegZero : register (* r31 *) - | RegMod : register - . - - Definition reg_dec (x y : register) : {x = y} + {x <> y}. - Proof. destruct x, y; try (apply left; congruence); right; congruence. Defined. - Definition reg_eqb x y := if reg_dec x y then true else false. - - Lemma reg_eqb_neq x y : x <> y -> reg_eqb x y = false. - Proof. cbv [reg_eqb]; break_match; congruence. Qed. - Lemma reg_eqb_refl x : reg_eqb x x = true. - Proof. cbv [reg_eqb]; break_match; congruence. Qed. - End Registers. - - Section of_prefancy. - Context (name : Type) (name_succ : name -> name) (error : name) (consts : Z -> option name). - - Fixpoint var (t : type) : Type := - match t with - | type.type_primitive type.Z => name - | type.prod a b => var a * var b - | _ => unit - end. - - Fixpoint of_prefancy_scalar {t} (s : @scalar var t) : var t := - match s with - | Var t v => v - | Pair a b x y => (of_prefancy_scalar x, of_prefancy_scalar y) - | Cast r x => of_prefancy_scalar x - | Cast2 r x => of_prefancy_scalar x - | Fst a b x => fst (of_prefancy_scalar x) - | Snd a b x => snd (of_prefancy_scalar x) - | Shiftr n x => error - | Shiftl n x => error - | Land n x => error - | CC_m n x => error - | @Primitive _ type.Z x => match consts x with - | Some n => n - | None => error - end - | @Primitive _ _ x => tt - | TT => tt - | Nil _ => tt - end. - - (* Note : some argument orders are reversed for MUL128LU, MUL128UL, SELC, SELM, and SELL *) - Definition of_prefancy_ident {s d} (idc : PreFancy.ident s d) - : @scalar var s -> {i : instruction & tuple name i.(num_source_regs) } := - match idc in PreFancy.ident s d return _ with - | PreFancy.add imm => fun args : @scalar var (type.Z * type.Z) => - existT _ (ADD imm) (of_prefancy_scalar args) - | PreFancy.addc imm => fun args : @scalar var (type.Z * type.Z * type.Z) => - existT _ (ADDC imm) (of_prefancy_scalar (Pair (Snd (Fst args)) (Snd args))) - | PreFancy.sub imm => fun args : @scalar var (type.Z * type.Z) => - existT _ (SUB imm) (of_prefancy_scalar args) - | PreFancy.subb imm => fun args : @scalar var (type.Z * type.Z * type.Z) => - existT _ (SUBC imm) (of_prefancy_scalar (Pair (Snd (Fst args)) (Snd args))) - | PreFancy.mulll => fun args : @scalar var (type.Z * type.Z) => - existT _ MUL128LL (of_prefancy_scalar args) - | PreFancy.mullh => fun args : @scalar var (type.Z * type.Z) => - existT _ MUL128LU (of_prefancy_scalar (Pair (Snd args) (Fst args))) - | PreFancy.mulhl => fun args : @scalar var (type.Z * type.Z) => - existT _ MUL128UL (of_prefancy_scalar (Pair (Snd args) (Fst args))) - | PreFancy.mulhh => fun args : @scalar var (type.Z * type.Z) => - existT _ MUL128UU (of_prefancy_scalar args) - | PreFancy.rshi imm => fun args : @scalar var (type.Z * type.Z) => - existT _ (RSHI imm) (of_prefancy_scalar args) - | PreFancy.selc => fun args : @scalar var (type.Z * type.Z * type.Z) => - existT _ SELC (of_prefancy_scalar (Pair (Snd args) (Snd (Fst args)))) - | PreFancy.selm => fun args : @scalar var (type.Z * type.Z * type.Z) => - existT _ SELM (of_prefancy_scalar (Pair (Snd args) (Snd (Fst args)))) - | PreFancy.sell => fun args : @scalar var (type.Z * type.Z * type.Z) => - existT _ SELL (of_prefancy_scalar (Pair (Snd args) (Snd (Fst args)))) - | PreFancy.addm => fun args : @scalar var (type.Z * type.Z * type.Z) => - existT _ ADDM (of_prefancy_scalar args) - end. - - Fixpoint of_prefancy (next_name : name) {t} (e : @Straightline.expr.expr var PreFancy.ident t) : expr := - match e with - | LetInAppIdentZ s d r idc x f => - let instr_args := @of_prefancy_ident s type.Z idc x in - let i : instruction := projT1 instr_args in - let args : tuple name i.(num_source_regs) := projT2 instr_args in - Instr i next_name args (of_prefancy (name_succ next_name) (f next_name)) - | LetInAppIdentZZ s d r idc x f => - let instr_args := @of_prefancy_ident s (type.Z * type.Z) idc x in - let i : instruction := projT1 instr_args in - let args : tuple name i.(num_source_regs) := projT2 instr_args in - Instr i next_name args (of_prefancy (name_succ next_name) (f (next_name, error))) (* we pass the error code as the carry register, because it cannot be read from directly. *) - | Scalar type.Z s => Ret (of_prefancy_scalar s) - | _ => Ret error - end. - End of_prefancy. - - Section allocate_registers. - Context (reg name : Type) (name_eqb : name -> name -> bool) (error : reg). - Fixpoint allocate (e : @expr name) (reg_list : list reg) (name_to_reg : name -> reg) : @expr reg := - match e with - | Ret n => Ret (name_to_reg n) - | Instr i rd args cont => - match reg_list with - | r :: reg_list' => Instr i r (Tuple.map name_to_reg args) (allocate cont reg_list' (fun n => if name_eqb n rd then r else name_to_reg n)) - | nil => Ret error - end - end. - End allocate_registers. - - Definition test_prog : @expr positive := - Instr (ADD (128)) 3%positive (1, 2)%positive - (Instr (ADDC 0) 4%positive (3,1)%positive - (Ret 4%positive)). - - Definition x1 := 2^256 - 1. - Definition x2 := 2^128 - 1. - Definition wordmax := 2^256. - Definition expected := - let r3' := (x1 + (x2 << 128)) in - let r3 := r3' mod wordmax in - let c := r3' / wordmax in - let r4' := (r3 + x1 + c) in - r4' mod wordmax. - Definition actual := - interp Pos.eqb - (2^256) cc_spec test_prog {|CC.cc_c:=false; CC.cc_m:=false; CC.cc_l:=false; CC.cc_z:=false|} - (fun n => if n =? 1%positive - then x1 - else if n =? 2%positive - then x2 - else 0). - Lemma test_prog_ok : expected = actual. - Proof. reflexivity. Qed. - - Definition of_Expr {s d} next_name (consts : Z -> option positive) (consts_list : list Z) (e : Expr (s -> d)) (x : var positive s) dummy_arrow : positive -> @expr positive := - fun error => - @of_prefancy positive Pos.succ error consts next_name _ (PreFancy.of_Expr 256 consts_list e (var positive) x dummy_arrow). - -End Fancy. - -Module Prod. - Import Fancy. Import Registers. - - Definition Mul256 (out src1 src2 tmp : register) (cont : Fancy.expr) : Fancy.expr := - Instr MUL128LL out (src1, src2) - (Instr MUL128UL tmp (src1, src2) - (Instr (ADD 128) out (out, tmp) - (Instr MUL128LU tmp (src1, src2) - (Instr (ADD 128) out (out, tmp) cont)))). - Definition Mul256x256 (out outHigh src1 src2 tmp : register) (cont : Fancy.expr) : Fancy.expr := - Instr MUL128LL out (src1, src2) - (Instr MUL128UU outHigh (src1, src2) - (Instr MUL128UL tmp (src1, src2) - (Instr (ADD 128) out (out, tmp) - (Instr (ADDC (-128)) outHigh (outHigh, tmp) - (Instr MUL128LU tmp (src1, src2) - (Instr (ADD 128) out (out, tmp) - (Instr (ADDC (-128)) outHigh (outHigh, tmp) cont))))))). - - Definition MontRed256 lo hi y t1 t2 scratch RegPInv : @Fancy.expr register := - Mul256 y lo RegPInv t1 - (Mul256x256 t1 t2 y RegMod scratch - (Instr (ADD 0) lo (lo, t1) - (Instr (ADDC 0) hi (hi, t2) - (Instr SELC y (RegMod, RegZero) - (Instr (SUB 0) lo (hi, y) - (Instr ADDM lo (lo, RegZero, RegMod) - (Ret lo))))))). - - (* Barrett reduction -- this is only the "reduce" part, excluding the initial multiplication. *) - Definition MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 : @Fancy.expr register := - let q1Bottom256 := scratchp1 in - let muSelect := scratchp2 in - let q2 := scratchp3 in - let q2High := scratchp4 in - let q2High2 := scratchp5 in - let q3 := scratchp1 in - let r2 := scratchp2 in - let r2High := scratchp3 in - let maybeM := scratchp1 in - Instr SELM muSelect (RegMuLow, RegZero) - (Instr (RSHI 255) q1Bottom256 (xHigh, x) - (Mul256x256 q2 q2High q1Bottom256 RegMuLow scratchp5 - (Instr (RSHI 255) q2High2 (RegZero, xHigh) - (Instr (ADD 0) q2High (q2High, q1Bottom256) - (Instr (ADDC 0) q2High2 (q2High2, RegZero) - (Instr (ADD 0) q2High (q2High, muSelect) - (Instr (ADDC 0) q2High2 (q2High2, RegZero) - (Instr (RSHI 1) q3 (q2High2, q2High) - (Mul256x256 r2 r2High RegMod q3 scratchp4 - (Instr (SUB 0) muSelect (x, r2) - (Instr (SUBC 0) xHigh (xHigh, r2High) - (Instr SELL maybeM (RegMod, RegZero) - (Instr (SUB 0) q3 (muSelect, maybeM) - (Instr ADDM x (q3, RegZero, RegMod) - (Ret x))))))))))))))). -End Prod. - -Module ProdEquiv. - Import Fancy. Import Registers. - - Definition interp256 := Fancy.interp reg_eqb (2^256) cc_spec. - Lemma interp_step i rd args cont cc ctx : - interp256 (Instr i rd args cont) cc ctx = - let result := spec i (Tuple.map ctx args) cc in - let new_cc := CC.update (writes_conditions i) result cc_spec cc in - let new_ctx := fun n => if reg_eqb n rd then result mod wordmax else ctx n in interp256 cont new_cc new_ctx. - Proof. reflexivity. Qed. - - Lemma interp_state_equiv e : - forall cc ctx cc' ctx', - cc = cc' -> (forall r, ctx r = ctx' r) -> - interp256 e cc ctx = interp256 e cc' ctx'. - Proof. - induction e; intros; subst; cbn; [solve[auto]|]. - apply IHe; rewrite Tuple.map_ext with (g:=ctx') by auto; - [reflexivity|]. - intros; break_match; auto. - Qed. - Lemma cc_overwrite_full x1 x2 l1 cc : - CC.update [CC.C; CC.M; CC.L; CC.Z] x2 cc_spec (CC.update l1 x1 cc_spec cc) = CC.update [CC.C; CC.M; CC.L; CC.Z] x2 cc_spec cc. - Proof. - cbv [CC.update]. cbn [CC.cc_c CC.cc_m CC.cc_l CC.cc_z]. - break_match; try match goal with H : ~ In _ _ |- _ => cbv [In] in H; tauto end. - reflexivity. - Qed. - - Definition value_unused r e : Prop := - forall x cc ctx, interp256 e cc ctx = interp256 e cc (fun r' => if reg_eqb r' r then x else ctx r'). - - Lemma value_unused_skip r i rd args cont (Hcont: value_unused r cont) : - r <> rd -> - (~ In r (Tuple.to_list _ args)) -> - value_unused r (Instr i rd args cont). - Proof. - cbv [value_unused] in *; intros. - rewrite !interp_step; cbv zeta. - rewrite Hcont with (x:=x). - match goal with |- ?lhs = ?rhs => - match lhs with context [Tuple.map ?f ?t] => - match rhs with context [Tuple.map ?g ?t] => - rewrite (Tuple.map_ext_In f g) by (intros; cbv [reg_eqb]; break_match; congruence) - end end end. - apply interp_state_equiv; [ congruence | ]. - { intros; cbv [reg_eqb] in *; break_match; congruence. } - Qed. - - Lemma value_unused_overwrite r i args cont : - (~ In r (Tuple.to_list _ args)) -> - value_unused r (Instr i r args cont). - Proof. - cbv [value_unused]; intros; rewrite !interp_step; cbv zeta. - match goal with |- ?lhs = ?rhs => - match lhs with context [Tuple.map ?f ?t] => - match rhs with context [Tuple.map ?g ?t] => - rewrite (Tuple.map_ext_In f g) by (intros; cbv [reg_eqb]; break_match; congruence) - end end end. - apply interp_state_equiv; [ congruence | ]. - { intros; cbv [reg_eqb] in *; break_match; congruence. } - Qed. - - Lemma value_unused_ret r r' : - r <> r' -> - value_unused r (Ret r'). - Proof. - cbv - [reg_dec]; intros. - break_match; congruence. - Qed. - - Ltac remember_results := - repeat match goal with |- context [(spec ?i ?args ?flags) mod ?w] => - let x := fresh "x" in - let y := fresh "y" in - let Heqx := fresh "Heqx" in - remember (spec i args flags) as x eqn:Heqx; - remember (x mod w) as y - end. - - Ltac do_interp_step := - rewrite interp_step; cbn - [interp spec]; - repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; - remember_results. - - Lemma interp_Mul256 out src1 src2 tmp tmp2 cont cc ctx: - out <> src1 -> - out <> src2 -> - out <> tmp -> - out <> tmp2 -> - src1 <> src2 -> - src1 <> tmp -> - src1 <> tmp2 -> - src2 <> tmp -> - src2 <> tmp2 -> - tmp <> tmp2 -> - value_unused tmp cont -> - value_unused tmp2 cont -> - interp256 (Prod.Mul256 out src1 src2 tmp cont) cc ctx = - interp256 ( - Instr MUL128LU tmp (src1, src2) - (Instr MUL128UL tmp2 (src1, src2) - (Instr MUL128LL out (src1, src2) - (Instr (ADD 128) out (out, tmp2) - (Instr (ADD 128) out (out, tmp) cont))))) cc ctx. - Proof. - intros; cbv [Prod.Mul256]. - repeat (do_interp_step; cbn [spec MUL128LL MUL128UL MUL128LU ADD] in * ). - - match goal with H : value_unused tmp _ |- _ => erewrite H end. - match goal with H : value_unused tmp2 _ |- _ => erewrite H end. - apply interp_state_equiv. - { rewrite !cc_overwrite_full. - f_equal. subst. lia. } - { intros; cbv [reg_eqb]. - repeat (break_match_step ltac:(fun _ => idtac); try congruence); reflexivity. } - Qed. - - Lemma interp_Mul256x256 out outHigh src1 src2 tmp tmp2 cont cc ctx: - out <> src1 -> - out <> outHigh -> - out <> src2 -> - out <> tmp -> - out <> tmp2 -> - outHigh <> src1 -> - outHigh <> src2 -> - outHigh <> tmp -> - outHigh <> tmp2 -> - src1 <> src2 -> - src1 <> tmp -> - src1 <> tmp2 -> - src2 <> tmp -> - src2 <> tmp2 -> - tmp <> tmp2 -> - value_unused tmp cont -> - value_unused tmp2 cont -> - interp256 (Prod.Mul256x256 out outHigh src1 src2 tmp cont) cc ctx = - interp256 ( - Instr MUL128LL out (src1, src2) - (Instr MUL128LU tmp (src1, src2) - (Instr MUL128UL tmp2 (src1, src2) - (Instr MUL128UU outHigh (src1, src2) - (Instr (ADD 128) out (out, tmp2) - (Instr (ADDC (-128)) outHigh (outHigh, tmp2) - (Instr (ADD 128) out (out, tmp) - (Instr (ADDC (-128)) outHigh (outHigh, tmp) cont)))))))) cc ctx. - Proof. - intros; cbv [Prod.Mul256x256]. - repeat (do_interp_step; cbn [spec MUL128LL MUL128UL MUL128LU MUL128UU ADD ADDC] in * ). - - match goal with H : value_unused tmp _ |- _ => erewrite H end. - match goal with H : value_unused tmp2 _ |- _ => erewrite H end. - apply interp_state_equiv. - { rewrite !cc_overwrite_full. - f_equal. - subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128]. - lia. } - { intros; cbv [reg_eqb]. - repeat (break_match_step ltac:(fun _ => idtac); try congruence); try reflexivity; [ ]. - subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128]. - lia. } - Qed. - - Lemma mulll_comm rd x y cont cc ctx : - ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (y, x) cont) cc ctx. - Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. - - Lemma mulhh_comm rd x y cont cc ctx : - ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (y, x) cont) cc ctx. - Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. - - Lemma mullh_mulhl rd x y cont cc ctx : - ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UL rd (y, x) cont) cc ctx. - Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. - - Lemma add_comm rd x y cont cc ctx : - 0 <= ctx x < 2^256 -> - 0 <= ctx y < 2^256 -> - ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (y, x) cont) cc ctx. - Proof. - intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.add_comm. - rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity. - Qed. - - Lemma addc_comm rd x y cont cc ctx : - 0 <= ctx x < 2^256 -> - 0 <= ctx y < 2^256 -> - ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (y, x) cont) cc ctx. - Proof. - intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite (Z.add_comm (ctx x)). - rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity. - Qed. - - (* Tactics to help prove that something in Fancy is line-by-line equivalent to something in PreFancy *) - Ltac push_value_unused := - repeat match goal with - | |- ~ In _ _ => cbn; intuition; congruence - | _ => apply ProdEquiv.value_unused_overwrite - | _ => apply ProdEquiv.value_unused_skip; [ | congruence | ] - | _ => apply ProdEquiv.value_unused_ret; congruence - end. - - Ltac remember_single_result := - match goal with |- context [(Fancy.spec ?i ?args ?cc) mod ?w] => - let x := fresh "x" in - let y := fresh "y" in - let Heqx := fresh "Heqx" in - remember (Fancy.spec i args cc) as x eqn:Heqx; - remember (x mod w) as y - end. - Ltac step_both_sides := - match goal with |- ProdEquiv.interp256 (Fancy.Instr ?i ?rd1 ?args1 _) _ ?ctx1 = ProdEquiv.interp256 (Fancy.Instr ?i ?rd2 ?args2 _) _ ?ctx2 => - rewrite (ProdEquiv.interp_step i rd1 args1); rewrite (ProdEquiv.interp_step i rd2 args2); - cbn - [Fancy.interp Fancy.spec]; - repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; - remember_single_result; - lazymatch goal with - | |- context [Fancy.spec i _ _] => - let Heqa1 := fresh in - let Heqa2 := fresh in - remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx1 args1) eqn:Heqa1; - remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx2 args2) eqn:Heqa2; - cbn in Heqa1; cbn in Heqa2; - repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa1 by congruence; - repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa2 by congruence; - let a1 := match type of Heqa1 with _ = ?a1 => a1 end in - let a2 := match type of Heqa2 with _ = ?a2 => a2 end in - (fail 1 "arguments to " i " do not match; LHS has " a1 " and RHS has " a2) - | _ => idtac - end - end. -End ProdEquiv. - -(* Lemmas to help prove that a fancy and prefancy expression have the -same meaning -- should be replaced eventually with a proof of fancy -passes in general. *) -Module Fancy_PreFancy_Equiv. - Import Fancy.Registers. - - Lemma interp_cast_mod_eq w u x: u = 2^w - 1 -> PreFancy.interp_cast_mod w r[0 ~> u] x = x mod 2^w. - Proof. - cbv [PreFancy.interp_cast_mod upper lower]; intros; subst. - rewrite !Z.eqb_refl. - reflexivity. - Qed. - Lemma interp_cast_mod_flag w x: PreFancy.interp_cast_mod w r[0 ~> 1] x = x mod 2. - Proof. - cbv [PreFancy.interp_cast_mod upper lower]. - break_match; Z.ltb_to_lt; subst; try omega. - f_equal; lia. - Qed. - - Lemma interp_equivZ {s} w u (Hu : u = 2^w-1) i rd regs e cc ctx idc args f : - (Fancy.spec i (Tuple.map ctx regs) cc - = PreFancy.interp_ident (d:=type.Z) w idc (Straightline.expr.interp_scalar (interp_cast:=PreFancy.interp_cast_mod w) args)) -> - ( let r := Fancy.spec i (Tuple.map ctx regs) cc in - Fancy.interp reg_eqb (2 ^ w) Fancy.cc_spec e - (Fancy.CC.update (Fancy.writes_conditions i) r Fancy.cc_spec cc) - (fun n : register => if reg_eqb n rd then r mod 2 ^ w else ctx n) = - PreFancy.interp (t:=type.Z) (interp_cast:=PreFancy.interp_cast_mod w) w (f (r mod 2 ^ w))) -> - Fancy.interp reg_eqb (2^w) Fancy.cc_spec (Fancy.Instr i rd regs e) cc ctx - = PreFancy.interp (t:=type.Z) (interp_cast:=PreFancy.interp_cast_mod w) w - (@Straightline.expr.LetInAppIdentZ _ _ s _ (r[0~>2^w-1])%zrange idc args f). - Proof. - cbv zeta; intros spec_eq next_eq. - cbn [Fancy.interp PreFancy.interp]. - rewrite next_eq. - rewrite <-spec_eq. - rewrite interp_cast_mod_eq by omega. - reflexivity. - Qed. - - Lemma interp_equivZZ {s} w (Hw : 2 < 2 ^ w) u (Hu : u = 2^w - 1) i rd regs e cc ctx idc args f : - ((Fancy.spec i (Tuple.map ctx regs) cc) mod 2 ^ w - = fst (PreFancy.interp_ident (d:=type.Z*type.Z) w idc (Straightline.expr.interp_scalar (interp_cast:=PreFancy.interp_cast_mod w) args))) -> - ((if Fancy.cc_spec Fancy.CC.C(Fancy.spec i (Tuple.map ctx regs) cc) then 1 else 0) - = snd (PreFancy.interp_ident (d:=type.Z*type.Z) w idc (Straightline.expr.interp_scalar (interp_cast:=PreFancy.interp_cast_mod w) args)) mod 2) -> - ( let r := Fancy.spec i (Tuple.map ctx regs) cc in - Fancy.interp reg_eqb (2 ^ w) Fancy.cc_spec e - (Fancy.CC.update (Fancy.writes_conditions i) r Fancy.cc_spec cc) - (fun n : register => if reg_eqb n rd then r mod 2 ^ w else ctx n) = - PreFancy.interp (t:=type.Z) (interp_cast:=PreFancy.interp_cast_mod w) w - (f (r mod 2 ^ w, if (Fancy.cc_spec Fancy.CC.C r) then 1 else 0))) -> - Fancy.interp reg_eqb (2^w) Fancy.cc_spec (Fancy.Instr i rd regs e) cc ctx - = PreFancy.interp (t:=type.Z) (interp_cast:=PreFancy.interp_cast_mod w) w - (@Straightline.expr.LetInAppIdentZZ _ _ s _ (r[0~>u], r[0~>1])%zrange idc args f). - Proof. - cbv zeta; intros spec_eq1 spec_eq2 next_eq. - cbn [Fancy.interp PreFancy.interp]. - cbv [Straightline.expr.interp_cast2]. cbn [fst snd]. - rewrite next_eq. - rewrite interp_cast_mod_eq by omega. - rewrite interp_cast_mod_flag by omega. - rewrite <-spec_eq1, <-spec_eq2. - rewrite Z.mod_mod by omega. - reflexivity. - Qed. -End Fancy_PreFancy_Equiv. - -Module BarrettReduction. - (* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *) - Section Generic. - Context {T} (rep : T -> Z -> Prop) - (k : Z) (k_pos : 0 < k) - (low : T -> Z) - (low_correct : forall a x, rep a x -> low a = x mod 2 ^ k) - (shiftr : T -> Z -> T) - (shiftr_correct : forall a x n, - rep a x -> - 0 <= n <= k -> - rep (shiftr a n) (x / 2 ^ n)) - (mul_high : T -> T -> Z -> T) - (mul_high_correct : forall a b x y x0y1, - rep a x -> - rep b y -> - 2 ^ k <= x < 2^(k+1) -> - 0 <= y < 2^(k+1) -> - x0y1 = x mod 2 ^ k * (y / 2 ^ k) -> - rep (mul_high a b x0y1) (x * y / 2 ^ k)) - (mul : Z -> Z -> T) - (mul_correct : forall x y, - 0 <= x < 2^k -> - 0 <= y < 2^k -> - rep (mul x y) (x * y)) - (sub : T -> T -> T) - (sub_correct : forall a b x y, - rep a x -> - rep b y -> - 0 <= x - y < 2^k * 2^k -> - rep (sub a b) (x - y)) - (cond_sub1 : T -> Z -> Z) - (cond_sub1_correct : forall a x y, - rep a x -> - 0 <= x < 2 * y -> - 0 <= y < 2 ^ k -> - cond_sub1 a y = if (x <? 2 ^ k) then x else x - y) - (cond_sub2 : Z -> Z -> Z) - (cond_sub2_correct : forall x y, cond_sub2 x y = if (x <? y) then x else x - y). - Context (xt mut : T) (M muSelect: Z). - - Let mu := 2 ^ (2 * k) / M. - Context x (mu_rep : rep mut mu) (x_rep : rep xt x). - Context (M_nz : 0 < M) - (x_range : 0 <= x < M * 2 ^ k) - (M_range : 2 ^ (k - 1) < M < 2 ^ k) - (M_good : 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu) - (muSelect_correct: muSelect = mu mod 2 ^ k * (x / 2 ^ (k - 1) / 2 ^ k)). - - Definition qt := - dlet_nd muSelect := muSelect in (* makes sure muSelect is not inlined in the output *) - dlet_nd q1 := shiftr xt (k - 1) in - dlet_nd twoq := mul_high mut q1 muSelect in - shiftr twoq 1. - Definition reduce := - dlet_nd qt := qt in - dlet_nd r2 := mul (low qt) M in - dlet_nd r := sub xt r2 in - let q3 := cond_sub1 r M in - cond_sub2 q3 M. - - Lemma looser_bound : M * 2 ^ k < 2 ^ (2*k). - Proof. clear -M_range M_nz x_range k_pos; rewrite <-Z.add_diag, Z.pow_add_r; nia. Qed. - - Lemma pow_2k_eq : 2 ^ (2*k) = 2 ^ (k - 1) * 2 ^ (k + 1). - Proof. clear -k_pos; rewrite <-Z.pow_add_r by omega. f_equal; ring. Qed. - - Lemma mu_bounds : 2 ^ k <= mu < 2^(k+1). - Proof. - pose proof looser_bound. - subst mu. split. - { apply Z.div_le_lower_bound; omega. } - { apply Z.div_lt_upper_bound; try omega. - rewrite pow_2k_eq; apply Z.mul_lt_mono_pos_r; auto with zarith. } - Qed. - - Lemma shiftr_x_bounds : 0 <= x / 2 ^ (k - 1) < 2^(k+1). - Proof. - pose proof looser_bound. - split; [ solve [Z.zero_bounds] | ]. - apply Z.div_lt_upper_bound; auto with zarith. - rewrite <-pow_2k_eq. omega. - Qed. - Hint Resolve shiftr_x_bounds. - - Ltac solve_rep := eauto using shiftr_correct, mul_high_correct, mul_correct, sub_correct with omega. - - Let q := mu * (x / 2 ^ (k - 1)) / 2 ^ (k + 1). - - Lemma q_correct : rep qt q . - Proof. - pose proof mu_bounds. cbv [qt]; subst q. - rewrite Z.pow_add_r, <-Z.div_div by Z.zero_bounds. - solve_rep. - Qed. - Hint Resolve q_correct. - - Lemma x_mod_small : x mod 2 ^ (k - 1) <= M. - Proof. transitivity (2 ^ (k - 1)); auto with zarith. Qed. - Hint Resolve x_mod_small. - - Lemma q_bounds : 0 <= q < 2 ^ k. - Proof. - pose proof looser_bound. pose proof x_mod_small. pose proof mu_bounds. - split; subst q; [ solve [Z.zero_bounds] | ]. - edestruct q_nice_strong with (n:=M) as [? Hqnice]; - try rewrite Hqnice; auto; try omega; [ ]. - apply Z.le_lt_trans with (m:= x / M). - { break_match; omega. } - { apply Z.div_lt_upper_bound; omega. } - Qed. - - Lemma two_conditional_subtracts : - forall a x, - rep a x -> - 0 <= x < 2 * M -> - cond_sub2 (cond_sub1 a M) M = cond_sub2 (cond_sub2 x M) M. - Proof. - intros. - erewrite !cond_sub2_correct, !cond_sub1_correct by (eassumption || omega). - break_match; Z.ltb_to_lt; try lia; discriminate. - Qed. - - Lemma r_bounds : 0 <= x - q * M < 2 * M. - Proof. - pose proof looser_bound. pose proof q_bounds. pose proof x_mod_small. - subst q mu; split. - { Z.zero_bounds. apply qn_small; omega. } - { apply r_small_strong; rewrite ?Z.pow_1_r; auto; omega. } - Qed. - - Lemma reduce_correct : reduce = x mod M. - Proof. - pose proof looser_bound. pose proof r_bounds. pose proof q_bounds. - assert (2 * M < 2^k * 2^k) by nia. - rewrite barrett_reduction_small with (k:=k) (m:=mu) (offset:=1) (b:=2) by (auto; omega). - cbv [reduce Let_In]. - erewrite low_correct by eauto. Z.rewrite_mod_small. - erewrite two_conditional_subtracts by solve_rep. - rewrite !cond_sub2_correct. - subst q; reflexivity. - Qed. - End Generic. - - Section BarrettReduction. - Context (k : Z) (k_bound : 2 <= k). - Context (M muLow : Z). - Context (M_pos : 0 < M) - (muLow_eq : muLow + 2^k = 2^(2*k) / M) - (muLow_bounds : 0 <= muLow < 2^k) - (M_bound1 : 2 ^ (k - 1) < M < 2^k) - (M_bound2: 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2^k)). - - Context (n:nat) (Hn_nz: n <> 0%nat) (n_le_k : Z.of_nat n <= k). - Context (nout : nat) (Hnout : nout = 2%nat). - Let w := weight k 1. - Local Lemma k_range : 0 < 1 <= k. Proof. omega. Qed. - Let props : @weight_properties w := wprops k 1 k_range. - - Hint Rewrite Positional.eval_nil Positional.eval_snoc : push_eval. - - Definition low (t : list Z) : Z := nth_default 0 t 0. - Definition high (t : list Z) : Z := nth_default 0 t 1. - Definition represents (t : list Z) (x : Z) := - t = [x mod 2^k; x / 2^k] /\ 0 <= x < 2^k * 2^k. - - Lemma represents_eq t x : - represents t x -> t = [x mod 2^k; x / 2^k]. - Proof. cbv [represents]; tauto. Qed. - - Lemma represents_length t x : represents t x -> length t = 2%nat. - Proof. cbv [represents]; intuition. subst t; reflexivity. Qed. - - Lemma represents_low t x : - represents t x -> low t = x mod 2^k. - Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. - - Lemma represents_high t x : - represents t x -> high t = x / 2^k. - Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. - - Lemma represents_low_range t x : - represents t x -> 0 <= x mod 2^k < 2^k. - Proof. auto with zarith. Qed. - - Lemma represents_high_range t x : - represents t x -> 0 <= x / 2^k < 2^k. - Proof. - destruct 1 as [? [? ?] ]; intros. - auto using Z.div_lt_upper_bound with zarith. - Qed. - Hint Resolve represents_length represents_low_range represents_high_range. - - Lemma represents_range t x : - represents t x -> 0 <= x < 2^k*2^k. - Proof. cbv [represents]; tauto. Qed. - - Lemma represents_id x : - 0 <= x < 2^k * 2^k -> - represents [x mod 2^k; x / 2^k] x. - Proof. - intros; cbv [represents]; autorewrite with cancel_pair. - Z.rewrite_mod_small; tauto. - Qed. - - Local Ltac push_rep := - repeat match goal with - | H : represents ?t ?x |- _ => unique pose proof (represents_low_range _ _ H) - | H : represents ?t ?x |- _ => unique pose proof (represents_high_range _ _ H) - | H : represents ?t ?x |- _ => rewrite (represents_low t x) in * by assumption - | H : represents ?t ?x |- _ => rewrite (represents_high t x) in * by assumption - end. - - Definition shiftr (t : list Z) (n : Z) : list Z := - [Z.rshi (2^k) (high t) (low t) n; Z.rshi (2^k) 0 (high t) n]. - - Lemma shiftr_represents a i x : - represents a x -> - 0 <= i <= k -> - represents (shiftr a i) (x / 2 ^ i). - Proof. - cbv [shiftr]; intros; push_rep. - match goal with H : _ |- _ => pose proof (represents_range _ _ H) end. - assert (0 < 2 ^ i) by auto with zarith. - assert (x < 2 ^ i * 2 ^ k * 2 ^ k) by nia. - assert (0 <= x / 2 ^ k / 2 ^ i < 2 ^ k) by - (split; Z.zero_bounds; auto using Z.div_lt_upper_bound with zarith). - repeat match goal with - | _ => rewrite Z.rshi_correct by auto with zarith - | _ => rewrite <-Z.div_mod''' by auto with zarith - | _ => progress autorewrite with zsimplify_fast - | _ => progress Z.rewrite_mod_small - | |- context [represents [(?a / ?c) mod ?b; ?a / ?b / ?c] ] => - rewrite (Z.div_div_comm a b c) by auto with zarith - | _ => solve [auto using represents_id, Z.div_lt_upper_bound with zarith lia] - end. - Qed. - - Context (Hw : forall i, w i = (2 ^ k) ^ Z.of_nat i). - Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r. - - Definition wideadd t1 t2 := fst (Rows.add w 2 t1 t2). - (* TODO: use this definition once issue #352 is resolved *) - (* Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2). *) - Definition widesub (t1 t2 : list Z) := - let t1_0 := hd 0 t1 in - let t1_1 := hd 0 (tl t1) in - let t2_0 := hd 0 t2 in - let t2_1 := hd 0 (tl t2) in - dlet_nd x0 := Z.sub_get_borrow_full (2^k) t1_0 t2_0 in - dlet_nd x1 := Z.sub_with_get_borrow_full (2^k) (snd x0) t1_1 t2_1 in - [fst x0; fst x1]. - Definition widemul := BaseConversion.widemul_inlined k n nout. - - Lemma partition_represents x : - 0 <= x < 2^k*2^k -> - represents (Rows.partition w 2 x) x. - Proof. - intros; cbn. change_weight. - Z.rewrite_mod_small. - autorewrite with zsimplify_fast. - auto using represents_id. - Qed. - - Lemma eval_represents t x : - represents t x -> eval w 2 t = x. - Proof. - intros; rewrite (represents_eq t x) by assumption. - cbn. change_weight; push_rep. - autorewrite with zsimplify. reflexivity. - Qed. - - Ltac wide_op partitions_pf := - repeat match goal with - | _ => rewrite partitions_pf by eauto - | _ => rewrite partitions_pf by auto with zarith - | _ => erewrite eval_represents by eauto - | _ => solve [auto using partition_represents, represents_id] - end. - - Lemma wideadd_represents t1 t2 x y : - represents t1 x -> - represents t2 y -> - 0 <= x + y < 2^k*2^k -> - represents (wideadd t1 t2) (x + y). - Proof. intros; cbv [wideadd]. wide_op Rows.add_partitions. Qed. - - Lemma widesub_represents t1 t2 x y : - represents t1 x -> - represents t2 y -> - 0 <= x - y < 2^k*2^k -> - represents (widesub t1 t2) (x - y). - Proof. - intros; cbv [widesub Let_In]. - rewrite (represents_eq t1 x) by assumption. - rewrite (represents_eq t2 y) by assumption. - cbn [hd tl]. - autorewrite with to_div_mod. - pull_Zmod. - match goal with |- represents [?m; ?d] ?x => - replace d with (x / 2 ^ k); [solve [auto using represents_id] |] end. - rewrite <-(Z.mod_small ((x - y) / 2^k) (2^k)) by (split; try apply Z.div_lt_upper_bound; Z.zero_bounds). - f_equal. - transitivity ((x mod 2^k - y mod 2^k + 2^k * (x / 2 ^ k) - 2^k * (y / 2^k)) / 2^k). { - rewrite (Z.div_mod x (2^k)) at 1 by auto using Z.pow_nonzero with omega. - rewrite (Z.div_mod y (2^k)) at 1 by auto using Z.pow_nonzero with omega. - f_equal. ring. } - autorewrite with zsimplify. - ring. - Qed. - (* Works with Rows.sub-based widesub definition - Proof. intros; cbv [widesub]. wide_op Rows.sub_partitions. Qed. - *) - - Lemma widemul_represents x y : - 0 <= x < 2^k -> - 0 <= y < 2^k -> - represents (widemul x y) (x * y). - Proof. - intros; cbv [widemul]. - assert (0 <= x * y < 2^k*2^k) by auto with zarith. - wide_op BaseConversion.widemul_correct. - Qed. - - Definition mul_high (a b : list Z) a0b1 : list Z := - dlet_nd a0b0 := widemul (low a) (low b) in - dlet_nd ab := wideadd [high a0b0; high b] [low b; 0] in - wideadd ab [a0b1; 0]. - - Lemma mul_high_idea d a b a0 a1 b0 b1 : - d <> 0 -> - a = d * a1 + a0 -> - b = d * b1 + b0 -> - (a * b) / d = a0 * b0 / d + d * a1 * b1 + a1 * b0 + a0 * b1. - Proof. - intros. subst a b. autorewrite with push_Zmul. - ring_simplify_subterms. rewrite Z.pow_2_r. - rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega). - repeat match goal with - | |- context [d * ?a * ?b * ?c] => - replace (d * a * b * c) with (a * b * c * d) by ring - | |- context [d * ?a * ?b] => - replace (d * a * b) with (a * b * d) by ring - end. - rewrite !Z.div_add by omega. - autorewrite with zsimplify. - rewrite (Z.mul_comm a0 b0). - ring_simplify. ring. - Qed. - - Lemma represents_trans t x y: - represents t y -> y = x -> - represents t x. - Proof. congruence. Qed. - - Lemma represents_add x y : - 0 <= x < 2 ^ k -> - 0 <= y < 2 ^ k -> - represents [x;y] (x + 2^k*y). - Proof. - intros; cbv [represents]; autorewrite with zsimplify. - repeat split; (reflexivity || nia). - Qed. - - Lemma represents_small x : - 0 <= x < 2^k -> - represents [x; 0] x. - Proof. - intros. - eapply represents_trans. - { eauto using represents_add with zarith. } - { ring. } - Qed. - - Lemma mul_high_represents a b x y a0b1 : - represents a x -> - represents b y -> - 2^k <= x < 2^(k+1) -> - 0 <= y < 2^(k+1) -> - a0b1 = x mod 2^k * (y / 2^k) -> - represents (mul_high a b a0b1) ((x * y) / 2^k). - Proof. - cbv [mul_high Let_In]; rewrite Z.pow_add_r, Z.pow_1_r by omega; intros. - assert (4 <= 2 ^ k) by (transitivity (Z.pow 2 2); auto with zarith). - assert (0 <= x * y / 2^k < 2^k*2^k) by (Z.div_mod_to_quot_rem_in_goal; nia). - - rewrite mul_high_idea with (a:=x) (b:=y) (a0 := low a) (a1 := high a) (b0 := low b) (b1 := high b) in * - by (push_rep; Z.div_mod_to_quot_rem_in_goal; lia). - - push_rep. subst a0b1. - assert (y / 2 ^ k < 2) by (apply Z.div_lt_upper_bound; omega). - replace (x / 2 ^ k) with 1 in * by (rewrite Z.div_between_1; lia). - autorewrite with zsimplify_fast in *. - - eapply represents_trans. - { repeat (apply wideadd_represents; - [ | apply represents_small; Z.div_mod_to_quot_rem_in_goal; nia| ]). - erewrite represents_high; [ | apply widemul_represents; solve [ auto with zarith ] ]. - { apply represents_add; try reflexivity; solve [auto with zarith]. } - { match goal with H : 0 <= ?x + ?y < ?z |- 0 <= ?x < ?z => - split; [ solve [Z.zero_bounds] | ]; - eapply Z.le_lt_trans with (m:= x + y); nia - end. } - { omega. } } - { ring. } - Qed. - - Definition cond_sub1 (a : list Z) y : Z := - dlet_nd maybe_y := Z.zselect (Z.cc_l (high a)) 0 y in - dlet_nd diff := Z.sub_get_borrow_full (2^k) (low a) maybe_y in - fst diff. - - Lemma cc_l_only_bit : forall x s, 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s. - Proof. - cbv [Z.cc_l]; intros. - rewrite Z.div_between_0_if by omega. - break_match; Z.ltb_to_lt; Z.rewrite_mod_small; omega. - Qed. - - Lemma cond_sub1_correct a x y : - represents a x -> - 0 <= x < 2 * y -> - 0 <= y < 2 ^ k -> - cond_sub1 a y = if (x <? 2 ^ k) then x else x - y. - Proof. - intros; cbv [cond_sub1 Let_In]. rewrite Z.zselect_correct. push_rep. - break_match; Z.ltb_to_lt; rewrite cc_l_only_bit in *; try omega; - autorewrite with zsimplify_fast to_div_mod pull_Zmod; auto with zarith. - Qed. - - Definition cond_sub2 x y := Z.add_modulo x 0 y. - Lemma cond_sub2_correct x y : - cond_sub2 x y = if (x <? y) then x else x - y. - Proof. - cbv [cond_sub2]. rewrite Z.add_modulo_correct. - autorewrite with zsimplify_fast. break_match; Z.ltb_to_lt; omega. - Qed. - - Section Defn. - Context (xLow xHigh : Z) (xLow_bounds : 0 <= xLow < 2^k) (xHigh_bounds : 0 <= xHigh < M). - Let xt := [xLow; xHigh]. - Let x := xLow + 2^k * xHigh. - - Lemma x_rep : represents xt x. - Proof. cbv [represents]; subst xt x; autorewrite with cancel_pair zsimplify; repeat split; nia. Qed. - - Lemma x_bounds : 0 <= x < M * 2 ^ k. - Proof. subst x; nia. Qed. - - Definition muSelect := Z.zselect (Z.cc_m (2 ^ k) xHigh) 0 muLow. - - Local Hint Resolve Z.div_nonneg Z.div_lt_upper_bound. - Local Hint Resolve shiftr_represents mul_high_represents widemul_represents widesub_represents - cond_sub1_correct cond_sub2_correct represents_low represents_add. - - Lemma muSelect_correct : - muSelect = (2 ^ (2 * k) / M) mod 2 ^ k * ((x / 2 ^ (k - 1)) / 2 ^ k). - Proof. - (* assertions to help arith tactics *) - pose proof x_bounds. - assert (2^k * M < 2 ^ (2*k)) by (rewrite <-Z.add_diag, Z.pow_add_r; nia). - assert (0 <= x / (2 ^ k * (2 ^ k / 2)) < 2) by (Z.div_mod_to_quot_rem_in_goal; auto with nia). - assert (0 < 2 ^ k / 2) by Z.zero_bounds. - assert (2 ^ (k - 1) <> 0) by auto with zarith. - assert (2 < 2 ^ k) by (eapply Z.le_lt_trans with (m:=2 ^ 1); auto with zarith). - - cbv [muSelect]. rewrite <-muLow_eq. - rewrite Z.zselect_correct, Z.cc_m_eq by auto with zarith. - replace xHigh with (x / 2^k) by (subst x; autorewrite with zsimplify; lia). - autorewrite with pull_Zdiv push_Zpow. - rewrite (Z.mul_comm (2 ^ k / 2)). - break_match; [ ring | ]. - match goal with H : 0 <= ?x < 2, H' : ?x <> 0 |- _ => replace x with 1 by omega end. - autorewrite with zsimplify; reflexivity. - Qed. - - Lemma mu_rep : represents [muLow; 1] (2 ^ (2 * k) / M). - Proof. rewrite <-muLow_eq. eapply represents_trans; auto with zarith. Qed. - - Derive barrett_reduce - SuchThat (barrett_reduce = x mod M) - As barrett_reduce_correct. - Proof. - erewrite <-reduce_correct with (rep:=represents) (muSelect:=muSelect) (k0:=k) (mut:=[muLow;1]) (xt0:=xt) - by (auto using x_bounds, muSelect_correct, x_rep, mu_rep; omega). - subst barrett_reduce. reflexivity. - Qed. - End Defn. - End BarrettReduction. - - (* all the list operations from for_reification.ident *) - Strategy 100 [length seq repeat combine map flat_map partition app rev fold_right update_nth nth_default ]. - - Derive barrett_red_gen - SuchThat (forall (k M muLow : Z) - (n nout: nat) - (xLow xHigh : Z), - Interp (t:=type.reify_type_of barrett_reduce) - barrett_red_gen k M muLow n nout xLow xHigh - = barrett_reduce k M muLow n nout xLow xHigh) - As barrett_red_gen_correct. - Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed. - (* TODO : reification here is still quite slow (~90s on a beefy machine). Possibly just due to size of term, but warrants further investigation. *) - Module Export ReifyHints. - Global Hint Extern 1 (_ = barrett_reduce _ _ _ _ _ _ _) => simple apply barrett_red_gen_correct : reify_gen_cache. - End ReifyHints. - - Section rbarrett_red. - Context (M : Z) - (machine_wordsize : Z). - - Let bound := Some r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. - Let mu := (2 ^ (2 * machine_wordsize)) / M. - Let muLow := mu mod (2 ^ machine_wordsize). - - Check barrett_reduce_correct. - Print Pipeline.Values_not_provably_distinct. - - Definition relax_zrange_of_machine_wordsize' - := relax_zrange_gen [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize]%Z. - (* TODO: This is a special-case hack to let the prefancy pass have enough bounds information. *) - Definition relax_zrange_of_machine_wordsize r : option zrange := - if (lower r =? 0) && (upper r =? 2) - then Some r - else relax_zrange_of_machine_wordsize' r. - - Lemma relax_zrange_good (r r' z : zrange) : - (z <=? r)%zrange = true -> - relax_zrange_of_machine_wordsize r = Some r' -> (z <=? r')%zrange = true. - Proof. - cbv [relax_zrange_of_machine_wordsize]; break_match; [congruence|]. - eauto using relax_zrange_gen_good. - Qed. - - Local Arguments relax_zrange_of_machine_wordsize / . - - Let relax_zrange := relax_zrange_of_machine_wordsize. - - Definition check_args {T} (res : Pipeline.ErrorT T) - : Pipeline.ErrorT T - := if (mu / (2 ^ machine_wordsize) =? 0) - then Pipeline.Error (Pipeline.Values_not_provably_distinct "mu / 2 ^ k ≠ 0" (mu / 2 ^ machine_wordsize) 0) - else if (machine_wordsize <? 2) - then Pipeline.Error (Pipeline.Value_not_le "~ (2 <=k)" 2 machine_wordsize) - else if (negb (Z.log2 M + 1 =? machine_wordsize)) - then Pipeline.Error - (Pipeline.Values_not_provably_equal "log2(M)+1 != k" (Z.log2 M + 1) machine_wordsize) - else if (2 ^ (machine_wordsize + 1) - mu <? 2 * (2 ^ (2 * machine_wordsize) mod M)) - then Pipeline.Error - (Pipeline.Value_not_le "~ (2 * (2 ^ (2*k) mod M) <= 2^(k + 1) - mu)" - (2 * (2 ^ (2*machine_wordsize) mod M)) - (2^(machine_wordsize + 1) - mu)) - else res. - - Notation BoundsPipeline_correct in_bounds out_bounds op - := (fun rv (rop : Expr (type.reify_type_of op)) Hrop - => @Pipeline.BoundsPipeline_correct_trans - false (* subst01 *) - relax_zrange - relax_zrange_good - _ - rop - in_bounds - out_bounds - op - Hrop rv) - (only parsing). - - Definition rbarrett_red_correct - := BoundsPipeline_correct - (bound, bound) - bound - (barrett_reduce machine_wordsize M muLow 2 2). - - Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _). - Definition rbarrett_red_correctT rv : Prop - := type_of_strip_3arrow (@rbarrett_red_correct rv). - End rbarrett_red. -End BarrettReduction. - -Ltac solve_rbarrett_red := solve_rop BarrettReduction.rbarrett_red_correct. -Ltac solve_rbarrett_red_nocache := solve_rop_nocache BarrettReduction.rbarrett_red_correct. - -Module Barrett256. - - Definition M := Eval lazy in (2^256-2^224+2^192+2^96-1). - Definition machine_wordsize := 256. - - Derive barrett_red256 - SuchThat (BarrettReduction.rbarrett_red_correctT M machine_wordsize barrett_red256) - As barrett_red256_correct. - Proof. Time solve_rbarrett_red machine_wordsize. Time Qed. - - Definition muLow := Eval lazy in (2 ^ (2 * machine_wordsize) / M) mod (2^machine_wordsize). - Definition barrett_red256_prefancy' := PreFancy.of_Expr machine_wordsize [M; muLow] barrett_red256. - - Derive barrett_red256_prefancy - SuchThat (barrett_red256_prefancy = barrett_red256_prefancy' type.interp) - As barrett_red256_prefancy_eq. - Proof. lazy - [type.interp]; reflexivity. Qed. - - Lemma barrett_reduce_correct_specialized : - forall (xLow xHigh : Z), - 0 <= xLow < 2 ^ machine_wordsize -> - 0 <= xHigh < M -> - BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 xLow xHigh = (xLow + 2 ^ machine_wordsize * xHigh) mod M. - Proof. - intros. - apply BarrettReduction.barrett_reduce_correct; cbv [machine_wordsize M muLow] in *; - try omega; - try match goal with - | |- context [weight] => intros; cbv [weight]; autorewrite with zsimplify; auto using Z.pow_mul_r with omega - end; lazy; try split; congruence. - Qed. - - (* Note: If this is not factored out, then for some reason Qed takes forever in barrett_red256_correct_full. *) - Lemma barrett_red256_correct_proj2 : - forall xy : type.interp (type.prod type.Z type.Z), - ZRange.type.option.is_bounded_by - (t:=type.prod type.Z type.Z) - (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) - xy = true -> - expr.Interp (@ident.interp) barrett_red256 xy = app_curried (t:=type.arrow (type.prod type.Z type.Z) type.Z) (fun xy => BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 (fst xy) (snd xy)) xy. - Proof. intros; destruct (barrett_red256_correct xy); assumption. Qed. - Lemma barrett_red256_correct_proj2' : - forall x y : Z, - ZRange.type.option.is_bounded_by - (t:=type.prod type.Z type.Z) - (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) - (x, y) = true -> - expr.Interp (@ident.interp) barrett_red256 (x, y) = BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 x y. - Proof. intros; rewrite barrett_red256_correct_proj2 by assumption; unfold app_curried; exact eq_refl. Qed. - - Lemma barrett_red256_correct_full : - forall (xLow xHigh : Z), - 0 <= xLow < 2 ^ machine_wordsize -> - 0 <= xHigh < M -> - expr.interp (@ident.interp) (barrett_red256 type.interp) (xLow, xHigh) = (xLow + 2 ^ machine_wordsize * xHigh) mod M. - Proof. - intros. - rewrite <-barrett_reduce_correct_specialized by assumption. - rewrite <-barrett_red256_correct_proj2'. - { cbv [expr.Interp type.uncurried_domain type.uncurry type.final_codomain]. - reflexivity. } - { cbn. rewrite !andb_true_iff. cbv [machine_wordsize M] in *. - cbn in *. repeat split; apply Z.leb_le; omega. } - Qed. - - Import PreFancy.Tactics. (* for ok_expr_step *) - Lemma barrett_red256_prefancy_correct : - forall xLow xHigh dummy_arrow, - 0 <= xLow < 2 ^ machine_wordsize -> - 0 <= xHigh < M -> - @PreFancy.interp machine_wordsize (PreFancy.interp_cast_mod machine_wordsize) type.Z (barrett_red256_prefancy (xLow, xHigh) dummy_arrow) = (xLow + 2 ^ machine_wordsize * xHigh) mod M. - Proof. - intros. rewrite barrett_red256_prefancy_eq; cbv [barrett_red256_prefancy']. - erewrite PreFancy.of_Expr_correct. - { apply barrett_red256_correct_full; try assumption; reflexivity. } - { reflexivity. } - { lazy; reflexivity. } - { lazy; reflexivity. } - { repeat constructor. } - { cbv [In M muLow]; intros; intuition; subst; cbv; congruence. } - { let r := (eval compute in (2 ^ machine_wordsize)) in - replace (2^machine_wordsize) with r in * by reflexivity. - cbv [M muLow machine_wordsize] in *. - assert (lower r[0~>1] = 0) by reflexivity. - repeat (ok_expr_step; [ ]). - ok_expr_step. - lazy; congruence. - constructor. - constructor. } - { lazy. omega. } - Qed. - - Definition barrett_red256_fancy' (xLow xHigh RegMuLow RegMod RegZero error : positive) := - Fancy.of_Expr 3%positive - (fun z => if z =? muLow then Some RegMuLow else if z =? M then Some RegMod else if z =? 0 then Some RegZero else None) - [M; muLow] - barrett_red256 - (xLow, xHigh)%positive - (fun _ _ => tt) - error. - Derive barrett_red256_fancy - SuchThat (forall xLow xHigh RegMuLow RegMod RegZero, - barrett_red256_fancy xLow xHigh RegMuLow RegMod RegZero = barrett_red256_fancy' xLow xHigh RegMuLow RegMod RegZero) - As barrett_red256_fancy_eq. - Proof. - intros. - lazy - [Fancy.ADD Fancy.ADDC Fancy.SUB Fancy.SUBC - Fancy.MUL128LL Fancy.MUL128LU Fancy.MUL128UL Fancy.MUL128UU - Fancy.RSHI Fancy.SELC Fancy.SELM Fancy.SELL Fancy.ADDM]. - reflexivity. - Qed. - - Import Fancy.Registers. - - Definition barrett_red256_alloc' xLow xHigh RegMuLow := - fun errorP errorR => - Fancy.allocate register - positive Pos.eqb - errorR - (barrett_red256_fancy 1000%positive 1001%positive 1002%positive 1003%positive 1004%positive errorP) - [r2;r3;r4;r5;r6;r7;r8;r9;r10;r5;r11;r6;r12;r13;r14;r15;r16;r17;r18;r19;r20;r21;r22;r23;r24;r25;r26;r27;r28;r29] - (fun n => if n =? 1000 then xLow - else if n =? 1001 then xHigh - else if n =? 1002 then RegMuLow - else if n =? 1003 then RegMod - else if n =? 1004 then RegZero - else errorR). - Derive barrett_red256_alloc - SuchThat (barrett_red256_alloc = barrett_red256_alloc') - As barrett_red256_alloc_eq. - Proof. - intros. - cbv [barrett_red256_alloc' barrett_red256_fancy]. - cbn. subst barrett_red256_alloc. - reflexivity. - Qed. - - Set Printing Depth 1000. - Import ProdEquiv. - - Local Ltac solve_bounds := - match goal with - | H : ?a = ?b mod ?c |- 0 <= ?a < ?c => rewrite H; apply Z.mod_pos_bound; omega - | _ => assumption - end. - - Lemma barrett_red256_alloc_equivalent errorP errorR cc_start_state start_context : - forall x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 extra_reg, - NoDup [x; xHigh; RegMuLow; scratchp1; scratchp2; scratchp3; scratchp4; scratchp5; extra_reg; RegMod; RegZero] -> - 0 <= start_context x < 2^machine_wordsize -> - 0 <= start_context xHigh < 2^machine_wordsize -> - 0 <= start_context RegMuLow < 2^machine_wordsize -> - ProdEquiv.interp256 (barrett_red256_alloc r0 r1 r30 errorP errorR) cc_start_state - (fun r => if reg_eqb r r0 - then start_context x - else if reg_eqb r r1 - then start_context xHigh - else if reg_eqb r r30 - then start_context RegMuLow - else start_context r) - = ProdEquiv.interp256 (Prod.MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5) cc_start_state start_context. - Proof. - intros. - let r := eval compute in (2^machine_wordsize) in - replace (2^machine_wordsize) with r in * by reflexivity. - cbv [Prod.MulMod barrett_red256_alloc]. - - (* Extract proofs that no registers are equal to each other *) - repeat match goal with - | H : NoDup _ |- _ => inversion H; subst; clear H - | H : ~ In _ _ |- _ => cbv [In] in H - | H : ~ (_ \/ _) |- _ => apply Decidable.not_or in H; destruct H - | H : ~ False |- _ => clear H - end. - - step_both_sides. - - (* TODO: To prove equivalence between these two, we need to either relocate the RSHI instructions so they're in the same places or use instruction commutativity to push them down. *) - - Admitted. - - Import Fancy_PreFancy_Equiv. - - Definition interp_equivZZ_256 {s} := - @interp_equivZZ s 256 ltac:(cbv; congruence) 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity). - Definition interp_equivZ_256 {s} := - @interp_equivZ s 256 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity). - - Local Ltac simplify_op_equiv start_ctx := - cbn - [Fancy.spec PreFancy.interp_ident Fancy.cc_spec Z.shiftl]; - repeat match goal with H : start_ctx _ = _ |- _ => rewrite H end; - cbv - [ - Z.rshi Z.cc_m Fancy.CC.cc_m - Z.add_with_get_carry_full Z.add_get_carry_full - Z.sub_get_borrow_full Z.sub_with_get_borrow_full - Z.le Z.lt Z.ltb Z.leb Z.geb Z.eqb Z.land Z.shiftr Z.shiftl - Z.add Z.mul Z.div Z.sub Z.modulo Z.testbit Z.pow Z.ones - fst snd]; cbn [fst snd]; - try (replace (2 ^ (256 / 2) - 1) with (Z.ones 128) by reflexivity; rewrite !Z.land_ones by omega); - autorewrite with to_div_mod; rewrite ?Z.mod_mod, <-?Z.testbit_spec' by omega; - let r := (eval compute in (2 ^ 256)) in - replace (2^256) with r in * by reflexivity; - repeat match goal with - | H : 0 <= ?x < ?m |- context [?x mod ?m] => rewrite (Z.mod_small x m) by apply H - | |- context [?x <? 0] => rewrite (proj2 (Z.ltb_ge x 0)) by (break_match; Z.zero_bounds) - | _ => rewrite Z.mod_small with (b:=2) by (break_match; omega) - | |- context [ (if Z.testbit ?a ?n then 1 else 0) + ?b + ?c] => - replace ((if Z.testbit a n then 1 else 0) + b + c) with (b + c + (if Z.testbit a n then 1 else 0)) by ring - end. - - Local Ltac solve_nonneg ctx := - match goal with x := (Fancy.spec _ _ _) |- _ => subst x end; - simplify_op_equiv ctx; Z.zero_bounds. - - Local Ltac generalize_result := - let v := fresh "v" in intro v; generalize v; clear v; intro v. - - Local Ltac generalize_result_nonneg ctx := - let v := fresh "v" in - let v_nonneg := fresh "v_nonneg" in - intro v; assert (0 <= v) as v_nonneg; [solve_nonneg ctx |generalize v v_nonneg; clear v v_nonneg; intros v v_nonneg]. - - Local Ltac step ctx := - match goal with - | |- Fancy.interp _ _ _ (Fancy.Instr (Fancy.ADD _) _ _ (Fancy.Instr (Fancy.ADDC _) _ _ _)) _ _ = _ => - apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result_nonneg ctx] - | _ => apply interp_equivZ_256; [simplify_op_equiv ctx | generalize_result] - | _ => apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result] - end. - - Lemma prod_barrett_red256_correct : - forall (cc_start_state : Fancy.CC.state) (* starting carry flags *) - (start_context : register -> Z) (* starting register values *) - (x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 extra_reg : register), (* registers to use in computation *) - NoDup [x; xHigh; RegMuLow; scratchp1; scratchp2; scratchp3; scratchp4; scratchp5; extra_reg; RegMod; RegZero] -> (* registers are unique *) - 0 <= start_context x < 2^machine_wordsize -> - 0 <= start_context xHigh < M -> - start_context RegMuLow = muLow -> - start_context RegMod = M -> - start_context RegZero = 0 -> - cc_start_state.(Fancy.CC.cc_m) = (Z.cc_m (2^256) (start_context xHigh) =? 1) -> - let X := start_context x + 2^machine_wordsize * start_context xHigh in - ProdEquiv.interp256 (Prod.MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5) cc_start_state start_context = X mod M. - Proof. - intros. subst X. - assert (0 <= start_context xHigh < 2^machine_wordsize) by (cbv [M] in *; cbn; omega). - let r := (eval compute in (2 ^ machine_wordsize)) in - replace (2^machine_wordsize) with r in * by reflexivity. - cbv [M muLow] in *. - - rewrite <-barrett_red256_prefancy_correct with (dummy_arrow := fun s d _ => DefaultValue.type.default) by auto. - rewrite <-barrett_red256_alloc_equivalent with (errorR := RegZero) (errorP := 1%positive) (extra_reg:=extra_reg) - by (cbn in *; auto with omega). - cbv [ProdEquiv.interp256]. - let r := (eval compute in (2 ^ 256)) in - replace (2^256) with r in * by reflexivity. - cbv [barrett_red256_alloc barrett_red256_prefancy]. - - step start_context. - { - match goal with H : Fancy.CC.cc_m _ = _ |- _ => rewrite H end. - match goal with |- context [Z.cc_m ?s ?x] => - pose proof (Z.cc_m_small s x ltac:(reflexivity) ltac:(omega)); - let H := fresh in - assert (Z.cc_m s x = 1 \/ Z.cc_m s x = 0) as H by omega; - destruct H as [H | H]; rewrite H in * - end; break_innermost_match; Z.ltb_to_lt; try congruence. } - apply interp_equivZ_256; [ simplify_op_equiv start_context | ]. (* apply manually instead of using [step] to allow a custom bounds proof *) - { rewrite Z.rshi_correct by omega. - autorewrite with zsimplify_fast. - rewrite Z.shiftr_div_pow2 by omega. - reflexivity. } - - (* Special case to remember the bound for the output of RSHI *) - let v := fresh "v" in - let v_bound := fresh "v_bound" in - intro v; assert (0 <= v <= 1) as v_bound; [ |generalize v v_bound; clear v v_bound; intros v v_bound]. - { solve_nonneg start_context. autorewrite with zsimplify_fast. - rewrite Z.shiftr_div_pow2 by omega. - rewrite Z.mod_pull_div by omega. - rewrite Z.mod_small by (cbn; omega). - split; [Z.zero_bounds|]. - apply Z.lt_succ_r. - apply Z.div_lt_upper_bound; cbn; omega. } - - step start_context. - { rewrite Z.rshi_correct by omega. - rewrite Z.shiftr_div_pow2 by omega. - repeat (f_equal; try ring). } - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; - [ rewrite Z.mod_small with (b:=2) by (rewrite Z.mod_small by omega; omega); (* Here we make use of the bound of RSHI *) - reflexivity - | rewrite Z.mod_small with (b:=2) by (rewrite Z.mod_small by omega; omega); (* Here we make use of the bound of RSHI *) - reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context. - { rewrite Z.rshi_correct by omega. - rewrite Z.shiftr_div_pow2 by omega. - repeat (f_equal; try ring). } - - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - - step start_context. - { reflexivity. } - { autorewrite with zsimplify_fast. - match goal with |- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - rewrite <-Z.testbit_neg_eq_if with (n:=256) by (cbn; omega). - reflexivity. } - step start_context. - { reflexivity. } - { autorewrite with zsimplify_fast. - rewrite Z.mod_small with (a:=(if (if _ <? 0 then true else _) then _ else _)) (b:=2) by (break_innermost_match; omega). - match goal with |- context [?a - ?b - ?c] => replace (a - b - c) with (a - (b + c)) by ring end. - match goal with |- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - rewrite <-Z.testbit_neg_eq_if with (n:=256) by (break_innermost_match; cbn; omega). - reflexivity. } - step start_context. - { rewrite Z.bit0_eqb. - match goal with |- context [(?x mod ?m) &' 1] => - replace (x mod m) with (x &' Z.ones 256) by (rewrite Z.land_ones by omega; reflexivity) end. - rewrite <-Z.land_assoc. - rewrite Z.land_ones with (n:=1) by omega. - cbn. - match goal with |- context [?x mod 2] => - let H := fresh in - assert (x mod 2 = 0 \/ x mod 2 = 1) as H - by (pose proof (Z.mod_pos_bound x 2 ltac:(omega)); omega); - destruct H as [H | H]; rewrite H - end; reflexivity. } - step start_context. - { reflexivity. } - { autorewrite with zsimplify_fast. - repeat match goal with |- context [?x mod ?m] => unique pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - rewrite <-Z.testbit_neg_eq_if with (n:=256) by (cbn; omega). - reflexivity. } - step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. - reflexivity. - Qed. - - Import PrintingNotations. - Set Printing Width 1000. - Open Scope expr_scope. - Print barrett_red256. - (* -barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, - expr_let x0 := SELM (x₂, 0, 26959946667150639793205513449348445388433292963828203772348655992835) in - expr_let x1 := RSHI (0, x₂, 255) in - expr_let x2 := RSHI (x₂, x₁, 255) in - expr_let x3 := 79228162514264337589248983038 *₂₅₆ (uint128)(x2 >> 128) in - expr_let x4 := 79228162514264337589248983038 *₂₅₆ ((uint128)(x2) & 340282366920938463463374607431768211455) in - expr_let x5 := 340282366841710300930663525764514709507 *₂₅₆ (uint128)(x2 >> 128) in - expr_let x6 := 340282366841710300930663525764514709507 *₂₅₆ ((uint128)(x2) & 340282366920938463463374607431768211455) in - expr_let x7 := ADD_256 ((uint256)(((uint128)(x5) & 340282366920938463463374607431768211455) << 128), x6) in - expr_let x8 := ADDC_256 (x7₂, (uint128)(x5 >> 128), x3) in - expr_let x9 := ADD_256 ((uint256)(((uint128)(x4) & 340282366920938463463374607431768211455) << 128), x7₁) in - expr_let x10 := ADDC_256 (x9₂, (uint128)(x4 >> 128), x8₁) in - expr_let x11 := ADD_256 (x2, x10₁) in - expr_let x12 := ADDC_128 (x11₂, 0, x1) in - expr_let x13 := ADD_256 (x0, x11₁) in - expr_let x14 := ADDC_128 (x13₂, 0, x12₁) in - expr_let x15 := RSHI (x14₁, x13₁, 1) in - expr_let x16 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x15 >> 128) in - expr_let x17 := 79228162514264337593543950335 *₂₅₆ (uint128)(x15 >> 128) in - expr_let x18 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x15) & 340282366920938463463374607431768211455) in - expr_let x19 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x15) & 340282366920938463463374607431768211455) in - expr_let x20 := ADD_256 ((uint256)(((uint128)(x18) & 340282366920938463463374607431768211455) << 128), x19) in - expr_let x21 := ADDC_256 (x20₂, (uint128)(x18 >> 128), x16) in - expr_let x22 := ADD_256 ((uint256)(((uint128)(x17) & 340282366920938463463374607431768211455) << 128), x20₁) in - expr_let x23 := ADDC_256 (x22₂, (uint128)(x17 >> 128), x21₁) in - expr_let x24 := SUB_256 (x₁, x22₁) in - expr_let x25 := SUBB_256 (x24₂, x₂, x23₁) in - expr_let x26 := SELL (x25₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in - expr_let x27 := SUB_256 (x24₁, x26) in - ADDM (x27₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) - : Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z)) - *) - - Import PreFancy. - Import PreFancy.Notations. - Local Notation "'RegMod'" := (Straightline.expr.Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951). - Local Notation "'RegMuLow'" := (Straightline.expr.Primitive (t:=type.Z) 26959946667150639793205513449348445388433292963828203772348655992835). - Print barrett_red256_prefancy. - (* - selm@(y, $x₂, RegZero, RegMuLow); - rshi@(y0, RegZero, $x₂,255); - rshi@(y1, $x₂, $x₁,255); - mulhh@(y2, RegMuLow, $y1); - mulhl@(y3, RegMuLow, $y1); - mullh@(y4, RegMuLow, $y1); - mulll@(y5, RegMuLow, $y1); - add@(y6, $y5, $y4, 128); - addc@(y7, carry{$y6}, $y2, $y4, -128); - add@(y8, $y6, $y3, 128); - addc@(y9, carry{$y8}, $y7, $y3, -128); - add@(y10, $y1, $y9, 0); - addc@(y11, carry{$y10}, RegZero, $y0, 0); #128 - add@(y12, $y, $y10, 0); - addc@(y13, carry{$y12}, RegZero, $y11, 0); #128 - rshi@(y14, $y13, $y12,1); - mulhh@(y15, RegMod, $y14); - mullh@(y16, RegMod, $y14); - mulhl@(y17, RegMod, $y14); - mulll@(y18, RegMod, $y14); - add@(y19, $y18, $y17, 128); - addc@(y20, carry{$y19}, $y15, $y17, -128); - add@(y21, $y19, $y16, 128); - addc@(y22, carry{$y21}, $y20, $y16, -128); - sub@(y23, $x₁, $y21, 0); - subb@(y24, carry{$y23}, $x₂, $y22, 0); - sell@(y25, $y24, RegZero, RegMod); - sub@(y26, $y23, $y25, 0); - addm@(y27, $y26, RegZero, RegMod); - ret $y27 - *) -End Barrett256. - -Module SaturatedSolinas. - Section MulMod. - Context (s : Z) (c : list (Z * Z)) - (s_nz : s <> 0) (modulus_nz : s - Associational.eval c <> 0). - Context (log2base : Z) (log2base_pos : 0 < log2base) - (n nreductions : nat) (n_nz : n <> 0%nat). - - Let weight := weight log2base 1. - Let props : @weight_properties weight := wprops log2base 1 ltac:(omega). - Local Lemma base_nz : 2 ^ log2base <> 0. Proof. auto with zarith. Qed. - - Derive mulmod - SuchThat (forall (f g : list Z) - (Hf : length f = n) - (Hg : length g = n), - (eval weight n (fst (mulmod f g)) + weight n * (snd (mulmod f g))) mod (s - Associational.eval c) - = (eval weight n f * eval weight n g) mod (s - Associational.eval c)) - As eval_mulmod. - Proof. - intros. - rewrite <-Rows.eval_mulmod with (base:=2^log2base) (s:=s) (c:=c) (nreductions:=nreductions) by auto using base_nz. - eapply f_equal2; [|trivial]. - (* expand_lists (). *) (* uncommenting this line removes some unused multiplications but also inlines a bunch of carry stuff at the end *) - subst mulmod. reflexivity. - Qed. - Definition mulmod' := fun x y => fst (mulmod x y). - End MulMod. - - Derive mulmod_gen - SuchThat (forall (log2base s : Z) (c : list (Z * Z)) (n nreductions : nat) - (f g : list Z), - Interp (t:=type.reify_type_of mulmod') - mulmod_gen s c log2base n nreductions f g - = mulmod' s c log2base n nreductions f g) - As mulmod_gen_correct. - Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed. - Module Export ReifyHints. - Global Hint Extern 1 (_ = mulmod' _ _ _ _ _ _ _) => simple apply mulmod_gen_correct : reify_gen_cache. - End ReifyHints. - - Section rmulmod. - Context (s : Z) - (c : list (Z * Z)) - (machine_wordsize : Z). - - Definition relax_zrange_of_machine_wordsize - := relax_zrange_gen [1; machine_wordsize]%Z. - - Let n : nat := Z.to_nat (Qceiling (Z.log2_up s / machine_wordsize)). - (* Number of reductions is calculated as follows : - Let i be the highest limb index of c. Then, each reduction - decreases the number of extra limbs by (n-i). So, to go from - the n extra limbs we have post-multiplication down to 0, we - need ceil (n / (n - i)) reductions. *) - Let nreductions : nat := - let i := fold_right Z.max 0 (map (fun t => Z.log2 (fst t) / machine_wordsize) c) in - Z.to_nat (Qceiling (Z.of_nat n / (Z.of_nat n - i))). - Let relax_zrange := relax_zrange_of_machine_wordsize. - Let bound := Some r[0 ~> (2^machine_wordsize - 1)]%zrange. - Let boundsn : list (ZRange.type.option.interp type.Z) - := repeat bound n. - - Definition check_args {T} (res : Pipeline.ErrorT T) - : Pipeline.ErrorT T - := if (negb (0 <? s - Associational.eval c))%Z - then Pipeline.Error (Pipeline.Value_not_lt "s - Associational.eval c ≤ 0" 0 (s - Associational.eval c)) - else if (s =? 0)%Z - then Pipeline.Error (Pipeline.Values_not_provably_distinct "s ≠ 0" s 0) - else if (n =? 0)%nat - then Pipeline.Error (Pipeline.Values_not_provably_distinct "n ≠ 0" n 0%nat) - else if (negb (0 <? machine_wordsize)) - then Pipeline.Error (Pipeline.Value_not_lt "0 < machine_wordsize" 0 machine_wordsize) - else res. - - Notation BoundsPipeline rop in_bounds out_bounds - := (Pipeline.BoundsPipeline - (*false*) true - relax_zrange - rop%Expr in_bounds out_bounds). - - Notation BoundsPipeline_correct in_bounds out_bounds op - := (fun rv (rop : Expr (type.reify_type_of op)) Hrop - => @Pipeline.BoundsPipeline_correct_trans - (*false*) true - relax_zrange - (relax_zrange_gen_good _) - _ - rop - in_bounds - out_bounds - op - Hrop rv) - (only parsing). - - Definition rmulmod_correct - := BoundsPipeline_correct - (Some boundsn, Some boundsn) - (Some boundsn) - (mulmod' s c machine_wordsize n nreductions). - - Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _). - Definition rmulmod_correctT rv : Prop - := type_of_strip_3arrow (@rmulmod_correct rv). - End rmulmod. -End SaturatedSolinas. - -Ltac solve_rmulmod := solve_rop SaturatedSolinas.rmulmod_correct. -Ltac solve_rmulmod_nocache := solve_rop_nocache SaturatedSolinas.rmulmod_correct. - -Module P192_64. - Definition s := 2^192. - Definition c := [(2^64, 1); (1,1)]. - Definition machine_wordsize := 64. - - Derive mulmod - SuchThat (SaturatedSolinas.rmulmod_correctT s c machine_wordsize mulmod) - As mulmod_correct. - Proof. Time solve_rmulmod machine_wordsize. Time Qed. - - Import PrintingNotations. - Open Scope expr_scope. - Set Printing Width 100000. - Set Printing Depth 100000. - - Local Notation "'mul64' '(' x ',' y ')'" := - (Z.cast2 (uint64, _)%core @@ (Z.mul_split_concrete 18446744073709551616 @@ (x , y)))%expr (at level 50) : expr_scope. - Local Notation "'add64' '(' x ',' y ')'" := - (Z.cast2 (uint64, bool)%core @@ (Z.add_get_carry_concrete 18446744073709551616 @@ (x , y)))%expr (at level 50) : expr_scope. - Local Notation "'adc64' '(' c ',' x ',' y ')'" := - (Z.cast2 (uint64, bool)%core @@ (Z.add_with_get_carry_concrete 18446744073709551616 @@ (c, x , y)))%expr (at level 50) : expr_scope. - Local Notation "'adx64' '(' c ',' x ',' y ')'" := - (Z.cast bool @@ (Z.add_with_carry @@ (c, x , y)))%expr (at level 50) : expr_scope. - - Print mulmod. -(* -mulmod = fun var : type -> Type => λ x : var (type.list (type.type_primitive type.Z) * type.list (type.type_primitive type.Z))%ctype, - expr_let x0 := mul64 ((uint64)(x₁[[2]]), (uint64)(x₂[[2]])) in - expr_let x1 := mul64 ((uint64)(x₁[[2]]), (uint64)(x₂[[1]])) in - expr_let x2 := mul64 ((uint64)(x₁[[2]]), (uint64)(x₂[[0]])) in - expr_let x3 := mul64 ((uint64)(x₁[[1]]), (uint64)(x₂[[2]])) in - expr_let x4 := mul64 ((uint64)(x₁[[1]]), (uint64)(x₂[[1]])) in - expr_let x5 := mul64 ((uint64)(x₁[[1]]), (uint64)(x₂[[0]])) in - expr_let x6 := mul64 ((uint64)(x₁[[0]]), (uint64)(x₂[[2]])) in - expr_let x7 := mul64 ((uint64)(x₁[[0]]), (uint64)(x₂[[1]])) in - expr_let x8 := mul64 ((uint64)(x₁[[0]]), (uint64)(x₂[[0]])) in - expr_let x9 := add64 (x0₂, x8₂) in - expr_let x10 := adc64 (x9₂, 0, x7₂) in - expr_let x11 := add64 (x0₁, x9₁) in - expr_let x12 := adc64 (x11₂, 0, x10₁) in - expr_let x13 := add64 (x1₂, x11₁) in - expr_let x14 := adc64 (x13₂, 0, x12₁) in - expr_let x15 := add64 (x3₂, x13₁) in - expr_let x16 := adc64 (x15₂, x0₂, x14₁) in - expr_let x17 := add64 (x1₁, x15₁) in - expr_let x18 := adc64 (x17₂, x0₁, x16₁) in - expr_let x19 := add64 (x0₂, x8₁) in - expr_let x20 := adc64 (x19₂, x2₂, x17₁) in - expr_let x21 := adc64 (x20₂, x1₂, x18₁) in - expr_let x22 := add64 (x1₁, x19₁) in - expr_let x23 := adc64 (x22₂, x3₁, x20₁) in - expr_let x24 := adc64 (x23₂, x3₂, x21₁) in - expr_let x25 := add64 (x2₂, x22₁) in - expr_let x26 := adc64 (x25₂, x4₂, x23₁) in - expr_let x27 := adc64 (x26₂, x2₁, x24₁) in - expr_let x28 := add64 (x3₁, x25₁) in - expr_let x29 := adc64 (x28₂, x6₂, x26₁) in - expr_let x30 := adc64 (x29₂, x4₁, x27₁) in - expr_let x31 := add64 (x4₂, x28₁) in - expr_let x32 := adc64 (x31₂, x5₁, x29₁) in - expr_let x33 := adc64 (x32₂, x5₂, x30₁) in - expr_let x34 := add64 (x6₂, x31₁) in - expr_let x35 := adc64 (x34₂, x7₁, x32₁) in - expr_let x36 := adc64 (x35₂, x6₁, x33₁) in - x34₁ :: x35₁ :: x36₁ :: [] - : Expr (type.uncurry (type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z))) -*) - -End P192_64. - -Module P192_32. - Definition s := 2^192. - Definition c := [(2^64, 1); (1,1)]. - Definition machine_wordsize := 32. - - Derive mulmod - SuchThat (SaturatedSolinas.rmulmod_correctT s c machine_wordsize mulmod) - As mulmod_correct. - Proof. Time solve_rmulmod machine_wordsize. Time Qed. - - Import PrintingNotations. - Open Scope expr_scope. - Set Printing Width 100000. - Set Printing Depth 100000. - - Local Notation "'mul32' '(' x ',' y ')'" := - (Z.cast2 (uint32, _)%core @@ (Z.mul_split_concrete 4294967296 @@ (x , y)))%expr (at level 50) : expr_scope. - Local Notation "'add32' '(' x ',' y ')'" := - (Z.cast2 (uint32, bool)%core @@ (Z.add_get_carry_concrete 4294967296 @@ (x , y)))%expr (at level 50) : expr_scope. - Local Notation "'adc32' '(' c ',' x ',' y ')'" := - (Z.cast2 (uint32, bool)%core @@ (Z.add_with_get_carry_concrete 4294967296 @@ (c, x , y)))%expr (at level 50) : expr_scope. - - Print mulmod. - (* -mulmod = fun var : type -> Type => λ x : var (type.list (type.type_primitive type.Z) * type.list (type.type_primitive type.Z))%ctype, - expr_let x0 := mul32 ((uint32)(x₁[[5]]), (uint32)(x₂[[5]])) in - expr_let x1 := mul32 ((uint32)(x₁[[5]]), (uint32)(x₂[[4]])) in - expr_let x2 := mul32 ((uint32)(x₁[[5]]), (uint32)(x₂[[3]])) in - expr_let x3 := mul32 ((uint32)(x₁[[5]]), (uint32)(x₂[[2]])) in - expr_let x4 := mul32 ((uint32)(x₁[[5]]), (uint32)(x₂[[1]])) in - expr_let x5 := mul32 ((uint32)(x₁[[5]]), (uint32)(x₂[[0]])) in - expr_let x6 := mul32 ((uint32)(x₁[[4]]), (uint32)(x₂[[5]])) in - expr_let x7 := mul32 ((uint32)(x₁[[4]]), (uint32)(x₂[[4]])) in - expr_let x8 := mul32 ((uint32)(x₁[[4]]), (uint32)(x₂[[3]])) in - expr_let x9 := mul32 ((uint32)(x₁[[4]]), (uint32)(x₂[[2]])) in - expr_let x10 := mul32 ((uint32)(x₁[[4]]), (uint32)(x₂[[1]])) in - expr_let x11 := mul32 ((uint32)(x₁[[4]]), (uint32)(x₂[[0]])) in - expr_let x12 := mul32 ((uint32)(x₁[[3]]), (uint32)(x₂[[5]])) in - expr_let x13 := mul32 ((uint32)(x₁[[3]]), (uint32)(x₂[[4]])) in - expr_let x14 := mul32 ((uint32)(x₁[[3]]), (uint32)(x₂[[3]])) in - expr_let x15 := mul32 ((uint32)(x₁[[3]]), (uint32)(x₂[[2]])) in - expr_let x16 := mul32 ((uint32)(x₁[[3]]), (uint32)(x₂[[1]])) in - expr_let x17 := mul32 ((uint32)(x₁[[3]]), (uint32)(x₂[[0]])) in - expr_let x18 := mul32 ((uint32)(x₁[[2]]), (uint32)(x₂[[5]])) in - expr_let x19 := mul32 ((uint32)(x₁[[2]]), (uint32)(x₂[[4]])) in - expr_let x20 := mul32 ((uint32)(x₁[[2]]), (uint32)(x₂[[3]])) in - expr_let x21 := mul32 ((uint32)(x₁[[2]]), (uint32)(x₂[[2]])) in - expr_let x22 := mul32 ((uint32)(x₁[[2]]), (uint32)(x₂[[1]])) in - expr_let x23 := mul32 ((uint32)(x₁[[2]]), (uint32)(x₂[[0]])) in - expr_let x24 := mul32 ((uint32)(x₁[[1]]), (uint32)(x₂[[5]])) in - expr_let x25 := mul32 ((uint32)(x₁[[1]]), (uint32)(x₂[[4]])) in - expr_let x26 := mul32 ((uint32)(x₁[[1]]), (uint32)(x₂[[3]])) in - expr_let x27 := mul32 ((uint32)(x₁[[1]]), (uint32)(x₂[[2]])) in - expr_let x28 := mul32 ((uint32)(x₁[[1]]), (uint32)(x₂[[1]])) in - expr_let x29 := mul32 ((uint32)(x₁[[1]]), (uint32)(x₂[[0]])) in - expr_let x30 := mul32 ((uint32)(x₁[[0]]), (uint32)(x₂[[5]])) in - expr_let x31 := mul32 ((uint32)(x₁[[0]]), (uint32)(x₂[[4]])) in - expr_let x32 := mul32 ((uint32)(x₁[[0]]), (uint32)(x₂[[3]])) in - expr_let x33 := mul32 ((uint32)(x₁[[0]]), (uint32)(x₂[[2]])) in - expr_let x34 := mul32 ((uint32)(x₁[[0]]), (uint32)(x₂[[1]])) in - expr_let x35 := mul32 ((uint32)(x₁[[0]]), (uint32)(x₂[[0]])) in - expr_let x36 := add32 (x0₁, x34₂) in - expr_let x37 := adc32 (x36₂, 0, x33₂) in - expr_let x38 := adc32 (x37₂, 0, x32₂) in - expr_let x39 := adc32 (x38₂, 0, x31₂) in - expr_let x40 := add32 (x1₂, x36₁) in - expr_let x41 := adc32 (x40₂, 0, x37₁) in - expr_let x42 := adc32 (x41₂, 0, x38₁) in - expr_let x43 := adc32 (x42₂, 0, x39₁) in - expr_let x44 := add32 (x6₂, x40₁) in - expr_let x45 := adc32 (x44₂, 0, x41₁) in - expr_let x46 := adc32 (x45₂, 0, x42₁) in - expr_let x47 := adc32 (x46₂, 0, x43₁) in - expr_let x48 := add32 (x2₁, x44₁) in - expr_let x49 := adc32 (x48₂, 0, x45₁) in - expr_let x50 := adc32 (x49₂, 0, x46₁) in - expr_let x51 := adc32 (x50₂, 0, x47₁) in - expr_let x52 := add32 (x3₂, x48₁) in - expr_let x53 := adc32 (x52₂, x0₂, x49₁) in - expr_let x54 := adc32 (x53₂, 0, x50₁) in - expr_let x55 := adc32 (x54₂, 0, x51₁) in - expr_let x56 := add32 (x7₁, x52₁) in - expr_let x57 := adc32 (x56₂, x1₁, x53₁) in - expr_let x58 := adc32 (x57₂, 0, x54₁) in - expr_let x59 := adc32 (x58₂, 0, x55₁) in - expr_let x60 := add32 (x8₂, x56₁) in - expr_let x61 := adc32 (x60₂, x2₂, x57₁) in - expr_let x62 := adc32 (x61₂, 0, x58₁) in - expr_let x63 := adc32 (x62₂, 0, x59₁) in - expr_let x64 := add32 (x12₁, x60₁) in - expr_let x65 := adc32 (x64₂, x6₁, x61₁) in - expr_let x66 := adc32 (x65₂, x0₁, x62₁) in - expr_let x67 := adc32 (x66₂, 0, x63₁) in - expr_let x68 := add32 (x13₂, x64₁) in - expr_let x69 := adc32 (x68₂, x7₂, x65₁) in - expr_let x70 := adc32 (x69₂, x1₂, x66₁) in - expr_let x71 := adc32 (x70₂, 0, x67₁) in - expr_let x72 := add32 (x18₂, x68₁) in - expr_let x73 := adc32 (x72₂, x12₂, x69₁) in - expr_let x74 := adc32 (x73₂, x6₂, x70₁) in - expr_let x75 := adc32 (x74₂, x0₂, x71₁) in - expr_let x76 := add32 (x4₁, x72₁) in - expr_let x77 := adc32 (x76₂, x3₁, x73₁) in - expr_let x78 := adc32 (x77₂, x2₁, x74₁) in - expr_let x79 := adc32 (x78₂, x1₁, x75₁) in - expr_let x80 := add32 (x0₁, x35₁) in - expr_let x81 := adc32 (x80₂, 0, x35₂) in - expr_let x82 := adc32 (x81₂, x5₂, x76₁) in - expr_let x83 := adc32 (x82₂, x4₂, x77₁) in - expr_let x84 := adc32 (x83₂, x3₂, x78₁) in - expr_let x85 := adc32 (x84₂, x2₂, x79₁) in - expr_let x86 := add32 (x1₂, x80₁) in - expr_let x87 := adc32 (x86₂, 0, x81₁) in - expr_let x88 := adc32 (x87₂, x9₁, x82₁) in - expr_let x89 := adc32 (x88₂, x8₁, x83₁) in - expr_let x90 := adc32 (x89₂, x7₁, x84₁) in - expr_let x91 := adc32 (x90₂, x6₁, x85₁) in - expr_let x92 := add32 (x6₂, x86₁) in - expr_let x93 := adc32 (x92₂, x0₂, x87₁) in - expr_let x94 := adc32 (x93₂, x10₂, x88₁) in - expr_let x95 := adc32 (x94₂, x9₂, x89₁) in - expr_let x96 := adc32 (x95₂, x8₂, x90₁) in - expr_let x97 := adc32 (x96₂, x7₂, x91₁) in - expr_let x98 := add32 (x4₁, x92₁) in - expr_let x99 := adc32 (x98₂, x3₁, x93₁) in - expr_let x100 := adc32 (x99₂, x14₁, x94₁) in - expr_let x101 := adc32 (x100₂, x13₁, x95₁) in - expr_let x102 := adc32 (x101₂, x12₁, x96₁) in - expr_let x103 := adc32 (x102₂, x12₂, x97₁) in - expr_let x104 := add32 (x5₂, x98₁) in - expr_let x105 := adc32 (x104₂, x4₂, x99₁) in - expr_let x106 := adc32 (x105₂, x15₂, x100₁) in - expr_let x107 := adc32 (x106₂, x14₂, x101₁) in - expr_let x108 := adc32 (x107₂, x13₂, x102₁) in - expr_let x109 := adc32 (x108₂, x5₁, x103₁) in - expr_let x110 := add32 (x9₁, x104₁) in - expr_let x111 := adc32 (x110₂, x8₁, x105₁) in - expr_let x112 := adc32 (x111₂, x19₁, x106₁) in - expr_let x113 := adc32 (x112₂, x18₁, x107₁) in - expr_let x114 := adc32 (x113₂, x18₂, x108₁) in - expr_let x115 := adc32 (x114₂, x10₁, x109₁) in - expr_let x116 := add32 (x10₂, x110₁) in - expr_let x117 := adc32 (x116₂, x9₂, x111₁) in - expr_let x118 := adc32 (x117₂, x20₂, x112₁) in - expr_let x119 := adc32 (x118₂, x19₂, x113₁) in - expr_let x120 := adc32 (x119₂, x11₁, x114₁) in - expr_let x121 := adc32 (x120₂, x11₂, x115₁) in - expr_let x122 := add32 (x14₁, x116₁) in - expr_let x123 := adc32 (x122₂, x13₁, x117₁) in - expr_let x124 := adc32 (x123₂, x24₁, x118₁) in - expr_let x125 := adc32 (x124₂, x24₂, x119₁) in - expr_let x126 := adc32 (x125₂, x16₁, x120₁) in - expr_let x127 := adc32 (x126₂, x15₁, x121₁) in - expr_let x128 := add32 (x15₂, x122₁) in - expr_let x129 := adc32 (x128₂, x14₂, x123₁) in - expr_let x130 := adc32 (x129₂, x25₂, x124₁) in - expr_let x131 := adc32 (x130₂, x17₁, x125₁) in - expr_let x132 := adc32 (x131₂, x17₂, x126₁) in - expr_let x133 := adc32 (x132₂, x16₂, x127₁) in - expr_let x134 := add32 (x19₁, x128₁) in - expr_let x135 := adc32 (x134₂, x18₁, x129₁) in - expr_let x136 := adc32 (x135₂, x30₂, x130₁) in - expr_let x137 := adc32 (x136₂, x22₁, x131₁) in - expr_let x138 := adc32 (x137₂, x21₁, x132₁) in - expr_let x139 := adc32 (x138₂, x20₁, x133₁) in - expr_let x140 := add32 (x20₂, x134₁) in - expr_let x141 := adc32 (x140₂, x19₂, x135₁) in - expr_let x142 := adc32 (x141₂, x23₁, x136₁) in - expr_let x143 := adc32 (x142₂, x23₂, x137₁) in - expr_let x144 := adc32 (x143₂, x22₂, x138₁) in - expr_let x145 := adc32 (x144₂, x21₂, x139₁) in - expr_let x146 := add32 (x24₁, x140₁) in - expr_let x147 := adc32 (x146₂, x24₂, x141₁) in - expr_let x148 := adc32 (x147₂, x28₁, x142₁) in - expr_let x149 := adc32 (x148₂, x27₁, x143₁) in - expr_let x150 := adc32 (x149₂, x26₁, x144₁) in - expr_let x151 := adc32 (x150₂, x25₁, x145₁) in - expr_let x152 := add32 (x25₂, x146₁) in - expr_let x153 := adc32 (x152₂, x29₁, x147₁) in - expr_let x154 := adc32 (x153₂, x29₂, x148₁) in - expr_let x155 := adc32 (x154₂, x28₂, x149₁) in - expr_let x156 := adc32 (x155₂, x27₂, x150₁) in - expr_let x157 := adc32 (x156₂, x26₂, x151₁) in - expr_let x158 := add32 (x30₂, x152₁) in - expr_let x159 := adc32 (x158₂, x34₁, x153₁) in - expr_let x160 := adc32 (x159₂, x33₁, x154₁) in - expr_let x161 := adc32 (x160₂, x32₁, x155₁) in - expr_let x162 := adc32 (x161₂, x31₁, x156₁) in - expr_let x163 := adc32 (x162₂, x30₁, x157₁) in - x158₁ :: x159₁ :: x160₁ :: x161₁ :: x162₁ :: x163₁ :: [] - : Expr (type.uncurry (type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z) -> type.list (type.type_primitive type.Z))) -*) - -End P192_32. - -(* TODO : Too slow! Many, many terms in this one. *) -(* -Module P256_32. - Definition s := 2^256. - Definition c := [(2^224, 1); (2^192, -1); (2^96, -1); (1,1)]. - Definition machine_wordsize := 32. - - Derive mulmod - SuchThat (SaturatedSolinas.rmulmod_correctT s c machine_wordsize mulmod) - As mulmod_correct. - Proof. Time solve_rmulmod machine_wordsize. Time Qed. - - Import PrintingNotations. - Open Scope expr_scope. - Set Printing Width 100000. - - Print mulmod. - -End P256_32. -*) - -Module MontgomeryReduction. - Section MontRed'. - Context (N R N' R' : Z). - Context (HN_range : 0 <= N < R) (HN'_range : 0 <= N' < R) (HN_nz : N <> 0) (R_gt_1 : R > 1) - (N'_good : Z.equiv_modulo R (N*N') (-1)) (R'_good: Z.equiv_modulo N (R*R') 1). - - Context (Zlog2R : Z) . - Let w : nat -> Z := weight Zlog2R 1. - Context (n:nat) (Hn_nz: n <> 0%nat) (n_good : Zlog2R mod Z.of_nat n = 0). - Context (R_big_enough : n <= Zlog2R) - (R_two_pow : 2^Zlog2R = R). - Let w_mul : nat -> Z := weight (Zlog2R / n) 1. - Context (nout : nat) (Hnout : nout = 2%nat). - - Definition montred' (lo_hi : (Z * Z)) := - dlet_nd y := nth_default 0 (BaseConversion.widemul_inlined Zlog2R n nout (fst lo_hi) N') 0 in - dlet_nd t1_t2 := (BaseConversion.widemul_inlined_reverse Zlog2R n nout N y) in - dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [fst lo_hi; snd lo_hi] t1_t2 in - dlet_nd y' := Z.zselect (snd sum_carry) 0 N in - dlet_nd lo''_carry := Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y' in - Z.add_modulo (fst lo''_carry) 0 N. - - Local Lemma Hw : forall i, w i = R ^ Z.of_nat i. - Proof. - clear -R_big_enough R_two_pow; cbv [w weight]; intro. - autorewrite with zsimplify. - rewrite Z.pow_mul_r, R_two_pow by omega; reflexivity. - Qed. - - Local Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r, ?Z.pow_1_l in *. - Local Ltac solve_range := - repeat match goal with - | _ => progress change_weight - | |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega)) - | |- 0 <= _ => progress Z.zero_bounds - | |- 0 <= _ * _ < _ * _ => - split; [ solve [Z.zero_bounds] | apply Z.mul_lt_mono_nonneg; omega ] - | _ => solve [auto] - | _ => omega - end. - - Local Lemma eval2 x y : eval w 2 [x;y] = x + R * y. - Proof. cbn. change_weight. ring. Qed. - - Hint Rewrite BaseConversion.widemul_inlined_reverse_correct BaseConversion.widemul_inlined_correct - using (autorewrite with widemul push_nth_default; solve [solve_range]) : widemul. - - Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N) - (Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R): - montred' lo_hi = reduce_via_partial N R N' T. - Proof. - rewrite <-reduce_via_partial_alt_eq by nia. - cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In]. - rewrite Hlo, Hhi. - assert (0 <= (T mod R) * N' < w 2) by (solve_range). - - autorewrite with widemul. - rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega). - rewrite R_two_pow. - cbv [Rows.partition seq]. rewrite !eval2. - autorewrite with push_nth_default push_map. - autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct. - change_weight. - - (* pull out value before last modular reduction *) - match goal with |- (if (?n <=? ?x)%Z then ?x - ?n else ?x) = (if (?n <=? ?y) then ?y - ?n else ?y)%Z => - let P := fresh "H" in assert (x = y) as P; [|rewrite P; reflexivity] end. - - autorewrite with zsimplify. - rewrite (Z.mul_comm (((T mod R) * N') mod R) N) in *. - break_match; try reflexivity; Z.ltb_to_lt; rewrite Z.div_small_iff in * by omega; - repeat match goal with - | _ => progress autorewrite with zsimplify_fast - | |- context [?x mod (R * R)] => - unique pose proof (Z.mod_pos_bound x (R * R)); - try rewrite (Z.mod_small x (R * R)) in * by Z.rewrite_mod_small_solver - | _ => omega - | _ => progress Z.rewrite_mod_small - end. - Qed. - - Lemma montred'_correct lo_hi T (HT_range: 0 <= T < R * N) - (Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R): montred' lo_hi = (T * R') mod N. - Proof. - erewrite montred'_eq by eauto. - apply Z.equiv_modulo_mod_small; auto using reduce_via_partial_correct. - replace 0 with (Z.min 0 (R-N)) by (apply Z.min_l; omega). - apply reduce_via_partial_in_range; omega. - Qed. - End MontRed'. - - Derive montred_gen - SuchThat (forall (N R N' : Z) - (Zlog2R : Z) - (n nout: nat) - (lo_hi : Z * Z), - Interp (t:=type.reify_type_of montred') - montred_gen N R N' Zlog2R n nout lo_hi - = montred' N R N' Zlog2R n nout lo_hi) - As montred_gen_correct. - Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed. - Module Export ReifyHints. - Global Hint Extern 1 (_ = montred' _ _ _ _ _ _ _) => simple apply montred_gen_correct : reify_gen_cache. - End ReifyHints. - - Section rmontred. - Context (N R N' : Z) - (machine_wordsize : Z). - - Let bound := Some r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. - - Definition relax_zrange_of_machine_wordsize - := relax_zrange_gen [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize]%Z. - Local Arguments relax_zrange_of_machine_wordsize / . - - Let relax_zrange := relax_zrange_of_machine_wordsize. - - Definition check_args {T} (res : Pipeline.ErrorT T) - : Pipeline.ErrorT T - := res. (* TODO: this should actually check stuff that corresponds with preconditions of montred'_correct *) - - Notation BoundsPipeline_correct in_bounds out_bounds op - := (fun rv (rop : Expr (type.reify_type_of op)) Hrop - => @Pipeline.BoundsPipeline_correct_trans - false (* subst01 *) - relax_zrange - (relax_zrange_gen_good _) - _ - rop - in_bounds - out_bounds - op - Hrop rv) - (only parsing). - - Definition rmontred_correct - := BoundsPipeline_correct - (bound, bound) - bound - (montred' N R N' (Z.log2 R) 2 2). - - Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _). - Definition rmontred_correctT rv : Prop - := type_of_strip_3arrow (@rmontred_correct rv). - End rmontred. -End MontgomeryReduction. - -Ltac solve_rmontred := solve_rop MontgomeryReduction.rmontred_correct. -Ltac solve_rmontred_nocache := solve_rop_nocache MontgomeryReduction.rmontred_correct. - -Module Montgomery256. - - Definition N := Eval lazy in (2^256-2^224+2^192+2^96-1). - Definition N':= (115792089210356248768974548684794254293921932838497980611635986753331132366849). - Definition R := Eval lazy in (2^256). - Definition R' := 115792089183396302114378112356516095823261736990586219612555396166510339686400. - Definition machine_wordsize := 256. - - Derive montred256 - SuchThat (MontgomeryReduction.rmontred_correctT N R N' machine_wordsize montred256) - As montred256_correct. - Proof. Time solve_rmontred machine_wordsize. Time Qed. - - Definition montred256_prefancy' := PreFancy.of_Expr machine_wordsize [N;N'] montred256. - - Derive montred256_prefancy - SuchThat (montred256_prefancy = montred256_prefancy' type.interp) - As montred256_prefancy_eq. - Proof. lazy - [type.interp]; reflexivity. Qed. - - - Lemma montred'_correct_specialized R' (R'_correct : Z.equiv_modulo N (R * R') 1) : - forall (lo hi : Z), - 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> - MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2 (lo, hi) = ((lo + R * hi) * R') mod N. - Proof. - intros. - apply MontgomeryReduction.montred'_correct with (T:=lo + R * hi) (R':=R'); - try match goal with - | |- context[R'] => assumption - | |- context [lo] => - try assumption; progress autorewrite with zsimplify cancel_pair; reflexivity - end; lazy; try split; congruence. - Qed. - - (* Note: If this is not factored out, then for some reason Qed takes forever in montred256_correct_full. *) - Lemma montred256_correct_proj2 : - forall xy : type.interp (type.prod type.Z type.Z), - ZRange.type.option.is_bounded_by - (t:=type.prod type.Z type.Z) - (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) - xy = true -> - expr.Interp (@ident.interp) montred256 xy = app_curried (t:=type.arrow (type.prod type.Z type.Z) type.Z) (MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2) xy. - Proof. intros; destruct (montred256_correct xy); assumption. Qed. - Lemma montred256_correct_proj2' : - forall xy : type.interp (type.prod type.Z type.Z), - ZRange.type.option.is_bounded_by - (t:=type.prod type.Z type.Z) - (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) - xy = true -> - expr.Interp (@ident.interp) montred256 xy = MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2 xy. - Proof. intros; rewrite montred256_correct_proj2 by assumption; unfold app_curried; exact eq_refl. Qed. - - Lemma montred256_correct_full R' (R'_correct : Z.equiv_modulo N (R * R') 1) : - forall (lo hi : Z), - 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> - expr.interp (@ident.interp) (montred256 type.interp) (lo, hi) = ((lo + R * hi) * R') mod N. - Proof. - intros. - rewrite <-montred'_correct_specialized by assumption. - rewrite <-montred256_correct_proj2'. - { cbv [expr.Interp type.uncurried_domain type.uncurry type.final_codomain]. - reflexivity. } - { cbn. rewrite !andb_true_iff. cbv [R N] in *. - repeat split; apply Z.leb_le; omega. } - Qed. - - (* TODO : maybe move these ok_expr tactics somewhere else *) - Ltac ok_expr_step' := - match goal with - | _ => assumption - | |- _ <= _ <= _ \/ @eq zrange _ _ => - right; lazy; try split; congruence - | |- _ <= _ <= _ \/ @eq zrange _ _ => - left; lazy; try split; congruence - | |- lower r[0~>_]%zrange = 0 => reflexivity - | |- context [PreFancy.ok_ident] => constructor - | |- context [PreFancy.ok_scalar] => constructor; try omega - | |- context [PreFancy.is_halved] => eapply PreFancy.is_halved_constant; [lazy; reflexivity | ] - | |- context [PreFancy.is_halved] => constructor - | |- context [PreFancy.in_word_range] => lazy; reflexivity - | |- context [PreFancy.in_flag_range] => lazy; reflexivity - | |- context [PreFancy.get_range] => - cbn [PreFancy.get_range lower upper fst snd ZRange.map] - | x : type.interp (type.prod _ _) |- _ => destruct x - | |- (_ <=? _)%zrange = true => - match goal with - | |- context [PreFancy.get_range_var] => - cbv [is_tighter_than_bool PreFancy.has_range fst snd upper lower R N] in *; cbn; - apply andb_true_iff; split; apply Z.leb_le - | _ => lazy - end; omega || reflexivity - | |- @eq zrange _ _ => lazy; reflexivity - | |- _ <= _ => cbv [machine_wordsize]; omega - | |- _ <= _ <= _ => cbv [machine_wordsize]; omega - end; intros. - - (* TODO : maybe move these ok_expr tactics somewhere else *) - Ltac ok_expr_step := - match goal with - | |- context [PreFancy.ok_expr] => constructor; cbn [fst snd]; repeat ok_expr_step' - end; intros; cbn [Nat.max]. - - Lemma montred256_prefancy_correct : - forall (lo hi : Z) dummy_arrow, - 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> - @PreFancy.interp machine_wordsize (PreFancy.interp_cast_mod machine_wordsize) type.Z (montred256_prefancy (lo,hi) dummy_arrow) = ((lo + R * hi) * R') mod N. - Proof. - intros. rewrite montred256_prefancy_eq; cbv [montred256_prefancy']. - erewrite PreFancy.of_Expr_correct. - { apply montred256_correct_full; try assumption; reflexivity. } - { reflexivity. } - { lazy; reflexivity. } - { lazy; reflexivity. } - { repeat constructor. } - { cbv [In N N']; intros; intuition; subst; cbv; congruence. } - { assert (340282366920938463463374607431768211455 * 2 ^ 128 <= 2 ^ machine_wordsize - 1) as shiftl_128_ok by (lazy; congruence). - repeat (ok_expr_step; [ ]). - ok_expr_step. - lazy; congruence. - constructor. - constructor. } - { lazy. omega. } - Qed. - - Definition montred256_fancy' (lo hi RegMod RegPInv RegZero error : positive) := - Fancy.of_Expr 3%positive - (fun z => if z =? N then Some RegMod else if z =? N' then Some RegPInv else if z =? 0 then Some RegZero else None) - [N;N'] - montred256 - (lo, hi)%positive - (fun _ _ => tt) - error. - Derive montred256_fancy - SuchThat (forall RegMod RegPInv RegZero, - montred256_fancy RegMod RegPInv RegZero = montred256_fancy' RegMod RegPInv RegZero) - As montred256_fancy_eq. - Proof. - intros. - lazy - [Fancy.ADD Fancy.ADDC Fancy.SUB - Fancy.MUL128LL Fancy.MUL128LU Fancy.MUL128UL Fancy.MUL128UU - Fancy.RSHI Fancy.SELC Fancy.SELM Fancy.SELL Fancy.ADDM]. - reflexivity. - Qed. - - Import Fancy.Registers. - - Definition montred256_alloc' lo hi RegPInv := - fun errorP errorR => - Fancy.allocate register - positive Pos.eqb - errorR - (montred256_fancy 1000%positive 1001%positive 1002%positive 1003%positive 1004%positive errorP) - [r2;r3;r4;r5;r6;r7;r8;r9;r10;r11;r12;r13;r14;r15;r16;r17;r18;r19;r20] - (fun n => if n =? 1000 then lo - else if n =? 1001 then hi - else if n =? 1002 then RegMod - else if n =? 1003 then RegPInv - else if n =? 1004 then RegZero - else errorR). - Derive montred256_alloc - SuchThat (montred256_alloc = montred256_alloc') - As montred256_alloc_eq. - Proof. - intros. - cbv [montred256_alloc' montred256_fancy]. - cbn. subst montred256_alloc. - reflexivity. - Qed. - - Import ProdEquiv. - - Local Ltac solve_bounds := - match goal with - | H : ?a = ?b mod ?c |- 0 <= ?a < ?c => rewrite H; apply Z.mod_pos_bound; omega - | _ => assumption - end. - - Lemma montred256_alloc_equivalent errorP errorR cc_start_state start_context : - forall lo hi y t1 t2 scratch RegPInv extra_reg, - NoDup [lo; hi; y; t1; t2; scratch; RegPInv; extra_reg; RegMod; RegZero] -> - 0 <= start_context lo < R -> - 0 <= start_context hi < R -> - 0 <= start_context RegPInv < R -> - ProdEquiv.interp256 (montred256_alloc r0 r1 r30 errorP errorR) cc_start_state - (fun r => if reg_eqb r r0 - then start_context lo - else if reg_eqb r r1 - then start_context hi - else if reg_eqb r r30 - then start_context RegPInv - else start_context r) - = ProdEquiv.interp256 (Prod.MontRed256 lo hi y t1 t2 scratch RegPInv) cc_start_state start_context. - Proof. - intros. cbv [R] in *. - cbv [Prod.MontRed256 montred256_alloc]. - - (* Extract proofs that no registers are equal to each other *) - repeat match goal with - | H : NoDup _ |- _ => inversion H; subst; clear H - | H : ~ In _ _ |- _ => cbv [In] in H - | H : ~ (_ \/ _) |- _ => apply Decidable.not_or in H; destruct H - | H : ~ False |- _ => clear H - end. - - rewrite ProdEquiv.interp_Mul256 with (tmp2:=extra_reg) by (congruence || push_value_unused). - - step_both_sides. - step_both_sides. - rewrite mulll_comm. step_both_sides. - step_both_sides. - step_both_sides. - - rewrite ProdEquiv.interp_Mul256x256 with (tmp2:=extra_reg) by (congruence || push_value_unused). - - rewrite mulll_comm. step_both_sides. - step_both_sides. - step_both_sides. - rewrite mulhh_comm. step_both_sides. - step_both_sides. - step_both_sides. - step_both_sides. - step_both_sides. - - - rewrite add_comm by (cbn; solve_bounds). step_both_sides. - rewrite addc_comm by (cbn; solve_bounds). step_both_sides. - step_both_sides. - step_both_sides. - step_both_sides. - - cbn; repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence. - reflexivity. - Qed. - - Import Fancy_PreFancy_Equiv. - - Definition interp_equivZZ_256 {s} := - @interp_equivZZ s 256 ltac:(cbv; congruence) 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity). - Definition interp_equivZ_256 {s} := - @interp_equivZ s 256 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity). - - Local Ltac simplify_op_equiv start_ctx := - cbn - [Fancy.spec PreFancy.interp_ident Fancy.cc_spec]; - repeat match goal with H : start_ctx _ = _ |- _ => rewrite H end; - cbv - [ - Z.add_with_get_carry_full - Z.add_get_carry_full Z.sub_get_borrow_full - Z.le Z.ltb Z.leb Z.geb Z.eqb Z.land Z.shiftr Z.shiftl - Z.add Z.mul Z.div Z.sub Z.modulo Z.testbit Z.pow Z.ones - fst snd]; cbn [fst snd]; - try (replace (2 ^ (256 / 2) - 1) with (Z.ones 128) by reflexivity; rewrite !Z.land_ones by omega); - autorewrite with to_div_mod; rewrite ?Z.mod_mod, <-?Z.testbit_spec' by omega; - repeat match goal with - | H : 0 <= ?x < ?m |- context [?x mod ?m] => rewrite (Z.mod_small x m) by apply H - | |- context [?x <? 0] => rewrite (proj2 (Z.ltb_ge x 0)) by (break_match; Z.zero_bounds) - | _ => rewrite Z.mod_small with (b:=2) by (break_match; omega) - | |- context [ (if Z.testbit ?a ?n then 1 else 0) + ?b + ?c] => - replace ((if Z.testbit a n then 1 else 0) + b + c) with (b + c + (if Z.testbit a n then 1 else 0)) by ring - end. - - Local Ltac solve_nonneg ctx := - match goal with x := (Fancy.spec _ _ _) |- _ => subst x end; - simplify_op_equiv ctx; Z.zero_bounds. - - Local Ltac generalize_result := - let v := fresh "v" in intro v; generalize v; clear v; intro v. - - Local Ltac generalize_result_nonneg ctx := - let v := fresh "v" in - let v_nonneg := fresh "v_nonneg" in - intro v; assert (0 <= v) as v_nonneg; [solve_nonneg ctx |generalize v v_nonneg; clear v v_nonneg; intros v v_nonneg]. - - Local Ltac step ctx := - match goal with - | |- Fancy.interp _ _ _ (Fancy.Instr (Fancy.ADD _) _ _ (Fancy.Instr (Fancy.ADDC _) _ _ _)) _ _ = _ => - apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result_nonneg ctx] - | _ => apply interp_equivZ_256; [simplify_op_equiv ctx | generalize_result] - | _ => apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result] - end. - - Lemma prod_montred256_correct : - forall (cc_start_state : Fancy.CC.state) (* starting carry flags can be anything *) - (start_context : register -> Z) (* starting register values *) - (lo hi y t1 t2 scratch RegPInv extra_reg : register), (* registers to use in computation *) - NoDup [lo; hi; y; t1; t2; scratch; RegPInv; extra_reg; RegMod; RegZero] -> (* registers must be distinct *) - start_context RegPInv = N' -> (* RegPInv needs to hold the inverse of the modulus *) - start_context RegMod = N -> (* RegMod needs to hold the modulus *) - start_context RegZero = 0 -> (* RegZero needs to hold zero *) - (0 <= start_context lo < R) -> (* low half of the input is in bounds (R=2^256) *) - (0 <= start_context hi < R) -> (* high half of the input is in bounds (R=2^256) *) - let x := (start_context lo) + R * (start_context hi) in (* x is the input (split into two registers) *) - (0 <= x < R * N) -> (* input precondition *) - (ProdEquiv.interp256 (Prod.MontRed256 lo hi y t1 t2 scratch RegPInv) cc_start_state start_context = (x * R') mod N). - Proof. - intros. subst x. cbv [N R N'] in *. - rewrite <-montred256_prefancy_correct with (dummy_arrow := fun s d _ => DefaultValue.type.default) by auto. - rewrite <-montred256_alloc_equivalent with (errorR := RegZero) (errorP := 1%positive) (extra_reg:=extra_reg) - by (cbv [R]; auto with omega). - cbv [ProdEquiv.interp256]. - cbv [montred256_alloc montred256_prefancy]. - - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ reflexivity | reflexivity | ]. - step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. - step start_context; [ reflexivity | | ]. - { - let r := eval cbv in (2^256) in replace (2^256) with r by reflexivity. - rewrite !Z.shiftl_0_r, !Z.mod_mod by omega. - repeat match goal with - | |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega)) - end. - apply Z.testbit_neg_eq_if; - let r := eval cbv in (2^256) in replace (2^256) with r by reflexivity; - omega. } - step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. - reflexivity. - Qed. - - Import PrintingNotations. - Set Printing Width 10000. - - Print montred256. -(* -montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, - expr_let x0 := 79228162514264337593543950337 *₂₅₆ (uint128)(x₁ >> 128) in - expr_let x1 := 340282366841710300986003757985643364352 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in - expr_let x2 := 79228162514264337593543950337 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in - expr_let x3 := ADD_256 ((uint256)(((uint128)(x1) & 340282366920938463463374607431768211455) << 128), x2) in - expr_let x4 := ADD_256 ((uint256)(((uint128)(x0) & 340282366920938463463374607431768211455) << 128), x3₁) in - expr_let x5 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in - expr_let x6 := 79228162514264337593543950335 *₂₅₆ (uint128)(x4₁ >> 128) in - expr_let x7 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in - expr_let x8 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x4₁ >> 128) in - expr_let x9 := ADD_256 ((uint256)(((uint128)(x7) & 340282366920938463463374607431768211455) << 128), x5) in - expr_let x10 := ADDC_256 (x9₂, (uint128)(x7 >> 128), x8) in - expr_let x11 := ADD_256 ((uint256)(((uint128)(x6) & 340282366920938463463374607431768211455) << 128), x9₁) in - expr_let x12 := ADDC_256 (x11₂, (uint128)(x6 >> 128), x10₁) in - expr_let x13 := ADD_256 (x11₁, x₁) in - expr_let x14 := ADDC_256 (x13₂, x12₁, x₂) in - expr_let x15 := SELC (x14₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in - expr_let x16 := SUB_256 (x14₁, x15) in - ADDM (x16₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951))%expr - : Expr (type.uncurry (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z)) -*) - - Import PreFancy. - Import PreFancy.Notations. - Local Notation "'RegMod'" := (Straightline.expr.Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951). - Local Notation "'RegPInv'" := (Straightline.expr.Primitive (t:=type.Z) 115792089210356248768974548684794254293921932838497980611635986753331132366849). - Print montred256_prefancy. - (* - mulhl@(y0, RegPInv, $x₁); - mulll@(y1, RegPInv, $x₁); - add@(y2, $y1, $y0, 128); - add@(y3, $y2, $y, 128); - mulll@(y4, RegMod, $y3); - mullh@(y5, RegMod, $y3); - mulhl@(y6, RegMod, $y3); - mulhh@(y7, RegMod, $y3); - add@(y8, $y4, $y6, 128); - addc@(y9, carry{$y8}, $y7, $y6, -128); - add@(y10, $y8, $y5, 128); - addc@(y11, carry{$y10}, $y9, $y5, -128); - add@(y12, $y10, $x₁, 0); - addc@(y13, carry{$y12}, $y11, $x₂, 0); - selc@(y14, carry{$y13}, RegZero, RegMod); - sub@(y15, $y13, $y14, 0); - addm@(y16, $y15, RegZero, RegMod); - ret $y16 - *) - -End Montgomery256. - -Local Notation "i rd x y ; cont" := (Fancy.Instr i rd (x, y) cont) (at level 40, cont at level 200, format "i rd x y ; '//' cont"). -Local Notation "i rd x y z ; cont" := (Fancy.Instr i rd (x, y, z) cont) (at level 40, cont at level 200, format "i rd x y z ; '//' cont"). - -Import Fancy.Registers. -Import Fancy. - -Import Barrett256 Montgomery256. - -(*** Montgomery Reduction ***) - -(* Status: Code in final form is proven correct modulo admits in compiler portions. *) - -(* Montgomery Code : *) -Eval cbv beta iota delta [Prod.MontRed256 Prod.Mul256 Prod.Mul256x256] in Prod.MontRed256. -(* - = fun lo hi y t1 t2 scratch RegPInv : register => - MUL128LL y lo RegPInv; - MUL128UL t1 lo RegPInv; - ADD 128 y y t1; - MUL128LU t1 lo RegPInv; - ADD 128 y y t1; - MUL128LL t1 y RegMod; - MUL128UU t2 y RegMod; - MUL128UL scratch y RegMod; - ADD 128 t1 t1 scratch; - ADDC (-128) t2 t2 scratch; - MUL128LU scratch y RegMod; - ADD 128 t1 t1 scratch; - ADDC (-128) t2 t2 scratch; - ADD 0 lo lo t1; - ADDC 0 hi hi t2; - SELC y RegMod RegZero; - SUB 0 lo hi y; - ADDM lo lo RegZero RegMod; - Ret lo - *) - -(* Uncomment to see proof statement and remaining admitted statements, -or search for "prod_montred256_correct" to see comments on the proof -preconditions. *) -(* -Check Montgomery256.prod_montred256_correct. -Print Assumptions Montgomery256.prod_montred256_correct. -*) - -(*** Barrett Reduction ***) - -(* Status: Code is proven correct modulo admits in compiler -portions. However, unlike for Montgomery, this code is not proven -equivalent to the register-allocated and efficiently-scheduled -reference (Prod.MulMod). This proof is currently admitted and would -require either fiddling with code generation to make instructions come -out in the right order or reasoning about which instructions -commute. *) - -(* Barrett reference code: *) -Eval cbv beta iota delta [Prod.MulMod Prod.Mul256x256] in Prod.MulMod. -(* - = fun x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 : register => - let q1Bottom256 := scratchp1 in - let muSelect := scratchp2 in - let q2 := scratchp3 in - let q2High := scratchp4 in - let q2High2 := scratchp5 in - let q3 := scratchp1 in - let r2 := scratchp2 in - let r2High := scratchp3 in - let maybeM := scratchp1 in - SELM muSelect RegMuLow RegZero; - RSHI 255 q1Bottom256 xHigh x; - MUL128LL q2 q1Bottom256 RegMuLow; - MUL128UU q2High q1Bottom256 RegMuLow; - MUL128UL scratchp5 q1Bottom256 RegMuLow; - ADD 128 q2 q2 scratchp5; - ADDC (-128) q2High q2High scratchp5; - MUL128LU scratchp5 q1Bottom256 RegMuLow; - ADD 128 q2 q2 scratchp5; - ADDC (-128) q2High q2High scratchp5; - RSHI 255 q2High2 RegZero xHigh; - ADD 0 q2High q2High q1Bottom256; - ADDC 0 q2High2 q2High2 RegZero; - ADD 0 q2High q2High muSelect; - ADDC 0 q2High2 q2High2 RegZero; - RSHI 1 q3 q2High2 q2High; - MUL128LL r2 RegMod q3; - MUL128UU r2High RegMod q3; - MUL128UL scratchp4 RegMod q3; - ADD 128 r2 r2 scratchp4; - ADDC (-128) r2High r2High scratchp4; - MUL128LU scratchp4 RegMod q3; - ADD 128 r2 r2 scratchp4; - ADDC (-128) r2High r2High scratchp4; - SUB 0 muSelect x r2; - SUBC 0 xHigh xHigh r2High; - SELL maybeM RegMod RegZero; - SUB 0 q3 muSelect maybeM; - ADDM x q3 RegZero RegMod; - Ret x - *) - -(* Barrett generated code (equivalence with reference admitted) *) -Eval cbv beta iota delta [barrett_red256_alloc] in barrett_red256_alloc. -(* - = fun (xLow xHigh RegMuLow : register) (_ : positive) (_ : register) => - SELM r2 RegMuLow RegZero; - RSHI 255 r3 RegZero xHigh; - RSHI 255 r4 xHigh xLow; - MUL128UU r5 RegMuLow r4; - MUL128UL r6 r4 RegMuLow; - MUL128LU r7 r4 RegMuLow; - MUL128LL r8 RegMuLow r4; - ADD 128 r9 r8 r7; - ADDC (-128) r10 r5 r7; - ADD 128 r5 r9 r6; - ADDC (-128) r11 r10 r6; - ADD 0 r6 r4 r11; - ADDC 0 r12 RegZero r3; - ADD 0 r13 r2 r6; - ADDC 0 r14 RegZero r12; - RSHI 1 r15 r14 r13; - MUL128UU r16 RegMod r15; - MUL128LU r17 r15 RegMod; - MUL128UL r18 r15 RegMod; - MUL128LL r19 RegMod r15; - ADD 128 r20 r19 r18; - ADDC (-128) r21 r16 r18; - ADD 128 r22 r20 r17; - ADDC (-128) r23 r21 r17; - SUB 0 r24 xLow r22; - SUBC 0 r25 xHigh r23; - SELL r26 RegMod RegZero; - SUB 0 r27 r24 r26; - ADDM r28 r27 RegZero RegMod; - Ret r28 - *) - -(* Uncomment to see proof statement and remaining admitted statements. *) -(* -Check prod_barrett_red256_correct. -Print Assumptions prod_barrett_red256_correct. -(* The equivalence with generated code is admitted as barrett_red256_alloc_equivalent. *) -*) |