aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2018-02-19 13:42:49 -0500
committerGravatar Jason Gross <jasongross9@gmail.com>2018-02-19 17:59:16 -0500
commita93e56748621f3afa227829697c3f97ba585885e (patch)
treecb6a6d11a8181e37f3205b6e6811277f35a48ab4 /src
parentaa6044f40e9e46856dd94748bfad61565de1266a (diff)
Remove runtime_scope
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v129
1 files changed, 41 insertions, 88 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index e6b8681b0..ed96850cb 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -18,14 +18,6 @@ Require Import Crypto.Util.ZUtil.Tactics.LtbToLt.
Require Import Crypto.Util.Notations.
Import ListNotations. Local Open Scope Z_scope.
-Definition runtime_mul := Z.mul.
-Definition runtime_add := Z.add.
-Definition runtime_opp := Z.opp.
-Delimit Scope runtime_scope with RT.
-Infix "*" := runtime_mul : runtime_scope.
-Infix "+" := runtime_add : runtime_scope.
-Notation "- a" := (runtime_opp a%RT) : runtime_scope.
-
Module Associational.
Definition eval (p:list (Z*Z)) : Z :=
fold_right Z.add 0%Z (map (fun t => fst t * snd t) p).
@@ -51,14 +43,14 @@ Module Associational.
Definition mul (p q:list (Z*Z)) : list (Z*Z) :=
flat_map (fun t =>
map (fun t' =>
- (fst t * fst t', (snd t * snd t')%RT))
+ (fst t * fst t', snd t * snd t'))
q) p.
Lemma eval_mul p q : eval (mul p q) = eval p * eval q.
Proof. induction p; cbv [mul]; push; nsatz. Qed.
Hint Rewrite eval_mul : push_eval.
Definition negate_snd (p:list (Z*Z)) : list (Z*Z) :=
- map (fun cx => (fst cx, (-snd cx)%RT)) p.
+ map (fun cx => (fst cx, -snd cx)) p.
Lemma eval_negate_snd p : eval (negate_snd p) = - eval p.
Proof. induction p; cbv [negate_snd]; push; nsatz. Qed.
Hint Rewrite eval_negate_snd : push_eval.
@@ -69,7 +61,7 @@ Module Associational.
(* Goal: eval ?ab = eval [(10,a1);(1,a0)] * eval [(10,b1);(1,b0)] *)
rewrite <-eval_mul.
(* Goal: eval ?ab = eval (mul [(10,a1);(1,a0)] [(10,b1);(1,b0)]) *)
- cbv -[runtime_mul eval].
+ cbv -[Z.mul eval]; cbn -[eval].
(* Goal: eval ?ab = eval [(100,(a1*b1));(10,a1*b0);(10,a0*b1);(1,a0*b0)]%RT *)
trivial. Defined.
@@ -153,13 +145,13 @@ Module Positional. Section Positional.
generalize dependent (List.seq 0 n); intro xs.
induction xs; simpl; nsatz. Qed.
Definition add_to_nth i x (ls : list Z) : list Z
- := ListUtil.update_nth i (fun y => runtime_add x y) ls.
+ := ListUtil.update_nth i (fun y => x + y) ls.
Lemma eval_add_to_nth (n:nat) (i:nat) (x:Z) (xs:list Z) (H:(i<length xs)%nat)
(Hn : length xs = n) (* N.B. We really only need [i < Nat.min n (length xs)] *) :
eval n (add_to_nth i x xs) = weight i * x + eval n xs.
Proof.
subst n.
- cbv [eval to_associational add_to_nth runtime_add].
+ cbv [eval to_associational add_to_nth].
rewrite ListUtil.combine_update_nth_r at 1.
rewrite <-(update_nth_id i (List.combine _ _)) at 2.
rewrite <-!(ListUtil.splice_nth_equiv_update_nth_update _ _
@@ -175,11 +167,11 @@ Module Positional. Section Positional.
Definition place (t:Z*Z) (i:nat) : nat * Z :=
nat_rect
(fun _ => (nat * Z)%type)
- ((O, fst t * snd t)%RT)
+ (O, fst t * snd t)
(fun i' place_i'
=> let i := S i' in
if (fst t mod weight i =? 0)
- then (i, let c := fst t / weight i in (c * snd t)%RT)
+ then (i, let c := fst t / weight i in c * snd t)
else place_i')
i.
@@ -772,8 +764,6 @@ Module Compilers.
| List_fold_right {A B} : ident ((B * A -> A) * A * list B) A
| List_update_nth {T} : ident (nat * (T -> T) * list T) (list T)
| List_nth_default {T} : ident (T * list T * nat) T
- | Z_runtime_mul : ident (Z * Z) Z
- | Z_runtime_add : ident (Z * Z) Z
| Z_add : ident (Z * Z) Z
| Z_mul : ident (Z * Z) Z
| Z_pow : ident (Z * Z) Z
@@ -819,8 +809,6 @@ Module Compilers.
| List_fold_right A B => curry3_1 (@List.fold_right (type.interp A) (type.interp B))
| List_update_nth T => curry3 (@update_nth (type.interp T))
| List_nth_default T => curry3 (@List.nth_default (type.interp T))
- | Z_runtime_mul => curry2 runtime_mul
- | Z_runtime_add => curry2 runtime_add
| Z_add => curry2 Z.add
| Z_mul => curry2 Z.mul
| Z_pow => curry2 Z.pow
@@ -914,8 +902,6 @@ Module Compilers.
| @List.nth_default ?T ?d ?ls ?n
=> let rT := type.reify T in
mkAppIdent (@ident.List_nth_default rT) (d, ls, n)
- | runtime_mul ?x ?y => mkAppIdent ident.Z_runtime_mul (x, y)
- | runtime_add ?x ?y => mkAppIdent ident.Z_runtime_add (x, y)
| Z.add ?x ?y => mkAppIdent ident.Z_add (x, y)
| Z.mul ?x ?y => mkAppIdent ident.Z_mul (x, y)
| Z.pow ?x ?y => mkAppIdent ident.Z_pow (x, y)
@@ -953,8 +939,6 @@ Module Compilers.
End List.
Module Z.
- Notation runtime_mul := Z_runtime_mul.
- Notation runtime_add := Z_runtime_add.
Notation add := Z_add.
Notation mul := Z_mul.
Notation pow := Z_pow.
@@ -1018,10 +1002,8 @@ Module Compilers.
| list_rect {A P} : ident (P * (A * list A * P -> P) * list A) P
| List_nth_default {T} : ident (T * list T * nat) T
| List_nth_default_concrete {T : type.primitive} (d : interp T) (n : Datatypes.nat) : ident (list T) T
- | Z_runtime_mul : ident (Z * Z) Z
- | Z_runtime_add : ident (Z * Z) Z
- | Z_runtime_shiftr (offset : BinInt.Z) : ident Z Z
- | Z_runtime_land (mask : BinInt.Z) : ident Z Z
+ | Z_shiftr (offset : BinInt.Z) : ident Z Z
+ | Z_land (mask : BinInt.Z) : ident Z Z
| Z_add : ident (Z * Z) Z
| Z_mul : ident (Z * Z) Z
| Z_pow : ident (Z * Z) Z
@@ -1061,10 +1043,8 @@ Module Compilers.
| list_rect A P => curry3_23 (@Datatypes.list_rect (type.interp A) (fun _ => type.interp P))
| List_nth_default T => curry3 (@List.nth_default (type.interp T))
| List_nth_default_concrete T d n => fun ls => @List.nth_default (type.interp T) d ls n
- | Z_runtime_mul => curry2 runtime_mul
- | Z_runtime_add => curry2 runtime_add
- | Z_runtime_shiftr n => fun v => Z.shiftr v n
- | Z_runtime_land mask => fun v => Z.land v mask
+ | Z_shiftr n => fun v => Z.shiftr v n
+ | Z_land mask => fun v => Z.land v mask
| Z_add => curry2 Z.add
| Z_mul => curry2 Z.mul
| Z_pow => curry2 Z.pow
@@ -1137,8 +1117,6 @@ Module Compilers.
| @List.nth_default ?T ?d ?ls ?n
=> let rT := type.reify T in
mkAppIdent (@ident.List_nth_default rT) (d, ls, n)
- | runtime_mul ?x ?y => mkAppIdent ident.Z_runtime_mul (x, y)
- | runtime_add ?x ?y => mkAppIdent ident.Z_runtime_add (x, y)
| Z.add ?x ?y => mkAppIdent ident.Z_add (x, y)
| Z.mul ?x ?y => mkAppIdent ident.Z_mul (x, y)
| Z.pow ?x ?y => mkAppIdent ident.Z_pow (x, y)
@@ -1167,10 +1145,8 @@ Module Compilers.
End List.
Module Z.
- Notation runtime_mul := Z_runtime_mul.
- Notation runtime_add := Z_runtime_add.
- Notation runtime_shiftr := Z_runtime_shiftr.
- Notation runtime_land := Z_runtime_land.
+ Notation shiftr := Z_shiftr.
+ Notation land := Z_land.
Notation add := Z_add.
Notation mul := Z_mul.
Notation pow := Z_pow.
@@ -1258,10 +1234,6 @@ Module Compilers.
=> AppIdent ident.pred
| for_reification.ident.primitive t v
=> AppIdent (ident.primitive v)
- | for_reification.ident.Z_runtime_mul
- => AppIdent ident.Z.runtime_mul
- | for_reification.ident.Z_runtime_add
- => AppIdent ident.Z.runtime_add
| for_reification.ident.Z_add
=> AppIdent ident.Z.add
| for_reification.ident.Z_mul
@@ -2029,10 +2001,8 @@ Module Compilers.
| ident.primitive _ _ as idc
| ident.Nat_succ as idc
| ident.pred as idc
- | ident.Z_runtime_mul as idc
- | ident.Z_runtime_add as idc
- | ident.Z_runtime_shiftr _ as idc
- | ident.Z_runtime_land _ as idc
+ | ident.Z_shiftr _ as idc
+ | ident.Z_land _ as idc
| ident.Z_add as idc
| ident.Z_mul as idc
| ident.Z_pow as idc
@@ -2198,16 +2168,14 @@ Module Compilers.
(ident.snd @@ (Var xyk))
@ ((idc : default.ident _ type.nat)
@@ (ident.fst @@ (Var xyk)))
- | ident.Z_runtime_shiftr _ as idc
- | ident.Z_runtime_land _ as idc
+ | ident.Z_shiftr _ as idc
+ | ident.Z_land _ as idc
| ident.Z_opp as idc
=> λ (xyk :
(* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.Z * (type.Z -> R))%ctype) ,
(ident.snd @@ (Var xyk))
@ ((idc : default.ident _ type.Z)
@@ (ident.fst @@ (Var xyk)))
- | ident.Z_runtime_mul as idc
- | ident.Z_runtime_add as idc
| ident.Z_add as idc
| ident.Z_mul as idc
| ident.Z_pow as idc
@@ -2817,15 +2785,13 @@ Module Compilers.
| ident.Nat_succ as idc
| ident.Z_of_nat as idc
| ident.Z_opp as idc
- | ident.Z_runtime_shiftr _ as idc
- | ident.Z_runtime_land _ as idc
+ | ident.Z_shiftr _ as idc
+ | ident.Z_land _ as idc
=> fun x : expr _ + type.interp _
=> match x return expr _ + type.interp _ with
| inr x => inr (ident.interp idc x)
| inl x => expr.reflect (AppIdent idc x)
end
- | ident.Z_add as idc
- | ident.Z_mul as idc
| ident.Z_pow as idc
| ident.Z_eqb as idc
=> fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _))
@@ -2840,7 +2806,7 @@ Module Compilers.
| inr (inr x, inr y) => inr (ident.interp idc (x, y))
| inr (x, inr y)
=> if Z.eqb y (2^Z.log2 y)
- then expr.reflect (AppIdent (ident.Z.runtime_shiftr (Z.log2 y)) (expr.reify (t:=type.Z) x))
+ then expr.reflect (AppIdent (ident.Z.shiftr (Z.log2 y)) (expr.reify (t:=type.Z) x))
else default
| _ => default
end
@@ -2851,11 +2817,11 @@ Module Compilers.
| inr (inr x, inr y) => inr (ident.interp idc (x, y))
| inr (x, inr y)
=> if Z.eqb y (2^Z.log2 y)
- then expr.reflect (AppIdent (ident.Z.runtime_land (y-1)) (expr.reify (t:=type.Z) x))
+ then expr.reflect (AppIdent (ident.Z.land (y-1)) (expr.reify (t:=type.Z) x))
else default
| _ => default
end
- | ident.Z_runtime_mul as idc
+ | ident.Z_mul as idc
=> fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _))
=> let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in
match x_y return expr _ + type.interp _ with
@@ -2869,7 +2835,7 @@ Module Compilers.
else default
| inr (inl _, inl _) | inl _ => default
end
- | ident.Z_runtime_add as idc
+ | ident.Z_add as idc
=> fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _))
=> let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in
match x_y return expr _ + type.interp _ with
@@ -2916,7 +2882,6 @@ Module Compilers.
Fixpoint to_mul_list (e : @expr var type.Z) : list (@expr var type.Z)
:= match e in expr.expr t return list (@expr var t) with
- | AppIdent s type.Z ident.Z_runtime_mul (Pair type.Z type.Z x y)
| AppIdent s type.Z ident.Z_mul (Pair type.Z type.Z x y)
=> to_mul_list x ++ to_mul_list y
| Var _ _ as e
@@ -2930,10 +2895,8 @@ Module Compilers.
Fixpoint to_add_mul_list (e : @expr var type.Z) : list (list (@expr var type.Z))
:= match e in expr.expr t return list (list (@expr var t)) with
- | AppIdent s type.Z ident.Z_runtime_add (Pair type.Z type.Z x y)
| AppIdent s type.Z ident.Z_add (Pair type.Z type.Z x y)
=> to_add_mul_list x ++ to_add_mul_list y
- | AppIdent s type.Z ident.Z_runtime_mul (Pair type.Z type.Z x y)
| AppIdent s type.Z ident.Z_mul (Pair type.Z type.Z x y)
=> [to_mul_list x ++ to_mul_list y]
| Var _ _ as e
@@ -3369,15 +3332,13 @@ Module Compilers.
| default.ident.List_nth_default T => None
| ident.List_nth_default_concrete T d n
=> Some (@List_nth (type.primitive.compile T) n)
- | ident.Z_runtime_mul
| default.ident.Z_mul
=> Some Z.mul
- | ident.Z_runtime_add
| default.ident.Z_add
=> Some Z.add
- | ident.Z_runtime_shiftr n
+ | ident.Z_shiftr n
=> Some (Z.shiftr n)
- | ident.Z_runtime_land mask
+ | ident.Z_land mask
=> Some (Z.land mask)
| ident.Z_pow
| ident.Z_opp
@@ -3835,25 +3796,17 @@ Local Coercion QArith_base.inject_Z : Z >-> Q.
- reassociation
- indexed + bounds analysis + of phoas *)
-Delimit Scope RT_expr_scope with RT_expr.
Import Uncurried.
Import expr.
Import for_reification.Notations.Reification.
Notation "x + y"
- := (AppIdent ident.Z.runtime_add (x%RT_expr, y%RT_expr)%expr)
- : RT_expr_scope.
-Notation "x * y"
- := (AppIdent ident.Z.runtime_mul (x%RT_expr, y%RT_expr)%expr)
- : RT_expr_scope.
-Notation "x + y"
- := (AppIdent ident.Z.runtime_add (x%RT_expr, y%RT_expr)%expr)
+ := (AppIdent ident.Z.add (x, y)%expr)
: expr_scope.
Notation "x * y"
- := (AppIdent ident.Z.runtime_mul (x%RT_expr, y%RT_expr)%expr)
+ := (AppIdent ident.Z.mul (x, y)%expr)
: expr_scope.
Notation "x" := (Var x) (only printing, at level 9) : expr_scope.
-Open Scope RT_expr_scope.
Require Import AdmitAxiom.
@@ -3874,8 +3827,8 @@ Module test2.
Proof.
let v := Reify (fun y : Z
=> (fun k : Z * Z -> Z * Z
- => dlet_nd x := (y * y)%RT in
- dlet_nd z := (x * x)%RT in
+ => dlet_nd x := (y * y) in
+ dlet_nd z := (x * x) in
k (z, z))
(fun v => v)) in
pose v as E.
@@ -3885,8 +3838,8 @@ Module test2.
lazymatch (eval cbv delta [E'] in E') with
| (fun var : type -> Type =>
(λ x : var (type.type_primitive type.Z),
- expr_let x0 := (Var x * Var x)%RT_expr in
- expr_let x1 := (Var x0 * Var x0)%RT_expr in
+ expr_let x0 := (Var x * Var x) in
+ expr_let x1 := (Var x0 * Var x0) in
(Var x1, Var x1))%expr) => idtac
end.
Import BoundsAnalysis.ident.
@@ -3909,11 +3862,11 @@ Module test3.
Example test3 : True.
Proof.
let v := Reify (fun y : Z
- => dlet_nd x := dlet_nd x := (y * y)%RT in
- (x * x)%RT in
- dlet_nd z := dlet_nd z := (x * x)%RT in
- (z * z)%RT in
- (z * z)%RT) in
+ => dlet_nd x := dlet_nd x := (y * y) in
+ (x * x) in
+ dlet_nd z := dlet_nd z := (x * x) in
+ (z * z) in
+ (z * z)) in
pose v as E.
vm_compute in E.
pose (PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'.
@@ -3925,7 +3878,7 @@ Module test3.
expr_let x1 := Var x0 * Var x0 in
expr_let x2 := Var x1 * Var x1 in
expr_let x3 := Var x2 * Var x2 in
- Var x3 * Var x3)%RT_expr%expr)
+ Var x3 * Var x3)%expr)
=> idtac
end.
Import BoundsAnalysis.ident.
@@ -3955,8 +3908,8 @@ Module test4.
let v := Reify (fun y : (list Z * list Z)
=> dlet_nd x := List.nth_default (-1) (fst y) 0 in
dlet_nd z := List.nth_default (-1) (snd y) 0 in
- dlet_nd xz := (x * z)%RT in
- (xz :: xz :: nil)%RT) in
+ dlet_nd xz := (x * z) in
+ (xz :: xz :: nil)) in
pose v as E.
vm_compute in E.
pose (PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'.
@@ -3986,7 +3939,7 @@ Module test5.
Example test5 : True.
Proof.
let v := Reify (fun y : (Z * Z)
- => dlet_nd x := (13 * (fst y * snd y))%RT in
+ => dlet_nd x := (13 * (fst y * snd y)) in
x) in
pose v as E.
vm_compute in E.
@@ -4010,7 +3963,7 @@ Module test6.
Proof.
let v := Reify (fun y : Z
=> if 0 =? 1
- then dlet_nd x := (y * y)%RT in
+ then dlet_nd x := (y * y) in
x
else y) in
pose v as E.