diff options
author | Jason Gross <jgross@mit.edu> | 2018-02-19 13:42:49 -0500 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-02-19 17:59:16 -0500 |
commit | a93e56748621f3afa227829697c3f97ba585885e (patch) | |
tree | cb6a6d11a8181e37f3205b6e6811277f35a48ab4 /src | |
parent | aa6044f40e9e46856dd94748bfad61565de1266a (diff) |
Remove runtime_scope
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 129 |
1 files changed, 41 insertions, 88 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index e6b8681b0..ed96850cb 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -18,14 +18,6 @@ Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. Require Import Crypto.Util.Notations. Import ListNotations. Local Open Scope Z_scope. -Definition runtime_mul := Z.mul. -Definition runtime_add := Z.add. -Definition runtime_opp := Z.opp. -Delimit Scope runtime_scope with RT. -Infix "*" := runtime_mul : runtime_scope. -Infix "+" := runtime_add : runtime_scope. -Notation "- a" := (runtime_opp a%RT) : runtime_scope. - Module Associational. Definition eval (p:list (Z*Z)) : Z := fold_right Z.add 0%Z (map (fun t => fst t * snd t) p). @@ -51,14 +43,14 @@ Module Associational. Definition mul (p q:list (Z*Z)) : list (Z*Z) := flat_map (fun t => map (fun t' => - (fst t * fst t', (snd t * snd t')%RT)) + (fst t * fst t', snd t * snd t')) q) p. Lemma eval_mul p q : eval (mul p q) = eval p * eval q. Proof. induction p; cbv [mul]; push; nsatz. Qed. Hint Rewrite eval_mul : push_eval. Definition negate_snd (p:list (Z*Z)) : list (Z*Z) := - map (fun cx => (fst cx, (-snd cx)%RT)) p. + map (fun cx => (fst cx, -snd cx)) p. Lemma eval_negate_snd p : eval (negate_snd p) = - eval p. Proof. induction p; cbv [negate_snd]; push; nsatz. Qed. Hint Rewrite eval_negate_snd : push_eval. @@ -69,7 +61,7 @@ Module Associational. (* Goal: eval ?ab = eval [(10,a1);(1,a0)] * eval [(10,b1);(1,b0)] *) rewrite <-eval_mul. (* Goal: eval ?ab = eval (mul [(10,a1);(1,a0)] [(10,b1);(1,b0)]) *) - cbv -[runtime_mul eval]. + cbv -[Z.mul eval]; cbn -[eval]. (* Goal: eval ?ab = eval [(100,(a1*b1));(10,a1*b0);(10,a0*b1);(1,a0*b0)]%RT *) trivial. Defined. @@ -153,13 +145,13 @@ Module Positional. Section Positional. generalize dependent (List.seq 0 n); intro xs. induction xs; simpl; nsatz. Qed. Definition add_to_nth i x (ls : list Z) : list Z - := ListUtil.update_nth i (fun y => runtime_add x y) ls. + := ListUtil.update_nth i (fun y => x + y) ls. Lemma eval_add_to_nth (n:nat) (i:nat) (x:Z) (xs:list Z) (H:(i<length xs)%nat) (Hn : length xs = n) (* N.B. We really only need [i < Nat.min n (length xs)] *) : eval n (add_to_nth i x xs) = weight i * x + eval n xs. Proof. subst n. - cbv [eval to_associational add_to_nth runtime_add]. + cbv [eval to_associational add_to_nth]. rewrite ListUtil.combine_update_nth_r at 1. rewrite <-(update_nth_id i (List.combine _ _)) at 2. rewrite <-!(ListUtil.splice_nth_equiv_update_nth_update _ _ @@ -175,11 +167,11 @@ Module Positional. Section Positional. Definition place (t:Z*Z) (i:nat) : nat * Z := nat_rect (fun _ => (nat * Z)%type) - ((O, fst t * snd t)%RT) + (O, fst t * snd t) (fun i' place_i' => let i := S i' in if (fst t mod weight i =? 0) - then (i, let c := fst t / weight i in (c * snd t)%RT) + then (i, let c := fst t / weight i in c * snd t) else place_i') i. @@ -772,8 +764,6 @@ Module Compilers. | List_fold_right {A B} : ident ((B * A -> A) * A * list B) A | List_update_nth {T} : ident (nat * (T -> T) * list T) (list T) | List_nth_default {T} : ident (T * list T * nat) T - | Z_runtime_mul : ident (Z * Z) Z - | Z_runtime_add : ident (Z * Z) Z | Z_add : ident (Z * Z) Z | Z_mul : ident (Z * Z) Z | Z_pow : ident (Z * Z) Z @@ -819,8 +809,6 @@ Module Compilers. | List_fold_right A B => curry3_1 (@List.fold_right (type.interp A) (type.interp B)) | List_update_nth T => curry3 (@update_nth (type.interp T)) | List_nth_default T => curry3 (@List.nth_default (type.interp T)) - | Z_runtime_mul => curry2 runtime_mul - | Z_runtime_add => curry2 runtime_add | Z_add => curry2 Z.add | Z_mul => curry2 Z.mul | Z_pow => curry2 Z.pow @@ -914,8 +902,6 @@ Module Compilers. | @List.nth_default ?T ?d ?ls ?n => let rT := type.reify T in mkAppIdent (@ident.List_nth_default rT) (d, ls, n) - | runtime_mul ?x ?y => mkAppIdent ident.Z_runtime_mul (x, y) - | runtime_add ?x ?y => mkAppIdent ident.Z_runtime_add (x, y) | Z.add ?x ?y => mkAppIdent ident.Z_add (x, y) | Z.mul ?x ?y => mkAppIdent ident.Z_mul (x, y) | Z.pow ?x ?y => mkAppIdent ident.Z_pow (x, y) @@ -953,8 +939,6 @@ Module Compilers. End List. Module Z. - Notation runtime_mul := Z_runtime_mul. - Notation runtime_add := Z_runtime_add. Notation add := Z_add. Notation mul := Z_mul. Notation pow := Z_pow. @@ -1018,10 +1002,8 @@ Module Compilers. | list_rect {A P} : ident (P * (A * list A * P -> P) * list A) P | List_nth_default {T} : ident (T * list T * nat) T | List_nth_default_concrete {T : type.primitive} (d : interp T) (n : Datatypes.nat) : ident (list T) T - | Z_runtime_mul : ident (Z * Z) Z - | Z_runtime_add : ident (Z * Z) Z - | Z_runtime_shiftr (offset : BinInt.Z) : ident Z Z - | Z_runtime_land (mask : BinInt.Z) : ident Z Z + | Z_shiftr (offset : BinInt.Z) : ident Z Z + | Z_land (mask : BinInt.Z) : ident Z Z | Z_add : ident (Z * Z) Z | Z_mul : ident (Z * Z) Z | Z_pow : ident (Z * Z) Z @@ -1061,10 +1043,8 @@ Module Compilers. | list_rect A P => curry3_23 (@Datatypes.list_rect (type.interp A) (fun _ => type.interp P)) | List_nth_default T => curry3 (@List.nth_default (type.interp T)) | List_nth_default_concrete T d n => fun ls => @List.nth_default (type.interp T) d ls n - | Z_runtime_mul => curry2 runtime_mul - | Z_runtime_add => curry2 runtime_add - | Z_runtime_shiftr n => fun v => Z.shiftr v n - | Z_runtime_land mask => fun v => Z.land v mask + | Z_shiftr n => fun v => Z.shiftr v n + | Z_land mask => fun v => Z.land v mask | Z_add => curry2 Z.add | Z_mul => curry2 Z.mul | Z_pow => curry2 Z.pow @@ -1137,8 +1117,6 @@ Module Compilers. | @List.nth_default ?T ?d ?ls ?n => let rT := type.reify T in mkAppIdent (@ident.List_nth_default rT) (d, ls, n) - | runtime_mul ?x ?y => mkAppIdent ident.Z_runtime_mul (x, y) - | runtime_add ?x ?y => mkAppIdent ident.Z_runtime_add (x, y) | Z.add ?x ?y => mkAppIdent ident.Z_add (x, y) | Z.mul ?x ?y => mkAppIdent ident.Z_mul (x, y) | Z.pow ?x ?y => mkAppIdent ident.Z_pow (x, y) @@ -1167,10 +1145,8 @@ Module Compilers. End List. Module Z. - Notation runtime_mul := Z_runtime_mul. - Notation runtime_add := Z_runtime_add. - Notation runtime_shiftr := Z_runtime_shiftr. - Notation runtime_land := Z_runtime_land. + Notation shiftr := Z_shiftr. + Notation land := Z_land. Notation add := Z_add. Notation mul := Z_mul. Notation pow := Z_pow. @@ -1258,10 +1234,6 @@ Module Compilers. => AppIdent ident.pred | for_reification.ident.primitive t v => AppIdent (ident.primitive v) - | for_reification.ident.Z_runtime_mul - => AppIdent ident.Z.runtime_mul - | for_reification.ident.Z_runtime_add - => AppIdent ident.Z.runtime_add | for_reification.ident.Z_add => AppIdent ident.Z.add | for_reification.ident.Z_mul @@ -2029,10 +2001,8 @@ Module Compilers. | ident.primitive _ _ as idc | ident.Nat_succ as idc | ident.pred as idc - | ident.Z_runtime_mul as idc - | ident.Z_runtime_add as idc - | ident.Z_runtime_shiftr _ as idc - | ident.Z_runtime_land _ as idc + | ident.Z_shiftr _ as idc + | ident.Z_land _ as idc | ident.Z_add as idc | ident.Z_mul as idc | ident.Z_pow as idc @@ -2198,16 +2168,14 @@ Module Compilers. (ident.snd @@ (Var xyk)) @ ((idc : default.ident _ type.nat) @@ (ident.fst @@ (Var xyk))) - | ident.Z_runtime_shiftr _ as idc - | ident.Z_runtime_land _ as idc + | ident.Z_shiftr _ as idc + | ident.Z_land _ as idc | ident.Z_opp as idc => λ (xyk : (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.Z * (type.Z -> R))%ctype) , (ident.snd @@ (Var xyk)) @ ((idc : default.ident _ type.Z) @@ (ident.fst @@ (Var xyk))) - | ident.Z_runtime_mul as idc - | ident.Z_runtime_add as idc | ident.Z_add as idc | ident.Z_mul as idc | ident.Z_pow as idc @@ -2817,15 +2785,13 @@ Module Compilers. | ident.Nat_succ as idc | ident.Z_of_nat as idc | ident.Z_opp as idc - | ident.Z_runtime_shiftr _ as idc - | ident.Z_runtime_land _ as idc + | ident.Z_shiftr _ as idc + | ident.Z_land _ as idc => fun x : expr _ + type.interp _ => match x return expr _ + type.interp _ with | inr x => inr (ident.interp idc x) | inl x => expr.reflect (AppIdent idc x) end - | ident.Z_add as idc - | ident.Z_mul as idc | ident.Z_pow as idc | ident.Z_eqb as idc => fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _)) @@ -2840,7 +2806,7 @@ Module Compilers. | inr (inr x, inr y) => inr (ident.interp idc (x, y)) | inr (x, inr y) => if Z.eqb y (2^Z.log2 y) - then expr.reflect (AppIdent (ident.Z.runtime_shiftr (Z.log2 y)) (expr.reify (t:=type.Z) x)) + then expr.reflect (AppIdent (ident.Z.shiftr (Z.log2 y)) (expr.reify (t:=type.Z) x)) else default | _ => default end @@ -2851,11 +2817,11 @@ Module Compilers. | inr (inr x, inr y) => inr (ident.interp idc (x, y)) | inr (x, inr y) => if Z.eqb y (2^Z.log2 y) - then expr.reflect (AppIdent (ident.Z.runtime_land (y-1)) (expr.reify (t:=type.Z) x)) + then expr.reflect (AppIdent (ident.Z.land (y-1)) (expr.reify (t:=type.Z) x)) else default | _ => default end - | ident.Z_runtime_mul as idc + | ident.Z_mul as idc => fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _)) => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in match x_y return expr _ + type.interp _ with @@ -2869,7 +2835,7 @@ Module Compilers. else default | inr (inl _, inl _) | inl _ => default end - | ident.Z_runtime_add as idc + | ident.Z_add as idc => fun (x_y : expr (_ * _) + (expr _ + type.interp _) * (expr _ + type.interp _)) => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in match x_y return expr _ + type.interp _ with @@ -2916,7 +2882,6 @@ Module Compilers. Fixpoint to_mul_list (e : @expr var type.Z) : list (@expr var type.Z) := match e in expr.expr t return list (@expr var t) with - | AppIdent s type.Z ident.Z_runtime_mul (Pair type.Z type.Z x y) | AppIdent s type.Z ident.Z_mul (Pair type.Z type.Z x y) => to_mul_list x ++ to_mul_list y | Var _ _ as e @@ -2930,10 +2895,8 @@ Module Compilers. Fixpoint to_add_mul_list (e : @expr var type.Z) : list (list (@expr var type.Z)) := match e in expr.expr t return list (list (@expr var t)) with - | AppIdent s type.Z ident.Z_runtime_add (Pair type.Z type.Z x y) | AppIdent s type.Z ident.Z_add (Pair type.Z type.Z x y) => to_add_mul_list x ++ to_add_mul_list y - | AppIdent s type.Z ident.Z_runtime_mul (Pair type.Z type.Z x y) | AppIdent s type.Z ident.Z_mul (Pair type.Z type.Z x y) => [to_mul_list x ++ to_mul_list y] | Var _ _ as e @@ -3369,15 +3332,13 @@ Module Compilers. | default.ident.List_nth_default T => None | ident.List_nth_default_concrete T d n => Some (@List_nth (type.primitive.compile T) n) - | ident.Z_runtime_mul | default.ident.Z_mul => Some Z.mul - | ident.Z_runtime_add | default.ident.Z_add => Some Z.add - | ident.Z_runtime_shiftr n + | ident.Z_shiftr n => Some (Z.shiftr n) - | ident.Z_runtime_land mask + | ident.Z_land mask => Some (Z.land mask) | ident.Z_pow | ident.Z_opp @@ -3835,25 +3796,17 @@ Local Coercion QArith_base.inject_Z : Z >-> Q. - reassociation - indexed + bounds analysis + of phoas *) -Delimit Scope RT_expr_scope with RT_expr. Import Uncurried. Import expr. Import for_reification.Notations.Reification. Notation "x + y" - := (AppIdent ident.Z.runtime_add (x%RT_expr, y%RT_expr)%expr) - : RT_expr_scope. -Notation "x * y" - := (AppIdent ident.Z.runtime_mul (x%RT_expr, y%RT_expr)%expr) - : RT_expr_scope. -Notation "x + y" - := (AppIdent ident.Z.runtime_add (x%RT_expr, y%RT_expr)%expr) + := (AppIdent ident.Z.add (x, y)%expr) : expr_scope. Notation "x * y" - := (AppIdent ident.Z.runtime_mul (x%RT_expr, y%RT_expr)%expr) + := (AppIdent ident.Z.mul (x, y)%expr) : expr_scope. Notation "x" := (Var x) (only printing, at level 9) : expr_scope. -Open Scope RT_expr_scope. Require Import AdmitAxiom. @@ -3874,8 +3827,8 @@ Module test2. Proof. let v := Reify (fun y : Z => (fun k : Z * Z -> Z * Z - => dlet_nd x := (y * y)%RT in - dlet_nd z := (x * x)%RT in + => dlet_nd x := (y * y) in + dlet_nd z := (x * x) in k (z, z)) (fun v => v)) in pose v as E. @@ -3885,8 +3838,8 @@ Module test2. lazymatch (eval cbv delta [E'] in E') with | (fun var : type -> Type => (λ x : var (type.type_primitive type.Z), - expr_let x0 := (Var x * Var x)%RT_expr in - expr_let x1 := (Var x0 * Var x0)%RT_expr in + expr_let x0 := (Var x * Var x) in + expr_let x1 := (Var x0 * Var x0) in (Var x1, Var x1))%expr) => idtac end. Import BoundsAnalysis.ident. @@ -3909,11 +3862,11 @@ Module test3. Example test3 : True. Proof. let v := Reify (fun y : Z - => dlet_nd x := dlet_nd x := (y * y)%RT in - (x * x)%RT in - dlet_nd z := dlet_nd z := (x * x)%RT in - (z * z)%RT in - (z * z)%RT) in + => dlet_nd x := dlet_nd x := (y * y) in + (x * x) in + dlet_nd z := dlet_nd z := (x * x) in + (z * z) in + (z * z)) in pose v as E. vm_compute in E. pose (PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. @@ -3925,7 +3878,7 @@ Module test3. expr_let x1 := Var x0 * Var x0 in expr_let x2 := Var x1 * Var x1 in expr_let x3 := Var x2 * Var x2 in - Var x3 * Var x3)%RT_expr%expr) + Var x3 * Var x3)%expr) => idtac end. Import BoundsAnalysis.ident. @@ -3955,8 +3908,8 @@ Module test4. let v := Reify (fun y : (list Z * list Z) => dlet_nd x := List.nth_default (-1) (fst y) 0 in dlet_nd z := List.nth_default (-1) (snd y) 0 in - dlet_nd xz := (x * z)%RT in - (xz :: xz :: nil)%RT) in + dlet_nd xz := (x * z) in + (xz :: xz :: nil)) in pose v as E. vm_compute in E. pose (PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. @@ -3986,7 +3939,7 @@ Module test5. Example test5 : True. Proof. let v := Reify (fun y : (Z * Z) - => dlet_nd x := (13 * (fst y * snd y))%RT in + => dlet_nd x := (13 * (fst y * snd y)) in x) in pose v as E. vm_compute in E. @@ -4010,7 +3963,7 @@ Module test6. Proof. let v := Reify (fun y : Z => if 0 =? 1 - then dlet_nd x := (y * y)%RT in + then dlet_nd x := (y * y) in x else y) in pose v as E. |