diff options
author | Jason Gross <jgross@mit.edu> | 2017-11-24 18:19:12 -0500 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-01-29 18:04:58 -0500 |
commit | 20de7a22f740afe92570223b10347c9773598de7 (patch) | |
tree | f16b2ccaf09d50da834b3279e4a1853733e23128 /src/Experiments/SimplyTypedArithmetic.v | |
parent | 23b11c08bb5b810aa934d26e5f0596ff099d5389 (diff) |
Partial work on implementing partial reduction
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 591 |
1 files changed, 591 insertions, 0 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 6584e813e..8bda15482 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -451,6 +451,590 @@ Module Compilers. Definition Interp {t} (e : Expr t) := interp (e _). + Definition const {var t} (v : type.interp t) : @expr var t + := Op (op.Const v) TT. + + Section option_partition. + Context {A : Type} (f : A -> option Datatypes.bool). + Fixpoint option_partition (l : list A) : option (list A * list A) + := match l with + | nil => Some (nil, nil) + | cons x tl + => match option_partition tl, f x with + | Some (g, d), Some fx + => Some (if fx then (x :: g, d) else (g, x :: d)) + | _, _ => None + end + end. + End option_partition. + Section option_flat_map. + Context {A B : Type} (f : A -> option (list B)). + Fixpoint option_flat_map (l : list A) : option (list B) + := match l with + | nil => Some nil + | cons x t => match f x, option_flat_map t with + | Some fx, Some ft + => Some (fx ++ ft) + | _, _ => None + end + end. + End option_flat_map. + + Definition lift_option_list {A} (ls : list (option A)) : option (list A) + := list_rect + (fun _ => _) + (Some nil) + (fun x _ xs + => match x, xs with + | Some x, Some xs => Some (cons x xs) + | _, _ => None + end) + ls. + + Section invert. + Context {var : type -> Type}. + + Definition invert_Var {t} (e : @expr var t) : option (var t) + := match e with + | Var t v => Some v + | _ => None + end. + + Definition invert_Abs {s d} (e : @expr var (type.arrow s d)) : option (var s -> @expr var d) + := match e in expr t return option match t with + | type.arrow _ _ => _ + | _ => True + end with + | Abs s d f => Some f + | _ => None + end. + + Definition invert_Pair {A B} (e : @expr var (type.prod A B)) : option (@expr var A * @expr var B) + := match e in expr t return option match t with + | type.prod _ _ => _ + | _ => True + end with + | Pair _ _ a b => Some (a, b) + | _ => None + end. + + Definition invert_Op {t} (e : @expr var t) : option { s : _ & op s t * @expr var s }%type + := match e with + | Op s d opc args => Some (existT _ s (opc, args)) + | _ => None + end. + + Definition invert_OpConst {t} (e : @expr var t) : option (type.interp t) + := match invert_Op e with + | Some (existT s (opc, args)) + => match opc with + | op.Const t v => Some v + | _ => None + end + | None => None + end. + + Definition invert_op_S (e : @expr var type.nat) : option (@expr var type.nat) + := match invert_Op e with + | Some (existT s (opc, args)) + => match opc in op s d return expr s -> option (expr type.nat) with + | op.S => fun args => Some args + | _ => fun _ => None + end args + | None => None + end. + + Definition invert_Z (e : @expr var type.Z) : option Z := invert_OpConst e. + Definition invert_bool (e : @expr var type.bool) : option bool := invert_OpConst e. + Fixpoint invert_nat_full (e : @expr var type.nat) : option nat + := match e with + | Op _ _ op.S args + => option_map S (invert_nat_full args) + | Op _ _ (op.Const type.nat v) _ + => Some v + | _ => None + end. + (* oh, the horrors of not being able to use non-linear deep pattern matches *) + Fixpoint invert_list_full {t} (e : @expr var (type.list t)) + : option (list (@expr var t)) + := match e in expr t return option match t with + | type.list t => list (expr t) + | _ => True + end + with + | Op s d opc args + => match opc in op s d + return option match s with + | type.prod A (type.list B) => expr A * list (expr B) + | _ => True + end + -> option match d with + | type.list t => list (expr t) + | _ => True + end + with + | op.Const (type.list _) v => fun _ => Some (List.map const v) + | op.cons _ => option_map (fun '(x, xs) => cons x xs) + | op.nil _ => fun _ => Some nil + | _ => fun _ => None + end + (match args in expr t + return option match t with + | type.prod A (type.list B) => expr A * list (expr B) + | _ => True + end + with + | Pair _ (type.list _) x xs + => match @invert_list_full _ xs with + | Some xs => Some (x, xs) + | None => None + end + | _ => None + end) + | _ => None + end. + (*Section with_map. + (* oh, the horrors of not being able to use non-linear deep + pattern matches. Luckily Coq's guard checker unfolds things, + so as long as the thing we need to evaluate at the bottom is + generic in what type it's looking at, we're good. We can + even give it data of the right type, which we need, but it + costs us a lot *) + Context (extra_dataT : type -> Type) {U} (f : forall t (d : extra_dataT t), @expr var t -> U t d). + Local Notation list_to_extra_dataT t + := (match t%ctype return Type with + | type.list t' => extra_dataT t' + | _ => True + end). + Local Notation list_to_forall_data_option_list t + := (forall (d : list_to_extra_dataT t), + option (match t%ctype as t' + return list_to_extra_dataT t' -> Type + with + | type.list t' => fun d => list (U t' d) + | _ => fun _ => True + end d)). + Local Notation prod_list_to_extra_dataT t + := (match t%ctype return Type with + | type.prod _ (type.list t') => extra_dataT t' + | _ => True + end). + Local Notation prod_list_to_forall_data_option_list t + := (forall (d : prod_list_to_extra_dataT t), + option (match t%ctype as t' + return prod_list_to_extra_dataT t' -> Type + with + | type.prod A (type.list t') => fun d => (expr A * list (U t' d))%type + | _ => fun _ => True + end d)). + + + + Fixpoint invert_map_list_full {t} (e : @expr var (type.list t)) + : forall d : extra_dataT t, option (list (U t d)) + := match e in expr t return list_to_forall_data_option_list t + with + | Op s d opc args + => match opc in op s d + return (prod_list_to_forall_data_option_list s + -> list_to_forall_data_option_list d) + with + | op.Const (type.list _) v + => fun _ data => Some (List.map (f _ data) (List.map const v)) + | op.cons _ + => fun xs data + => option_map (fun '(x, xs) => cons (f _ data x) xs) (xs data) + | op.nil _ + => fun _ _ => Some nil + | _ => fun _ _ => None + end + (match args in expr t + return prod_list_to_forall_data_option_list t + with + | Pair _ (type.list _) x xs + => fun data + => match @invert_map_list_full _ xs data with + | Some xs => Some (x, xs) + | None => None + end + | _ => fun _ => None + end) + | _ => fun _ => None + end. + End with_map.*) + End invert. + + Section gallina_reify. + Context {var : type -> Type}. + Definition reify_list {t} (ls : list (@expr var t)) : @expr var (type.list t) + := list_rect + (fun _ => _) + (Op op.nil TT) + (fun x _ xs => Op op.cons (x, xs)) + ls. + End gallina_reify. + + + Module arguments. + Inductive arguments : type -> Set := + | generic {T} : arguments T + (*| cps {T} (aT : arguments T) : arguments T*) + | arrow {A B} (aB : arguments B) : arguments (A -> B) + | unit : arguments type.unit + | prod {A B} (aA : arguments A) (aB : arguments B) : arguments (A * B) + | list {T} (aT : arguments T) : arguments (type.list T) + | nat : arguments type.nat + | Z : arguments type.Z + | bool : arguments type.bool. + + Fixpoint ground {t : type} : arguments t + := match t with + | type.unit => unit + | type.prod A B => prod (@ground A) (@ground B) + | type.arrow s d => arrow (@ground d) + | type.list A => list (@ground A) + | type.nat => nat + | type.Z => Z + | type.bool => bool + end. + + Module type. + Local Notation interp_type := type.interp. + Section interp. + Context (var_dom var_cod : type -> Type). + Fixpoint interp {t} (a : arguments t) : Type + := match a with + | generic T => var_cod T + (*| cps T aT => forall U, (@interp T aT -> var U) -> var U*) + | arrow A B aB => var_dom A -> @interp B aB + | unit => Datatypes.unit + | prod A B aA aB => @interp A aA * @interp B aB + | list T aT => Datatypes.list (@interp T aT) + | nat => Datatypes.nat + | Z => BinInt.Z + | bool => Datatypes.bool + end%type. + End interp. + + Section ground. + Context {var_dom var_cod : type -> Type}. + Fixpoint const_of_ground {t} + : interp_type t -> option (interp var_dom (@expr var_cod) (@ground t)) + := match t return interp_type t -> option (interp var_dom expr (@ground t)) with + | type.prod A B + => fun '((a, b) : interp_type A * interp_type B) + => match @const_of_ground A a, @const_of_ground B b with + | Some a', Some b' => Some (a', b') + | _, _ => None + end + | type.arrow s d => fun _ => None + | type.list A + => fun ls : Datatypes.list (interp_type A) + => lift_option_list + (List.map (@const_of_ground A) ls) + | type.unit + | type.nat + | type.Z + | type.bool + => fun v => Some v + end. + End ground. + + Module option. + Section interp. + Context (var_dom var_cod : type -> Type). + Fixpoint interp {t} (a : arguments t) : Type + := match a with + | generic T => var_cod T + (*| cps T aT => forall U, (@interp T aT -> var U) -> var U*) + | arrow A B aB => var_dom A -> @interp B aB + | unit => Datatypes.unit + | prod A B aA aB => option (@interp A aA * @interp B aB) + | list T aT => option (Datatypes.list (@interp T aT)) + | nat => option Datatypes.nat + | Z => option BinInt.Z + | bool => option Datatypes.bool + end%type. + End interp. + + Section flat_interp. + Context (var_generic var_dom : type -> Type) (var_cod : forall t, arguments t -> Type). + Fixpoint flat_interp {t} (a : arguments t) : Type + := match a with + | generic T => var_generic T + (*| cps T aT => forall U, (@interp T aT -> var U) -> var U*) + | arrow A B aB => var_dom A -> var_cod B aB + | unit => Datatypes.unit + | prod A B aA aB => @flat_interp A aA * @flat_interp B aB + | list T aT => Datatypes.list (@flat_interp T aT) + | nat => Datatypes.nat + | Z => BinInt.Z + | bool => Datatypes.bool + end%type. + End flat_interp. + + Definition interp_to_arrow_or_generic var_dom var_cod {t} a + := @flat_interp var_cod var_dom (@interp var_dom var_cod) t a. + + Section lift_option. + Context {var_dom var_cod : type -> Type}. + Fixpoint lift_interp {t} {a : arguments t} + : interp var_dom var_cod a -> option (interp_to_arrow_or_generic var_dom var_cod a) + := match a in arguments t + return interp var_dom var_cod a -> option (interp_to_arrow_or_generic var_dom var_cod a) + with + | prod A B aA aB + => fun ab : option (interp _ _ aA * interp _ _ aB) + => match ab with + | Some (a, b) + => match @lift_interp A aA a, @lift_interp B aB b with + | Some a, Some b => Some (a, b) + | _, _ => None + end + | None => None + end + | list T aT + => fun ls : option (Datatypes.list (interp _ _ aT)) + => match ls with + | Some ls + => lift_option_list + (List.map (@lift_interp T aT) ls) + | None => None + end + | arrow _ _ _ + | generic _ + | unit + => fun v => Some v + | nat + | Z + | bool + => fun x => x + end. + End lift_option. + End option. + End type. + + Module expr. + Section interp. + Context (var : type -> Type). + Fixpoint interp {t} (a : arguments t) + : @expr var t -> type.option.interp (@expr var) (@expr var) a + := match a in arguments t return expr t -> type.option.interp expr expr a with + | generic T => fun e => e + (*| cps T aT => fun e*) + | arrow A B aB + => fun e arg + => @interp + B aB + match invert_Abs e, invert_Var arg with + | Some f, Some arg => f arg + | _, _ => Op op.App (e, arg) + end + | unit => fun _ => tt + | prod A B aA aB + => fun e + => option_map (fun '(a, b) + => (@interp A aA a, @interp B aB b)) + (invert_Pair e) + | list T aT + => fun e + => option_map + (List.map (@interp T aT)) + (invert_list_full e) + | nat => invert_nat_full + | Z => invert_Z + | bool => invert_bool + end. + End interp. + + Section reify. + Context (var : type -> Type). + Fixpoint reify {t} (a : arguments t) + : type.interp var (@expr var) a -> @expr var t + := match a in arguments t return type.interp var expr a -> expr t with + | generic T => fun e => e + | arrow A B aB => fun f => Abs (fun x => @reify B aB (f x)) + | unit => fun _ => TT + | prod A B aA aB + => fun '((a, b) : type.interp _ _ aA * type.interp _ _ aB) + => (@reify A aA a, @reify B aB b)%expr + | list T aT + => fun ls + => reify_list (List.map (@reify T aT) ls) + | nat => @const var type.nat + | Z => @const var type.Z + | bool => @const var type.bool + end. + End reify. + End expr. + + Module Export Notations. + Delimit Scope arguments_scope with arguments. + Bind Scope arguments_scope with arguments. + Notation "()" := unit : arguments_scope. + Notation "A * B" := (prod A B) : arguments_scope. + Notation "A -> B" := (@arrow A _ B) (only printing) : arguments_scope. + Notation "A -> B" := (arrow B) (only parsing) : arguments_scope. + Global Coercion generic : type.type >-> arguments. + Notation arguments := arguments. + End Notations. + + Module op. + Local Open Scope arguments_scope. + Definition lookup {s d} (opc : op s d) : arguments s * arguments d + := match opc in op s d return arguments s * arguments d with + | op.Const t v => (generic, ground) + | op.Let_In tx tC => (tx * (tx -> tC), generic) + | op.App s d => ((s -> d) * s, generic) + | op.S => (ground, ground) + | op.nil t => (generic, ground) + | op.cons t => (t * list t, list t) + | op.fst A B => (A * B, generic) + | op.snd A B => (A * B, generic) + | op.bool_rect T => (T * T * bool, generic) + | op.nat_rect P => (P * (nat -> P -> P) * nat, generic) + | op.pred => (nat, ground) + | op.List_seq => (nat * nat, ground) + | op.List_repeat A => (A * nat, list A) + | op.List_combine A B => (list A * list B, list (A * B)) + | op.List_map A B => ((A -> B) * list A, list B) + | op.List_flat_map A B => ((A -> list B) * list A, list B) + | op.List_partition A => ((A -> bool) * list A, list A * list A) + | op.List_app A => (list A * list A, list A) + | op.List_fold_right A B => ((B -> A -> A) * A * list B, generic) + | op.List_update_nth T => (nat * (T -> T) * list T, list T) + | op.Z_runtime_mul => (Z * Z, Z) + | op.Z_runtime_add => (Z * Z, Z) + | op.Z_add => (Z * Z, Z) + | op.Z_mul => (Z * Z, Z) + | op.Z_pow => (Z * Z, Z) + | op.Z_opp => (Z, Z) + | op.Z_div => (Z * Z, Z) + | op.Z_modulo => (Z * Z, Z) + | op.Z_eqb => (Z * Z, bool) + | op.Z_of_nat => (nat, Z) + end. + + Definition lookup_src {s d} opc := fst (@lookup s d opc). + Definition lookup_dst {s d} opc := snd (@lookup s d opc). + + Definition option_map_prod {A B C} (f : A -> B -> C) (v : option (option A * option B)) + : option C + := match v with + | Some (Some a, Some b) => Some (f a b) + | _ => None + end. + + + + Definition rewrite + {var : type -> Type} + {s d} (opc : op s d) + (exploded_arguments : type.option.interp (@expr var) (@expr var) (lookup_src opc)) + : option (type.interp var (@expr var) (lookup_dst opc)) + := match opc in op s d + return + (forall (exploded_arguments' : option (type.option.interp_to_arrow_or_generic expr expr (lookup_src opc))), + option (type.interp var expr (lookup_dst opc))) + with + | op.Const t v => fun _ => arguments.type.const_of_ground v + | op.Let_In tx tC + => option_map + (fun '(ex, eC) + => match invert_Var ex, invert_OpConst ex with + | Some v, _ => eC ex + | None, Some v => eC ex + | None, None => Op op.Let_In (ex, Abs (fun v => eC (Var v))) + end) + | op.App s d => option_map (fun '(f, x) => f x) + | op.S as opc + | op.pred as opc + | op.Z_runtime_mul as opc + | op.Z_runtime_add as opc + | op.Z_add as opc + | op.Z_mul as opc + | op.Z_pow as opc + | op.Z_opp as opc + | op.Z_div as opc + | op.Z_modulo as opc + | op.Z_eqb as opc + | op.Z_of_nat as opc + => option_map (op.interp opc) + | op.nil t => fun _ => Some (@nil (type.interp _ _ ground)) + | op.cons t => option_map (op.curry2 cons) + | op.fst A B => option_map (@fst (expr A) (expr B)) + | op.snd A B => option_map (@snd (expr A) (expr B)) + | op.bool_rect T => option_map (op.curry3 (bool_rect (fun _ => _))) + | op.nat_rect P + => option_map + (fun '(O_case, S_case, v) + => nat_rect (fun _ => expr P) O_case (fun n (v : expr P) => S_case (@const _ type.nat n) v) v) + | op.List_seq => option_map (op.curry2 List.seq) + | op.List_repeat A => option_map (op.curry2 (@List.repeat (expr A))) + | op.List_combine A B => option_map (op.curry2 (@List.combine (expr A) (expr B))) + | op.List_map A B => option_map (op.curry2 (@List.map (expr A) (expr B))) + | op.List_flat_map A B + => fun args : option ((expr A -> option (Datatypes.list (expr B))) * Datatypes.list (expr A)) + => match args with + | Some (f, ls) => option_flat_map f ls + | None => None + end + | op.List_partition A + => fun args : option ((expr A -> option Datatypes.bool) * Datatypes.list (expr A)) + => match args with + | Some (f, ls) => option_partition f ls + | None => None + end + | op.List_app A => option_map (op.curry2 (@List.app (expr A))) + | op.List_fold_right A B => option_map (op.curry3 (@List.fold_right (expr A) (expr B))) + | op.List_update_nth T => option_map (op.curry3 (@update_nth (expr T))) + end + (type.option.lift_interp exploded_arguments). + End op. + End arguments. + Export arguments.Notations. + + Section partial_reduce. + Context {var : type -> Type}. + + Fixpoint partial_reduce_cps {T} {t} (e : @expr (@expr var) t) + : (@expr var t -> @expr var T) -> @expr var T + := match e in expr t return (expr t -> expr T) -> expr T with + | TT => fun k => k TT + | Pair A B a b + => fun k + => @partial_reduce_cps + T A a + (fun a' + => @partial_reduce_cps + T B b + (fun b' => k (Pair a' b'))) + | Var t v => fun k => k v + | Abs s d f + => fun k + => k (Abs (fun x => @partial_reduce_cps _ d (f (Var x)) id)) + | Op s d opc args + => fun k + => @partial_reduce_cps + T s args + (fun args' + => k + match arguments.op.rewrite + opc + (arguments.expr.interp _ _ args') + with + | Some e => arguments.expr.reify _ _ e + | None => Op opc args' + end) + end. + + Definition partial_reduce {t} (e : @expr (@expr var) t) : @expr var t + := @partial_reduce_cps t t e id. + End partial_reduce. + + Definition PartialReduce {t} (e : Expr t) : Expr t + := fun var => @partial_reduce var t (e _). + Ltac is_known_const_cps2 term on_success on_failure := let recurse term := is_known_const_cps2 term on_success on_failure in lazymatch term with @@ -749,6 +1333,13 @@ Example base_25_5_mul (*(f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 g0 g1 g2 g3 g4 g5 g6 g7 g cbv [n]. cbv delta [mulmod w to_associational mul to_associational reduce from_associational add_to_nth zeros place split]. Reify_rhs (). + let e := match goal with |- _ = Interp ?e => e end in + pose e as E. + exfalso. + Timeout 2 vm_compute in E. + pose (PartialReduce E) as E'. + Timeout 2 vm_compute in E'. + (*cbv -[runtime_mul runtime_add]; cbv [runtime_mul runtime_add]. ring_simplify_subterms.*) (* ?fg = |