From e6b25ccf1bb99582c3f83aa8b3b4fe3f0de31870 Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Wed, 9 May 2018 13:17:52 +0200 Subject: Proofs for pre-fancy pass (could use cleanup) --- src/Experiments/SimplyTypedArithmetic.v | 1175 ++++++++++++++++++++++++------- 1 file changed, 904 insertions(+), 271 deletions(-) (limited to 'src/Experiments/SimplyTypedArithmetic.v') diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index e7c46ee9a..4a5be2505 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -3031,14 +3031,14 @@ Module Compilers. | Z_mul_split => curry3 Z.mul_split | Z_mul_split_concrete s => curry2 (Z.mul_split s) | Z_add_get_carry => curry3 Z.add_get_carry_full - | Z_add_get_carry_concrete s => curry2 (Z.add_get_carry s) + | Z_add_get_carry_concrete s => curry2 (Z.add_get_carry_full s) | Z_add_with_carry => curry3 Z.add_with_carry | Z_add_with_get_carry => curry4 Z.add_with_get_carry_full - | Z_add_with_get_carry_concrete s => curry3 (Z.add_with_get_carry s) + | Z_add_with_get_carry_concrete s => curry3 (Z.add_with_get_carry_full s) | Z_sub_get_borrow => curry3 Z.sub_get_borrow_full - | Z_sub_get_borrow_concrete s => curry2 (Z.sub_get_borrow s) + | Z_sub_get_borrow_concrete s => curry2 (Z.sub_get_borrow_full s) | Z_sub_with_get_borrow => curry4 Z.sub_with_get_borrow_full - | Z_sub_with_get_borrow_concrete s => curry3 (Z.sub_with_get_borrow s) + | Z_sub_with_get_borrow_concrete s => curry3 (Z.sub_with_get_borrow_full s) | Z_zselect => curry3 Z.zselect | Z_add_modulo => curry3 Z.add_modulo | Z_rshi => curry4 Z.rshi @@ -4198,13 +4198,13 @@ Module Compilers. | ident.Z_mul_split_concrete s => cps_of (curry2 (Z.mul_split s)) | ident.Z_add_get_carry_concrete s - => cps_of (curry2 (Z.add_get_carry s)) + => cps_of (curry2 (Z.add_get_carry_full s)) | ident.Z_add_with_get_carry_concrete s - => cps_of (curry3 (Z.add_with_get_carry s)) + => cps_of (curry3 (Z.add_with_get_carry_full s)) | ident.Z_sub_get_borrow_concrete s - => cps_of (curry2 (Z.sub_get_borrow s)) + => cps_of (curry2 (Z.sub_get_borrow_full s)) | ident.Z_sub_with_get_borrow_concrete s - => cps_of (curry3 (Z.sub_with_get_borrow s)) + => cps_of (curry3 (Z.sub_with_get_borrow_full s)) | ident.Z_rshi_concrete s n => cps_of (curry2 (fun x y => Z.rshi s x y n)) | ident.Z_cc_m_concrete s @@ -7919,6 +7919,22 @@ End X25519_32. Module Straightline. Module expr. + (* TODO: move these to a better location *) + Module type. + Definition primitive_eq_dec (a b : type.primitive) : {a = b} + {a <> b}. + Proof. destruct a,b; auto; right; congruence. Defined. + Fixpoint type_eq_dec (a b : type) : {a = b} + {a <> b}. + Proof. + destruct a, b; try solve [right; congruence]; [ | | | ]. + { destruct (primitive_eq_dec p p0); subst; [left | right]; congruence. } + { destruct (type_eq_dec a1 b1); destruct (type_eq_dec a2 b2); subst; try solve [right; congruence]. + left; congruence. } + { destruct (type_eq_dec a1 b1); destruct (type_eq_dec a2 b2); subst; try solve [right; congruence]. + left; congruence. } + { destruct (type_eq_dec a b); [left | right]; congruence. } + Defined. + End type. + Section with_var. Context {var : type.type -> Type}. Context {dummy_arrow : forall s d, var (s -> d)}. (* TODO: remove once arrow-containing pairs are removed at type level *) @@ -7927,24 +7943,26 @@ Module Straightline. Section with_ident. Context {ident : type.type -> type.type -> Type}. + Inductive scalar : type.type -> Type := + | Var t : var t -> scalar t + | TT : scalar (type.type_primitive type.unit) + | Nil t : scalar (type.list t) + | Pair {a b} : scalar a -> scalar b -> scalar (a * b) + | Cast : zrange -> scalar type.Z -> scalar type.Z + | Cast2 : zrange * zrange -> scalar (type.Z*type.Z) -> scalar (type.Z*type.Z) + | Fst {a b} : scalar (a * b) -> scalar a + | Snd {a b} : scalar (a * b) -> scalar b + | Shiftr : Z -> scalar type.Z -> scalar type.Z + | Shiftl : Z -> scalar type.Z -> scalar type.Z + | Land : Z -> scalar type.Z -> scalar type.Z + | CC_m : Z -> scalar type.Z -> scalar type.Z + | Primitive {t} : type.interp (type.type_primitive t) -> scalar t + . + Inductive expr : type.type -> Type := | Scalar {t} : scalar t -> expr t | LetInAppIdentZ {s d} : zrange -> ident s (type.Z) -> scalar s -> (var (type.Z) -> expr d) -> expr d | LetInAppIdentZZ {s d} : zrange * zrange -> ident s (type.Z*type.Z) -> scalar s -> (var (type.Z*type.Z) -> expr d) -> expr d - with scalar : type.type -> Type := - | Var t : var t -> scalar t - | TT : scalar (type.type_primitive type.unit) - | Nil t : scalar (type.list t) - | Pair {a b} : scalar a -> scalar b -> scalar (a * b) - | Cast : zrange -> scalar type.Z -> scalar type.Z - | Cast2 : zrange * zrange -> scalar (type.Z*type.Z) -> scalar (type.Z*type.Z) - | Fst {a b} : scalar (a * b) -> scalar a - | Snd {a b} : scalar (a * b) -> scalar b - | Shiftr : Z -> scalar type.Z -> scalar type.Z - | Shiftl : Z -> scalar type.Z -> scalar type.Z - | Land : Z -> scalar type.Z -> scalar type.Z - | CC_m : Z -> scalar type.Z -> scalar type.Z - | Primitive {t} : type.interp (type.type_primitive t) -> scalar t . Fixpoint dummy_scalar t : scalar t := @@ -7960,7 +7978,7 @@ Module Straightline. Definition of_uncurried_scalar_ident {s d} (idc : ident.ident s d) : scalar s -> option (scalar d) := - match idc in ident.ident s d return scalar (ident:=ident.ident) s -> option (scalar d) with + match idc in ident.ident s d return scalar s -> option (scalar d) with | ident.Z.cast r => fun args => Some (Cast r args) | ident.Z.cast2 r => fun args => Some (Cast2 r args) | @ident.fst A B => fun args => Some (Fst args) @@ -8131,20 +8149,20 @@ Module Straightline. End depth. Section interp. - Local Notation scalar := (@scalar type.interp default.ident). - Local Notation expr := (@expr type.interp default.ident). + Context {ident : type -> type -> Type} {interp_ident : forall s d, ident s d -> type.interp s -> type.interp d}. Definition interp_cast (r : zrange) (x : Z) : Z := ident.cast ident.cast_outside_of_range r x. - - Fixpoint interp_scalar {t} (s : scalar t) : type.interp t := + Definition interp_cast2 (r : zrange * zrange) (x : Z * Z) : Z * Z := + (interp_cast (fst r) (fst x), interp_cast (snd r) (snd x)). + + Fixpoint interp_scalar {t} (s : @scalar type.interp t) : type.interp t := match s with | Var t v => v | TT => tt | Nil _ => [] | Pair _ _ x y => (interp_scalar x, interp_scalar y) | Cast r x => interp_cast r (interp_scalar x) - | Cast2 r x => - let '(a,b) := interp_scalar x in (interp_cast (fst r) a, interp_cast (snd r) b) + | Cast2 r x => interp_cast2 r (interp_scalar x) | Fst _ _ p => fst (interp_scalar p) | Snd _ _ p => snd (interp_scalar p) | Shiftr n x => Z.shiftr (interp_scalar x) n @@ -8154,15 +8172,20 @@ Module Straightline. | Primitive _ x => x end. - Fixpoint interp {t} (e : expr t) : type.interp t := + Fixpoint interp {t} (e : @expr type.interp ident t) : type.interp t := match e with | Scalar _ s => interp_scalar s | LetInAppIdentZ _ _ r idc x f => - interp (f (interp_cast r (ident.interp idc (interp_scalar x)))) - | LetInAppIdentZZ _ _ (r1,r2) idc x f => - let '(a,b) := ident.interp idc (interp_scalar x) in - interp (f (interp_cast r1 a, interp_cast r2 b)) + interp (f (interp_cast r (interp_ident _ _ idc (interp_scalar x)))) + | LetInAppIdentZZ _ _ r idc x f => + interp (f (interp_cast2 r (interp_ident _ _ idc (interp_scalar x)))) end. + End interp. + + Section proofs. + Local Notation straightline_interp := (expr.interp (ident:=default.ident) (interp_ident:=@ident.interp)). + Local Notation uinterp := (Uncurried.expr.interp (@ident.interp)). + Local Notation uexpr := (@Uncurried.expr.expr ident type.interp). Inductive ok_scalar_ident : forall {s d}, ident.ident s d -> Prop := | ok_si_cast : forall r, ok_scalar_ident (ident.Z.cast r) @@ -8176,7 +8199,7 @@ Module Straightline. | ok_prim : forall p x, ok_scalar_ident (@ident.primitive p x) . - Inductive ok_scalar: forall {t}, @Uncurried.expr.expr ident type.interp t -> Prop := + Inductive ok_scalar: forall {t}, uexpr t -> Prop := | ok_Var : forall t v, @ok_scalar t (Uncurried.expr.Var v) | ok_TT : ok_scalar Uncurried.expr.TT | ok_AppIdent : @@ -8191,7 +8214,7 @@ Module Straightline. ok_scalar (Uncurried.expr.Pair a b) . - Inductive ok_expr : forall {t}, Uncurried.expr.expr t -> Prop := + Inductive ok_expr : forall {t}, uexpr t -> Prop := | ok_LetInAppIdentZ : forall tC r s (idc : ident s type.Z) x k, ok_scalar x -> (forall y, @ok_expr tC (k y)) -> @@ -8214,34 +8237,18 @@ Module Straightline. forall t x, @ok_scalar t x -> @ok_expr t x . - Lemma interp_cast_correct {t} r (idc : ident t type.Z) x : - interp_cast r (ident.interp idc (expr.interp (@ident.interp) x)) - = expr.interp (@ident.interp) (AppIdent (ident.Z.cast r) (AppIdent idc x)). + Lemma interp_cast_correct r (x : uexpr type.Z) : + interp_cast r (uinterp x) = uinterp (AppIdent (ident.Z.cast r) x). Proof. reflexivity. Qed. - Lemma interp_cast2_correct {t} r1 r2 (idc : ident t (type.prod type.Z type.Z)) x : - (interp_cast r1 (fst (ident.interp idc (expr.interp (@ident.interp) x))), - interp_cast r2 (snd (ident.interp idc (expr.interp (@ident.interp) x)))) - = expr.interp (@ident.interp) (AppIdent (ident.Z.cast2 (r1, r2)) (AppIdent idc x)). + Lemma interp_cast2_correct r (x : uexpr (type.prod type.Z type.Z)) : + interp_cast2 r (uinterp x) = uinterp (AppIdent (ident.Z.cast2 r) x). Proof. cbn; break_match; reflexivity. Qed. - Definition primitive_eq_dec (a b : type.primitive) : {a = b} + {a <> b}. - Proof. destruct a,b; auto; right; congruence. Defined. - Fixpoint type_eq_dec (a b : type) : {a = b} + {a <> b}. - Proof. - destruct a, b; try solve [right; congruence]; [ | | | ]. - { destruct (primitive_eq_dec p p0); subst; [left | right]; congruence. } - { destruct (type_eq_dec a1 b1); destruct (type_eq_dec a2 b2); subst; try solve [right; congruence]. - left; congruence. } - { destruct (type_eq_dec a1 b1); destruct (type_eq_dec a2 b2); subst; try solve [right; congruence]. - left; congruence. } - { destruct (type_eq_dec a b); [left | right]; congruence. } - Defined. - Ltac invert H := inversion H; subst; repeat match goal with - | H : existT _ _ _ = existT _ _ _ |- _ => apply (Eqdep_dec.inj_pair2_eq_dec _ type_eq_dec) in H; subst + | H : existT _ _ _ = existT _ _ _ |- _ => apply (Eqdep_dec.inj_pair2_eq_dec _ type.type_eq_dec) in H; subst end. Ltac invert_ok_expr := @@ -8253,7 +8260,7 @@ Module Straightline. Ltac simpl_inversions := cbn [invert_LetInAppIdent invert_LetInCast invert_Pair invert_cast invert_AppIdent invert_Abs]. - Lemma invert_AppIdent_correct {d} (e : @Uncurried.expr.expr ident type.interp d) x p : + Lemma invert_AppIdent_correct {d} (e : uexpr d) x p : invert_AppIdent e = Some (existT (fun s : type => (ident s d * default.expr s)%type) x p) -> e = AppIdent (fst p) (snd p). Proof. @@ -8268,37 +8275,38 @@ Module Straightline. Lemma of_uncurried_scalar_ident_correct {s d} (idc : ident s d) args args': ok_scalar_ident idc -> of_uncurried_scalar args = Some args' -> - interp_scalar args' = expr.interp (@ident.interp) args -> + interp_scalar args' = uinterp args -> exists s, of_uncurried_scalar_ident idc args' = Some s - /\ interp_scalar s = expr.interp (@ident.interp) (AppIdent idc args). + /\ interp_scalar s = uinterp (AppIdent idc args). Proof. destruct 1; intros; repeat match goal with | _ => eexists; split; [ reflexivity | cbn [interp_scalar] ] | H : interp_scalar _ = _ |- _ => rewrite H | _ => reflexivity + | _ => solve [auto using interp_cast2_correct] + | |- context [@Uncurried.expr.interp _ _ (type.type_primitive _)] => + cbn; break_match; reflexivity end. - { cbn; break_match; reflexivity. } - { cbn; break_match; reflexivity. } Qed. - Lemma of_uncurried_scalar_correct {t} (e : Uncurried.expr.expr t) : + Lemma of_uncurried_scalar_correct {t} (e : uexpr t) : ok_scalar e -> exists s, of_uncurried_scalar e = Some s - /\ interp_scalar s = expr.interp (@ident.interp) e. + /\ interp_scalar s = uinterp e. Proof. induction 1; cbn [of_uncurried_scalar]; intros; repeat match goal with + | _ => progress cbn [interp_scalar] | IH : exists _, _ /\ _ |- _ => destruct IH as [? [? ?] ] | H : of_uncurried_scalar _ = _ |- _ => rewrite H + | H : interp_scalar _ = _ |- _ => rewrite H | _ => apply of_uncurried_scalar_ident_correct; solve [auto] - | _ => eexists; tauto - end; [ ]. - eexists; split; [ reflexivity | ]. - cbn [interp_scalar expr.interp]. - congruence. + | _ => eexists; split; [ reflexivity | ] + | _ => reflexivity + end. Qed. Ltac rewrite_ok_scalar := @@ -8312,10 +8320,10 @@ Module Straightline. end. Lemma of_uncurried_correct dummy_arrow fuel dummy_var : - forall {t} (e : @Uncurried.expr.expr ident type.interp t), + forall {t} (e : uexpr t), (depth _ dummy_var e <= fuel)%nat -> ok_expr e -> - expr.interp (@ident.interp) e = interp (@of_uncurried _ dummy_arrow fuel _ e). + uinterp e = straightline_interp (@of_uncurried _ dummy_arrow fuel _ e). Proof. induction fuel; intros; [ pose proof (depth_positive dummy_var e); omega | ]. destruct e; cbn [depth of_uncurried expr.interp interp]; intros; invert_ok_expr; @@ -8327,7 +8335,7 @@ Module Straightline. end; [ | | | | ]. { match goal with H : interp_scalar _ = _ |- _ => rewrite H end. - rewrite <-IHfuel; rewrite interp_cast_correct. + rewrite <-IHfuel. { reflexivity. } { cbn [depth] in *. (* here we have to reason about the depth calculation for arrows; this will probably be unnecessary with new compilers setup *) @@ -8335,12 +8343,7 @@ Module Straightline. { auto. } } { match goal with H : interp_scalar _ = _ |- _ => rewrite H end. - break_match. - match goal with H : ?p = (?x, ?y) |- _ => - replace x with (fst p) by (rewrite H; reflexivity); - replace y with (snd p) by (rewrite H; reflexivity) - end. - rewrite <-IHfuel. rewrite interp_cast2_correct. + rewrite <-IHfuel. { cbn; break_match; reflexivity. } { cbn [depth] in *. (* here we have to reason about the depth calculation for arrows; this will probably be unnecessary with new compilers setup *) @@ -8348,16 +8351,11 @@ Module Straightline. { auto. } } { match goal with H : interp_scalar _ = _ |- _ => rewrite H end. - rewrite interp_cast_correct. + rewrite <-interp_cast_correct. reflexivity. } { match goal with H : interp_scalar _ = _ |- _ => rewrite H end. - break_match. - match goal with H : ?p = (?x, ?y) |- _ => - replace x with (fst p) by (rewrite H; reflexivity); - replace y with (snd p) by (rewrite H; reflexivity) - end. - rewrite interp_cast2_correct. + rewrite <-interp_cast2_correct. cbn; break_match; reflexivity. } { invert_ok_scalar. rewrite <-H2. @@ -8385,7 +8383,7 @@ Module Straightline. repeat match goal with H : interp_scalar _ = _ |- _ => rewrite H end. destruct r; reflexivity. } Admitted. - End interp. + End proofs. End expr. Definition of_Expr {s d} (e : Expr (s->d)) (var : type -> Type) (x:var s) dummy_arrow: expr.expr d @@ -8466,213 +8464,848 @@ End StraightlineTest. (* Convert straightline code to code that uses only a certain set of identifiers *) Module PreFancy. - Section with_var. - Import Straightline.expr. - Context {var : type -> Type} (dummy_arrow : forall s d, var (type.arrow s d)) (log2wordmax : Z) - (constant_to_scalar : forall ident, Z -> option (@scalar var ident type.Z)). - Local Notation Z := (type.type_primitive type.Z). + Import Straightline.expr. + Section with_wordmax. + Context (log2wordmax : Z) (log2wordmax_pos : 1 < log2wordmax). Let wordmax := 2 ^ log2wordmax. Let half_bits := log2wordmax / 2. Let wordmax_half_bits := 2 ^ half_bits. - Inductive ident : type -> type -> Type := - | add : ident (Z * Z) (Z * Z) - | addc : ident (Z * Z * Z) (Z * Z) - | mulll : ident (Z * Z) Z - | mullh : ident (Z * Z) Z - | mulhl : ident (Z * Z) Z - | mulhh : ident (Z * Z) Z - | sub : ident (Z * Z) (Z * Z) - | land : BinInt.Z -> ident Z Z - | shiftr : BinInt.Z -> ident Z Z - | shiftl : BinInt.Z -> ident Z Z - | rshi : BinInt.Z -> ident (Z * Z) Z - | selc : ident (Z * Z * Z) Z - | selm : ident (Z * Z * Z) Z - | sell : ident (Z * Z * Z) Z - | addm : ident (Z * Z * Z) Z - . - Definition dummy t : @expr var ident t := Scalar (dummy_scalar (dummy_arrow:=dummy_arrow) t). - - Definition invert_lower' {t} (e : @scalar var ident t) : - option (@scalar var ident Z) := - match e in scalar t return option (@scalar var ident Z) with - | Cast r (Land n x) => - if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? 2^half_bits-1) - then Some x - else None - | _ => None - end. + Lemma wordmax_gt_2 : 2 < wordmax. + Proof. + clear - wordmax log2wordmax log2wordmax_pos. subst wordmax. + apply Z.le_lt_trans with (m:=2 ^ 1); [ reflexivity | ]. + apply Z.pow_lt_mono_r; omega. + Qed. - Definition invert_upper' {t} (e : @scalar var ident t) : - option (@scalar var ident Z) := - match e in scalar t return option (@scalar var ident Z) with - | Cast r (Shiftr n x) => - if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? half_bits) - then Some x - else None - | _ => None - end. + Lemma wordmax_even : wordmax mod 2 = 0. + Proof. + replace 2 with (2 ^ 1) by reflexivity. + subst wordmax. apply Z.mod_same_pow; omega. + Qed. - Definition invert_lower {t} (e : @scalar var ident t) : - option (@scalar var ident Z) := - match e in scalar t return option (@scalar var ident Z) with - | Primitive type.Z x => - match constant_to_scalar ident x with - | Some y => invert_lower' y - | None => None - end - | _ => invert_lower' e - end. + Lemma wordmax_half_bits_pos : 0 < wordmax_half_bits. + Proof. subst wordmax_half_bits half_bits. Z.zero_bounds. Qed. - Definition invert_upper {t} (e : @scalar var ident t) : - option (@scalar var ident Z) := - match e in scalar t return option (@scalar var ident Z) with - | Primitive type.Z x => - match constant_to_scalar ident x with - | Some y => invert_upper' y - | None => None - end - | _ => invert_upper' e - end. + Lemma half_bits_squared : (wordmax_half_bits - 1) * (wordmax_half_bits - 1) <= wordmax - 1. + Proof. + pose proof wordmax_half_bits_pos. + subst wordmax_half_bits. + transitivity (2 ^ (half_bits + half_bits) - 2 * 2 ^ half_bits + 1). + { rewrite Z.pow_add_r by (subst half_bits; Z.zero_bounds). + autorewrite with push_Zmul; omega. } + { transitivity (wordmax - 2 * 2 ^ half_bits + 1); [ | lia]. + subst wordmax. + apply Z.add_le_mono_r. + apply Z.sub_le_mono_r. + apply Z.pow_le_mono_r; [ omega | ]. + rewrite Z.add_diag; subst half_bits. + apply BinInt.Z.mul_div_le; omega. } + Qed. - Definition invert_sell {t} (e : @scalar var ident t) : - option (@scalar var ident Z * @scalar var ident Z * @scalar var ident Z) := - match e return _ with - | Pair _ Z (Pair Z Z x y) z => - match x return option (@scalar var ident Z * @scalar var ident Z * @scalar var ident Z) with - | Cast r (Land n x') => - if (lower r =? 0) && (upper r =? 1) && (n =? 1) - then Some (x', y, z) + Section with_var. + Context {var : type -> Type} (dummy_arrow : forall s d, var (type.arrow s d)) + (constant_to_scalar : Z -> option (@scalar var type.Z)). + Local Notation Z := (type.type_primitive type.Z). + + Inductive ident : type -> type -> Type := + | add : ident (Z * Z) (Z * Z) + | addc : ident (Z * Z * Z) (Z * Z) + | mulll : ident (Z * Z) Z + | mullh : ident (Z * Z) Z + | mulhl : ident (Z * Z) Z + | mulhh : ident (Z * Z) Z + | sub : ident (Z * Z) (Z * Z) + (* | land : BinInt.Z -> ident Z Z *) + | shiftr : BinInt.Z -> ident Z Z + | shiftl : BinInt.Z -> ident Z Z + | rshi : BinInt.Z -> ident (Z * Z) Z + | selc : ident (Z * Z * Z) Z + | selm : ident (Z * Z * Z) Z + | sell : ident (Z * Z * Z) Z + | addm : ident (Z * Z * Z) Z + . + Definition dummy t : @expr var ident t := Scalar (dummy_scalar (dummy_arrow:=dummy_arrow) t). + + Definition invert_lower' {t} (e : @scalar var t) : + option (@scalar var Z) := + match e in scalar t return option (@scalar var Z) with + | Cast r (Land n x) => + if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? 2^half_bits-1) + then Some x else None - | _ => (@None _) - end - | _ => None - end. + | _ => None + end. - Definition invert_selm {t} (e : @scalar var ident t) : - option (@scalar var ident Z * @scalar var ident Z * @scalar var ident Z) := - match e return _ with - | Pair _ Z (Pair Z Z x y) z => - match x return option (@scalar var ident Z * @scalar var ident Z * @scalar var ident Z) with - | Cast r (CC_m n x') => - if (lower r =? 0) && (upper r =? 1) && (n =? wordmax) - then Some (x', y, z) + Definition invert_upper' {t} (e : @scalar var t) : + option (@scalar var Z) := + match e in scalar t return option (@scalar var Z) with + | Cast r (Shiftr n x) => + if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? half_bits) + then Some x else None - | _ => (@None _) - end - | _ => None - end. + | _ => None + end. - Definition of_straightline_ident {s d} (idc : ident.ident s d) - : forall t, range_type d -> @scalar var ident s -> (var d -> @expr var ident t) -> @expr var ident t := - match idc in ident.ident s d return forall t, range_type d -> scalar s -> (var d -> @expr var ident t) -> @expr var ident t with - | ident.Z.add_get_carry_concrete w => - fun t r x f => - if w =? wordmax - then LetInAppIdentZZ r add x f - else dummy _ - | ident.Z.add_with_get_carry_concrete w => - fun t r x f => - if w =? wordmax - then LetInAppIdentZZ r addc x f - else dummy _ - | ident.Z.sub_get_borrow_concrete w => - fun t r x f => - if w =? wordmax - then LetInAppIdentZZ r sub x f - else dummy _ - | ident.Z.land n => fun _ r => LetInAppIdentZ r (land n) - | ident.Z.shiftr n => fun _ r => LetInAppIdentZ r (shiftr n) - | ident.Z.shiftl n => fun _ r => LetInAppIdentZ r (shiftl n) - | ident.Z.rshi_concrete w n => - fun _ r x f => - if w =? wordmax - then LetInAppIdentZ r (rshi n) x f - else dummy _ - | ident.Z.zselect => - fun t r x f => - match invert_selm x with - | Some (x, y, z) => LetInAppIdentZ r selm (Pair (Pair x y) z) f - | None => match invert_sell x with - | Some (x, y, z) => LetInAppIdentZ r sell (Pair (Pair x y) z) f - | None => LetInAppIdentZ r selc x f - end + Definition invert_lower {t} (e : @scalar var t) : + option (@scalar var Z) := + match e in scalar t return option (@scalar var Z) with + | Primitive type.Z x => + match constant_to_scalar x with + | Some y => invert_lower' y + | None => None end - | ident.Z.add_modulo => fun _ r => LetInAppIdentZ r addm - | ident.Z.mul => - fun t r x f => - match x return expr t with - | Pair _ _ x0 x1 => - match invert_lower x0, invert_lower x1 with - | Some y0, Some y1 => LetInAppIdentZ r mulll (Pair y0 y1) f - | Some y0, None => - match invert_upper x1 with - | Some y1 => LetInAppIdentZ r mullh (Pair y0 y1) f - | None => dummy _ - end - | None, Some y1 => - match invert_upper x0 with - | Some y0 => LetInAppIdentZ r mulhl (Pair y0 y1) f - | None => dummy _ - end - | None, None => - match invert_upper x0, invert_upper x1 with - | Some y0, Some y1 => LetInAppIdentZ r mulhh (Pair y0 y1) f - | _,_ => dummy _ + | _ => invert_lower' e + end. + + Definition invert_upper {t} (e : @scalar var t) : + option (@scalar var Z) := + match e in scalar t return option (@scalar var Z) with + | Primitive type.Z x => + match constant_to_scalar x with + | Some y => invert_upper' y + | None => None + end + | _ => invert_upper' e + end. + + Definition invert_sell {t} (e : @scalar var t) : + option (@scalar var Z * @scalar var Z * @scalar var Z) := + match e return _ with + | Pair _ Z (Pair Z Z x y) z => + match x return option (@scalar var Z * @scalar var Z * @scalar var Z) with + | Cast r (Land n x') => + if (lower r =? 0) && (upper r =? 1) && (n =? 1) + then Some (x', y, z) + else None + | _ => (@None _) + end + | _ => None + end. + + Definition invert_selm {t} (e : @scalar var t) : + option (@scalar var Z * @scalar var Z * @scalar var Z) := + match e return _ with + | Pair _ Z (Pair Z Z x y) z => + match x return option (@scalar var Z * @scalar var Z * @scalar var Z) with + | Cast r (CC_m n x') => + if (lower r =? 0) && (upper r =? 1) && (n =? wordmax) + then Some (x', y, z) + else None + | _ => (@None _) + end + | _ => None + end. + + Definition of_straightline_ident {s d} (idc : ident.ident s d) + : forall t, range_type d -> @scalar var s -> (var d -> @expr var ident t) -> @expr var ident t := + match idc in ident.ident s d return forall t, range_type d -> scalar s -> (var d -> @expr var ident t) -> @expr var ident t with + | ident.Z.add_get_carry_concrete w => + fun t r x f => + if w =? wordmax + then LetInAppIdentZZ r add x f + else dummy _ + | ident.Z.add_with_get_carry_concrete w => + fun t r x f => + if w =? wordmax + then LetInAppIdentZZ r addc x f + else dummy _ + | ident.Z.sub_get_borrow_concrete w => + fun t r x f => + if w =? wordmax + then LetInAppIdentZZ r sub x f + else dummy _ + (* | ident.Z.land n => fun _ r => LetInAppIdentZ r (land n) *) + | ident.Z.shiftr n => fun _ r => LetInAppIdentZ r (shiftr n) + | ident.Z.shiftl n => fun _ r => LetInAppIdentZ r (shiftl n) + | ident.Z.rshi_concrete w n => + fun _ r x f => + if w =? wordmax + then LetInAppIdentZ r (rshi n) x f + else dummy _ + | ident.Z.zselect => + fun t r x f => + match invert_selm x with + | Some (x, y, z) => LetInAppIdentZ r selm (Pair (Pair x y) z) f + | None => match invert_sell x with + | Some (x, y, z) => LetInAppIdentZ r sell (Pair (Pair x y) z) f + | None => LetInAppIdentZ r selc x f + end + end + | ident.Z.add_modulo => fun _ r => LetInAppIdentZ r addm + | ident.Z.mul => + fun t r x f => + match x return expr t with + | Pair _ _ x0 x1 => + match invert_lower x0, invert_lower x1 with + | Some y0, Some y1 => LetInAppIdentZ r mulll (Pair y0 y1) f + | Some y0, None => + match invert_upper x1 with + | Some y1 => LetInAppIdentZ r mullh (Pair y0 y1) f + | None => dummy _ + end + | None, Some y1 => + match invert_upper x0 with + | Some y0 => LetInAppIdentZ r mulhl (Pair y0 y1) f + | None => dummy _ + end + | None, None => + match invert_upper x0, invert_upper x1 with + | Some y0, Some y1 => LetInAppIdentZ r mulhh (Pair y0 y1) f + | _,_ => dummy _ + end end + | _ => dummy _ end - | _ => dummy _ + | _ => fun t _ _ _ => dummy t + end. + + Fixpoint of_straightline {t} (e : @expr var ident.ident t) + : @expr var ident t := + match e with + | Scalar _ s => Scalar s + | LetInAppIdentZ _ t r idc x f => + of_straightline_ident idc t r x (fun y => of_straightline (f y)) + | LetInAppIdentZZ _ t r idc x f => + of_straightline_ident idc t r x (fun y => of_straightline (f y)) + end. + + Definition constant_to_scalar_single (const x : BinInt.Z) : option (@scalar var Z) := + if x =? (BinInt.Z.shiftr const half_bits) + then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Shiftr half_bits (Primitive (t:=type.Z) const))) + else if x =? (BinInt.Z.land const (wordmax_half_bits - 1)) + then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Land (wordmax_half_bits-1) (Primitive (t:=type.Z) const))) + else None. + + Definition constant_to_scalar_gen (consts : list BinInt.Z) (x : BinInt.Z) + : option (Straightline.expr.scalar Z) := + fold_right (fun c res => match res with + | Some s => Some s + | None => constant_to_scalar_single c x + end) None consts. + End with_var. + + Section interp. + Local Notation low x := (Z.land x (wordmax_half_bits - 1)). + Local Notation high x := (x >> half_bits). + + Definition interp_ident {s d} (idc : ident s d) : type.interp s -> type.interp d := + match idc with + | add => fun x => Z.add_get_carry_full wordmax (fst x) (snd x) + | addc => fun x => Z.add_with_get_carry_full wordmax (fst (fst x)) (snd (fst x)) (snd x) + | mulll => fun x => low (fst x) * low (snd x) + | mullh => fun x => low (fst x) * high (snd x) + | mulhl => fun x => high (fst x) * low (snd x) + | mulhh => fun x => high (fst x) * high (snd x) + | sub => fun x => Z.sub_get_borrow_full wordmax (fst x) (snd x) + (* | land n => fun x => Z.land x n *) (* only allowed inside select/mul? *) + | shiftr n => fun x => Z.shiftr x n + | shiftl n => fun x => Z.shiftl x n + | rshi n => fun x => Z.rshi wordmax (fst x) (snd x) n + | selc => fun x => Z.zselect (fst (fst x)) (snd (fst x)) (snd x) + | selm => fun x => Z.zselect (Z.cc_m wordmax (fst (fst x))) (snd (fst x)) (snd x) + | sell => fun x => Z.zselect (Z.land (fst (fst x)) 1) (snd (fst x)) (snd x) + | addm => fun x => Z.add_modulo (fst (fst x)) (snd (fst x)) (snd x) + end. + + Fixpoint interp {t} (e : @expr type.interp ident t) : type.interp t := + match e with + | Scalar t s => interp_scalar s + | LetInAppIdentZ s d r idc x f => + interp (f (interp_cast r (interp_ident idc (interp_scalar x)))) + | LetInAppIdentZZ s d r idc x f => + interp (f (interp_cast2 r (interp_ident idc (interp_scalar x)))) + end. + End interp. + + Section proofs. + Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype) + (constant_to_scalar : Z -> option (@scalar type.interp type.Z)). + + Local Notation word_range := (r[0~>wordmax-1])%zrange. + Local Notation half_word_range := (r[0~>wordmax_half_bits-1])%zrange. + Local Notation flag_range := (r[0~>1])%zrange. + + Definition in_word_range (r : zrange) := is_tighter_than_bool r word_range = true. + Definition in_flag_range (r : zrange) := is_tighter_than_bool r flag_range = true. + + Fixpoint get_range_var (t : type) : type.interp t -> range_type t := + match t with + | type.type_primitive type.Z => + fun x => {| lower := x; upper := x |} + | type.prod a b => + fun x => (get_range_var a (fst x), get_range_var b (snd x)) + | _ => fun _ => tt + end. + + Fixpoint get_range {t} (x : @scalar type.interp t) : range_type t := + match x with + | Var t v => get_range_var t v + | TT => tt + | Nil _ => tt + | Pair _ _ x y => (get_range x, get_range y) + | Cast r _ => r + | Cast2 r _ => r + | Fst _ _ p => fst (get_range p) + | Snd _ _ p => snd (get_range p) + | Shiftr n x => ZRange.map (fun y => Z.shiftr y n) (get_range x) + | Shiftl n x => ZRange.map (fun y => Z.shiftl y n) (get_range x) + | Land n x => r[0~>n]%zrange + | CC_m n x => ZRange.map (Z.cc_m n) (get_range x) + | Primitive type.Z x => {| lower := x; upper := x |} + | Primitive p x => tt + end. + + Fixpoint has_range {t} : range_type t -> type.interp t -> Prop := + match t with + | type.type_primitive type.Z => + fun r x => + lower r <= x <= upper r + | type.prod a b => + fun r x => + has_range (fst r) (fst x) /\ has_range (snd r) (snd x) + | _ => fun _ _ => True + end. + + Inductive ok_scalar : forall {t}, @scalar type.interp t -> Prop := + | sc_ok_var : forall t v, ok_scalar (Var t v) + | sc_ok_unit : ok_scalar TT + | sc_ok_nil : forall t, ok_scalar (Nil t) + | sc_ok_pair : forall A B x y, + @ok_scalar A x -> + @ok_scalar B y -> + ok_scalar (Pair x y) + | sc_ok_cast : forall r (x : scalar type.Z), + ok_scalar x -> + is_tighter_than_bool (get_range x) r = true -> + ok_scalar (Cast r x) + | sc_ok_cast2 : forall r (x : scalar (type.prod type.Z type.Z)), + ok_scalar x -> + is_tighter_than_bool (fst (get_range x)) (fst r) = true -> + is_tighter_than_bool (snd (get_range x)) (snd r) = true -> + ok_scalar (Cast2 r x) + | sc_ok_fst : + forall A B p, @ok_scalar (A * B) p -> ok_scalar (Fst p) + | sc_ok_snd : + forall A B p, @ok_scalar (A * B) p -> ok_scalar (Snd p) + | sc_ok_shiftr : + forall n x, 0 <= n -> ok_scalar x -> ok_scalar (Shiftr n x) + | sc_ok_shiftl : + forall n x, 0 <= n -> 0 <= lower (@get_range type.Z x) -> ok_scalar x -> ok_scalar (Shiftl n x) + | sc_ok_land : + forall n x, 0 <= n -> 0 <= lower (@get_range type.Z x) -> ok_scalar x -> ok_scalar (Land n x) + | sc_ok_cc_m : + forall x, ok_scalar x -> ok_scalar (CC_m wordmax x) + | sc_ok_prim : forall p x, ok_scalar (@Primitive _ p x) + . + + Print ident. + Inductive ok_ident : forall s d, scalar s -> range_type d -> ident.ident s d -> Prop := + | ok_add : + forall x : scalar (type.prod type.Z type.Z), + in_word_range (fst (get_range x)) -> + in_word_range (snd (get_range x)) -> + ok_ident _ + (type.prod type.Z type.Z) + x + (word_range, flag_range) + (ident.Z.add_get_carry_concrete wordmax) + | ok_addc : + forall x : scalar (type.prod (type.prod type.Z type.Z) type.Z), + in_flag_range (fst (fst (get_range x))) -> + in_word_range (snd (fst (get_range x))) -> + in_word_range (snd (get_range x)) -> + ok_ident _ + (type.prod type.Z type.Z) + x + (word_range, flag_range) + (ident.Z.add_with_get_carry_concrete wordmax) + | ok_sub : + forall x : scalar (type.prod type.Z type.Z), + in_word_range (fst (get_range x)) -> + in_word_range (snd (get_range x)) -> + ok_ident _ + (type.prod type.Z type.Z) + x + (word_range, flag_range) + (ident.Z.sub_get_borrow_concrete wordmax) + (* | ok_land : + forall (x : scalar type.Z) n, + in_word_range (get_range x) -> + 0 <= n < wordmax -> + ok_ident type.Z type.Z x (ZRange.map (fun y => Z.land y n) (get_range x)) (ident.Z.land n)*) + | ok_shiftr : + forall (x : scalar type.Z) n, + in_word_range (get_range x) -> + 0 <= n <= log2wordmax -> + ok_ident type.Z type.Z x (ZRange.map (fun y => Z.shiftr y n) (get_range x)) (ident.Z.shiftr n) + | ok_shiftl : + forall (x : scalar type.Z) n, + in_word_range (get_range x) -> + 0 <= n < log2wordmax - Z.log2_up (upper (get_range x)) -> + ok_ident type.Z type.Z x word_range (ident.Z.shiftl n) + | ok_rshi : + forall (x : scalar (type.prod type.Z type.Z)) n, + in_word_range (fst (get_range x)) -> + in_word_range (snd (get_range x)) -> + 0 <= n < 2 * log2wordmax -> + ok_ident (type.prod type.Z type.Z) type.Z x word_range (ident.Z.rshi_concrete wordmax n) + | ok_selc : + forall (x : scalar (type.prod type.Z type.Z)) (y z : scalar type.Z), + in_flag_range (snd (get_range x)) -> + in_word_range (get_range y) -> + in_word_range (get_range z) -> + ok_ident _ + type.Z + (Pair (Pair (Snd x) y) z) + word_range + ident.Z.zselect + | ok_selm : + forall x y z : scalar type.Z, + in_word_range (get_range x) -> + in_word_range (get_range y) -> + in_word_range (get_range z) -> + ok_ident _ + type.Z + (Pair (Pair (Cast flag_range (CC_m wordmax x)) y) z) + word_range + ident.Z.zselect + | ok_sell : + forall x y z : scalar type.Z, + in_word_range (get_range x) -> + in_word_range (get_range y) -> + in_word_range (get_range z) -> + ok_ident _ + type.Z + (Pair (Pair (Cast flag_range (Land 1 x)) y) z) + word_range + ident.Z.zselect + | ok_addm : + forall (x : scalar (type.prod (type.prod type.Z type.Z) type.Z)), + in_word_range (fst (fst (get_range x))) -> + in_word_range (snd (fst (get_range x))) -> + in_word_range (snd (get_range x)) -> + upper (fst (fst (get_range x))) + upper (snd (fst (get_range x))) - lower (snd (get_range x)) < wordmax -> + ok_ident _ + type.Z + x + word_range + ident.Z.add_modulo + | ok_mulll : + forall x y : scalar type.Z, + in_word_range (get_range x) -> + in_word_range (get_range y) -> + ok_ident (type.prod type.Z type.Z) + type.Z + (Pair + (Cast half_word_range (Land (wordmax_half_bits - 1) x)) + (Cast half_word_range (Land (wordmax_half_bits - 1) y))) + word_range + ident.Z.mul + | ok_mullh : + forall x y : scalar type.Z, + in_word_range (get_range x) -> + in_word_range (get_range y) -> + ok_ident (type.prod type.Z type.Z) + type.Z + (Pair + (Cast half_word_range (Land (wordmax_half_bits - 1) x)) + (Cast half_word_range (Shiftr half_bits y))) + word_range + ident.Z.mul + | ok_mulhl : + forall x y : scalar type.Z, + in_word_range (get_range x) -> + in_word_range (get_range y) -> + ok_ident (type.prod type.Z type.Z) + type.Z + (Pair + (Cast half_word_range (Shiftr half_bits x)) + (Cast half_word_range (Land (wordmax_half_bits - 1) y))) + word_range + ident.Z.mul + | ok_mulhh : + forall x y : scalar type.Z, + in_word_range (get_range x) -> + in_word_range (get_range y) -> + ok_ident (type.prod type.Z type.Z) + type.Z + (Pair + (Cast half_word_range (Shiftr half_bits x)) + (Cast half_word_range (Shiftr half_bits y))) + word_range + ident.Z.mul + . + + Inductive ok_expr : forall {t}, @expr type.interp ident.ident t -> Prop := + | ok_of_scalar : forall t s, @ok_expr t (Scalar s) + | ok_letin_z : forall s d r idc x f, + ok_ident _ type.Z x r idc -> + ok_scalar x -> + (forall y, has_range r y -> ok_expr (f y)) -> + ok_expr (@LetInAppIdentZ _ _ s d r idc x f) + | ok_letin_zz : forall s d r idc x f, + ok_ident _ (type.prod type.Z type.Z) x r idc -> + ok_scalar x -> + (forall y, has_range r y -> ok_expr (f y)) -> + ok_expr (@LetInAppIdentZZ _ _ s d r idc x f) + . + + Ltac invert H := + inversion H; subst; + repeat match goal with + | H : existT _ _ _ = existT _ _ _ |- _ => apply (Eqdep_dec.inj_pair2_eq_dec _ type.type_eq_dec) in H; subst + end. + + Lemma has_range_get_range_var {t} (v : type.interp t) : + has_range (get_range_var _ v) v. + Proof. + induction t; cbn [get_range_var has_range fst snd]; auto. + destruct p; auto; cbn [upper lower]; omega. + Qed. + + Lemma has_range_loosen r1 r2 (x : Z) : + @has_range type.Z r1 x -> + is_tighter_than_bool r1 r2 = true -> + @has_range type.Z r2 x. + Proof. + cbv [is_tighter_than_bool has_range]; intros; + match goal with H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H end; + Z.ltb_to_lt; omega. + Qed. + + Lemma interp_cast_noop x r : + @has_range type.Z r x -> + interp_cast r x = x. + Proof. + cbv [has_range interp_cast ident.cast]; intros. + break_match; + try match goal with H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H end; + try match goal with H : _ && _ = false |- _ => rewrite andb_false_iff in H; destruct H end; + Z.ltb_to_lt; omega. + Qed. + + Lemma interp_cast2_noop x r : + @has_range (type.prod type.Z type.Z) r x -> + interp_cast2 r x = x. + Proof. + cbv [has_range interp_cast2 interp_cast ident.cast]; intros. + break_match; destruct x; + repeat match goal with H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H end; + repeat match goal with H : _ && _ = false |- _ => rewrite andb_false_iff in H; destruct H end; + Z.ltb_to_lt; try omega; reflexivity. + Qed. + + Lemma has_range_shiftr n (x : scalar type.Z) : + 0 <= n -> + has_range (get_range x) (interp_scalar x) -> + @has_range type.Z (ZRange.map (fun y : Z => y >> n) (get_range x)) (interp_scalar x >> n). + Proof. cbv [has_range]; intros; cbn. auto using Z.shiftr_le with omega. Qed. + Hint Resolve has_range_shiftr : has_range. + + Lemma has_range_shiftl n r x : + 0 <= n -> 0 <= lower r -> + @has_range type.Z r x -> + @has_range type.Z (ZRange.map (fun y : Z => y << n) r) (x << n). + Proof. cbv [has_range]; intros; cbn. auto using Z.shiftl_le_mono with omega. Qed. + Hint Resolve has_range_shiftl : has_range. + + Lemma has_range_land n (x : scalar type.Z) : + 0 <= n -> 0 <= lower (get_range x) -> + has_range (get_range x) (interp_scalar x) -> + @has_range type.Z (r[0~>n])%zrange (Z.land (interp_scalar x) n). + Proof. + cbv [has_range]; intros; cbn. + split; [ apply Z.land_nonneg | apply Z.land_upper_bound_r ]; omega. + Qed. + Hint Resolve has_range_land : has_range. + + Lemma has_range_interp_scalar {t} (x : scalar t) : + ok_scalar x -> + has_range (get_range x) (interp_scalar x). + Proof. + induction 1; cbn [interp_scalar get_range]; + auto with has_range; + try solve [try inversion IHok_scalar; cbn [has_range]; + auto using has_range_get_range_var]; [ | | | ]. + { rewrite interp_cast_noop by eauto using has_range_loosen. + eapply has_range_loosen; eauto. } + { inversion IHok_scalar. + rewrite interp_cast2_noop; + cbn [has_range]; split; eapply has_range_loosen; eauto. } + { cbn. cbv [has_range] in *. + pose proof wordmax_gt_2. + rewrite !Z.cc_m_eq by omega. + split; apply Z.div_le_mono; Z.zero_bounds; omega. } + { destruct p; cbn [has_range upper lower]; auto; omega. } + Qed. + Hint Resolve has_range_interp_scalar : has_range. + + Lemma has_word_range_interp_scalar (x : scalar type.Z) : + ok_scalar x -> + in_word_range (get_range x) -> + @has_range type.Z word_range (interp_scalar x). + Proof. eauto using has_range_loosen, has_range_interp_scalar. Qed. + + Lemma in_word_range_nonneg r : in_word_range r -> 0 <= lower r. + Proof. + cbv [in_word_range is_tighter_than_bool]. + rewrite andb_true_iff; intuition. + Qed. + + Lemma in_word_range_upper_nonneg r x : @has_range type.Z r x -> in_word_range r -> 0 <= upper r. + Proof. + cbv [in_word_range is_tighter_than_bool]; cbn. + rewrite andb_true_iff; intuition. + Z.ltb_to_lt. omega. + Qed. + + Lemma has_word_range_shiftl n r x : + 0 <= n < log2wordmax - Z.log2_up (upper r) -> + @has_range type.Z r x -> + in_word_range r -> + @has_range type.Z word_range (x << n). + Proof. + intros. + eapply has_range_loosen; + [ apply has_range_shiftl; eauto using in_word_range_nonneg with has_range; omega | ]. + cbv [is_tighter_than_bool]. cbn. + apply andb_true_iff; split; apply Z.leb_le; + [ apply Z.shiftl_nonneg; solve [auto using in_word_range_nonneg] | ]. + rewrite Z.shiftl_mul_pow2 by omega. + destruct (dec (upper r = 0)); + [ match goal with H : _ = 0 |- _ => rewrite H end; pose proof wordmax_gt_2; lia | ]. + match goal with |- ?a * ?b <= _ => + transitivity (2 ^ (Z.log2_up a) * b) end. + { apply Z.mul_le_mono_nonneg_r; auto with zarith; [ ]. + apply Z.log2_log2_up_spec. + apply Z.le_neq; eauto using in_word_range_upper_nonneg with has_range. } + { rewrite <-Z.pow_add_r by auto with zarith. + assert (2 ^ (Z.log2_up (upper r) + n) < 2 ^ log2wordmax) + by (apply Z.pow_lt_mono_r; omega). + replace wordmax with (2 ^ log2wordmax) by reflexivity. + omega. } + Qed. + + Lemma has_word_range_rshi n x y : + 0 <= n -> + @has_range type.Z word_range (Z.rshi wordmax x y n). + Proof. + pose proof wordmax_gt_2. + intros; rewrite Z.rshi_correct by omega. + match goal with |- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + cbn [has_range lower upper]; lia. + Qed. + + Lemma in_word_range_spec r : + 0 <= lower r -> upper r <= wordmax - 1 -> + in_word_range r. + Proof. + intros; cbv [in_word_range is_tighter_than_bool]. + apply andb_true_iff. + split; apply Z.leb_le; cbn [upper lower]; omega. + Qed. + + Ltac destruct_scalar := + match goal with + | x : scalar (type.prod (type.prod _ _) _) |- _ => + match goal with |- context [interp_scalar x] => + destruct (interp_scalar x) as [ [? ?] ?]; + destruct (get_range x) as [ [? ?] ?] end - | _ => fun t _ _ _ => dummy t - end. + | x : scalar (type.prod _ _) |- _ => + match goal with |- context [interp_scalar x] => + destruct (interp_scalar x) as [? ?]; destruct (get_range x) as [? ?] + end + end. - Fixpoint of_straightline_scalar {t} (s : @scalar var ident.ident t) - : @scalar var ident t := - match s with - | Var _ v => Var _ v - | TT => TT - | Nil _ => Nil _ - | Pair _ _ x y => Pair (of_straightline_scalar x) (of_straightline_scalar y) - | Cast r x => Cast r (of_straightline_scalar x) - | Cast2 r x => Cast2 r (of_straightline_scalar x) - | Fst _ _ x => Fst (of_straightline_scalar x) - | Snd _ _ x => Snd (of_straightline_scalar x) - | Shiftr n x => Shiftr n (of_straightline_scalar x) - | Shiftl n x => Shiftl n (of_straightline_scalar x) - | Land n x => Land n (of_straightline_scalar x) - | CC_m n x => CC_m n (of_straightline_scalar x) - | Primitive _ x => Primitive x - end. + Lemma ident_interp_has_range s d x r idc: + ok_scalar x -> + ok_ident s d x r idc -> + has_range r (ident.interp idc (interp_scalar x)). + Proof. + intro. + pose proof (has_range_interp_scalar x ltac:(assumption)). + pose proof wordmax_gt_2. + induction 1; cbn [ident.interp ident.gen_interp]; intros; try destruct_scalar; + repeat match goal with + | H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt + | H : _ /\ _ |- _ => destruct H + | _ => progress subst + | _ => progress (cbv [in_word_range in_flag_range is_tighter_than_bool] in * ) + | _ => progress (cbn [interp_scalar get_range has_range upper lower fst snd] in * ) + end. + { + autorewrite with to_div_mod. + match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + rewrite Z.div_between_0_if by omega. + split; break_match; lia. } + { + autorewrite with to_div_mod. + match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + rewrite Z.div_between_0_if by omega. + split; break_match; lia. } + { + autorewrite with to_div_mod. + match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + rewrite Z.div_sub_small by omega. + split; break_match; lia. } + { apply has_range_shiftr; cbn [has_range]; omega. } + { eapply has_word_range_shiftl; eauto using in_word_range_spec with has_range omega. } + { apply has_word_range_rshi; omega. } + { rewrite Z.zselect_correct. break_match; omega. } + { cbn [interp_scalar fst snd get_range] in *. + rewrite Z.zselect_correct. break_match; omega. } + { cbn [interp_scalar fst snd get_range] in *. + rewrite Z.zselect_correct. break_match; omega. } + { rewrite Z.add_modulo_correct. + break_match; Z.ltb_to_lt; omega. } + { cbn [interp_scalar fst snd get_range upper lower] in *. + pose proof half_bits_squared. nia. } + { cbn [interp_scalar fst snd get_range upper lower] in *. + pose proof half_bits_squared. nia. } + { cbn [interp_scalar fst snd get_range upper lower] in *. + pose proof half_bits_squared. nia. } + { cbn [interp_scalar fst snd get_range upper lower] in *. + pose proof half_bits_squared. nia. } + Qed. - Fixpoint of_straightline {t} (e : @expr var ident.ident t) - : @expr var ident t := - match e with - | Scalar _ s => Scalar (of_straightline_scalar s) - | LetInAppIdentZ _ t r idc x f => - of_straightline_ident idc t r (of_straightline_scalar x) (fun y => of_straightline (f y)) - | LetInAppIdentZZ _ t r idc x f => - of_straightline_ident idc t r (of_straightline_scalar x) (fun y => of_straightline (f y)) - end. + Ltac extract_ok_scalar' level x := + match goal with + | H : ok_scalar (Pair (Pair (?f (?g x)) _) _) |- _ => + match (eval compute in (4 <=? level)) with + | true => invert H; extract_ok_scalar' 3 x + | _ => fail + end + | H : ok_scalar (Pair (?f (?g x)) _) |- _ => + match (eval compute in (3 <=? level)) with + | true => invert H; extract_ok_scalar' 2 x + | _ => fail + end + | H : ok_scalar (?f (?g x)) |- _ => + match (eval compute in (2 <=? level)) with + | true => invert H; extract_ok_scalar' 1 x + | _ => fail + end + | H : ok_scalar (?g x) |- _ => invert H + end. + + + Ltac extract_ok_scalar := + match goal with |- ok_scalar ?x => extract_ok_scalar' 4 x; assumption end. - Definition constant_to_scalar_single ident (const x : BinInt.Z) : option (@Straightline.expr.scalar var ident Z) := - if x =? (BinInt.Z.shiftr const half_bits) - then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Shiftr half_bits (Primitive (t:=type.Z) const))) - else if x =? (BinInt.Z.land const (wordmax_half_bits - 1)) - then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Land (wordmax_half_bits-1) (Primitive (t:=type.Z) const))) - else None. - - Definition constant_to_scalar_gen (consts : list BinInt.Z) ident (x : BinInt.Z) - : option (Straightline.expr.scalar Z) := - fold_right (fun c res => match res with - | Some s => Some s - | None => constant_to_scalar_single ident c x - end) None consts. - End with_var. + Lemma has_flag_range_cc_m r x : + @has_range type.Z r x -> + in_word_range r -> + @has_range type.Z flag_range (Z.cc_m wordmax x). + Proof. + cbv [has_range in_word_range is_tighter_than_bool]. + cbn [upper lower]; rewrite andb_true_iff; intros. + match goal with H : _ /\ _ |- _ => destruct H; Z.ltb_to_lt end. + pose proof wordmax_gt_2. pose proof wordmax_even. + pose proof (Z.cc_m_small wordmax x). omega. + Qed. + + Lemma has_flag_range_cc_m' (x : scalar type.Z) : + ok_scalar x -> + in_word_range (get_range x) -> + @has_range type.Z flag_range (Z.cc_m wordmax (interp_scalar x)). + Proof. eauto using has_flag_range_cc_m with has_range. Qed. + + Lemma has_flag_range_land r x : + @has_range type.Z r x -> + in_word_range r -> + @has_range type.Z flag_range (Z.land x 1). + Proof. + cbv [has_range in_word_range is_tighter_than_bool]. + cbn [upper lower]; rewrite andb_true_iff; intuition; Z.ltb_to_lt. + { apply Z.land_nonneg. left; omega. } + { apply Z.land_upper_bound_r; omega. } + Qed. + + Lemma has_flag_range_land' (x : scalar type.Z) : + ok_scalar x -> + in_word_range (get_range x) -> + @has_range type.Z flag_range (Z.land (interp_scalar x) 1). + Proof. eauto using has_flag_range_land with has_range. Qed. + + Ltac rewrite_cast_noop_in_mul := + repeat match goal with + | _ => rewrite interp_cast_noop with (r:=half_word_range) in * + by (eapply has_range_loosen; auto using has_range_land, has_range_interp_scalar) + | _ => rewrite interp_cast_noop with (r:=half_word_range) in * + by (eapply has_range_loosen; try apply has_range_shiftr; auto using has_range_interp_scalar; + cbn [ZRange.map get_range] in *; auto) + | _ => rewrite interp_cast_noop by assumption + end. + + Lemma of_straightline_ident_correct s d t x r (idc : ident.ident s d) g : + ok_ident s d x r idc -> + ok_scalar x -> + interp (of_straightline_ident dummy_arrow constant_to_scalar idc t r x g) = + interp (g (ident.interp idc (interp_scalar x))). + Proof. + intros. + pose proof wordmax_half_bits_pos. + pose proof (ident_interp_has_range _ _ x r idc ltac:(assumption) ltac:(assumption)). + induction H; cbn [of_straightline_ident ident.interp ident.gen_interp invert_selm invert_sell invert_upper invert_lower invert_upper' invert_lower'] in *; intros; + rewrite ?Z.eqb_refl; cbn [interp interp_ident andb]; try destruct_scalar; + repeat match goal with + | _ => progress (cbn [fst snd interp_scalar] in * ) + | _ => progress break_match; [ ] + | _ => rewrite interp_cast_noop with (r:=flag_range) in * + by (apply has_flag_range_cc_m'; auto; extract_ok_scalar) + | _ => rewrite interp_cast_noop with (r:=flag_range) in * + by (apply has_flag_range_land'; auto; extract_ok_scalar) + | H : _ = (_,_) |- _ => progress (inversion H; subst) + | _ => rewrite interp_cast_noop by assumption + | _ => rewrite interp_cast2_noop by assumption + | _ => reflexivity + end; [ | | | ]; (* leftover cases are the 4 kinds of multiplication *) + match goal with + | H1 : in_word_range (get_range ?x), H2 : in_word_range (get_range ?y) |- _ => + extract_ok_scalar' 4 x; extract_ok_scalar' 4 y + end; rewrite_cast_noop_in_mul; reflexivity. + Qed. + + Lemma of_straightline_correct {t} (e : expr t) : + ok_expr e -> + interp (of_straightline dummy_arrow constant_to_scalar e) = Straightline.expr.interp (interp_ident:=@ident.interp) e. + Proof. + induction 1; cbn [of_straightline]; intros; + repeat match goal with + | _ => progress cbn [Straightline.expr.interp] + | _ => rewrite of_straightline_ident_correct by auto + | _ => rewrite interp_cast_noop by auto using ident_interp_has_range + | _ => rewrite interp_cast2_noop by auto using ident_interp_has_range + | H : forall y, has_range _ y -> interp _ = _ |- _ => rewrite H by auto using ident_interp_has_range + | _ => reflexivity + end. + Qed. + End proofs. + End with_wordmax. Definition of_Expr {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d)) (var : type -> Type) (x : var s) dummy_arrow : @Straightline.expr.expr var ident d := - @of_straightline var dummy_arrow log2wordmax (@constant_to_scalar_gen var log2wordmax consts) _ (Straightline.of_Expr e var x dummy_arrow). + @of_straightline log2wordmax var dummy_arrow (@constant_to_scalar_gen log2wordmax var consts) _ (Straightline.of_Expr e var x dummy_arrow). + + Lemma of_Expr_correct {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d)) + (e' : (type.interp s -> Uncurried.expr.expr d)) + (x : type.interp s) dummy_arrow : + e type.interp = Abs e' -> + 1 < log2wordmax -> + Straightline.expr.ok_expr (e' x) -> + ok_expr log2wordmax (of_uncurried (dummy_arrow:=dummy_arrow) (depth (fun _ : type => unit) (fun _ : type => tt) (e (fun _ : type => unit))) (e' x)) -> + (depth type.interp (@DefaultValue.type.default) (e' x) <= depth (fun _ : type => unit) (fun _ : type => tt) (e (fun _ : type => unit)))%nat -> + interp log2wordmax (of_Expr log2wordmax consts e type.interp x dummy_arrow) = @Uncurried.expr.interp _ (@ident.interp) _ (e type.interp) x. + Proof. + intro He'; intros; cbv [of_Expr Straightline.of_Expr]. + rewrite He'; cbn [invert_Abs expr.interp]. + rewrite of_straightline_correct by assumption. + erewrite Straightline.expr.of_uncurried_correct by eassumption. + reflexivity. + Qed. + Module Notations. Import PrintingNotations. Import Straightline.expr. -- cgit v1.2.3