aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/SimplyTypedArithmetic.v
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-05-09 13:17:52 +0200
committerGravatar Jade Philipoom <jadep@google.com>2018-05-31 13:46:48 +0200
commite6b25ccf1bb99582c3f83aa8b3b4fe3f0de31870 (patch)
tree13c241c54dd62935f63725671181ff7b4472fa13 /src/Experiments/SimplyTypedArithmetic.v
parent0d3ae3b975eb56d3247646d0a2023f675c14d759 (diff)
Proofs for pre-fancy pass (could use cleanup)
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v1175
1 files changed, 904 insertions, 271 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index e7c46ee9a..4a5be2505 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -3031,14 +3031,14 @@ Module Compilers.
| Z_mul_split => curry3 Z.mul_split
| Z_mul_split_concrete s => curry2 (Z.mul_split s)
| Z_add_get_carry => curry3 Z.add_get_carry_full
- | Z_add_get_carry_concrete s => curry2 (Z.add_get_carry s)
+ | Z_add_get_carry_concrete s => curry2 (Z.add_get_carry_full s)
| Z_add_with_carry => curry3 Z.add_with_carry
| Z_add_with_get_carry => curry4 Z.add_with_get_carry_full
- | Z_add_with_get_carry_concrete s => curry3 (Z.add_with_get_carry s)
+ | Z_add_with_get_carry_concrete s => curry3 (Z.add_with_get_carry_full s)
| Z_sub_get_borrow => curry3 Z.sub_get_borrow_full
- | Z_sub_get_borrow_concrete s => curry2 (Z.sub_get_borrow s)
+ | Z_sub_get_borrow_concrete s => curry2 (Z.sub_get_borrow_full s)
| Z_sub_with_get_borrow => curry4 Z.sub_with_get_borrow_full
- | Z_sub_with_get_borrow_concrete s => curry3 (Z.sub_with_get_borrow s)
+ | Z_sub_with_get_borrow_concrete s => curry3 (Z.sub_with_get_borrow_full s)
| Z_zselect => curry3 Z.zselect
| Z_add_modulo => curry3 Z.add_modulo
| Z_rshi => curry4 Z.rshi
@@ -4198,13 +4198,13 @@ Module Compilers.
| ident.Z_mul_split_concrete s
=> cps_of (curry2 (Z.mul_split s))
| ident.Z_add_get_carry_concrete s
- => cps_of (curry2 (Z.add_get_carry s))
+ => cps_of (curry2 (Z.add_get_carry_full s))
| ident.Z_add_with_get_carry_concrete s
- => cps_of (curry3 (Z.add_with_get_carry s))
+ => cps_of (curry3 (Z.add_with_get_carry_full s))
| ident.Z_sub_get_borrow_concrete s
- => cps_of (curry2 (Z.sub_get_borrow s))
+ => cps_of (curry2 (Z.sub_get_borrow_full s))
| ident.Z_sub_with_get_borrow_concrete s
- => cps_of (curry3 (Z.sub_with_get_borrow s))
+ => cps_of (curry3 (Z.sub_with_get_borrow_full s))
| ident.Z_rshi_concrete s n
=> cps_of (curry2 (fun x y => Z.rshi s x y n))
| ident.Z_cc_m_concrete s
@@ -7919,6 +7919,22 @@ End X25519_32.
Module Straightline.
Module expr.
+ (* TODO: move these to a better location *)
+ Module type.
+ Definition primitive_eq_dec (a b : type.primitive) : {a = b} + {a <> b}.
+ Proof. destruct a,b; auto; right; congruence. Defined.
+ Fixpoint type_eq_dec (a b : type) : {a = b} + {a <> b}.
+ Proof.
+ destruct a, b; try solve [right; congruence]; [ | | | ].
+ { destruct (primitive_eq_dec p p0); subst; [left | right]; congruence. }
+ { destruct (type_eq_dec a1 b1); destruct (type_eq_dec a2 b2); subst; try solve [right; congruence].
+ left; congruence. }
+ { destruct (type_eq_dec a1 b1); destruct (type_eq_dec a2 b2); subst; try solve [right; congruence].
+ left; congruence. }
+ { destruct (type_eq_dec a b); [left | right]; congruence. }
+ Defined.
+ End type.
+
Section with_var.
Context {var : type.type -> Type}.
Context {dummy_arrow : forall s d, var (s -> d)}. (* TODO: remove once arrow-containing pairs are removed at type level *)
@@ -7927,24 +7943,26 @@ Module Straightline.
Section with_ident.
Context {ident : type.type -> type.type -> Type}.
+ Inductive scalar : type.type -> Type :=
+ | Var t : var t -> scalar t
+ | TT : scalar (type.type_primitive type.unit)
+ | Nil t : scalar (type.list t)
+ | Pair {a b} : scalar a -> scalar b -> scalar (a * b)
+ | Cast : zrange -> scalar type.Z -> scalar type.Z
+ | Cast2 : zrange * zrange -> scalar (type.Z*type.Z) -> scalar (type.Z*type.Z)
+ | Fst {a b} : scalar (a * b) -> scalar a
+ | Snd {a b} : scalar (a * b) -> scalar b
+ | Shiftr : Z -> scalar type.Z -> scalar type.Z
+ | Shiftl : Z -> scalar type.Z -> scalar type.Z
+ | Land : Z -> scalar type.Z -> scalar type.Z
+ | CC_m : Z -> scalar type.Z -> scalar type.Z
+ | Primitive {t} : type.interp (type.type_primitive t) -> scalar t
+ .
+
Inductive expr : type.type -> Type :=
| Scalar {t} : scalar t -> expr t
| LetInAppIdentZ {s d} : zrange -> ident s (type.Z) -> scalar s -> (var (type.Z) -> expr d) -> expr d
| LetInAppIdentZZ {s d} : zrange * zrange -> ident s (type.Z*type.Z) -> scalar s -> (var (type.Z*type.Z) -> expr d) -> expr d
- with scalar : type.type -> Type :=
- | Var t : var t -> scalar t
- | TT : scalar (type.type_primitive type.unit)
- | Nil t : scalar (type.list t)
- | Pair {a b} : scalar a -> scalar b -> scalar (a * b)
- | Cast : zrange -> scalar type.Z -> scalar type.Z
- | Cast2 : zrange * zrange -> scalar (type.Z*type.Z) -> scalar (type.Z*type.Z)
- | Fst {a b} : scalar (a * b) -> scalar a
- | Snd {a b} : scalar (a * b) -> scalar b
- | Shiftr : Z -> scalar type.Z -> scalar type.Z
- | Shiftl : Z -> scalar type.Z -> scalar type.Z
- | Land : Z -> scalar type.Z -> scalar type.Z
- | CC_m : Z -> scalar type.Z -> scalar type.Z
- | Primitive {t} : type.interp (type.type_primitive t) -> scalar t
.
Fixpoint dummy_scalar t : scalar t :=
@@ -7960,7 +7978,7 @@ Module Straightline.
Definition of_uncurried_scalar_ident {s d} (idc : ident.ident s d)
: scalar s -> option (scalar d) :=
- match idc in ident.ident s d return scalar (ident:=ident.ident) s -> option (scalar d) with
+ match idc in ident.ident s d return scalar s -> option (scalar d) with
| ident.Z.cast r => fun args => Some (Cast r args)
| ident.Z.cast2 r => fun args => Some (Cast2 r args)
| @ident.fst A B => fun args => Some (Fst args)
@@ -8131,20 +8149,20 @@ Module Straightline.
End depth.
Section interp.
- Local Notation scalar := (@scalar type.interp default.ident).
- Local Notation expr := (@expr type.interp default.ident).
+ Context {ident : type -> type -> Type} {interp_ident : forall s d, ident s d -> type.interp s -> type.interp d}.
Definition interp_cast (r : zrange) (x : Z) : Z := ident.cast ident.cast_outside_of_range r x.
-
- Fixpoint interp_scalar {t} (s : scalar t) : type.interp t :=
+ Definition interp_cast2 (r : zrange * zrange) (x : Z * Z) : Z * Z :=
+ (interp_cast (fst r) (fst x), interp_cast (snd r) (snd x)).
+
+ Fixpoint interp_scalar {t} (s : @scalar type.interp t) : type.interp t :=
match s with
| Var t v => v
| TT => tt
| Nil _ => []
| Pair _ _ x y => (interp_scalar x, interp_scalar y)
| Cast r x => interp_cast r (interp_scalar x)
- | Cast2 r x =>
- let '(a,b) := interp_scalar x in (interp_cast (fst r) a, interp_cast (snd r) b)
+ | Cast2 r x => interp_cast2 r (interp_scalar x)
| Fst _ _ p => fst (interp_scalar p)
| Snd _ _ p => snd (interp_scalar p)
| Shiftr n x => Z.shiftr (interp_scalar x) n
@@ -8154,15 +8172,20 @@ Module Straightline.
| Primitive _ x => x
end.
- Fixpoint interp {t} (e : expr t) : type.interp t :=
+ Fixpoint interp {t} (e : @expr type.interp ident t) : type.interp t :=
match e with
| Scalar _ s => interp_scalar s
| LetInAppIdentZ _ _ r idc x f =>
- interp (f (interp_cast r (ident.interp idc (interp_scalar x))))
- | LetInAppIdentZZ _ _ (r1,r2) idc x f =>
- let '(a,b) := ident.interp idc (interp_scalar x) in
- interp (f (interp_cast r1 a, interp_cast r2 b))
+ interp (f (interp_cast r (interp_ident _ _ idc (interp_scalar x))))
+ | LetInAppIdentZZ _ _ r idc x f =>
+ interp (f (interp_cast2 r (interp_ident _ _ idc (interp_scalar x))))
end.
+ End interp.
+
+ Section proofs.
+ Local Notation straightline_interp := (expr.interp (ident:=default.ident) (interp_ident:=@ident.interp)).
+ Local Notation uinterp := (Uncurried.expr.interp (@ident.interp)).
+ Local Notation uexpr := (@Uncurried.expr.expr ident type.interp).
Inductive ok_scalar_ident : forall {s d}, ident.ident s d -> Prop :=
| ok_si_cast : forall r, ok_scalar_ident (ident.Z.cast r)
@@ -8176,7 +8199,7 @@ Module Straightline.
| ok_prim : forall p x, ok_scalar_ident (@ident.primitive p x)
.
- Inductive ok_scalar: forall {t}, @Uncurried.expr.expr ident type.interp t -> Prop :=
+ Inductive ok_scalar: forall {t}, uexpr t -> Prop :=
| ok_Var : forall t v, @ok_scalar t (Uncurried.expr.Var v)
| ok_TT : ok_scalar Uncurried.expr.TT
| ok_AppIdent :
@@ -8191,7 +8214,7 @@ Module Straightline.
ok_scalar (Uncurried.expr.Pair a b)
.
- Inductive ok_expr : forall {t}, Uncurried.expr.expr t -> Prop :=
+ Inductive ok_expr : forall {t}, uexpr t -> Prop :=
| ok_LetInAppIdentZ :
forall tC r s (idc : ident s type.Z) x k,
ok_scalar x -> (forall y, @ok_expr tC (k y)) ->
@@ -8214,34 +8237,18 @@ Module Straightline.
forall t x, @ok_scalar t x -> @ok_expr t x
.
- Lemma interp_cast_correct {t} r (idc : ident t type.Z) x :
- interp_cast r (ident.interp idc (expr.interp (@ident.interp) x))
- = expr.interp (@ident.interp) (AppIdent (ident.Z.cast r) (AppIdent idc x)).
+ Lemma interp_cast_correct r (x : uexpr type.Z) :
+ interp_cast r (uinterp x) = uinterp (AppIdent (ident.Z.cast r) x).
Proof. reflexivity. Qed.
- Lemma interp_cast2_correct {t} r1 r2 (idc : ident t (type.prod type.Z type.Z)) x :
- (interp_cast r1 (fst (ident.interp idc (expr.interp (@ident.interp) x))),
- interp_cast r2 (snd (ident.interp idc (expr.interp (@ident.interp) x))))
- = expr.interp (@ident.interp) (AppIdent (ident.Z.cast2 (r1, r2)) (AppIdent idc x)).
+ Lemma interp_cast2_correct r (x : uexpr (type.prod type.Z type.Z)) :
+ interp_cast2 r (uinterp x) = uinterp (AppIdent (ident.Z.cast2 r) x).
Proof. cbn; break_match; reflexivity. Qed.
- Definition primitive_eq_dec (a b : type.primitive) : {a = b} + {a <> b}.
- Proof. destruct a,b; auto; right; congruence. Defined.
- Fixpoint type_eq_dec (a b : type) : {a = b} + {a <> b}.
- Proof.
- destruct a, b; try solve [right; congruence]; [ | | | ].
- { destruct (primitive_eq_dec p p0); subst; [left | right]; congruence. }
- { destruct (type_eq_dec a1 b1); destruct (type_eq_dec a2 b2); subst; try solve [right; congruence].
- left; congruence. }
- { destruct (type_eq_dec a1 b1); destruct (type_eq_dec a2 b2); subst; try solve [right; congruence].
- left; congruence. }
- { destruct (type_eq_dec a b); [left | right]; congruence. }
- Defined.
-
Ltac invert H :=
inversion H; subst;
repeat match goal with
- | H : existT _ _ _ = existT _ _ _ |- _ => apply (Eqdep_dec.inj_pair2_eq_dec _ type_eq_dec) in H; subst
+ | H : existT _ _ _ = existT _ _ _ |- _ => apply (Eqdep_dec.inj_pair2_eq_dec _ type.type_eq_dec) in H; subst
end.
Ltac invert_ok_expr :=
@@ -8253,7 +8260,7 @@ Module Straightline.
Ltac simpl_inversions :=
cbn [invert_LetInAppIdent invert_LetInCast invert_Pair invert_cast invert_AppIdent invert_Abs].
- Lemma invert_AppIdent_correct {d} (e : @Uncurried.expr.expr ident type.interp d) x p :
+ Lemma invert_AppIdent_correct {d} (e : uexpr d) x p :
invert_AppIdent e = Some (existT (fun s : type => (ident s d * default.expr s)%type) x p) ->
e = AppIdent (fst p) (snd p).
Proof.
@@ -8268,37 +8275,38 @@ Module Straightline.
Lemma of_uncurried_scalar_ident_correct {s d} (idc : ident s d) args args':
ok_scalar_ident idc ->
of_uncurried_scalar args = Some args' ->
- interp_scalar args' = expr.interp (@ident.interp) args ->
+ interp_scalar args' = uinterp args ->
exists s,
of_uncurried_scalar_ident idc args' = Some s
- /\ interp_scalar s = expr.interp (@ident.interp) (AppIdent idc args).
+ /\ interp_scalar s = uinterp (AppIdent idc args).
Proof.
destruct 1; intros;
repeat match goal with
| _ => eexists; split; [ reflexivity | cbn [interp_scalar] ]
| H : interp_scalar _ = _ |- _ => rewrite H
| _ => reflexivity
+ | _ => solve [auto using interp_cast2_correct]
+ | |- context [@Uncurried.expr.interp _ _ (type.type_primitive _)] =>
+ cbn; break_match; reflexivity
end.
- { cbn; break_match; reflexivity. }
- { cbn; break_match; reflexivity. }
Qed.
- Lemma of_uncurried_scalar_correct {t} (e : Uncurried.expr.expr t) :
+ Lemma of_uncurried_scalar_correct {t} (e : uexpr t) :
ok_scalar e ->
exists s,
of_uncurried_scalar e = Some s
- /\ interp_scalar s = expr.interp (@ident.interp) e.
+ /\ interp_scalar s = uinterp e.
Proof.
induction 1; cbn [of_uncurried_scalar]; intros;
repeat match goal with
+ | _ => progress cbn [interp_scalar]
| IH : exists _, _ /\ _ |- _ => destruct IH as [? [? ?] ]
| H : of_uncurried_scalar _ = _ |- _ => rewrite H
+ | H : interp_scalar _ = _ |- _ => rewrite H
| _ => apply of_uncurried_scalar_ident_correct; solve [auto]
- | _ => eexists; tauto
- end; [ ].
- eexists; split; [ reflexivity | ].
- cbn [interp_scalar expr.interp].
- congruence.
+ | _ => eexists; split; [ reflexivity | ]
+ | _ => reflexivity
+ end.
Qed.
Ltac rewrite_ok_scalar :=
@@ -8312,10 +8320,10 @@ Module Straightline.
end.
Lemma of_uncurried_correct dummy_arrow fuel dummy_var :
- forall {t} (e : @Uncurried.expr.expr ident type.interp t),
+ forall {t} (e : uexpr t),
(depth _ dummy_var e <= fuel)%nat ->
ok_expr e ->
- expr.interp (@ident.interp) e = interp (@of_uncurried _ dummy_arrow fuel _ e).
+ uinterp e = straightline_interp (@of_uncurried _ dummy_arrow fuel _ e).
Proof.
induction fuel; intros; [ pose proof (depth_positive dummy_var e); omega | ].
destruct e; cbn [depth of_uncurried expr.interp interp]; intros; invert_ok_expr;
@@ -8327,7 +8335,7 @@ Module Straightline.
end; [ | | | | ].
{
match goal with H : interp_scalar _ = _ |- _ => rewrite H end.
- rewrite <-IHfuel; rewrite interp_cast_correct.
+ rewrite <-IHfuel.
{ reflexivity. }
{ cbn [depth] in *.
(* here we have to reason about the depth calculation for arrows; this will probably be unnecessary with new compilers setup *)
@@ -8335,12 +8343,7 @@ Module Straightline.
{ auto. } }
{
match goal with H : interp_scalar _ = _ |- _ => rewrite H end.
- break_match.
- match goal with H : ?p = (?x, ?y) |- _ =>
- replace x with (fst p) by (rewrite H; reflexivity);
- replace y with (snd p) by (rewrite H; reflexivity)
- end.
- rewrite <-IHfuel. rewrite interp_cast2_correct.
+ rewrite <-IHfuel.
{ cbn; break_match; reflexivity. }
{ cbn [depth] in *.
(* here we have to reason about the depth calculation for arrows; this will probably be unnecessary with new compilers setup *)
@@ -8348,16 +8351,11 @@ Module Straightline.
{ auto. } }
{
match goal with H : interp_scalar _ = _ |- _ => rewrite H end.
- rewrite interp_cast_correct.
+ rewrite <-interp_cast_correct.
reflexivity. }
{
match goal with H : interp_scalar _ = _ |- _ => rewrite H end.
- break_match.
- match goal with H : ?p = (?x, ?y) |- _ =>
- replace x with (fst p) by (rewrite H; reflexivity);
- replace y with (snd p) by (rewrite H; reflexivity)
- end.
- rewrite interp_cast2_correct.
+ rewrite <-interp_cast2_correct.
cbn; break_match; reflexivity. }
{ invert_ok_scalar.
rewrite <-H2.
@@ -8385,7 +8383,7 @@ Module Straightline.
repeat match goal with H : interp_scalar _ = _ |- _ => rewrite H end.
destruct r; reflexivity. }
Admitted.
- End interp.
+ End proofs.
End expr.
Definition of_Expr {s d} (e : Expr (s->d)) (var : type -> Type) (x:var s) dummy_arrow: expr.expr d
@@ -8466,213 +8464,848 @@ End StraightlineTest.
(* Convert straightline code to code that uses only a certain set of identifiers *)
Module PreFancy.
- Section with_var.
- Import Straightline.expr.
- Context {var : type -> Type} (dummy_arrow : forall s d, var (type.arrow s d)) (log2wordmax : Z)
- (constant_to_scalar : forall ident, Z -> option (@scalar var ident type.Z)).
- Local Notation Z := (type.type_primitive type.Z).
+ Import Straightline.expr.
+ Section with_wordmax.
+ Context (log2wordmax : Z) (log2wordmax_pos : 1 < log2wordmax).
Let wordmax := 2 ^ log2wordmax.
Let half_bits := log2wordmax / 2.
Let wordmax_half_bits := 2 ^ half_bits.
- Inductive ident : type -> type -> Type :=
- | add : ident (Z * Z) (Z * Z)
- | addc : ident (Z * Z * Z) (Z * Z)
- | mulll : ident (Z * Z) Z
- | mullh : ident (Z * Z) Z
- | mulhl : ident (Z * Z) Z
- | mulhh : ident (Z * Z) Z
- | sub : ident (Z * Z) (Z * Z)
- | land : BinInt.Z -> ident Z Z
- | shiftr : BinInt.Z -> ident Z Z
- | shiftl : BinInt.Z -> ident Z Z
- | rshi : BinInt.Z -> ident (Z * Z) Z
- | selc : ident (Z * Z * Z) Z
- | selm : ident (Z * Z * Z) Z
- | sell : ident (Z * Z * Z) Z
- | addm : ident (Z * Z * Z) Z
- .
- Definition dummy t : @expr var ident t := Scalar (dummy_scalar (dummy_arrow:=dummy_arrow) t).
-
- Definition invert_lower' {t} (e : @scalar var ident t) :
- option (@scalar var ident Z) :=
- match e in scalar t return option (@scalar var ident Z) with
- | Cast r (Land n x) =>
- if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? 2^half_bits-1)
- then Some x
- else None
- | _ => None
- end.
+ Lemma wordmax_gt_2 : 2 < wordmax.
+ Proof.
+ clear - wordmax log2wordmax log2wordmax_pos. subst wordmax.
+ apply Z.le_lt_trans with (m:=2 ^ 1); [ reflexivity | ].
+ apply Z.pow_lt_mono_r; omega.
+ Qed.
- Definition invert_upper' {t} (e : @scalar var ident t) :
- option (@scalar var ident Z) :=
- match e in scalar t return option (@scalar var ident Z) with
- | Cast r (Shiftr n x) =>
- if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? half_bits)
- then Some x
- else None
- | _ => None
- end.
+ Lemma wordmax_even : wordmax mod 2 = 0.
+ Proof.
+ replace 2 with (2 ^ 1) by reflexivity.
+ subst wordmax. apply Z.mod_same_pow; omega.
+ Qed.
- Definition invert_lower {t} (e : @scalar var ident t) :
- option (@scalar var ident Z) :=
- match e in scalar t return option (@scalar var ident Z) with
- | Primitive type.Z x =>
- match constant_to_scalar ident x with
- | Some y => invert_lower' y
- | None => None
- end
- | _ => invert_lower' e
- end.
+ Lemma wordmax_half_bits_pos : 0 < wordmax_half_bits.
+ Proof. subst wordmax_half_bits half_bits. Z.zero_bounds. Qed.
- Definition invert_upper {t} (e : @scalar var ident t) :
- option (@scalar var ident Z) :=
- match e in scalar t return option (@scalar var ident Z) with
- | Primitive type.Z x =>
- match constant_to_scalar ident x with
- | Some y => invert_upper' y
- | None => None
- end
- | _ => invert_upper' e
- end.
+ Lemma half_bits_squared : (wordmax_half_bits - 1) * (wordmax_half_bits - 1) <= wordmax - 1.
+ Proof.
+ pose proof wordmax_half_bits_pos.
+ subst wordmax_half_bits.
+ transitivity (2 ^ (half_bits + half_bits) - 2 * 2 ^ half_bits + 1).
+ { rewrite Z.pow_add_r by (subst half_bits; Z.zero_bounds).
+ autorewrite with push_Zmul; omega. }
+ { transitivity (wordmax - 2 * 2 ^ half_bits + 1); [ | lia].
+ subst wordmax.
+ apply Z.add_le_mono_r.
+ apply Z.sub_le_mono_r.
+ apply Z.pow_le_mono_r; [ omega | ].
+ rewrite Z.add_diag; subst half_bits.
+ apply BinInt.Z.mul_div_le; omega. }
+ Qed.
- Definition invert_sell {t} (e : @scalar var ident t) :
- option (@scalar var ident Z * @scalar var ident Z * @scalar var ident Z) :=
- match e return _ with
- | Pair _ Z (Pair Z Z x y) z =>
- match x return option (@scalar var ident Z * @scalar var ident Z * @scalar var ident Z) with
- | Cast r (Land n x') =>
- if (lower r =? 0) && (upper r =? 1) && (n =? 1)
- then Some (x', y, z)
+ Section with_var.
+ Context {var : type -> Type} (dummy_arrow : forall s d, var (type.arrow s d))
+ (constant_to_scalar : Z -> option (@scalar var type.Z)).
+ Local Notation Z := (type.type_primitive type.Z).
+
+ Inductive ident : type -> type -> Type :=
+ | add : ident (Z * Z) (Z * Z)
+ | addc : ident (Z * Z * Z) (Z * Z)
+ | mulll : ident (Z * Z) Z
+ | mullh : ident (Z * Z) Z
+ | mulhl : ident (Z * Z) Z
+ | mulhh : ident (Z * Z) Z
+ | sub : ident (Z * Z) (Z * Z)
+ (* | land : BinInt.Z -> ident Z Z *)
+ | shiftr : BinInt.Z -> ident Z Z
+ | shiftl : BinInt.Z -> ident Z Z
+ | rshi : BinInt.Z -> ident (Z * Z) Z
+ | selc : ident (Z * Z * Z) Z
+ | selm : ident (Z * Z * Z) Z
+ | sell : ident (Z * Z * Z) Z
+ | addm : ident (Z * Z * Z) Z
+ .
+ Definition dummy t : @expr var ident t := Scalar (dummy_scalar (dummy_arrow:=dummy_arrow) t).
+
+ Definition invert_lower' {t} (e : @scalar var t) :
+ option (@scalar var Z) :=
+ match e in scalar t return option (@scalar var Z) with
+ | Cast r (Land n x) =>
+ if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? 2^half_bits-1)
+ then Some x
else None
- | _ => (@None _)
- end
- | _ => None
- end.
+ | _ => None
+ end.
- Definition invert_selm {t} (e : @scalar var ident t) :
- option (@scalar var ident Z * @scalar var ident Z * @scalar var ident Z) :=
- match e return _ with
- | Pair _ Z (Pair Z Z x y) z =>
- match x return option (@scalar var ident Z * @scalar var ident Z * @scalar var ident Z) with
- | Cast r (CC_m n x') =>
- if (lower r =? 0) && (upper r =? 1) && (n =? wordmax)
- then Some (x', y, z)
+ Definition invert_upper' {t} (e : @scalar var t) :
+ option (@scalar var Z) :=
+ match e in scalar t return option (@scalar var Z) with
+ | Cast r (Shiftr n x) =>
+ if (lower r =? 0) && (upper r =? (wordmax_half_bits - 1)) && (n =? half_bits)
+ then Some x
else None
- | _ => (@None _)
- end
- | _ => None
- end.
+ | _ => None
+ end.
- Definition of_straightline_ident {s d} (idc : ident.ident s d)
- : forall t, range_type d -> @scalar var ident s -> (var d -> @expr var ident t) -> @expr var ident t :=
- match idc in ident.ident s d return forall t, range_type d -> scalar s -> (var d -> @expr var ident t) -> @expr var ident t with
- | ident.Z.add_get_carry_concrete w =>
- fun t r x f =>
- if w =? wordmax
- then LetInAppIdentZZ r add x f
- else dummy _
- | ident.Z.add_with_get_carry_concrete w =>
- fun t r x f =>
- if w =? wordmax
- then LetInAppIdentZZ r addc x f
- else dummy _
- | ident.Z.sub_get_borrow_concrete w =>
- fun t r x f =>
- if w =? wordmax
- then LetInAppIdentZZ r sub x f
- else dummy _
- | ident.Z.land n => fun _ r => LetInAppIdentZ r (land n)
- | ident.Z.shiftr n => fun _ r => LetInAppIdentZ r (shiftr n)
- | ident.Z.shiftl n => fun _ r => LetInAppIdentZ r (shiftl n)
- | ident.Z.rshi_concrete w n =>
- fun _ r x f =>
- if w =? wordmax
- then LetInAppIdentZ r (rshi n) x f
- else dummy _
- | ident.Z.zselect =>
- fun t r x f =>
- match invert_selm x with
- | Some (x, y, z) => LetInAppIdentZ r selm (Pair (Pair x y) z) f
- | None => match invert_sell x with
- | Some (x, y, z) => LetInAppIdentZ r sell (Pair (Pair x y) z) f
- | None => LetInAppIdentZ r selc x f
- end
+ Definition invert_lower {t} (e : @scalar var t) :
+ option (@scalar var Z) :=
+ match e in scalar t return option (@scalar var Z) with
+ | Primitive type.Z x =>
+ match constant_to_scalar x with
+ | Some y => invert_lower' y
+ | None => None
end
- | ident.Z.add_modulo => fun _ r => LetInAppIdentZ r addm
- | ident.Z.mul =>
- fun t r x f =>
- match x return expr t with
- | Pair _ _ x0 x1 =>
- match invert_lower x0, invert_lower x1 with
- | Some y0, Some y1 => LetInAppIdentZ r mulll (Pair y0 y1) f
- | Some y0, None =>
- match invert_upper x1 with
- | Some y1 => LetInAppIdentZ r mullh (Pair y0 y1) f
- | None => dummy _
- end
- | None, Some y1 =>
- match invert_upper x0 with
- | Some y0 => LetInAppIdentZ r mulhl (Pair y0 y1) f
- | None => dummy _
- end
- | None, None =>
- match invert_upper x0, invert_upper x1 with
- | Some y0, Some y1 => LetInAppIdentZ r mulhh (Pair y0 y1) f
- | _,_ => dummy _
+ | _ => invert_lower' e
+ end.
+
+ Definition invert_upper {t} (e : @scalar var t) :
+ option (@scalar var Z) :=
+ match e in scalar t return option (@scalar var Z) with
+ | Primitive type.Z x =>
+ match constant_to_scalar x with
+ | Some y => invert_upper' y
+ | None => None
+ end
+ | _ => invert_upper' e
+ end.
+
+ Definition invert_sell {t} (e : @scalar var t) :
+ option (@scalar var Z * @scalar var Z * @scalar var Z) :=
+ match e return _ with
+ | Pair _ Z (Pair Z Z x y) z =>
+ match x return option (@scalar var Z * @scalar var Z * @scalar var Z) with
+ | Cast r (Land n x') =>
+ if (lower r =? 0) && (upper r =? 1) && (n =? 1)
+ then Some (x', y, z)
+ else None
+ | _ => (@None _)
+ end
+ | _ => None
+ end.
+
+ Definition invert_selm {t} (e : @scalar var t) :
+ option (@scalar var Z * @scalar var Z * @scalar var Z) :=
+ match e return _ with
+ | Pair _ Z (Pair Z Z x y) z =>
+ match x return option (@scalar var Z * @scalar var Z * @scalar var Z) with
+ | Cast r (CC_m n x') =>
+ if (lower r =? 0) && (upper r =? 1) && (n =? wordmax)
+ then Some (x', y, z)
+ else None
+ | _ => (@None _)
+ end
+ | _ => None
+ end.
+
+ Definition of_straightline_ident {s d} (idc : ident.ident s d)
+ : forall t, range_type d -> @scalar var s -> (var d -> @expr var ident t) -> @expr var ident t :=
+ match idc in ident.ident s d return forall t, range_type d -> scalar s -> (var d -> @expr var ident t) -> @expr var ident t with
+ | ident.Z.add_get_carry_concrete w =>
+ fun t r x f =>
+ if w =? wordmax
+ then LetInAppIdentZZ r add x f
+ else dummy _
+ | ident.Z.add_with_get_carry_concrete w =>
+ fun t r x f =>
+ if w =? wordmax
+ then LetInAppIdentZZ r addc x f
+ else dummy _
+ | ident.Z.sub_get_borrow_concrete w =>
+ fun t r x f =>
+ if w =? wordmax
+ then LetInAppIdentZZ r sub x f
+ else dummy _
+ (* | ident.Z.land n => fun _ r => LetInAppIdentZ r (land n) *)
+ | ident.Z.shiftr n => fun _ r => LetInAppIdentZ r (shiftr n)
+ | ident.Z.shiftl n => fun _ r => LetInAppIdentZ r (shiftl n)
+ | ident.Z.rshi_concrete w n =>
+ fun _ r x f =>
+ if w =? wordmax
+ then LetInAppIdentZ r (rshi n) x f
+ else dummy _
+ | ident.Z.zselect =>
+ fun t r x f =>
+ match invert_selm x with
+ | Some (x, y, z) => LetInAppIdentZ r selm (Pair (Pair x y) z) f
+ | None => match invert_sell x with
+ | Some (x, y, z) => LetInAppIdentZ r sell (Pair (Pair x y) z) f
+ | None => LetInAppIdentZ r selc x f
+ end
+ end
+ | ident.Z.add_modulo => fun _ r => LetInAppIdentZ r addm
+ | ident.Z.mul =>
+ fun t r x f =>
+ match x return expr t with
+ | Pair _ _ x0 x1 =>
+ match invert_lower x0, invert_lower x1 with
+ | Some y0, Some y1 => LetInAppIdentZ r mulll (Pair y0 y1) f
+ | Some y0, None =>
+ match invert_upper x1 with
+ | Some y1 => LetInAppIdentZ r mullh (Pair y0 y1) f
+ | None => dummy _
+ end
+ | None, Some y1 =>
+ match invert_upper x0 with
+ | Some y0 => LetInAppIdentZ r mulhl (Pair y0 y1) f
+ | None => dummy _
+ end
+ | None, None =>
+ match invert_upper x0, invert_upper x1 with
+ | Some y0, Some y1 => LetInAppIdentZ r mulhh (Pair y0 y1) f
+ | _,_ => dummy _
+ end
end
+ | _ => dummy _
end
- | _ => dummy _
+ | _ => fun t _ _ _ => dummy t
+ end.
+
+ Fixpoint of_straightline {t} (e : @expr var ident.ident t)
+ : @expr var ident t :=
+ match e with
+ | Scalar _ s => Scalar s
+ | LetInAppIdentZ _ t r idc x f =>
+ of_straightline_ident idc t r x (fun y => of_straightline (f y))
+ | LetInAppIdentZZ _ t r idc x f =>
+ of_straightline_ident idc t r x (fun y => of_straightline (f y))
+ end.
+
+ Definition constant_to_scalar_single (const x : BinInt.Z) : option (@scalar var Z) :=
+ if x =? (BinInt.Z.shiftr const half_bits)
+ then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Shiftr half_bits (Primitive (t:=type.Z) const)))
+ else if x =? (BinInt.Z.land const (wordmax_half_bits - 1))
+ then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Land (wordmax_half_bits-1) (Primitive (t:=type.Z) const)))
+ else None.
+
+ Definition constant_to_scalar_gen (consts : list BinInt.Z) (x : BinInt.Z)
+ : option (Straightline.expr.scalar Z) :=
+ fold_right (fun c res => match res with
+ | Some s => Some s
+ | None => constant_to_scalar_single c x
+ end) None consts.
+ End with_var.
+
+ Section interp.
+ Local Notation low x := (Z.land x (wordmax_half_bits - 1)).
+ Local Notation high x := (x >> half_bits).
+
+ Definition interp_ident {s d} (idc : ident s d) : type.interp s -> type.interp d :=
+ match idc with
+ | add => fun x => Z.add_get_carry_full wordmax (fst x) (snd x)
+ | addc => fun x => Z.add_with_get_carry_full wordmax (fst (fst x)) (snd (fst x)) (snd x)
+ | mulll => fun x => low (fst x) * low (snd x)
+ | mullh => fun x => low (fst x) * high (snd x)
+ | mulhl => fun x => high (fst x) * low (snd x)
+ | mulhh => fun x => high (fst x) * high (snd x)
+ | sub => fun x => Z.sub_get_borrow_full wordmax (fst x) (snd x)
+ (* | land n => fun x => Z.land x n *) (* only allowed inside select/mul? *)
+ | shiftr n => fun x => Z.shiftr x n
+ | shiftl n => fun x => Z.shiftl x n
+ | rshi n => fun x => Z.rshi wordmax (fst x) (snd x) n
+ | selc => fun x => Z.zselect (fst (fst x)) (snd (fst x)) (snd x)
+ | selm => fun x => Z.zselect (Z.cc_m wordmax (fst (fst x))) (snd (fst x)) (snd x)
+ | sell => fun x => Z.zselect (Z.land (fst (fst x)) 1) (snd (fst x)) (snd x)
+ | addm => fun x => Z.add_modulo (fst (fst x)) (snd (fst x)) (snd x)
+ end.
+
+ Fixpoint interp {t} (e : @expr type.interp ident t) : type.interp t :=
+ match e with
+ | Scalar t s => interp_scalar s
+ | LetInAppIdentZ s d r idc x f =>
+ interp (f (interp_cast r (interp_ident idc (interp_scalar x))))
+ | LetInAppIdentZZ s d r idc x f =>
+ interp (f (interp_cast2 r (interp_ident idc (interp_scalar x))))
+ end.
+ End interp.
+
+ Section proofs.
+ Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype)
+ (constant_to_scalar : Z -> option (@scalar type.interp type.Z)).
+
+ Local Notation word_range := (r[0~>wordmax-1])%zrange.
+ Local Notation half_word_range := (r[0~>wordmax_half_bits-1])%zrange.
+ Local Notation flag_range := (r[0~>1])%zrange.
+
+ Definition in_word_range (r : zrange) := is_tighter_than_bool r word_range = true.
+ Definition in_flag_range (r : zrange) := is_tighter_than_bool r flag_range = true.
+
+ Fixpoint get_range_var (t : type) : type.interp t -> range_type t :=
+ match t with
+ | type.type_primitive type.Z =>
+ fun x => {| lower := x; upper := x |}
+ | type.prod a b =>
+ fun x => (get_range_var a (fst x), get_range_var b (snd x))
+ | _ => fun _ => tt
+ end.
+
+ Fixpoint get_range {t} (x : @scalar type.interp t) : range_type t :=
+ match x with
+ | Var t v => get_range_var t v
+ | TT => tt
+ | Nil _ => tt
+ | Pair _ _ x y => (get_range x, get_range y)
+ | Cast r _ => r
+ | Cast2 r _ => r
+ | Fst _ _ p => fst (get_range p)
+ | Snd _ _ p => snd (get_range p)
+ | Shiftr n x => ZRange.map (fun y => Z.shiftr y n) (get_range x)
+ | Shiftl n x => ZRange.map (fun y => Z.shiftl y n) (get_range x)
+ | Land n x => r[0~>n]%zrange
+ | CC_m n x => ZRange.map (Z.cc_m n) (get_range x)
+ | Primitive type.Z x => {| lower := x; upper := x |}
+ | Primitive p x => tt
+ end.
+
+ Fixpoint has_range {t} : range_type t -> type.interp t -> Prop :=
+ match t with
+ | type.type_primitive type.Z =>
+ fun r x =>
+ lower r <= x <= upper r
+ | type.prod a b =>
+ fun r x =>
+ has_range (fst r) (fst x) /\ has_range (snd r) (snd x)
+ | _ => fun _ _ => True
+ end.
+
+ Inductive ok_scalar : forall {t}, @scalar type.interp t -> Prop :=
+ | sc_ok_var : forall t v, ok_scalar (Var t v)
+ | sc_ok_unit : ok_scalar TT
+ | sc_ok_nil : forall t, ok_scalar (Nil t)
+ | sc_ok_pair : forall A B x y,
+ @ok_scalar A x ->
+ @ok_scalar B y ->
+ ok_scalar (Pair x y)
+ | sc_ok_cast : forall r (x : scalar type.Z),
+ ok_scalar x ->
+ is_tighter_than_bool (get_range x) r = true ->
+ ok_scalar (Cast r x)
+ | sc_ok_cast2 : forall r (x : scalar (type.prod type.Z type.Z)),
+ ok_scalar x ->
+ is_tighter_than_bool (fst (get_range x)) (fst r) = true ->
+ is_tighter_than_bool (snd (get_range x)) (snd r) = true ->
+ ok_scalar (Cast2 r x)
+ | sc_ok_fst :
+ forall A B p, @ok_scalar (A * B) p -> ok_scalar (Fst p)
+ | sc_ok_snd :
+ forall A B p, @ok_scalar (A * B) p -> ok_scalar (Snd p)
+ | sc_ok_shiftr :
+ forall n x, 0 <= n -> ok_scalar x -> ok_scalar (Shiftr n x)
+ | sc_ok_shiftl :
+ forall n x, 0 <= n -> 0 <= lower (@get_range type.Z x) -> ok_scalar x -> ok_scalar (Shiftl n x)
+ | sc_ok_land :
+ forall n x, 0 <= n -> 0 <= lower (@get_range type.Z x) -> ok_scalar x -> ok_scalar (Land n x)
+ | sc_ok_cc_m :
+ forall x, ok_scalar x -> ok_scalar (CC_m wordmax x)
+ | sc_ok_prim : forall p x, ok_scalar (@Primitive _ p x)
+ .
+
+ Print ident.
+ Inductive ok_ident : forall s d, scalar s -> range_type d -> ident.ident s d -> Prop :=
+ | ok_add :
+ forall x : scalar (type.prod type.Z type.Z),
+ in_word_range (fst (get_range x)) ->
+ in_word_range (snd (get_range x)) ->
+ ok_ident _
+ (type.prod type.Z type.Z)
+ x
+ (word_range, flag_range)
+ (ident.Z.add_get_carry_concrete wordmax)
+ | ok_addc :
+ forall x : scalar (type.prod (type.prod type.Z type.Z) type.Z),
+ in_flag_range (fst (fst (get_range x))) ->
+ in_word_range (snd (fst (get_range x))) ->
+ in_word_range (snd (get_range x)) ->
+ ok_ident _
+ (type.prod type.Z type.Z)
+ x
+ (word_range, flag_range)
+ (ident.Z.add_with_get_carry_concrete wordmax)
+ | ok_sub :
+ forall x : scalar (type.prod type.Z type.Z),
+ in_word_range (fst (get_range x)) ->
+ in_word_range (snd (get_range x)) ->
+ ok_ident _
+ (type.prod type.Z type.Z)
+ x
+ (word_range, flag_range)
+ (ident.Z.sub_get_borrow_concrete wordmax)
+ (* | ok_land :
+ forall (x : scalar type.Z) n,
+ in_word_range (get_range x) ->
+ 0 <= n < wordmax ->
+ ok_ident type.Z type.Z x (ZRange.map (fun y => Z.land y n) (get_range x)) (ident.Z.land n)*)
+ | ok_shiftr :
+ forall (x : scalar type.Z) n,
+ in_word_range (get_range x) ->
+ 0 <= n <= log2wordmax ->
+ ok_ident type.Z type.Z x (ZRange.map (fun y => Z.shiftr y n) (get_range x)) (ident.Z.shiftr n)
+ | ok_shiftl :
+ forall (x : scalar type.Z) n,
+ in_word_range (get_range x) ->
+ 0 <= n < log2wordmax - Z.log2_up (upper (get_range x)) ->
+ ok_ident type.Z type.Z x word_range (ident.Z.shiftl n)
+ | ok_rshi :
+ forall (x : scalar (type.prod type.Z type.Z)) n,
+ in_word_range (fst (get_range x)) ->
+ in_word_range (snd (get_range x)) ->
+ 0 <= n < 2 * log2wordmax ->
+ ok_ident (type.prod type.Z type.Z) type.Z x word_range (ident.Z.rshi_concrete wordmax n)
+ | ok_selc :
+ forall (x : scalar (type.prod type.Z type.Z)) (y z : scalar type.Z),
+ in_flag_range (snd (get_range x)) ->
+ in_word_range (get_range y) ->
+ in_word_range (get_range z) ->
+ ok_ident _
+ type.Z
+ (Pair (Pair (Snd x) y) z)
+ word_range
+ ident.Z.zselect
+ | ok_selm :
+ forall x y z : scalar type.Z,
+ in_word_range (get_range x) ->
+ in_word_range (get_range y) ->
+ in_word_range (get_range z) ->
+ ok_ident _
+ type.Z
+ (Pair (Pair (Cast flag_range (CC_m wordmax x)) y) z)
+ word_range
+ ident.Z.zselect
+ | ok_sell :
+ forall x y z : scalar type.Z,
+ in_word_range (get_range x) ->
+ in_word_range (get_range y) ->
+ in_word_range (get_range z) ->
+ ok_ident _
+ type.Z
+ (Pair (Pair (Cast flag_range (Land 1 x)) y) z)
+ word_range
+ ident.Z.zselect
+ | ok_addm :
+ forall (x : scalar (type.prod (type.prod type.Z type.Z) type.Z)),
+ in_word_range (fst (fst (get_range x))) ->
+ in_word_range (snd (fst (get_range x))) ->
+ in_word_range (snd (get_range x)) ->
+ upper (fst (fst (get_range x))) + upper (snd (fst (get_range x))) - lower (snd (get_range x)) < wordmax ->
+ ok_ident _
+ type.Z
+ x
+ word_range
+ ident.Z.add_modulo
+ | ok_mulll :
+ forall x y : scalar type.Z,
+ in_word_range (get_range x) ->
+ in_word_range (get_range y) ->
+ ok_ident (type.prod type.Z type.Z)
+ type.Z
+ (Pair
+ (Cast half_word_range (Land (wordmax_half_bits - 1) x))
+ (Cast half_word_range (Land (wordmax_half_bits - 1) y)))
+ word_range
+ ident.Z.mul
+ | ok_mullh :
+ forall x y : scalar type.Z,
+ in_word_range (get_range x) ->
+ in_word_range (get_range y) ->
+ ok_ident (type.prod type.Z type.Z)
+ type.Z
+ (Pair
+ (Cast half_word_range (Land (wordmax_half_bits - 1) x))
+ (Cast half_word_range (Shiftr half_bits y)))
+ word_range
+ ident.Z.mul
+ | ok_mulhl :
+ forall x y : scalar type.Z,
+ in_word_range (get_range x) ->
+ in_word_range (get_range y) ->
+ ok_ident (type.prod type.Z type.Z)
+ type.Z
+ (Pair
+ (Cast half_word_range (Shiftr half_bits x))
+ (Cast half_word_range (Land (wordmax_half_bits - 1) y)))
+ word_range
+ ident.Z.mul
+ | ok_mulhh :
+ forall x y : scalar type.Z,
+ in_word_range (get_range x) ->
+ in_word_range (get_range y) ->
+ ok_ident (type.prod type.Z type.Z)
+ type.Z
+ (Pair
+ (Cast half_word_range (Shiftr half_bits x))
+ (Cast half_word_range (Shiftr half_bits y)))
+ word_range
+ ident.Z.mul
+ .
+
+ Inductive ok_expr : forall {t}, @expr type.interp ident.ident t -> Prop :=
+ | ok_of_scalar : forall t s, @ok_expr t (Scalar s)
+ | ok_letin_z : forall s d r idc x f,
+ ok_ident _ type.Z x r idc ->
+ ok_scalar x ->
+ (forall y, has_range r y -> ok_expr (f y)) ->
+ ok_expr (@LetInAppIdentZ _ _ s d r idc x f)
+ | ok_letin_zz : forall s d r idc x f,
+ ok_ident _ (type.prod type.Z type.Z) x r idc ->
+ ok_scalar x ->
+ (forall y, has_range r y -> ok_expr (f y)) ->
+ ok_expr (@LetInAppIdentZZ _ _ s d r idc x f)
+ .
+
+ Ltac invert H :=
+ inversion H; subst;
+ repeat match goal with
+ | H : existT _ _ _ = existT _ _ _ |- _ => apply (Eqdep_dec.inj_pair2_eq_dec _ type.type_eq_dec) in H; subst
+ end.
+
+ Lemma has_range_get_range_var {t} (v : type.interp t) :
+ has_range (get_range_var _ v) v.
+ Proof.
+ induction t; cbn [get_range_var has_range fst snd]; auto.
+ destruct p; auto; cbn [upper lower]; omega.
+ Qed.
+
+ Lemma has_range_loosen r1 r2 (x : Z) :
+ @has_range type.Z r1 x ->
+ is_tighter_than_bool r1 r2 = true ->
+ @has_range type.Z r2 x.
+ Proof.
+ cbv [is_tighter_than_bool has_range]; intros;
+ match goal with H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H end;
+ Z.ltb_to_lt; omega.
+ Qed.
+
+ Lemma interp_cast_noop x r :
+ @has_range type.Z r x ->
+ interp_cast r x = x.
+ Proof.
+ cbv [has_range interp_cast ident.cast]; intros.
+ break_match;
+ try match goal with H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H end;
+ try match goal with H : _ && _ = false |- _ => rewrite andb_false_iff in H; destruct H end;
+ Z.ltb_to_lt; omega.
+ Qed.
+
+ Lemma interp_cast2_noop x r :
+ @has_range (type.prod type.Z type.Z) r x ->
+ interp_cast2 r x = x.
+ Proof.
+ cbv [has_range interp_cast2 interp_cast ident.cast]; intros.
+ break_match; destruct x;
+ repeat match goal with H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H end;
+ repeat match goal with H : _ && _ = false |- _ => rewrite andb_false_iff in H; destruct H end;
+ Z.ltb_to_lt; try omega; reflexivity.
+ Qed.
+
+ Lemma has_range_shiftr n (x : scalar type.Z) :
+ 0 <= n ->
+ has_range (get_range x) (interp_scalar x) ->
+ @has_range type.Z (ZRange.map (fun y : Z => y >> n) (get_range x)) (interp_scalar x >> n).
+ Proof. cbv [has_range]; intros; cbn. auto using Z.shiftr_le with omega. Qed.
+ Hint Resolve has_range_shiftr : has_range.
+
+ Lemma has_range_shiftl n r x :
+ 0 <= n -> 0 <= lower r ->
+ @has_range type.Z r x ->
+ @has_range type.Z (ZRange.map (fun y : Z => y << n) r) (x << n).
+ Proof. cbv [has_range]; intros; cbn. auto using Z.shiftl_le_mono with omega. Qed.
+ Hint Resolve has_range_shiftl : has_range.
+
+ Lemma has_range_land n (x : scalar type.Z) :
+ 0 <= n -> 0 <= lower (get_range x) ->
+ has_range (get_range x) (interp_scalar x) ->
+ @has_range type.Z (r[0~>n])%zrange (Z.land (interp_scalar x) n).
+ Proof.
+ cbv [has_range]; intros; cbn.
+ split; [ apply Z.land_nonneg | apply Z.land_upper_bound_r ]; omega.
+ Qed.
+ Hint Resolve has_range_land : has_range.
+
+ Lemma has_range_interp_scalar {t} (x : scalar t) :
+ ok_scalar x ->
+ has_range (get_range x) (interp_scalar x).
+ Proof.
+ induction 1; cbn [interp_scalar get_range];
+ auto with has_range;
+ try solve [try inversion IHok_scalar; cbn [has_range];
+ auto using has_range_get_range_var]; [ | | | ].
+ { rewrite interp_cast_noop by eauto using has_range_loosen.
+ eapply has_range_loosen; eauto. }
+ { inversion IHok_scalar.
+ rewrite interp_cast2_noop;
+ cbn [has_range]; split; eapply has_range_loosen; eauto. }
+ { cbn. cbv [has_range] in *.
+ pose proof wordmax_gt_2.
+ rewrite !Z.cc_m_eq by omega.
+ split; apply Z.div_le_mono; Z.zero_bounds; omega. }
+ { destruct p; cbn [has_range upper lower]; auto; omega. }
+ Qed.
+ Hint Resolve has_range_interp_scalar : has_range.
+
+ Lemma has_word_range_interp_scalar (x : scalar type.Z) :
+ ok_scalar x ->
+ in_word_range (get_range x) ->
+ @has_range type.Z word_range (interp_scalar x).
+ Proof. eauto using has_range_loosen, has_range_interp_scalar. Qed.
+
+ Lemma in_word_range_nonneg r : in_word_range r -> 0 <= lower r.
+ Proof.
+ cbv [in_word_range is_tighter_than_bool].
+ rewrite andb_true_iff; intuition.
+ Qed.
+
+ Lemma in_word_range_upper_nonneg r x : @has_range type.Z r x -> in_word_range r -> 0 <= upper r.
+ Proof.
+ cbv [in_word_range is_tighter_than_bool]; cbn.
+ rewrite andb_true_iff; intuition.
+ Z.ltb_to_lt. omega.
+ Qed.
+
+ Lemma has_word_range_shiftl n r x :
+ 0 <= n < log2wordmax - Z.log2_up (upper r) ->
+ @has_range type.Z r x ->
+ in_word_range r ->
+ @has_range type.Z word_range (x << n).
+ Proof.
+ intros.
+ eapply has_range_loosen;
+ [ apply has_range_shiftl; eauto using in_word_range_nonneg with has_range; omega | ].
+ cbv [is_tighter_than_bool]. cbn.
+ apply andb_true_iff; split; apply Z.leb_le;
+ [ apply Z.shiftl_nonneg; solve [auto using in_word_range_nonneg] | ].
+ rewrite Z.shiftl_mul_pow2 by omega.
+ destruct (dec (upper r = 0));
+ [ match goal with H : _ = 0 |- _ => rewrite H end; pose proof wordmax_gt_2; lia | ].
+ match goal with |- ?a * ?b <= _ =>
+ transitivity (2 ^ (Z.log2_up a) * b) end.
+ { apply Z.mul_le_mono_nonneg_r; auto with zarith; [ ].
+ apply Z.log2_log2_up_spec.
+ apply Z.le_neq; eauto using in_word_range_upper_nonneg with has_range. }
+ { rewrite <-Z.pow_add_r by auto with zarith.
+ assert (2 ^ (Z.log2_up (upper r) + n) < 2 ^ log2wordmax)
+ by (apply Z.pow_lt_mono_r; omega).
+ replace wordmax with (2 ^ log2wordmax) by reflexivity.
+ omega. }
+ Qed.
+
+ Lemma has_word_range_rshi n x y :
+ 0 <= n ->
+ @has_range type.Z word_range (Z.rshi wordmax x y n).
+ Proof.
+ pose proof wordmax_gt_2.
+ intros; rewrite Z.rshi_correct by omega.
+ match goal with |- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
+ cbn [has_range lower upper]; lia.
+ Qed.
+
+ Lemma in_word_range_spec r :
+ 0 <= lower r -> upper r <= wordmax - 1 ->
+ in_word_range r.
+ Proof.
+ intros; cbv [in_word_range is_tighter_than_bool].
+ apply andb_true_iff.
+ split; apply Z.leb_le; cbn [upper lower]; omega.
+ Qed.
+
+ Ltac destruct_scalar :=
+ match goal with
+ | x : scalar (type.prod (type.prod _ _) _) |- _ =>
+ match goal with |- context [interp_scalar x] =>
+ destruct (interp_scalar x) as [ [? ?] ?];
+ destruct (get_range x) as [ [? ?] ?]
end
- | _ => fun t _ _ _ => dummy t
- end.
+ | x : scalar (type.prod _ _) |- _ =>
+ match goal with |- context [interp_scalar x] =>
+ destruct (interp_scalar x) as [? ?]; destruct (get_range x) as [? ?]
+ end
+ end.
- Fixpoint of_straightline_scalar {t} (s : @scalar var ident.ident t)
- : @scalar var ident t :=
- match s with
- | Var _ v => Var _ v
- | TT => TT
- | Nil _ => Nil _
- | Pair _ _ x y => Pair (of_straightline_scalar x) (of_straightline_scalar y)
- | Cast r x => Cast r (of_straightline_scalar x)
- | Cast2 r x => Cast2 r (of_straightline_scalar x)
- | Fst _ _ x => Fst (of_straightline_scalar x)
- | Snd _ _ x => Snd (of_straightline_scalar x)
- | Shiftr n x => Shiftr n (of_straightline_scalar x)
- | Shiftl n x => Shiftl n (of_straightline_scalar x)
- | Land n x => Land n (of_straightline_scalar x)
- | CC_m n x => CC_m n (of_straightline_scalar x)
- | Primitive _ x => Primitive x
- end.
+ Lemma ident_interp_has_range s d x r idc:
+ ok_scalar x ->
+ ok_ident s d x r idc ->
+ has_range r (ident.interp idc (interp_scalar x)).
+ Proof.
+ intro.
+ pose proof (has_range_interp_scalar x ltac:(assumption)).
+ pose proof wordmax_gt_2.
+ induction 1; cbn [ident.interp ident.gen_interp]; intros; try destruct_scalar;
+ repeat match goal with
+ | H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt
+ | H : _ /\ _ |- _ => destruct H
+ | _ => progress subst
+ | _ => progress (cbv [in_word_range in_flag_range is_tighter_than_bool] in * )
+ | _ => progress (cbn [interp_scalar get_range has_range upper lower fst snd] in * )
+ end.
+ {
+ autorewrite with to_div_mod.
+ match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
+ rewrite Z.div_between_0_if by omega.
+ split; break_match; lia. }
+ {
+ autorewrite with to_div_mod.
+ match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
+ rewrite Z.div_between_0_if by omega.
+ split; break_match; lia. }
+ {
+ autorewrite with to_div_mod.
+ match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
+ rewrite Z.div_sub_small by omega.
+ split; break_match; lia. }
+ { apply has_range_shiftr; cbn [has_range]; omega. }
+ { eapply has_word_range_shiftl; eauto using in_word_range_spec with has_range omega. }
+ { apply has_word_range_rshi; omega. }
+ { rewrite Z.zselect_correct. break_match; omega. }
+ { cbn [interp_scalar fst snd get_range] in *.
+ rewrite Z.zselect_correct. break_match; omega. }
+ { cbn [interp_scalar fst snd get_range] in *.
+ rewrite Z.zselect_correct. break_match; omega. }
+ { rewrite Z.add_modulo_correct.
+ break_match; Z.ltb_to_lt; omega. }
+ { cbn [interp_scalar fst snd get_range upper lower] in *.
+ pose proof half_bits_squared. nia. }
+ { cbn [interp_scalar fst snd get_range upper lower] in *.
+ pose proof half_bits_squared. nia. }
+ { cbn [interp_scalar fst snd get_range upper lower] in *.
+ pose proof half_bits_squared. nia. }
+ { cbn [interp_scalar fst snd get_range upper lower] in *.
+ pose proof half_bits_squared. nia. }
+ Qed.
- Fixpoint of_straightline {t} (e : @expr var ident.ident t)
- : @expr var ident t :=
- match e with
- | Scalar _ s => Scalar (of_straightline_scalar s)
- | LetInAppIdentZ _ t r idc x f =>
- of_straightline_ident idc t r (of_straightline_scalar x) (fun y => of_straightline (f y))
- | LetInAppIdentZZ _ t r idc x f =>
- of_straightline_ident idc t r (of_straightline_scalar x) (fun y => of_straightline (f y))
- end.
+ Ltac extract_ok_scalar' level x :=
+ match goal with
+ | H : ok_scalar (Pair (Pair (?f (?g x)) _) _) |- _ =>
+ match (eval compute in (4 <=? level)) with
+ | true => invert H; extract_ok_scalar' 3 x
+ | _ => fail
+ end
+ | H : ok_scalar (Pair (?f (?g x)) _) |- _ =>
+ match (eval compute in (3 <=? level)) with
+ | true => invert H; extract_ok_scalar' 2 x
+ | _ => fail
+ end
+ | H : ok_scalar (?f (?g x)) |- _ =>
+ match (eval compute in (2 <=? level)) with
+ | true => invert H; extract_ok_scalar' 1 x
+ | _ => fail
+ end
+ | H : ok_scalar (?g x) |- _ => invert H
+ end.
+
+
+ Ltac extract_ok_scalar :=
+ match goal with |- ok_scalar ?x => extract_ok_scalar' 4 x; assumption end.
- Definition constant_to_scalar_single ident (const x : BinInt.Z) : option (@Straightline.expr.scalar var ident Z) :=
- if x =? (BinInt.Z.shiftr const half_bits)
- then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Shiftr half_bits (Primitive (t:=type.Z) const)))
- else if x =? (BinInt.Z.land const (wordmax_half_bits - 1))
- then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Land (wordmax_half_bits-1) (Primitive (t:=type.Z) const)))
- else None.
-
- Definition constant_to_scalar_gen (consts : list BinInt.Z) ident (x : BinInt.Z)
- : option (Straightline.expr.scalar Z) :=
- fold_right (fun c res => match res with
- | Some s => Some s
- | None => constant_to_scalar_single ident c x
- end) None consts.
- End with_var.
+ Lemma has_flag_range_cc_m r x :
+ @has_range type.Z r x ->
+ in_word_range r ->
+ @has_range type.Z flag_range (Z.cc_m wordmax x).
+ Proof.
+ cbv [has_range in_word_range is_tighter_than_bool].
+ cbn [upper lower]; rewrite andb_true_iff; intros.
+ match goal with H : _ /\ _ |- _ => destruct H; Z.ltb_to_lt end.
+ pose proof wordmax_gt_2. pose proof wordmax_even.
+ pose proof (Z.cc_m_small wordmax x). omega.
+ Qed.
+
+ Lemma has_flag_range_cc_m' (x : scalar type.Z) :
+ ok_scalar x ->
+ in_word_range (get_range x) ->
+ @has_range type.Z flag_range (Z.cc_m wordmax (interp_scalar x)).
+ Proof. eauto using has_flag_range_cc_m with has_range. Qed.
+
+ Lemma has_flag_range_land r x :
+ @has_range type.Z r x ->
+ in_word_range r ->
+ @has_range type.Z flag_range (Z.land x 1).
+ Proof.
+ cbv [has_range in_word_range is_tighter_than_bool].
+ cbn [upper lower]; rewrite andb_true_iff; intuition; Z.ltb_to_lt.
+ { apply Z.land_nonneg. left; omega. }
+ { apply Z.land_upper_bound_r; omega. }
+ Qed.
+
+ Lemma has_flag_range_land' (x : scalar type.Z) :
+ ok_scalar x ->
+ in_word_range (get_range x) ->
+ @has_range type.Z flag_range (Z.land (interp_scalar x) 1).
+ Proof. eauto using has_flag_range_land with has_range. Qed.
+
+ Ltac rewrite_cast_noop_in_mul :=
+ repeat match goal with
+ | _ => rewrite interp_cast_noop with (r:=half_word_range) in *
+ by (eapply has_range_loosen; auto using has_range_land, has_range_interp_scalar)
+ | _ => rewrite interp_cast_noop with (r:=half_word_range) in *
+ by (eapply has_range_loosen; try apply has_range_shiftr; auto using has_range_interp_scalar;
+ cbn [ZRange.map get_range] in *; auto)
+ | _ => rewrite interp_cast_noop by assumption
+ end.
+
+ Lemma of_straightline_ident_correct s d t x r (idc : ident.ident s d) g :
+ ok_ident s d x r idc ->
+ ok_scalar x ->
+ interp (of_straightline_ident dummy_arrow constant_to_scalar idc t r x g) =
+ interp (g (ident.interp idc (interp_scalar x))).
+ Proof.
+ intros.
+ pose proof wordmax_half_bits_pos.
+ pose proof (ident_interp_has_range _ _ x r idc ltac:(assumption) ltac:(assumption)).
+ induction H; cbn [of_straightline_ident ident.interp ident.gen_interp invert_selm invert_sell invert_upper invert_lower invert_upper' invert_lower'] in *; intros;
+ rewrite ?Z.eqb_refl; cbn [interp interp_ident andb]; try destruct_scalar;
+ repeat match goal with
+ | _ => progress (cbn [fst snd interp_scalar] in * )
+ | _ => progress break_match; [ ]
+ | _ => rewrite interp_cast_noop with (r:=flag_range) in *
+ by (apply has_flag_range_cc_m'; auto; extract_ok_scalar)
+ | _ => rewrite interp_cast_noop with (r:=flag_range) in *
+ by (apply has_flag_range_land'; auto; extract_ok_scalar)
+ | H : _ = (_,_) |- _ => progress (inversion H; subst)
+ | _ => rewrite interp_cast_noop by assumption
+ | _ => rewrite interp_cast2_noop by assumption
+ | _ => reflexivity
+ end; [ | | | ]; (* leftover cases are the 4 kinds of multiplication *)
+ match goal with
+ | H1 : in_word_range (get_range ?x), H2 : in_word_range (get_range ?y) |- _ =>
+ extract_ok_scalar' 4 x; extract_ok_scalar' 4 y
+ end; rewrite_cast_noop_in_mul; reflexivity.
+ Qed.
+
+ Lemma of_straightline_correct {t} (e : expr t) :
+ ok_expr e ->
+ interp (of_straightline dummy_arrow constant_to_scalar e) = Straightline.expr.interp (interp_ident:=@ident.interp) e.
+ Proof.
+ induction 1; cbn [of_straightline]; intros;
+ repeat match goal with
+ | _ => progress cbn [Straightline.expr.interp]
+ | _ => rewrite of_straightline_ident_correct by auto
+ | _ => rewrite interp_cast_noop by auto using ident_interp_has_range
+ | _ => rewrite interp_cast2_noop by auto using ident_interp_has_range
+ | H : forall y, has_range _ y -> interp _ = _ |- _ => rewrite H by auto using ident_interp_has_range
+ | _ => reflexivity
+ end.
+ Qed.
+ End proofs.
+ End with_wordmax.
Definition of_Expr {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d))
(var : type -> Type) (x : var s) dummy_arrow : @Straightline.expr.expr var ident d :=
- @of_straightline var dummy_arrow log2wordmax (@constant_to_scalar_gen var log2wordmax consts) _ (Straightline.of_Expr e var x dummy_arrow).
+ @of_straightline log2wordmax var dummy_arrow (@constant_to_scalar_gen log2wordmax var consts) _ (Straightline.of_Expr e var x dummy_arrow).
+
+ Lemma of_Expr_correct {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d))
+ (e' : (type.interp s -> Uncurried.expr.expr d))
+ (x : type.interp s) dummy_arrow :
+ e type.interp = Abs e' ->
+ 1 < log2wordmax ->
+ Straightline.expr.ok_expr (e' x) ->
+ ok_expr log2wordmax (of_uncurried (dummy_arrow:=dummy_arrow) (depth (fun _ : type => unit) (fun _ : type => tt) (e (fun _ : type => unit))) (e' x)) ->
+ (depth type.interp (@DefaultValue.type.default) (e' x) <= depth (fun _ : type => unit) (fun _ : type => tt) (e (fun _ : type => unit)))%nat ->
+ interp log2wordmax (of_Expr log2wordmax consts e type.interp x dummy_arrow) = @Uncurried.expr.interp _ (@ident.interp) _ (e type.interp) x.
+ Proof.
+ intro He'; intros; cbv [of_Expr Straightline.of_Expr].
+ rewrite He'; cbn [invert_Abs expr.interp].
+ rewrite of_straightline_correct by assumption.
+ erewrite Straightline.expr.of_uncurried_correct by eassumption.
+ reflexivity.
+ Qed.
+
Module Notations.
Import PrintingNotations.
Import Straightline.expr.