diff options
author | 2018-02-16 19:18:06 -0500 | |
---|---|---|
committer | 2018-03-19 14:17:26 -0400 | |
commit | a605e01f6da045dd7f8140a55aa951fd7799821a (patch) | |
tree | 4d29b45a91207e5c6afabb2d8e5c35585efe7639 /src | |
parent | 768dc8d4524b0e48b54fd56876312032e626484a (diff) |
Add a ring goal
Unfortunately, the ring proofs are a bit messy.
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 1761 |
1 files changed, 1556 insertions, 205 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 87571459c..fef25fe59 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -9,6 +9,7 @@ Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn. Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil. Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. Require Import Crypto.Util.ZRange. Require Import Crypto.Util.ZRange.Operations. Require Import Crypto.Util.Tactics.RunTacticAsConstr. @@ -16,13 +17,14 @@ Require Import Crypto.Util.Tactics.Head. Require Import Crypto.Util.Option. Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.Tactics.SpecializeBy. Require Import Crypto.Util.Notations. Require Import Crypto.Util.ZUtil.Definitions. Import ListNotations. Local Open Scope Z_scope. Module Associational. Definition eval (p:list (Z*Z)) : Z := - fold_right Z.add 0%Z (map (fun t => fst t * snd t) p). + fold_right (fun x y => x + y) 0%Z (map (fun t => fst t * snd t) p). Lemma eval_nil : eval nil = 0. Proof. trivial. Qed. @@ -409,40 +411,116 @@ Module Positional. Section Positional. Hint Rewrite @eval_opp @eval_sub : push_eval. Hint Rewrite @length_sub @length_opp : distr_length. - Section carry_mulmod. - Context (s:Z) - (c:list (Z*Z)) + Section mod_ops. + Context (s : Z) + (c : list (Z*Z)) (n : nat) (len_c : nat) (idxs : list nat) (len_idxs : nat) - (fg : list Z * list Z). + (m_nz:s - Associational.eval c <> 0) (s_nz:s <> 0) + (Hn_nz : n <> 0%nat) + (Hc : length c = len_c) + (Hidxs : length idxs = len_idxs) + (Hw_div_nz : forall i : nat, weight (S i) / weight i <> 0). Derive carry_mulmod - SuchThat (forall (f := fst fg) (g := snd fg) - (m_nz:s - Associational.eval c <> 0) (s_nz:s <> 0) + SuchThat (forall (fg : list Z * list Z) + (f := fst fg) (g := snd fg) (Hf : length f = n) - (Hg : length g = n) - (Hn_nz : n <> 0%nat) - (Hc : length c = len_c) - (Hidxs : length idxs = len_idxs) - (Hw_div_nz : forall i : nat, weight (S i) / weight i <> 0), - (eval n carry_mulmod) mod (s - Associational.eval c) + (Hg : length g = n), + (eval n (carry_mulmod fg)) mod (s - Associational.eval c) = (eval n f * eval n g) mod (s - Associational.eval c)) As eval_carry_mulmod. Proof. intros. - erewrite <-eval_mulmod with (s:=s) (c:=c) - by (subst; try assumption; try reflexivity). + rewrite <-eval_mulmod with (s:=s) (c:=c) by auto. etransitivity; [ | rewrite <- @eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs) - by (subst; auto); reflexivity ]. + by auto; reflexivity ]. + eapply f_equal2; [|trivial]. eapply f_equal. + expand_lists (). + subst f g carry_mulmod; reflexivity. + Qed. + + Derive carrymod + SuchThat (forall (f : list Z) + (Hf : length f = n), + (eval n (carrymod f)) mod (s - Associational.eval c) + = (eval n f) mod (s - Associational.eval c)) + As eval_carrymod. + Proof. + intros. + etransitivity; + [ | rewrite <- @eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs) + by auto; reflexivity ]. + eapply f_equal2; [|trivial]. eapply f_equal. + expand_lists (). + subst carrymod; reflexivity. + Qed. + + Derive addmod + SuchThat (forall (fg: list Z * list Z) + (f := fst fg) (g := snd fg) + (Hf : length f = n) + (Hg : length g = n), + (eval n (addmod fg)) mod (s - Associational.eval c) + = (eval n f + eval n g) mod (s - Associational.eval c)) + As eval_addmod. + Proof. + intros. + rewrite <-eval_add by assumption. + eapply f_equal2; [|trivial]. eapply f_equal. + expand_lists (). + subst f g addmod; reflexivity. + Qed. + + Derive submod + SuchThat (forall (coef:Z) + (fg: list Z * list Z) + (f := fst fg) (g := snd fg) + (Hf : length f = n) + (Hg : length g = n), + (eval n (submod coef fg)) mod (s - Associational.eval c) + = (eval n f - eval n g) mod (s - Associational.eval c)) + As eval_submod. + Proof. + intros. + rewrite <-eval_sub with (coef:=coef) by auto. + eapply f_equal2; [|trivial]. eapply f_equal. + expand_lists (). + subst f g submod; reflexivity. + Qed. + + Derive oppmod + SuchThat (forall (coef:Z) + (f: list Z) + (Hf : length f = n), + (eval n (oppmod coef f)) mod (s - Associational.eval c) + = (- eval n f) mod (s - Associational.eval c)) + As eval_oppmod. + Proof. + intros. + rewrite <-eval_opp with (coef:=coef) by auto. + eapply f_equal2; [|trivial]. eapply f_equal. + expand_lists (). + subst oppmod; reflexivity. + Qed. + + Derive encodemod + SuchThat (forall (f:Z), + (eval n (encodemod f)) mod (s - Associational.eval c) + = f mod (s - Associational.eval c)) + As eval_encodemod. + Proof. + intros. + etransitivity. + 2:rewrite <-@eval_encode with (n:=n) by auto; reflexivity. eapply f_equal2; [|trivial]. eapply f_equal. expand_lists (). - subst carry_mulmod. - reflexivity. + subst encodemod; reflexivity. Qed. - End carry_mulmod. + End mod_ops. End Positional. End Positional. Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. @@ -1267,16 +1345,18 @@ Module Compilers. do_reify_ident term' ltac:(fun _ - => let term - := match constr:(Set) with - | _ => (eval cbv delta [term] in term) (* might fail, so we wrap it in a match to give better error messages *) - | _ - => let dummy := match goal with - | _ => fail 2 "Unrecognized term:" term' - end in - constr:(I : I) - end in - reify_rec term) + => + (*let __ := match goal with _ => idtac "Attempting to unfold" term end in*) + let term + := match constr:(Set) with + | _ => (eval cbv delta [term] in term) (* might fail, so we wrap it in a match to give better error messages *) + | _ + => let dummy := match goal with + | _ => fail 2 "Unrecognized term:" term' + end in + constr:(I : I) + end in + reify_rec term) end) end end. @@ -1325,6 +1405,7 @@ Module Compilers. | Z_add : ident (Z * Z) Z | Z_mul : ident (Z * Z) Z | Z_pow : ident (Z * Z) Z + | Z_sub : ident (Z * Z) Z | Z_opp : ident Z Z | Z_div : ident (Z * Z) Z | Z_modulo : ident (Z * Z) Z @@ -1390,6 +1471,7 @@ Module Compilers. | Z_pow => curry2 Z.pow | Z_modulo => curry2 Z.modulo | Z_opp => Z.opp + | Z_sub => curry2 Z.sub | Z_div => curry2 Z.div | Z_eqb => curry2 Z.eqb | Z_leb => curry2 Z.leb @@ -1504,6 +1586,7 @@ Module Compilers. | Z.add ?x ?y => mkAppIdent ident.Z_add (x, y) | Z.mul ?x ?y => mkAppIdent ident.Z_mul (x, y) | Z.pow ?x ?y => mkAppIdent ident.Z_pow (x, y) + | Z.sub ?x ?y => mkAppIdent ident.Z_sub (x, y) | Z.opp ?x => mkAppIdent ident.Z_opp x | Z.div ?x ?y => mkAppIdent ident.Z_div (x, y) | Z.modulo ?x ?y => mkAppIdent ident.Z_modulo (x, y) @@ -1549,6 +1632,7 @@ Module Compilers. Notation add := Z_add. Notation mul := Z_mul. Notation pow := Z_pow. + Notation sub := Z_sub. Notation opp := Z_opp. Notation div := Z_div. Notation modulo := Z_modulo. @@ -1626,6 +1710,7 @@ Module Compilers. | Z_add : ident (Z * Z) Z | Z_mul : ident (Z * Z) Z | Z_pow : ident (Z * Z) Z + | Z_sub : ident (Z * Z) Z | Z_opp : ident Z Z | Z_div : ident (Z * Z) Z | Z_modulo : ident (Z * Z) Z @@ -1685,6 +1770,7 @@ Module Compilers. | Z_mul => curry2 Z.mul | Z_pow => curry2 Z.pow | Z_modulo => curry2 Z.modulo + | Z_sub => curry2 Z.sub | Z_opp => Z.opp | Z_div => curry2 Z.div | Z_eqb => curry2 Z.eqb @@ -1769,6 +1855,7 @@ Module Compilers. | Z.add ?x ?y => mkAppIdent ident.Z_add (x, y) | Z.mul ?x ?y => mkAppIdent ident.Z_mul (x, y) | Z.pow ?x ?y => mkAppIdent ident.Z_pow (x, y) + | Z.sub ?x ?y => mkAppIdent ident.Z_sub (x, y) | Z.opp ?x => mkAppIdent ident.Z_opp x | Z.div ?x ?y => mkAppIdent ident.Z_div (x, y) | Z.modulo ?x ?y => mkAppIdent ident.Z_modulo (x, y) @@ -1807,6 +1894,7 @@ Module Compilers. Notation add := Z_add. Notation mul := Z_mul. Notation pow := Z_pow. + Notation sub := Z_sub. Notation opp := Z_opp. Notation div := Z_div. Notation modulo := Z_modulo. @@ -1916,6 +2004,8 @@ Module Compilers. => AppIdent ident.Z.mul | for_reification.ident.Z_pow => AppIdent ident.Z.pow + | for_reification.ident.Z_sub + => AppIdent ident.Z.sub | for_reification.ident.Z_opp => AppIdent ident.Z.opp | for_reification.ident.Z_div @@ -2207,6 +2297,15 @@ Module Compilers. end | None => None end. + Definition invert_Z_opp (e : @expr var type.Z) : option (@expr var type.Z) + := match invert_AppIdent e with + | Some (existT s (idc, args)) + => match idc in ident s t return expr s -> option (expr type.Z) with + | ident.Z_opp => fun v => Some v + | _ => fun _ => None + end args + | None => None + end. Local Notation list_expr := (fun t => match t return Type with @@ -2710,6 +2809,7 @@ Module Compilers. | ident.Z_add as idc | ident.Z_mul as idc | ident.Z_pow as idc + | ident.Z_sub as idc | ident.Z_opp as idc | ident.Z_div as idc | ident.Z_modulo as idc @@ -2905,6 +3005,7 @@ Module Compilers. @@ (ident.fst @@ (Var xyk))) | ident.Z_add as idc | ident.Z_mul as idc + | ident.Z_sub as idc | ident.Z_pow as idc | ident.Z_div as idc | ident.Z_modulo as idc @@ -3321,6 +3422,36 @@ Module Compilers. Include default. End CPS. + Module sign. + Local Set Boolean Equality Schemes. + Local Set Decidable Equality Schemes. + Inductive sign := pos | neg. + Module Z. + Definition interp (s : sign) : Z -> Z + := match s with + | pos => id + | neg => Z.opp + end. + End Z. + Definition of_Z (v : Z) : sign + := match v with + | Zneg _ => neg + | Zpos _ => pos + | Z0 => (* default *) pos + end. + Definition opp (s : sign) : sign + := match s with + | pos => neg + | neg => pos + end. + Definition mul (s1 s2 : sign) : sign + := match s1, s2 with + | pos, pos => pos + | neg, neg => pos + | pos, neg | neg, pos => neg + end. + End sign. + Notation sign := sign.sign. Module partial. Section value. Context (var : type -> Type). @@ -3336,6 +3467,8 @@ Module Compilers. := match t return Type with | type.arrow _ _ as t => value_prestep value t + | type.type_primitive type.Z as t + => sign * @expr var t + value_prestep value t | type.prod _ _ as t | type.list _ as t | type.type_primitive _ as t @@ -3367,6 +3500,13 @@ Module Compilers. | inl v => v | inr v => reify_list (List.map (@reify A) v) end + | type.type_primitive type.Z as t + => fun x : sign * expr t + type.interp t + => match x with + | inl (sign.pos, v) => v + | inl (sign.neg, v) => ident.Z.opp @@ v + | inr v => ident.primitive v @@ TT + end%core%expr | type.type_primitive _ as t => fun x : expr t + type.interp t => match x with @@ -3400,6 +3540,15 @@ Module Compilers. | None => inl v end + | type.type_primitive type.Z as t + => fun v : expr t + => let inr := @inr (sign * expr t) (value_prestep (value var) t) in + let inl := @inl (sign * expr t) (value_prestep (value var) t) in + match reflect_primitive v, invert_Z_opp v with + | Some v, _ => inr v + | None, Some v => inl (sign.neg, v) + | None, None => inl (sign.pos, v) + end | type.type_primitive _ as t => fun v : expr t => let inr := @inr (expr t) (value_prestep (value var) t) in @@ -3446,6 +3595,17 @@ Module Compilers. (fun b => f (inr (a, b)))) | inl e => partial.expr.reflect (expr_let y := e in partial.expr.reify (f (inl (Var y))))%expr end + | type.type_primitive type.Z as t + => fun (x : sign * expr t + type.interp t) + (f : sign * expr t + type.interp t -> value var tC) + => match x with + | inl (sgn, e) + => match invert_Var e with + | Some v => f (inl (sgn, Var v)) + | None => partial.expr.reflect (expr_let y := e in partial.expr.reify (f (inl (sgn, Var y)%core)))%expr + end + | inr v => f (inr v) + end | type.type_primitive _ as t => fun (x : expr t + type.interp t) (f : expr t + type.interp t -> value var tC) => match x with @@ -3468,6 +3628,8 @@ Module Compilers. end | ident.nil t => fun _ => inr (@nil (value var t)) + | ident.primitive type.Z v + => fun _ => inr v | ident.primitive t v => fun _ => inr v | ident.cons t as idc @@ -3517,6 +3679,15 @@ Module Compilers. ls | _ => expr.reflect (AppIdent idc (expr.reify (t:=P * (A * type.list A * P -> P) * type.list A) nil_case_cons_case_ls)) end + | ident.List.nth_default type.Z as idc + => fun (default_ls_idx : expr (type.Z * type.list type.Z * type.nat) + (expr (type.Z * type.list type.Z) + (sign * expr type.Z + type.interp type.Z) * (expr (type.list type.Z) + list (value var type.Z))) * (expr type.nat + nat)) + => match default_ls_idx with + | inr (inr (default, inr ls), inr idx) + => List.nth_default default ls idx + | inr (inr (inr default, ls), inr idx) + => expr.reflect (AppIdent (ident.List.nth_default_concrete default idx) (expr.reify (t:=type.list type.Z) ls)) + | _ => expr.reflect (AppIdent idc (expr.reify (t:=type.Z * type.list type.Z * type.nat) default_ls_idx)) + end | ident.List.nth_default (type.type_primitive A) as idc => fun (default_ls_idx : expr (A * type.list A * type.nat) + (expr (A * type.list A) + (expr A + type.interp A) * (expr (type.list A) + list (value var A))) * (expr type.nat + nat)) => match default_ls_idx with @@ -3542,8 +3713,8 @@ Module Compilers. end | ident.Z_mul_split as idc => fun (x_y_z : (expr (type.Z * type.Z * type.Z) + - (expr (type.Z * type.Z) + (expr type.Z + Z) * (expr type.Z + Z)) * (expr type.Z + Z))%type) - => match x_y_z return (expr _ + (expr _ + type.interp _) * (expr _ + type.interp _)) with + (expr (type.Z * type.Z) + (sign * expr type.Z + Z) * (sign * expr type.Z + Z)) * (sign * expr type.Z + Z))%type) + => match x_y_z return (expr _ + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) with | inr (inr (inr x, inr y), inr z) => let result := ident.interp idc (x, y, z) in inr (inr (fst result), inr (snd result)) @@ -3553,8 +3724,8 @@ Module Compilers. end | ident.Z_add_get_carry as idc => fun (x_y_z : (expr (type.Z * type.Z * type.Z) + - (expr (type.Z * type.Z) + (expr type.Z + Z) * (expr type.Z + Z)) * (expr type.Z + Z))%type) - => match x_y_z return (expr _ + (expr _ + type.interp _) * (expr _ + type.interp _)) with + (expr (type.Z * type.Z) + (sign * expr type.Z + Z) * (sign * expr type.Z + Z)) * (sign * expr type.Z + Z))%type) + => match x_y_z return (expr _ + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) with | inr (inr (inr x, inr y), inr z) => let result := ident.interp idc (x, y, z) in inr (inr (fst result), inr (snd result)) @@ -3573,9 +3744,9 @@ Module Compilers. | ident.Z_add_with_get_carry as idc => fun (x_y_z_a : (expr (_ * _ * _ * _) + (expr (_ * _ * _) + - (expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _)) * - (expr _ + type.interp _)) * (expr _ + type.interp _))%type) - => match x_y_z_a return (expr _ + (expr _ + type.interp _) * (expr _ + type.interp _)) with + (expr (_ * _) + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) * + (sign * expr _ + type.interp _)) * (sign * expr _ + type.interp _))%type) + => match x_y_z_a return (expr _ + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) with | inr (inr (inr (inr x, inr y), inr z), inr a) => let result := ident.interp idc (x, y, z, a) in inr (inr (fst result), inr (snd result)) @@ -3585,8 +3756,8 @@ Module Compilers. end | ident.Z_sub_get_borrow as idc => fun (x_y_z : (expr (type.Z * type.Z * type.Z) + - (expr (type.Z * type.Z) + (expr type.Z + Z) * (expr type.Z + Z)) * (expr type.Z + Z))%type) - => match x_y_z return (expr _ + (expr _ + type.interp _) * (expr _ + type.interp _)) with + (expr (type.Z * type.Z) + (sign * expr type.Z + Z) * (sign * expr type.Z + Z)) * (sign * expr type.Z + Z))%type) + => match x_y_z return (expr _ + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) with | inr (inr (inr x, inr y), inr z) => let result := ident.interp idc (x, y, z) in inr (inr (fst result), inr (snd result)) @@ -3596,17 +3767,17 @@ Module Compilers. end | ident.Z_mul_split_concrete _ as idc | ident.Z.sub_get_borrow_concrete _ as idc - => fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _)) - => match x_y return (expr _ + (expr _ + type.interp _) * (expr _ + type.interp _)) with + => fun (x_y : expr (_ * _) + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) + => match x_y return (expr _ + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) with | inr (inr x, inr y) => let result := ident.interp idc (x, y) in inr (inr (fst result), inr (snd result)) | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) end | ident.Z.add_get_carry_concrete _ as idc - => fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _)) + => fun (x_y : expr (_ * _) + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in - match x_y return (expr _ + (expr _ + type.interp _) * (expr _ + type.interp _)) with + match x_y return (expr _ + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) with | inr (inr x, inr y) => let result := ident.interp idc (x, y) in inr (inr (fst result), inr (snd result)) @@ -3619,8 +3790,8 @@ Module Compilers. end | ident.Z.add_with_get_carry_concrete _ as idc => fun (x_y_z : (expr (type.Z * type.Z * type.Z) + - (expr (type.Z * type.Z) + (expr type.Z + Z) * (expr type.Z + Z)) * (expr type.Z + Z))%type) - => match x_y_z return (expr _ + (expr _ + type.interp _) * (expr _ + type.interp _)) with + (expr (type.Z * type.Z) + (sign * expr type.Z + Z) * (sign * expr type.Z + Z)) * (sign * expr type.Z + Z))%type) + => match x_y_z return (expr _ + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) with | inr (inr (inr x, inr y), inr z) => let result := ident.interp idc (x, y, z) in inr (inr (fst result), inr (snd result)) @@ -3628,30 +3799,50 @@ Module Compilers. end | ident.pred as idc | ident.Nat_succ as idc + => fun x : expr _ + type.interp _ + => match x return expr _ + type.interp _ with + | inr x => inr (ident.interp idc x) + | inl x => expr.reflect (AppIdent idc x) + end | ident.Z_of_nat as idc + => fun x : expr _ + type.interp _ + => match x return sign * expr _ + type.interp _ with + | inr x => inr (ident.interp idc x) + | inl x => expr.reflect (AppIdent idc x) + end | ident.Z_opp as idc + => fun x : sign * expr _ + type.interp _ + => match x return sign * expr _ + type.interp _ with + | inr x => inr (ident.interp idc x) + | inl (sgn, x) => inl (sign.opp sgn, x) + end | ident.Z_shiftr _ as idc | ident.Z_shiftl _ as idc | ident.Z_land _ as idc - => fun x : expr _ + type.interp _ - => match x return expr _ + type.interp _ with + => fun x : sign * expr _ + type.interp _ + => match x return sign * expr _ + type.interp _ with | inr x => inr (ident.interp idc x) - | inl x => expr.reflect (AppIdent idc x) + | inl _ => expr.reflect (AppIdent idc (expr.reify (t:=type.Z) x)) end | ident.Nat_add as idc | ident.Nat_mul as idc | ident.Z_pow as idc + => fun (x_y : expr (_ * _) + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) + => match x_y return sign * expr _ + type.interp _ with + | inr (inr x, inr y) => inr (ident.interp idc (x, y)) + | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) + end | ident.Z_eqb as idc | ident.Z_leb as idc - => fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _)) + => fun (x_y : expr (_ * _) + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) => match x_y return expr _ + type.interp _ with | inr (inr x, inr y) => inr (ident.interp idc (x, y)) | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) end | ident.Z_div as idc - => fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _)) + => fun (x_y : expr (_ * _) + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in - match x_y return expr _ + type.interp _ with + match x_y return sign * expr _ + type.interp _ with | inr (inr x, inr y) => inr (ident.interp idc (x, y)) | inr (x, inr y) => if Z.eqb y (2^Z.log2 y) @@ -3660,9 +3851,9 @@ Module Compilers. | _ => default end | ident.Z_modulo as idc - => fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _)) + => fun (x_y : expr (_ * _) + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in - match x_y return expr _ + type.interp _ with + match x_y return sign * expr _ + type.interp _ with | inr (inr x, inr y) => inr (ident.interp idc (x, y)) | inr (x, inr y) => if Z.eqb y (2^Z.log2 y) @@ -3671,38 +3862,82 @@ Module Compilers. | _ => default end | ident.Z_mul as idc - => fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _)) + => fun (x_y : expr (_ * _) + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in - match x_y return expr _ + type.interp _ with + match x_y return sign * expr _ + type.interp _ with | inr (inr x, inr y) => inr (ident.interp idc (x, y)) - | inr (inr x, inl e) - | inr (inl e, inr x) + | inr (inr x, inl (sgn, e)) + | inr (inl (sgn, e), inr x) => if Z.eqb x 0 then inr 0%Z else if Z.eqb x 1 - then inl e - else if Z.eqb x (2^Z.log2 x) - then expr.reflect (AppIdent (ident.Z.shiftl (Z.log2 x)) e) - else default - | inr (inl _, inl _) | inl _ => default + then inl (sgn, e) + else if Z.eqb x (-1) + then inl (sign.opp sgn, e) + else let sgn' := sign.mul sgn (sign.of_Z x) in + let x' := Z.abs x in + if Z.eqb x' (2^Z.log2 x') + then inl (sgn', + AppIdent (ident.Z.shiftl (Z.log2 x')) e) + else inl (sgn', + AppIdent idc (ident.primitive (t:=type.Z) x @@ TT, e)) + | inr (inl (sgna, a), inl (sgnb, b)) + => inl (sign.mul sgna sgnb, AppIdent idc (a, b)) + | inl _ => default end | ident.Z_add as idc - => fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _)) + => fun (x_y : expr (_ * _) + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in - match x_y return expr _ + type.interp _ with + match x_y return sign * expr _ + type.interp _ with | inr (inr x, inr y) => inr (ident.interp idc (x, y)) - | inr (inr x, inl e) - | inr (inl e, inr x) + | inr (inr x, inl (sgn, e)) + | inr (inl (sgn, e), inr x) + => if Z.eqb x 0 + then inl (sgn, e) + else match sgn with + | sign.pos => default + | sign.neg + => inl (sign.pos, + AppIdent + ident.Z.sub + (ident.primitive (t:=type.Z) x @@ TT, + e)) + end + | inr (inl (sign.pos, p), inl (sign.neg, n)) + | inr (inl (sign.neg, n), inl (sign.pos, p)) + => inl (sign.pos, AppIdent ident.Z.sub (p, n)%expr) + | inr (inl (sign.neg, a), inl (sign.neg, b)) + => inl (sign.neg, AppIdent idc (a, b)%expr) + | inr (inl (sign.pos, _), inl (sign.pos, _)) + | inl _ => default + end + | ident.Z_sub as idc + => fun (x_y : expr (_ * _) + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) + => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in + match x_y return sign * expr _ + type.interp _ with + | inr (inr x, inr y) => inr (ident.interp idc (x, y)) + | inr (inr x, inl (sgn, e)) + => if Z.eqb x 0 + then inl (sign.opp sgn, e) + else default + | inr (inl (sgn, e), inr x) => if Z.eqb x 0 - then inl e + then inl (sgn, e) else default - | inr (inl _, inl _) | inl _ => default + | inr (inl (sign.pos, p), inl (sign.neg, n)) + => inl (sign.pos, AppIdent ident.Z.add (p, n)) + | inr (inl (sign.neg, n), inl (sign.pos, p)) + => inl (sign.neg, AppIdent ident.Z.add (p, n)) + | inr (inl (sign.neg, a), inl (sign.neg, b)) + => inl (sign.pos, AppIdent ident.Z.sub (b, a)) + | inr (inl (sign.pos, _), inl (sign.pos, _)) + | inl _ => default end | ident.Z_zselect as idc | ident.Z_add_modulo as idc => fun (x_y_z : (expr (_ * _ * _) + - (expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _)) * (expr _ + type.interp _))%type) - => match x_y_z return expr _ + type.interp _ with + (expr (_ * _) + (sign * expr _ + type.interp _) * (sign * expr _ + type.interp _)) * (sign * expr _ + type.interp _))%type) + => match x_y_z return sign * expr _ + type.interp _ with | inr (inr (inr x, inr y), inr z) => inr (ident.interp idc (x, y, z)) | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_) x_y_z)) end @@ -3941,6 +4176,7 @@ Module Compilers. | List_nth {T : type.primitive} (n : nat) : ident (list T) T | mul (T1 T2 Tout : type.primitive) : ident (T1 * T2) Tout | add (T1 T2 Tout : type.primitive) : ident (T1 * T2) Tout + | sub (T1 T2 Tout : type.primitive) : ident (T1 * T2) Tout | shiftr (T1 Tout : type.primitive) (offset : BinInt.Z) : ident T1 Tout | shiftl (T1 Tout : type.primitive) (offset : BinInt.Z) : ident T1 Tout | land (T1 Tout : type.primitive) (mask : BinInt.Z) : ident T1 Tout @@ -3960,7 +4196,8 @@ Module Compilers. Notation curry3 f := (fun '(a, b, c) => f a b c). - Axiom admit : forall {T}, T. + Axiom admit_pf : False. + Notation admit := (match admit_pf with end). Definition resize {T1 Tout : type.primitive} : type.interp T1 -> type.interp Tout := match T1, Tout return type.interp T1 -> type.interp Tout with @@ -4021,6 +4258,7 @@ Module Compilers. | snd A B => @Datatypes.snd (type.interp A) (type.interp B) | List_nth T n => fun ls => @List.nth_default (type.interp T) default ls n | add T1 T2 Tout => @resize2uc type.Z type.Z type.Z _ _ _ (curry2 Z.add) + | sub T1 T2 Tout => @resize2uc type.Z type.Z type.Z _ _ _ (curry2 Z.sub) | mul T1 T2 Tout => @resize2uc type.Z type.Z type.Z _ _ _ (curry2 Z.mul) | shiftr _ _ n => @resize1 type.Z type.Z _ _ (fun v => Z.shiftr v n) | shiftl _ _ n => @resize1 type.Z type.Z _ _ (fun v => Z.shiftl v n) @@ -4041,6 +4279,7 @@ Module Compilers. Module Z. Notation mul := (@mul type.Z type.Z type.Z). + Notation sub := (@sub type.Z type.Z type.Z). Notation add := (@add type.Z type.Z type.Z). Notation shiftr := (@shiftr type.Z type.Z). Notation shiftl := (@shiftl type.Z type.Z). @@ -4101,6 +4340,44 @@ Module Compilers. else false end. + Definition primitive_bounded_by {t} : type.primitive.interp t -> zrange -> bool + := match t return primitive.interp t -> zrange -> bool with + | unit => fun _ _ => true + | Z => fun v r => (lower r <=? v) && (v <=? upper r) + | ZBounded _ _ + => fun v r => (lower r <=? value v) && (value v <=? upper r) + end. + + Fixpoint list_bounded_by {t} (ls : Datatypes.list (primitive.interp t)) + (r : Datatypes.list zrange) + : bool + := match ls, r with + | nil, nil => true + | nil, _ => false + | cons x xs, cons r rs + => primitive_bounded_by x r && @list_bounded_by _ xs rs + | cons _ _, _ => false + end. + + Lemma length_list_bounded_by {t} ls r + (H : @list_bounded_by t ls r = true) + : List.length ls = List.length r. + Proof. + revert dependent r; induction ls as [|x xs IHxs], r; + cbn; try reflexivity; try discriminate. + rewrite Bool.andb_true_iff. + intros [? ?]; erewrite IHxs; [ reflexivity | assumption ]. + Qed. + + Fixpoint bounded_by {t} : type.interp t -> range t -> bool + := match t return type.interp t -> range t -> bool with + | type_primitive x => primitive_bounded_by + | prod A B => fun '((a, a') : type.interp A * type.interp B) + '((r, r') : range A * range B) + => @bounded_by A a r && @bounded_by B a' r' + | list A => list_bounded_by + end. + Fixpoint type_for_range {t} : range t -> type := match t return range t -> type with | type_primitive _ => primitive_for_zrange @@ -4112,6 +4389,18 @@ Module Compilers. => type.list (primitive_for_zrange (List.fold_right ZRange.union r[0 ~> 0]%zrange ls)) end. + + Fixpoint type_for_range_bounded_by {t} + : forall (r : range t) + (v : type.interp (type_for_range r)), + bool + := match t return forall (r : range t), type.interp (type_for_range r) -> bool with + | type_primitive x => fun r v => primitive_bounded_by v r + | prod A B => fun '((r, r') : range A * range B) + '((a, a') : type.interp (type_for_range r) * type.interp (type_for_range r')) + => @type_for_range_bounded_by A r a && @type_for_range_bounded_by B r' a' + | list A => fun r v => list_bounded_by v r + end. End with_relax. End Range. @@ -4245,6 +4534,8 @@ Module Compilers. => Some Z.mul | default.ident.Z_add => Some Z.add + | default.ident.Z_sub + => Some Z.sub | ident.Z_shiftr n => Some (Z.shiftr n) | ident.Z_shiftl n @@ -4373,18 +4664,7 @@ Module Compilers. Context (relax_zrange : zrange -> option zrange). Local Notation primitive_for_zrange := (primitive_for_zrange relax_zrange). - - Fixpoint type_for_range {t} : range t -> type - := match t return range t -> type with - | type_primitive _ => primitive_for_zrange - | prod A B => fun (ab : range A * range B) - => prod (@type_for_range A (Datatypes.fst ab)) - (@type_for_range B (Datatypes.snd ab)) - | list A - => fun ls : Datatypes.list zrange - => type.list - (primitive_for_zrange (List.fold_right ZRange.union r[0 ~> 0]%zrange ls)) - end. + Local Notation type_for_range := (type_for_range relax_zrange). Definition upper_lor_and_bounds (x y : BinInt.Z) : BinInt.Z := 2^(1 + Z.log2_up (Z.max x y)). @@ -4582,6 +4862,11 @@ Module Compilers. (fun '(existT r args) => existT _ (ZRange.four_corners BinInt.Z.add (Datatypes.fst r) (Datatypes.snd r)) (AppIdent (add _ _ _) args)) + | sub _ _ _ + => option_map + (fun '(existT r args) + => existT _ (ZRange.four_corners BinInt.Z.sub (Datatypes.fst r) (Datatypes.snd r)) + (AppIdent (sub _ _ _) args)) | shiftr _ _ offset => option_map (fun '(existT r args) @@ -4696,7 +4981,7 @@ Module Compilers. (s_bounds : range (Indexed.OfPHOAS.type.compile s)) : option { bs : range (Indexed.OfPHOAS.type.compile d) & - @expr ident (AdjustBounds.ident.type_for_range relax_zrange bs) } + @expr ident (type_for_range relax_zrange bs) } := let e := Indexed.OfPHOAS.expr.Compile e in match e with | Some e @@ -4729,12 +5014,12 @@ Module Compilers. {d : Compilers.type.Notations.type} (relax_zrange : zrange -> option zrange) {bs : range (Indexed.OfPHOAS.type.compile d)} - (v : type.interp (AdjustBounds.ident.type_for_range relax_zrange bs)) + (v : type.interp (type_for_range relax_zrange bs)) {struct d} : Compilers.type.interp d := match d return (forall bs : range (Indexed.OfPHOAS.type.compile d), - type.interp (AdjustBounds.ident.type_for_range relax_zrange bs) + type.interp (type_for_range relax_zrange bs) -> Compilers.type.interp d) with | Compilers.type.type_primitive x => @cast_back_primitive _ _ @@ -4742,7 +5027,7 @@ Module Compilers. => fun (bs : (* ignore this line, for type inference badness *) range (Indexed.OfPHOAS.type.compile A) * range (Indexed.OfPHOAS.type.compile B)) (v : - (* ignore this line, for type inference badness *) type.interp (AdjustBounds.ident.type_for_range relax_zrange (fst bs)) * type.interp (AdjustBounds.ident.type_for_range relax_zrange (snd bs))) + (* ignore this line, for type inference badness *) type.interp (type_for_range relax_zrange (fst bs)) * type.interp (type_for_range relax_zrange (snd bs))) => (@cast_back A relax_zrange (fst bs) (fst v), @cast_back B relax_zrange (snd bs) (snd v)) | type.arrow s d => fun bs v _ => @cast_back d relax_zrange bs v @@ -4752,17 +5037,93 @@ Module Compilers. | Compilers.type.list A => fun _ _ => nil end bs v. + Definition option_cast_back + {d : Compilers.type.Notations.type} + (relax_zrange : zrange -> option zrange) + {bs : range (Indexed.OfPHOAS.type.compile d)} + (v : option (type.interp (type_for_range relax_zrange bs))) + := option_map (@cast_back d relax_zrange bs) v. + + Definition relax_is_good + (relax_zrange : zrange -> option zrange) + : Prop + := forall z z', relax_zrange z = Some z' + -> is_tighter_than_bool z z' = true. + + Definition cast_primitiveZ + (d : Compilers.type.primitive := Compilers.type.Z) + (relax_zrange : zrange -> option zrange) + {bs : zrange} + (v : Compilers.type.interp d) + (Hrelax : relax_is_good relax_zrange) + (Hv : is_bounded_by' None bs v) + : type.primitive.interp (primitive_for_zrange relax_zrange bs). + Proof. + cbn in *; hnf in Hv; cbv [primitive_for_zrange primitive_for_option_zrange]. + specialize (Hrelax bs). + destruct (relax_zrange bs) as [bs'|]; [ exists v | exact v ]. + clear d relax_zrange. + abstract ( + specialize (Hrelax bs' eq_refl); + unfold is_tighter_than_bool in *; + rewrite Bool.andb_true_iff in *; split; + destruct_head'_and; + Z.ltb_to_lt; + (apply Z.min_case_strong || apply Z.max_case_strong); + omega + ). + Defined. + + Lemma cast_back_primitive_cast_primitive + relax_zrange bs v Hrelax Hv + : cast_back_primitive relax_zrange + (@cast_primitiveZ relax_zrange bs v Hrelax Hv) + = v. + Proof. + cbv [cast_primitiveZ]; cbn. + generalize (Hrelax bs). + break_innermost_match; reflexivity. + Qed. + Definition Interp {s d : Compilers.type.Notations.type} (relax_zrange : zrange -> option zrange) (s_bounds : range (Indexed.OfPHOAS.type.compile s)) {bs : range (Indexed.OfPHOAS.type.compile d)} - (args : type.interp (AdjustBounds.ident.type_for_range relax_zrange s_bounds)) - (e : @expr ident (AdjustBounds.ident.type_for_range relax_zrange bs)) + (args : type.interp (type_for_range relax_zrange s_bounds)) + (e : @expr ident (type_for_range relax_zrange bs)) : option (Compilers.type.interp d) := let ctx := extendb (PositiveMap.empty _) 1 args in let v := Indexed.expr.interp (@ident.interp) e ctx in - option_map (cast_back relax_zrange) v. + option_cast_back relax_zrange v. + + Definition AnalyzeBoundsConst + {d : Compilers.type.Notations.type} + (relax_zrange : zrange -> option zrange) + (e : Expr d) + : option + { bs : range (Indexed.OfPHOAS.type.compile d) & + @expr ident (type_for_range relax_zrange bs) } + := let e := Indexed.OfPHOAS.expr.Compile e in + match e with + | Some e + => let e := AdjustBounds.expr.adjust_bounds + relax_zrange + (PositiveMap.empty _) + e in + e + | None => None + end. + + Definition InterpConst + {d : Compilers.type.Notations.type} + (relax_zrange : zrange -> option zrange) + {bs : range (Indexed.OfPHOAS.type.compile d)} + (e : @expr ident (type_for_range relax_zrange bs)) + : option (Compilers.type.interp d) + := let ctx := PositiveMap.empty _ in + let v := Indexed.expr.interp (@ident.interp) e ctx in + option_cast_back relax_zrange v. End OfPHOAS. End BoundsAnalysis. End Compilers. @@ -4962,7 +5323,33 @@ Module test6. Qed. End test6. -Axiom admit : forall {T}, T. +Axiom admit_pf : False. +Notation admit := (match admit_pf with end). + +Ltac cache_reify _ := + intros; + etransitivity; + [ + | repeat apply (f_equal (fun f => f _)); + Reify_rhs (); + reflexivity ]; + cbv beta; + let RHS := match goal with |- _ = ?RHS => RHS end in + let e := match RHS with context[expr.Interp _ ?e] => e end in + let E := fresh "E" in + set (E := e); + let E' := constr:(PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) in + let LHS := match goal with |- ?LHS = _ => LHS end in + lazymatch LHS with + | context LHS[@expr.Interp ?ident ?interp_ident ?t ?e] + => let LHS := context LHS[@expr.Interp ident interp_ident t E'] in + transitivity LHS; [ | clear e ] + end; + [ repeat match goal with |- context[expr.Interp _ _ _] => apply (f_equal (fun f => f _)) end; + apply f_equal; + time lazy; + reflexivity + | clearbody E ]. Derive carry_mul_gen SuchThat (forall (w : nat -> Z) @@ -4977,32 +5364,118 @@ Derive carry_mul_gen carry_mul_gen w s c n len_c idxs len_idxs fg = carry_mulmod w s c n len_c idxs len_idxs fg) As carry_mul_gen_correct. +Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed. + +Derive carry_gen + SuchThat (forall (w : nat -> Z) + (f : list Z) + (n : nat) + (s : Z) + (c : list (Z * Z)) + (len_c : nat) + (idxs : list nat) + (len_idxs : nat), + Interp (t:=type.reify_type_of carrymod) + carry_gen w s c n len_c idxs len_idxs f + = carrymod w s c n len_c idxs len_idxs f) + As carry_gen_correct. +Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. + +Derive encode_gen + SuchThat (forall (w : nat -> Z) + (v : Z) + (n : nat) + (s : Z) + (c : list (Z * Z)) + (len_c : nat), + Interp (t:=type.reify_type_of encodemod) + encode_gen w s c n len_c v + = encodemod w s c n len_c v) + As encode_gen_correct. +Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. + +Derive add_gen + SuchThat (forall (w : nat -> Z) + (fg : list Z * list Z) + (n : nat), + Interp (t:=type.reify_type_of addmod) + add_gen w n fg + = addmod w n fg) + As add_gen_correct. +Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. + +Derive sub_gen + SuchThat (forall (w : nat -> Z) + (n : nat) + (s : Z) + (c : list (Z * Z)) + (len_c : nat) + (coef : Z) + (fg : list Z * list Z), + Interp (t:=type.reify_type_of submod) + sub_gen w s c n len_c coef fg + = submod w s c n len_c coef fg) + As sub_gen_correct. +Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. + +Derive opp_gen + SuchThat (forall (w : nat -> Z) + (n : nat) + (s : Z) + (c : list (Z * Z)) + (len_c : nat) + (coef : Z) + (f : list Z), + Interp (t:=type.reify_type_of oppmod) + opp_gen w s c n len_c coef f + = oppmod w s c n len_c coef f) + As opp_gen_correct. +Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. + +Definition zeromod w n s c len_c := encodemod w n s c len_c 0. +Definition onemod w n s c len_c := encodemod w n s c len_c 1. + +Derive zero_gen + SuchThat (forall (w : nat -> Z) + (n : nat) + (s : Z) + (c : list (Z * Z)) + (len_c : nat), + Interp (t:=type.reify_type_of zeromod) + zero_gen w s c n len_c + = zeromod w s c n len_c) + As zero_gen_correct. +Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. + +Derive one_gen + SuchThat (forall (w : nat -> Z) + (n : nat) + (s : Z) + (c : list (Z * Z)) + (len_c : nat), + Interp (t:=type.reify_type_of onemod) + one_gen w s c n len_c + = onemod w s c n len_c) + As one_gen_correct. +Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. + +Definition expanding_id (n : nat) (ls : list Z) := expand_list (-1)%Z ls n. + +Lemma expanding_id_id n ls (H : List.length ls = n) + : expanding_id n ls = ls. Proof. - intros. - etransitivity. - Focus 2. - { repeat apply (f_equal (fun f => f _)). - Reify_rhs (). - reflexivity. - } Unfocus. - cbv beta. - let RHS := match goal with |- _ = ?RHS => RHS end in - let e := match RHS with context[expr.Interp _ ?e] => e end in - set (E := e). - Time let E' := constr:(PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) in - let E' := (eval vm_compute in E') in (* 0.131 for vm, about 0.6 for lazy, slower for native and cbv *) - pose E' as E''. - let LHS := match goal with |- ?LHS = _ => LHS end in - lazymatch LHS with - | context LHS[expr.Interp _ _] - => let LHS := context LHS[Interp E''] in - transitivity LHS - end; - [ clear E | exact admit ]. - subst carry_mul_gen. - reflexivity. + unfold expanding_id. rewrite expand_list_correct by assumption; reflexivity. Qed. +Derive id_gen + SuchThat (forall (n : nat) + (ls : list Z), + Interp (t:=type.reify_type_of expanding_id) + id_gen n ls + = expanding_id n ls) + As id_gen_correct. +Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed. + Module Pipeline. Inductive ErrorMessage := | Computed_bounds_are_not_tight_enough @@ -5010,6 +5483,7 @@ Module Pipeline. | Bounds_analysis_failed | Return_type_mismatch {T'} (found expected : T') | Value_not_le (descr : string) {T'} (lhs rhs : T') + | Value_not_lt (descr : string) {T'} (lhs rhs : T') | Values_not_provably_distinct (descr : string) {T'} (lhs rhs : T') | Values_not_provably_equal (descr : string) {T'} (lhs rhs : T'). @@ -5036,7 +5510,7 @@ Module Pipeline. arg_bounds out_bounds (E : Expr (s -> d)) - : ErrorT (BoundsAnalysis.Indexed.expr.Notations.expr (BoundsAnalysis.AdjustBounds.ident.type_for_range relax_zrange (t:=BoundsAnalysis.Indexed.OfPHOAS.type.compile d) out_bounds)) + : ErrorT (BoundsAnalysis.Indexed.expr.Notations.expr (BoundsAnalysis.Indexed.Range.type_for_range relax_zrange (t:=BoundsAnalysis.Indexed.OfPHOAS.type.compile d) out_bounds)) := let E := PartialReduce E in let E := ReassociateSmallConstants.Reassociate (2^8) E in let E := BoundsAnalysis.OfPHOAS.AnalyzeBounds relax_zrange E arg_bounds in @@ -5058,20 +5532,60 @@ Module Pipeline. rv (Hrv : BoundsPipeline relax_zrange arg_bounds out_bounds E = Success rv) : forall arg + (Harg : BoundsAnalysis.Indexed.Range.type_for_range_bounded_by + relax_zrange + arg_bounds arg = true) (arg' := @BoundsAnalysis.OfPHOAS.cast_back s relax_zrange arg_bounds arg), - BoundsAnalysis.OfPHOAS.Interp - (s:=s) - (d:=d) - relax_zrange - arg_bounds - (bs:=out_bounds) - arg - rv - = Some (Interp E arg'). + exists res, + let ctx := + BoundsAnalysis.Indexed.Context.extendb + (PositiveMap.empty _) 1 arg in + BoundsAnalysis.Indexed.expr.interp (@BoundsAnalysis.ident.interp) rv ctx + = Some res + /\ BoundsAnalysis.Indexed.Range.type_for_range_bounded_by + relax_zrange out_bounds res = true + /\ BoundsAnalysis.OfPHOAS.cast_back relax_zrange res + = Interp E arg'. + Proof. + Admitted. + + Definition BoundsPipelineConst + relax_zrange + {t} + bounds + (E : Expr t) + : ErrorT (BoundsAnalysis.Indexed.expr.Notations.expr (BoundsAnalysis.Indexed.Range.type_for_range relax_zrange (t:=BoundsAnalysis.Indexed.OfPHOAS.type.compile t) bounds)) + := let E := PartialReduce E in + let E := ReassociateSmallConstants.Reassociate (2^8) E in + let E := BoundsAnalysis.OfPHOAS.AnalyzeBoundsConst relax_zrange E in + let E := match E with + | Some (existT b v) + => if BoundsAnalysis.Indexed.Range.range_is_tighter_than b bounds + then transport_or_error BoundsAnalysis.Indexed.expr.Notations.expr v + else Error (Computed_bounds_are_not_tight_enough b bounds) + | None => Error Bounds_analysis_failed + end in + E. + + Lemma BoundsPipelineConst_correct + relax_zrange + {d} + bounds + (E : Expr d) + rv + (Hrv : BoundsPipelineConst relax_zrange bounds E = Success rv) + : exists res, + let ctx := PositiveMap.empty _ in + BoundsAnalysis.Indexed.expr.interp (@BoundsAnalysis.ident.interp) rv ctx + = Some res + /\ BoundsAnalysis.Indexed.Range.type_for_range_bounded_by + relax_zrange bounds res = true + /\ BoundsAnalysis.OfPHOAS.cast_back relax_zrange res + = Interp E. Proof. Admitted. End Pipeline. @@ -5147,6 +5661,9 @@ Proof. reflexivity. Qed. +Require Import Crypto.Algebra.Ring. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. + (** XXX TODO: Translate Jade's python script *) Section rcarry_mul. Context (n : nat) @@ -5156,7 +5673,12 @@ Section rcarry_mul. Let limbwidth := (Z.log2_up (s - Associational.eval c) / Z.of_nat n)%Q. Let idxs := (seq 0 n ++ [0; 1])%list%nat. - Let f_bounds := List.repeat r[0~>(2^Qceiling limbwidth + 2^(Qceiling limbwidth - 3))%Z]%zrange n. + Let upperbound_tight := (2^Qceiling limbwidth + 2^(Qceiling limbwidth - 3))%Z. + Let upperbound_loose := (3 * upperbound_tight)%Z. + Let f_bounds_tight := List.repeat r[0~>upperbound_tight]%zrange n. + Let f_bounds_loose := List.repeat r[0~>upperbound_loose]%zrange n. + Let prime_bound : BoundsAnalysis.Indexed.Range.range (BoundsAnalysis.Indexed.OfPHOAS.type.compile (type.Z)) + := r[0~>(s - Associational.eval c - 1)]%zrange. Definition relax_zrange_of_machine_wordsize := relax_zrange_gen [machine_wordsize; 2 * machine_wordsize]%Z. @@ -5169,23 +5691,30 @@ Section rcarry_mul. Let ridxs := GallinaReify.Reify idxs. Let rlen_c := GallinaReify.Reify (List.length c). Let rlen_idxs := GallinaReify.Reify (List.length idxs). + Let rcoef := GallinaReify.Reify 2. Let relax_zrange := relax_zrange_of_machine_wordsize. - Let arg_bounds : BoundsAnalysis.Indexed.Range.range (BoundsAnalysis.Indexed.OfPHOAS.type.compile (type.list type.Z * type.list type.Z)) - := (f_bounds, f_bounds). - Let out_bounds : BoundsAnalysis.Indexed.Range.range (BoundsAnalysis.Indexed.OfPHOAS.type.compile (type.list type.Z)) - := f_bounds. + Let tight_bounds : BoundsAnalysis.Indexed.Range.range (BoundsAnalysis.Indexed.OfPHOAS.type.compile (type.list type.Z)) + := f_bounds_tight. + Let tight2_bounds : BoundsAnalysis.Indexed.Range.range (BoundsAnalysis.Indexed.OfPHOAS.type.compile (type.list type.Z * type.list type.Z)) + := (tight_bounds, tight_bounds). + Let loose_bounds : BoundsAnalysis.Indexed.Range.range (BoundsAnalysis.Indexed.OfPHOAS.type.compile (type.list type.Z)) + := f_bounds_loose. + Let loose2_bounds : BoundsAnalysis.Indexed.Range.range (BoundsAnalysis.Indexed.OfPHOAS.type.compile (type.list type.Z * type.list type.Z)) + := (loose_bounds, loose_bounds). Definition check_args {T} (res : Pipeline.ErrorT T) : Pipeline.ErrorT T := if negb (Qle_bool 1 limbwidth)%Q then Pipeline.Error (Pipeline.Value_not_le "1 ≤ limbwidth" 1%Q limbwidth) - else if (s - Associational.eval c =? 0)%Z - then Pipeline.Error (Pipeline.Values_not_provably_distinct "s - Associational.eval c ≠ 0" (s - Associational.eval c) 0) + else if (negb (0 <? s - Associational.eval c))%Z + then Pipeline.Error (Pipeline.Value_not_lt "s - Associational.eval c ≤ 0" 0 (s - Associational.eval c)) else if (s =? 0)%Z then Pipeline.Error (Pipeline.Values_not_provably_distinct "s ≠ 0" s 0) else if (n =? 0)%nat then Pipeline.Error (Pipeline.Values_not_provably_distinct "n ≠ 0" n 0%nat) - else res. + else if (negb (0 <? machine_wordsize)) + then Pipeline.Error (Pipeline.Value_not_lt "0 < machine_wordsize" 0 machine_wordsize) + else res. Lemma check_args_success_id {T} {rv : T} {res} : check_args res = Pipeline.Success rv @@ -5194,13 +5723,114 @@ Section rcarry_mul. cbv [check_args]; break_innermost_match; congruence. Qed. - Definition rcarry_mul + Local Ltac solve_correct_gen pipeline_lem gen_correct := + let Hrv := lazymatch goal with H : ?rop = Pipeline.Success _ |- _ => H end in + let rop := lazymatch type of Hrv with ?rop = Pipeline.Success _ => rop end in + hnf; intros; cbv [rop] in Hrv; + eapply pipeline_lem in Hrv; [ | eassumption.. ]; + let res := fresh "res" in + destruct Hrv as [res Hrv]; + exists res; do 2 try apply conj; + [ | | etransitivity ]; + [ solve [ apply Hrv ].. | ]; + repeat match goal with H := _ |- _ => subst H end; + erewrite <- gen_correct; + cbv [expr.Interp]; + cbn [expr.interp]; + f_equal; + cbn -[reify_list]; + try (rewrite interp_reify_list, map_map; cbn; + erewrite map_ext with (g:=id), map_id; try reflexivity); + try (intros []; reflexivity). + Local Ltac solve_correct gen_correct := + solve_correct_gen Pipeline.BoundsPipeline_correct gen_correct. + Local Ltac solve_correct_const gen_correct := + solve_correct_gen Pipeline.BoundsPipelineConst_correct gen_correct. + + Let BoundsPipeline21 in_bounds out_bounds res := let res := Pipeline.BoundsPipeline - (relax_zrange) + relax_zrange (s:=(type.list type.Z * type.list type.Z)%ctype) (d:=(type.list type.Z)%ctype) - (arg_bounds) - (out_bounds) + (in_bounds, in_bounds) + out_bounds + res in + res. + + Let BoundsPipeline11 in_bounds out_bounds res + := let res := Pipeline.BoundsPipeline + relax_zrange + (s:=(type.list type.Z)%ctype) + (d:=(type.list type.Z)%ctype) + (in_bounds) + out_bounds + res in + res. + + Definition rexpr_1_correctT_ctx + ctx + out_bounds + (f : type.interp (type.list type.Z)) + rv + := (exists res, + BoundsAnalysis.Indexed.expr.interp (@BoundsAnalysis.ident.interp) rv ctx + = Some res + /\ BoundsAnalysis.Indexed.Range.type_for_range_bounded_by + relax_zrange out_bounds res = true + /\ BoundsAnalysis.OfPHOAS.cast_back relax_zrange res + = f). + + Definition rexpr_n1_correctT + t_in + in_bounds out_bounds + (f : _ -> type.interp (type.list type.Z)) + rv + := forall arg + (arg' := @BoundsAnalysis.OfPHOAS.cast_back + t_in + relax_zrange + in_bounds + arg) + (Harg : BoundsAnalysis.Indexed.Range.type_for_range_bounded_by + relax_zrange + (t:=BoundsAnalysis.Indexed.OfPHOAS.type.compile t_in) + in_bounds arg = true), + let ctx := + BoundsAnalysis.Indexed.Context.extendb + (PositiveMap.empty _) 1 arg in + @rexpr_1_correctT_ctx ctx out_bounds (f arg') rv. + + Definition rexpr_21_correctT + in_bounds out_bounds + (f : _ -> type.interp (type.list type.Z)) + rv + := @rexpr_n1_correctT (type.list type.Z * type.list type.Z) + (in_bounds, in_bounds) out_bounds f rv. + + Definition rexpr_11_correctT + in_bounds out_bounds + (f : _ -> type.interp (type.list type.Z)) + rv + := @rexpr_n1_correctT (type.list type.Z) + in_bounds out_bounds f rv. + + Definition rexpr_Z1_correctT + in_bounds out_bounds + (f : _ -> type.interp (type.list type.Z)) + rv + := @rexpr_n1_correctT type.Z + in_bounds out_bounds f rv. + + Definition rexpr_01_correctT + out_bounds + (f : type.interp (type.list type.Z)) + rv + := @rexpr_1_correctT_ctx (PositiveMap.empty _) out_bounds f rv. + + Definition rcarry_mul + := let res := BoundsPipeline21 + loose_bounds + tight_bounds (fun var => (carry_mul_gen _) @ (rw _) @@ -5211,84 +5841,733 @@ Section rcarry_mul. @ (ridxs _) @ (rlen_idxs _) )%expr in - check_args res. + res. Definition rcarry_mul_correctT rv - := forall arg - (arg' := @BoundsAnalysis.OfPHOAS.cast_back - _ - (relax_zrange) - (arg_bounds) - arg) - (Hf : List.length (fst arg) = n) - (Hg : List.length (snd arg) = n), - BoundsAnalysis.OfPHOAS.Interp - (relax_zrange) - (arg_bounds) - (bs:=out_bounds) - arg - rv - = Some (carry_mulmod (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs) arg'). + := Eval hnf in + rexpr_21_correctT + loose_bounds tight_bounds + (carry_mulmod (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs)) + rv. Lemma rcarry_mul_correct - rv - (Hrv : rcarry_mul = Pipeline.Success rv) + rv (Hrv : rcarry_mul = Pipeline.Success rv) : rcarry_mul_correctT rv. + Proof. solve_correct carry_mul_gen_correct. Qed. + + Definition rcarry + := let res := Pipeline.BoundsPipeline + relax_zrange + (s:=(type.list type.Z)%ctype) + (d:=(type.list type.Z)%ctype) + loose_bounds + tight_bounds + (fun var + => (carry_gen _) + @ (rw _) + @ (rs _) + @ (rc _) + @ (rn _) + @ (rlen_c _) + @ (ridxs _) + @ (rlen_idxs _) + )%expr in + res. + + Definition rcarry_correctT + rv + := Eval hnf in + rexpr_11_correctT + loose_bounds tight_bounds + (carrymod (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs)) + rv. + + Lemma rcarry_correct + rv (Hrv : rcarry = Pipeline.Success rv) + : rcarry_correctT rv. + Proof. solve_correct carry_gen_correct. Qed. + + Definition rrelax + := let res := Pipeline.BoundsPipeline + relax_zrange + (s:=(type.list type.Z)%ctype) + (d:=(type.list type.Z)%ctype) + tight_bounds + loose_bounds + (fun var + => (id_gen _) + @ (rn _) + )%expr in + res. + + Definition rrelax_correctT + rv + := Eval hnf in + rexpr_11_correctT + tight_bounds loose_bounds + (expanding_id n) + rv. + + Lemma rrelax_correct + rv (Hrv : rrelax = Pipeline.Success rv) + : rrelax_correctT rv. + Proof. solve_correct id_gen_correct. Qed. + + Definition radd + := let res := BoundsPipeline21 + tight_bounds + loose_bounds + (fun var + => (add_gen _) + @ (rw _) + @ (rn _) + )%expr in + res. + + Definition radd_correctT + rv + := Eval hnf in + rexpr_21_correctT + tight_bounds loose_bounds + (addmod (Interp rw) n) + rv. + + Lemma radd_correct + rv (Hrv : radd = Pipeline.Success rv) + : radd_correctT rv. + Proof. solve_correct add_gen_correct. Qed. + + Definition rsub + := let res := BoundsPipeline21 + tight_bounds + loose_bounds + (fun var + => (sub_gen _) + @ (rw _) + @ (rs _) + @ (rc _) + @ (rn _) + @ (rlen_c _) + @ (rcoef _) + )%expr in + res. + + Definition rsub_correctT + rv + := Eval hnf in + rexpr_21_correctT + tight_bounds loose_bounds + (submod (Interp rw) s c n (Interp rlen_c) (Interp rcoef)) + rv. + + Lemma rsub_correct + rv (Hrv : rsub = Pipeline.Success rv) + : rsub_correctT rv. + Proof. solve_correct sub_gen_correct. Qed. + + Definition ropp + := let res := BoundsPipeline11 + tight_bounds + loose_bounds + (fun var + => (opp_gen _) + @ (rw _) + @ (rs _) + @ (rc _) + @ (rn _) + @ (rlen_c _) + @ (rcoef _) + )%expr in + res. + + Definition ropp_correctT + rv + := Eval hnf in + rexpr_11_correctT + tight_bounds loose_bounds + (oppmod (Interp rw) s c n (Interp rlen_c) (Interp rcoef)) + rv. + + Lemma ropp_correct + rv (Hrv : ropp = Pipeline.Success rv) + : ropp_correctT rv. + Proof. solve_correct opp_gen_correct. Qed. + + Definition rencode + := let res := Pipeline.BoundsPipeline + relax_zrange + (s:=type.Z) + (d:=(type.list type.Z)%ctype) + (prime_bound) + tight_bounds + (fun var + => (encode_gen _) + @ (rw _) + @ (rs _) + @ (rc _) + @ (rn _) + @ (rlen_c _) + )%expr in + res. + + Definition rencode_correctT + rv + := Eval hnf in + rexpr_Z1_correctT + prime_bound tight_bounds + (encodemod (Interp rw) s c n (Interp rlen_c)) + rv. + + Lemma rencode_correct + rv (Hrv : rencode = Pipeline.Success rv) + : rencode_correctT rv. + Proof. solve_correct encode_gen_correct. Qed. + + Definition rzero + := let res := Pipeline.BoundsPipelineConst + relax_zrange + (t:=(type.list type.Z)%ctype) + tight_bounds + (fun var + => (zero_gen _) + @ (rw _) + @ (rs _) + @ (rc _) + @ (rn _) + @ (rlen_c _) + )%expr in + res. + + Definition rzero_correctT + rv + := Eval hnf in + rexpr_01_correctT + tight_bounds + (zeromod (Interp rw) s c n (Interp rlen_c)) + rv. + + Lemma rzero_correct + rv (Hrv : rzero = Pipeline.Success rv) + : rzero_correctT rv. + Proof. solve_correct_const zero_gen_correct. Qed. + + Definition rone + := let res := Pipeline.BoundsPipelineConst + relax_zrange + (t:=(type.list type.Z)%ctype) + tight_bounds + (fun var + => (one_gen _) + @ (rw _) + @ (rs _) + @ (rc _) + @ (rn _) + @ (rlen_c _) + )%expr in + res. + + Definition rone_correctT + rv + := Eval hnf in + rexpr_01_correctT + tight_bounds + (onemod (Interp rw) s c n (Interp rlen_c)) + rv. + + Lemma rone_correct + rv (Hrv : rone = Pipeline.Success rv) + : rone_correctT rv. + Proof. solve_correct_const one_gen_correct. Qed. + + Definition encodedT := { ls : option + (list (BoundsAnalysis.type.primitive.interp + (BoundsAnalysis.Indexed.Range.primitive_for_option_zrange + (relax_zrange (fold_right ZRange.union r[0 ~> 0]%zrange f_bounds_tight))))) + | exists res, + ls = Some res + /\ BoundsAnalysis.Indexed.Range.type_for_range_bounded_by + relax_zrange tight_bounds res = true }. + + Program Definition list_of_encodedT (v : encodedT) : list _ + := @Option.always_invert_Some + _ + (proj1_sig v) + _. + Next Obligation. + repeat match goal with H : _ |- _ => clear H end. + destruct v as [? [? [H0 H1] ] ]; cbn; subst; congruence. + Qed. + + Lemma length_list_of_encodedT v : List.length (list_of_encodedT v) = n. Proof. - hnf; intros. - cbv [rcarry_mul] in Hrv. - edestruct (Pipeline.BoundsPipeline _ _ _ _) as [rv'|] eqn:Hrv'; - [ | clear -Hrv; cbv [check_args] in Hrv; break_innermost_match_hyps; discriminate ]. - erewrite <- carry_mul_gen_correct. - eapply Pipeline.BoundsPipeline_correct in Hrv'. - apply check_args_success_id in Hrv; inversion Hrv; subst rv. - rewrite Hrv'. - cbv [expr.Interp]. - cbn [expr.interp]. - apply f_equal; f_equal; - cbn -[reify_list]; - rewrite interp_reify_list, map_map; cbn; - erewrite map_ext with (g:=id), map_id; try reflexivity. - intros []; reflexivity. + destruct v as [v [res [Hpf H] ] ]; destruct v; cbn in *; [ | discriminate ]. + apply BoundsAnalysis.Indexed.Range.length_list_bounded_by in H. + inversion_option; subst. + etransitivity; [ exact H | ]. + subst tight_bounds f_bounds_tight. + rewrite repeat_length; reflexivity. Qed. + Let m : positive := Z.to_pos (s - Associational.eval c). + Definition Zdecode (v : encodedT) + := BoundsAnalysis.OfPHOAS.cast_back + (d:=type.list type.Z) + _ + (list_of_encodedT v). + Definition Fdecode (v : encodedT) : F m + := F.of_Z m (Positional.eval (Interp rw) n (Zdecode v)). + Definition encodedT_eq (x y : encodedT) + := Fdecode x = Fdecode y. + + Lemma length_Zdecode v : List.length (Zdecode v) = n. + Proof. + cbn [Zdecode BoundsAnalysis.OfPHOAS.cast_back]. + rewrite map_length, length_list_of_encodedT; reflexivity. + Qed. - (** This code may eventually be useful; it proves that [check_args] - is sufficient to satisfy the preconditions of - [eval_carry_mulmod] *) - (* -<< - break_innermost_match_hyps; try solve [ exfalso; clear -Hrv; discriminate ]; []. - Z.ltb_to_lt. - rewrite negb_false_iff in *. - rewrite Qle_bool_iff in *. - rewrite NPeano.Nat.eqb_neq in *. - erewrite <- carry_mul_gen_correct - by (clear Hrv rv; try clear arg arg'; - generalize (@pow_ceil_mul_nat_divide_nonzero 2 limbwidth); - generalize (@pow_ceil_mul_nat_nonzero 2 limbwidth); - cbv [Qceiling Qfloor Qopp Qnum Qdiv Qplus inject_Z Qmult Qinv Qden]; - cbv [Qle] in *; - cbn; rewrite ?map_length, ?Z.mul_0_r, ?Pos.mul_1_r, ?Z.mul_1_r; - repeat match goal with H := _ |- _ => subst H || clear H end; - try destruct limbwidth; cbn in *; - do 2 try match goal with - | [ |- forall _, _ ] - => (let H := fresh in intro H; apply H) || intro - end; - intros; - repeat match goal with - | [ H : _ |- _ <> _ ] => eapply H - end; - try reflexivity; - try lia). ->> *) + Section make_ring. + Context (curve_good : check_args (Pipeline.Success tt) = Pipeline.Success tt) + {rcarry_mulv} (Hrmulv : rcarry_mul_correctT rcarry_mulv) + {rcarryv} (Hrcarryv : rcarry_correctT rcarryv) + {rrelaxv} (Hrrelaxv : rrelax_correctT rrelaxv) + {raddv} (Hraddv : radd_correctT raddv) + {rsubv} (Hrsubv : rsub_correctT rsubv) + {roppv} (Hroppv : ropp_correctT roppv) + {rzerov} (Hrzerov : rzero_correctT rzerov) + {ronev} (Hronev : rone_correctT ronev) + {rencodev} (Hrencodev : rencode_correctT rencodev). + + Local Ltac use_curve_good_t := + repeat first [ progress rewrite ?map_length, ?Z.mul_0_r, ?Pos.mul_1_r, ?Z.mul_1_r in * + | reflexivity + | lia + | rewrite interp_reify_list, ?map_map + | rewrite map_ext with (g:=id), map_id + | rewrite repeat_length + | progress cbv [Qceiling Qfloor Qopp Qdiv Qplus inject_Z Qmult Qinv] in * + | progress cbv [Qle] in * + | progress cbn -[reify_list] in * + | progress intros + | solve [ auto ] ]. + + Lemma use_curve_good + : Z.pos m = s - Associational.eval c + /\ (Interp rw 0%nat = 1) + /\ (forall i, Interp rw i <> 0) + /\ ' m <> 0 + /\ s - Associational.eval c <> 0 + /\ s <> 0 + /\ 0 < machine_wordsize + /\ n <> 0%nat + /\ List.length (Interp ridxs) = Interp rlen_idxs + /\ List.length (Interp rc) = Interp rlen_c + /\ List.length idxs = Interp rlen_idxs + /\ List.length c = Interp rlen_c + /\ List.length tight_bounds = n + /\ List.length loose_bounds = n + /\ forall i, Interp rw (S i) / Interp rw i <> 0. + Proof. + clear -curve_good. + cbv [check_args] in curve_good. + break_innermost_match_hyps; try discriminate. + rewrite negb_false_iff in *. + Z.ltb_to_lt. + rewrite Qle_bool_iff in *. + rewrite NPeano.Nat.eqb_neq in *. + generalize (@pow_ceil_mul_nat_divide_nonzero 2 limbwidth). + generalize (@pow_ceil_mul_nat_nonzero 2 limbwidth). + intros. + cbv [Qle] in *; cbn [Qnum Qden] in *. + cbv [Qceiling Qfloor Qopp Qdiv Qplus inject_Z Qmult Qinv] in *. + cbn [Qnum Qden] in *. + rewrite ?map_length, ?Z.mul_0_r, ?Pos.mul_1_r, ?Z.mul_1_r in *. + specialize_by lia. + repeat match goal with H := _ |- _ => subst H end. + repeat apply conj. + { destruct (s - Associational.eval c); cbn; lia. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + Qed. + + Local Lemma m_eq : Z.pos m = s - Associational.eval c. + Proof. apply use_curve_good. Qed. + + Local Lemma sc_pos : 0 < s - Associational.eval c. + Proof. pose proof use_curve_good; destruct_head'_and; lia. Qed. + + Local Lemma length_tight_bounds : List.length tight_bounds = n. + Proof. apply use_curve_good. Qed. + + Local Lemma length_loose_bounds : List.length loose_bounds = n. + Proof. apply use_curve_good. Qed. + + Local Arguments Z.pow !_ !_ . + Lemma relax_is_good + : BoundsAnalysis.OfPHOAS.relax_is_good relax_zrange. + Proof. + cbn; cbv [BoundsAnalysis.OfPHOAS.relax_is_good relax_zrange_gen]; cbn. + intros [l u] [l' u']; cbn. + pose proof (Z.log2_up_nonneg (u + 1)). + destruct (Z.log2_up_null (u + 1)). + unfold is_tighter_than_bool; rewrite Bool.andb_true_iff; + break_innermost_match; cbn; Z.ltb_to_lt; intros; + inversion_option; inversion_zrange; subst; + repeat apply conj; Z.ltb_to_lt; try omega; + try (rewrite <- ZUtil.Z.log2_up_le_pow2_full in * by lia; cbn in *; lia). + Qed. + + Local Notation option_interp0 f + := (BoundsAnalysis.Indexed.expr.interp + (@BoundsAnalysis.ident.interp) + f + (PositiveMap.empty _)). + + Local Notation option_interp1 f arg + := (arg' <- arg; + BoundsAnalysis.Indexed.expr.interp + (@BoundsAnalysis.ident.interp) + f + (BoundsAnalysis.Indexed.Context.extendb + (T:=BoundsAnalysis.type.list _) + (PositiveMap.empty _) + 1 arg'))%option. + + Local Notation option_interp2 f arg1 arg2 + := (x <- arg1; + y <- arg2; + BoundsAnalysis.Indexed.expr.interp + (@BoundsAnalysis.ident.interp) + f + (BoundsAnalysis.Indexed.Context.extendb + (T:=BoundsAnalysis.type.prod + (BoundsAnalysis.type.list _) + (BoundsAnalysis.type.list _)) + (PositiveMap.empty _) + 1 (x, y)))%option. + + Local Ltac solve_encodedT _ := + repeat match goal with + | _ => progress destruct_head' encodedT + | _ => progress (destruct_head'_ex; destruct_head'_and; subst) + | _ => progress cbn [proj1_sig Option.bind always_invert_Some] in * + | [ H : ?x = Some _ |- context[?x] ] => rewrite H + | [ H : ?x = Some _, H' : context[?x] |- _ ] => rewrite H in H' + | [ |- BoundsAnalysis.Indexed.Range.type_for_range_bounded_by _ _ (_, _) = true ] + => progress cbn [BoundsAnalysis.Indexed.Range.type_for_range_bounded_by BoundsAnalysis.Indexed.OfPHOAS.type.compile] + | _ => progress cbn [BoundsAnalysis.Indexed.Range.type_for_range BoundsAnalysis.Indexed.OfPHOAS.type.compile fst snd] in * (* for getting rewrite to match *) + | [ |- andb _ _ = true ] => rewrite Bool.andb_true_iff + | _ => solve [ eauto ] + | [ |- sig _ ] => progress rewrite expanding_id_id in * + | [ |- List.length (List.map _ _) = _ ] => rewrite map_length + | [ H : BoundsAnalysis.Indexed.Range.type_for_range_bounded_by _ _ ?x = true + |- List.length ?x = _ ] + => erewrite BoundsAnalysis.Indexed.Range.length_list_bounded_by by exact H + | _ => progress rewrite ?length_tight_bounds, ?length_loose_bounds + | [ H : map _ ?x = map _ _ |- _ ] => rewrite H in * + | [ H : map _ ?x = _ |- context[map _ ?x] ] + => cbv [tight_bounds loose_bounds] in H |- *; rewrite H + | _ => rewrite BoundsAnalysis.OfPHOAS.cast_back_primitive_cast_primitive in * + | [ fx := (BoundsAnalysis.Indexed.expr.interp _ ?rop (BoundsAnalysis.Indexed.Context.extendb _ _ ?x)), + H : forall arg, BoundsAnalysis.Indexed.Range.type_for_range_bounded_by _ ?bs arg = true -> rexpr_1_correctT_ctx _ _ _ ?rop + |- _ ] + => destruct (H x); [ clear H | subst fx ] + | [ fx := Option.bind ?x ?f, H : ?x = Some ?v |- _ ] + => let H' := fresh in + let fx' := fresh fx in + pose (Option.bind (Some v) f) as fx'; + assert (H' : fx = fx') + by (subst fx fx'; apply f_equal2; [ exact H | reflexivity ]); + cbn [Option.bind] in fx'; clearbody fx; subst fx + (*| [ H : forall arg, BoundsAnalysis.Indexed.Range.type_for_range_bounded_by _ ?bs arg = true -> rexpr_1_correctT_ctx _ _ _ ?rop + |- context[BoundsAnalysis.Indexed.expr.interp _ ?rop (BoundsAnalysis.Indexed.Context.extendb _ _ ?x)] ] + => specialize (H x); destruct H*) + | _ => progress cbn [BoundsAnalysis.OfPHOAS.cast_back] in * (* for getting eauto to work *) + end. + + Definition ring_mul_sig + : forall (x y : encodedT), + { v : encodedT + | Zdecode v = carry_mulmod (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs) (Zdecode x, Zdecode y) }. + Proof. + simple refine + (fun x y + => let x' := option_interp1 rrelaxv (proj1_sig x) in + let y' := option_interp1 rrelaxv (proj1_sig y) in + let v' := option_interp2 rcarry_mulv x' y' in + let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } + := _ in + exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). + cbv [Zdecode list_of_encodedT]; cbn [proj1_sig]. + hnf in Hrrelaxv, Hrmulv. + abstract solve_encodedT (). + Defined. + + Definition ring_add_sig + : forall (x y : encodedT), + { v : encodedT + | Zdecode v = carrymod + (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs) + (addmod (Interp rw) n (Zdecode x, Zdecode y)) }. + Proof. + simple refine + (fun x y + => let v'' := option_interp2 raddv (proj1_sig x) (proj1_sig y) in + let v' := option_interp1 rcarryv v'' in + let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } + := _ in + exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). + cbv [Zdecode list_of_encodedT]; cbn [proj1_sig]. + hnf in Hrcarryv, Hraddv; clear -Hrcarryv Hraddv. + abstract solve_encodedT (). + Defined. + Definition ring_sub_sig + : forall (x y : encodedT), + { v : encodedT + | Zdecode v = carrymod + (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs) + (submod (Interp rw) s c n (Interp rlen_c) (Interp rcoef) (Zdecode x, Zdecode y)) }. + Proof. + simple refine + (fun x y + => let v'' := option_interp2 rsubv (proj1_sig x) (proj1_sig y) in + let v' := option_interp1 rcarryv v'' in + let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } + := _ in + exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). + cbv [Zdecode list_of_encodedT]; cbn [proj1_sig]. + hnf in Hrcarryv, Hrsubv. + abstract solve_encodedT (). + Defined. + Definition ring_opp_sig + : forall (x : encodedT), + { v : encodedT + | Zdecode v = carrymod + (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs) + (oppmod (Interp rw) s c n (Interp rlen_c) (Interp rcoef) (Zdecode x)) }. + Proof. + simple refine + (fun x + => let v'' := option_interp1 roppv (proj1_sig x) in + let v' := option_interp1 rcarryv v'' in + let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } + := _ in + exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). + cbv [Zdecode list_of_encodedT]; cbn [proj1_sig]. + hnf in Hrcarryv, Hroppv. + abstract solve_encodedT (). + Defined. + Definition ring_zero_sig + : { v : encodedT + | Zdecode v = zeromod + (Interp rw) s c n (Interp rlen_c) }. + Proof. + simple refine + (let v' := option_interp0 rzerov in + let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } + := _ in + exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). + cbv [Zdecode list_of_encodedT]; cbn [proj1_sig]. + abstract ( + destruct Hrzerov; subst v'; + solve_encodedT () + ). + Defined. + Definition ring_one_sig + : { v : encodedT + | Zdecode v = onemod + (Interp rw) s c n (Interp rlen_c) }. + Proof. + simple refine + (let v' := option_interp0 ronev in + let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } + := _ in + exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). + cbv [Zdecode list_of_encodedT]; cbn [proj1_sig]. + abstract ( + destruct Hronev; subst v'; + solve_encodedT () + ). + Defined. + Arguments Z.mul !_ !_ . + Definition ring_encode_sig + : forall (x : F m), + { v : encodedT + | Zdecode v = encodemod + (Interp rw) s c n (Interp rlen_c) (F.to_Z x) }. + Proof. + simple refine + (fun v + => let pf0 := _ in + let pf1 := _ in + let arg := BoundsAnalysis.OfPHOAS.cast_primitiveZ + relax_zrange + (F.to_Z v) relax_is_good pf0 in + let Hrencodev' := Hrencodev arg pf1 in + let v' := BoundsAnalysis.Indexed.expr.interp + (@BoundsAnalysis.ident.interp) + rencodev + (BoundsAnalysis.Indexed.Context.extendb + (T:=BoundsAnalysis.type.type_primitive _) + (PositiveMap.empty _) + 1 arg) in + let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } + := _ in + exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). + { pose proof use_curve_good as keep. + clear -curve_good keep. + abstract ( + split; [ | exact I ]; + destruct v as [v Hv]; hnf; cbn; + rewrite m_eq in Hv; rewrite Hv; + pose proof (Z.mod_pos_bound v _ sc_pos); lia + ). } + { clearbody pf0. + clear -pf0. + cbn. + (** TODO: clean up this part of the proof. It is annoying + because [BoundsAnalysis.OfPHOAS.cast_primitiveZ] and + [BoundsAnalysis.Indexed.Range.type_for_range_bounded_by] + use different notions of the "interpretation" of a + not-necessarily-Z thing as probably-a-Z *) + generalize (BoundsAnalysis.OfPHOAS.cast_back_primitive_cast_primitive + relax_zrange _ _ relax_is_good pf0). + generalize (BoundsAnalysis.OfPHOAS.cast_primitiveZ + relax_zrange (F.to_Z v) relax_is_good pf0). + generalize (relax_is_good prime_bound). + cbv [BoundsAnalysis.Indexed.Range.primitive_bounded_by]. + cbn [BoundsAnalysis.type.primitive.interp]. + cbv [BoundsAnalysis.OfPHOAS.cast_back_primitive]. + cbv [BoundsAnalysis.Indexed.Range.primitive_for_zrange]. + abstract ( + break_innermost_match; cbn; + intros; destruct_head' BoundsAnalysis.type.BoundedZ; cbn [BoundsAnalysis.type.value] in *; + subst; + hnf in pf0; + cbv [prime_bound is_tighter_than_bool] in *; + try lazymatch goal with + | [ H : forall _, Some _ = Some _ -> _ |- _ ] + => specialize (H _ eq_refl) + end; + rewrite ?Bool.andb_true_iff in *; + destruct_head'_and; + apply conj; Z.ltb_to_lt; + cbn [upper lower] in *; + try omega + ). } + { cbv [Zdecode list_of_encodedT]; cbn [proj1_sig]. + subst v' arg; clearbody Hrencodev'; cbv beta zeta in *. + abstract ( + destruct Hrencodev'; + solve_encodedT () + ). } + Defined. + + Lemma length_addmod x y + : List.length (addmod (Interp rw) n (Zdecode x, Zdecode y)) = n. + Proof. + cbv [addmod]; apply length_add; rewrite expand_list_correct; + cbn [fst snd]; apply length_Zdecode. + Qed. + Lemma length_submod x y + : List.length (submod (Interp rw) s c n (Interp rlen_c) (Interp rcoef) (Zdecode x, Zdecode y)) = n. + Proof. + cbv [submod]; apply length_sub; rewrite expand_list_correct; + cbn [fst snd]; apply length_Zdecode. + Qed. + Lemma length_oppmod x + : List.length (oppmod (Interp rw) s c n (Interp rlen_c) (Interp rcoef) (Zdecode x)) = n. + Proof. + cbv [oppmod]; apply length_opp; rewrite expand_list_correct; + cbn [fst snd]; apply length_Zdecode. + Qed. + + Definition ring_mul x y := proj1_sig (ring_mul_sig x y). + Definition ring_add x y := proj1_sig (ring_add_sig x y). + Definition ring_sub x y := proj1_sig (ring_sub_sig x y). + Definition ring_opp x := proj1_sig (ring_opp_sig x). + Definition ring_zero := proj1_sig ring_zero_sig. + Definition ring_one := proj1_sig ring_one_sig. + Definition ring_encode x := proj1_sig (ring_encode_sig x). + + Definition GoodT : Prop + := @Hierarchy.ring + encodedT encodedT_eq ring_zero ring_one ring_opp ring_add ring_sub ring_mul + /\ @Ring.is_homomorphism + (F m) eq 1%F F.add F.mul + encodedT encodedT_eq ring_one ring_add ring_mul ring_encode + /\ @Ring.is_homomorphism + encodedT encodedT_eq ring_one ring_add ring_mul + (F m) eq 1%F F.add F.mul + Fdecode. + + Hint Rewrite ->@F.to_Z_add : push_FtoZ. + Hint Rewrite ->@F.to_Z_mul : push_FtoZ. + Hint Rewrite ->@F.to_Z_opp : push_FtoZ. + Hint Rewrite ->@F.to_Z_of_Z : push_FtoZ. + + Local Ltac rewrite_proj2_sig _ := + lazymatch goal with + | [ |- context[proj1_sig ?x] ] => rewrite (proj2_sig x) + end. + + Lemma Good : GoodT. + Proof. + pose proof use_curve_good. + destruct_head'_and. + eapply ring_by_isomorphism; intros; [ | reflexivity | .. ]; rewrite F.eq_to_Z_iff; + cbv [F.sub Fdecode]; + autorewrite with push_FtoZ; + pull_Zmod; + rewrite ?Z.add_opp_r. + { cbv [ring_encode]; rewrite_proj2_sig (). + erewrite m_eq, eval_encodemod, <- m_eq by assumption. + let A := lazymatch goal with A : F _ |- _ => A end in + destruct A as [v Hv]; cbn; congruence. } + { cbv [ring_zero]; rewrite_proj2_sig (). + cbv [zeromod]; erewrite m_eq, eval_encodemod, Zmod_0_l by eassumption; reflexivity. } + { cbv [ring_one]; rewrite_proj2_sig (). + cbv [onemod]; cbn; erewrite m_eq, eval_encodemod by eassumption; reflexivity. } + { cbv [ring_opp]; rewrite_proj2_sig (). + erewrite m_eq, eval_carrymod, eval_oppmod + by eauto using length_Zdecode, length_oppmod; + reflexivity. } + { cbv [ring_add]; rewrite_proj2_sig (). + erewrite m_eq, eval_carrymod, eval_addmod + by eauto using length_Zdecode, length_addmod; + reflexivity. } + { cbv [ring_sub]; rewrite_proj2_sig (). + erewrite m_eq, eval_carrymod, eval_submod + by eauto using length_Zdecode, length_submod; + reflexivity. } + { cbv [ring_mul]; rewrite_proj2_sig (). + erewrite m_eq, eval_carry_mulmod + by eauto using length_Zdecode; reflexivity. } + Qed. + End make_ring. End rcarry_mul. -Ltac solve_rcarry_mul _ := - eapply rcarry_mul_correct; - lazy; reflexivity. +Ltac solve_rcarry_mul _ := eapply rcarry_mul_correct; lazy; reflexivity. +Ltac solve_rcarry _ := eapply rcarry_correct; lazy; reflexivity. +Ltac solve_radd _ := eapply radd_correct; lazy; reflexivity. +Ltac solve_rsub _ := eapply rsub_correct; lazy; reflexivity. +Ltac solve_ropp _ := eapply ropp_correct; lazy; reflexivity. +Ltac solve_rencode _ := eapply rencode_correct; lazy; reflexivity. +Ltac solve_rrelax _ := eapply rrelax_correct; lazy; reflexivity. +Ltac solve_rzero _ := eapply rzero_correct; lazy; reflexivity. +Ltac solve_rone _ := eapply rone_correct; lazy; reflexivity. Module PrintingNotations. Export ident. @@ -5303,6 +6582,8 @@ Module PrintingNotations. := (r[0 ~> 18446744073709551615]) : btype_scope. Notation "'uint32'" := (r[0 ~> 4294967295]) : btype_scope. + Notation "'ℤ'" + := BoundsAnalysis.type.Z : btype_scope. Notation "ls [[ n ]]" := (List.nth n @@ ls)%nexpr : nexpr_scope. Notation "x *₆₄₋₆₄₋₁₂₈ y" := (mul uint64 uint64 uint128 @@ (x, y))%nexpr (at level 40) : nexpr_scope. @@ -5326,6 +6607,16 @@ Module PrintingNotations. := (add uint64 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. Notation "x +₃₂ y" := (add uint32 uint32 uint32 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x -₁₂₈ y" + := (sub uint128 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x -₆₄₋₁₂₈₋₁₂₈ y" + := (sub uint64 uint128 uint128 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x -₃₂₋₆₄₋₆₄ y" + := (sub uint32 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x -₆₄ y" + := (sub uint64 uint64 uint64 @@ (x, y))%nexpr (at level 50) : nexpr_scope. + Notation "x -₃₂ y" + := (sub uint32 uint32 uint32 @@ (x, y))%nexpr (at level 50) : nexpr_scope. Notation "x" := ({| BoundsAnalysis.type.value := x |}) (only printing) : nexpr_scope. Notation "( out_t )( v >> count )" := ((shiftr _ out_t count @@ v)%nexpr) @@ -5363,6 +6654,58 @@ Module X25519_64. SuchThat (rcarry_mul_correctT n s c machine_wordsize base_51_carry_mul) As base_51_carry_mul_correct. Proof. Time solve_rcarry_mul (). Time Qed. + Derive base_51_carry + SuchThat (rcarry_correctT n s c machine_wordsize base_51_carry) + As base_51_carry_correct. + Proof. Time solve_rcarry (). Time Qed. + Derive base_51_relax + SuchThat (rrelax_correctT n s c machine_wordsize base_51_relax) + As base_51_relax_correct. + Proof. Time solve_rrelax (). Time Qed. + Derive base_51_add + SuchThat (radd_correctT n s c machine_wordsize base_51_add) + As base_51_add_correct. + Proof. Time solve_radd (). Time Qed. + Derive base_51_sub + SuchThat (rsub_correctT n s c machine_wordsize base_51_sub) + As base_51_sub_correct. + Proof. Time solve_rsub (). Time Qed. + Derive base_51_opp + SuchThat (ropp_correctT n s c machine_wordsize base_51_opp) + As base_51_opp_correct. + Proof. Time solve_ropp (). Time Qed. + Derive base_51_encode + SuchThat (rencode_correctT n s c machine_wordsize base_51_encode) + As base_51_encode_correct. + Proof. Time solve_rencode (). Time Qed. + Derive base_51_zero + SuchThat (rzero_correctT n s c machine_wordsize base_51_zero) + As base_51_zero_correct. + Proof. Time solve_rzero (). Time Qed. + Derive base_51_one + SuchThat (rone_correctT n s c machine_wordsize base_51_one) + As base_51_one_correct. + Proof. Time solve_rone (). Time Qed. + Lemma base_51_curve_good + : check_args n s c machine_wordsize (Pipeline.Success tt) = Pipeline.Success tt. + Proof. vm_compute; reflexivity. Qed. + + Definition base_51_goodT + := GoodT n s c machine_wordsize + base_51_curve_good + base_51_carry_mul_correct + base_51_carry_correct + base_51_relax_correct + base_51_add_correct + base_51_sub_correct + base_51_opp_correct + base_51_zero_correct + base_51_one_correct + base_51_encode_correct. + Theorem base_51_good : base_51_goodT. + Proof. apply Good. Qed. + + Print Assumptions base_51_good. Import PrintingNotations. Print base_51_carry_mul. @@ -5415,8 +6758,16 @@ Module X25519_64. x_19 :: x_22 :: x_21 +₆₄ x_10 :: x_13 :: x_16 :: [])%nexpr : expr (BoundsAnalysis.AdjustBounds.ident.type_for_range - (relax_zrange (make_carry_mul_rargs limbwidth s c machine_wordsize)) - (out_bounds (make_carry_mul_rargs limbwidth s c machine_wordsize))) + (relax_zrange_of_machine_wordsize machine_wordsize) + (List.repeat + r[0 ~> 2 + ^ Qceiling + (inject_Z (Z.log2_up (s - Associational.eval c)) / + inject_Z (BinInt.Z.of_nat n)) + + 2 + ^ (Qceiling + (inject_Z (Z.log2_up (s - Associational.eval c)) / + inject_Z (BinInt.Z.of_nat n)) - 3)]%zrange n)) *) End X25519_64. @@ -5593,8 +6944,8 @@ Module X25519_32. x_34 :: x_37 :: x_36 +₃₂ x_10 :: x_13 :: x_16 :: x_19 :: x_22 :: x_25 :: x_28 :: x_31 :: [])%nexpr : expr (BoundsAnalysis.AdjustBounds.ident.type_for_range - (relax_zrange (make_carry_mul_rargs limbwidth s c machine_wordsize)) - (out_bounds (make_carry_mul_rargs limbwidth s c machine_wordsize))) + (relax_zrange (make_carry_mul_rargs n s c machine_wordsize)) + (out_bounds (make_carry_mul_rargs n s c machine_wordsize))) *) End X25519_32. *) |