diff options
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 124 |
1 files changed, 95 insertions, 29 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 8bda15482..3fd269fe7 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -687,6 +687,51 @@ Module Compilers. | Z : arguments type.Z | bool : arguments type.bool. + Definition preinvertT (t : type) := + match t with + | type.unit => Datatypes.unit + | type.prod A B => arguments A * arguments B + | type.arrow s d => arguments d + | type.list A => arguments A + | type.nat => Datatypes.unit + | type.Z => Datatypes.unit + | type.bool => Datatypes.unit + end%type. + Definition invertT (t : type) := + option (* [None] means "generic" *) (preinvertT t). + + Definition invert {t : type} (P : arguments t -> Type) + (generic_case : P generic) + (non_generic_cases + : forall v : preinvertT t, + match t return forall v : preinvertT t, (arguments t -> Type) -> Type with + | type.unit + => fun v P => P unit + | type.prod A B + => fun '((a, b) : arguments A * arguments B) P + => P (prod a b) + | type.arrow s d => fun v P => P (arrow v) + | type.list A => fun v P => P (list v) + | type.nat => fun v P => P nat + | type.Z => fun v P => P Z + | type.bool => fun v P => P bool + end v P) + (a : arguments t) + : P a. + Proof. + destruct a; + try specialize (fun a b => non_generic_cases (a, b)); + cbn in *; + [ exact generic_case | apply non_generic_cases; apply tt .. ]. + Defined. + + Definition invert_arrow {s d} (a : arguments (type.arrow s d)) : arguments d + := @invert (type.arrow s d) (fun _ => arguments d) generic (fun ad => ad) a. + + Definition invert_prod {A B} (a : arguments (type.prod A B)) : arguments A * arguments B + := @invert (type.prod A B) (fun _ => arguments A * arguments B)%type (generic, generic) (fun '(a, b) => (a, b)) a. + + Fixpoint ground {t : type} : arguments t := match t with | type.unit => unit @@ -997,39 +1042,54 @@ Module Compilers. Section partial_reduce. Context {var : type -> Type}. - Fixpoint partial_reduce_cps {T} {t} (e : @expr (@expr var) t) - : (@expr var t -> @expr var T) -> @expr var T - := match e in expr t return (expr t -> expr T) -> expr T with - | TT => fun k => k TT + Local Notation partial_reduceT t a + := ((@expr var t * arguments.type.option.interp (@expr var) (@expr var) a)%type) + (only parsing). + + Fixpoint partial_reduce' {t} (e : @expr (@expr var) t) + : forall a : arguments t, partial_reduceT t a + := match e in expr t return (forall a : arguments t, partial_reduceT t a) with + | TT + => arguments.invert + (fun a : arguments type.unit => partial_reduceT type.unit a) + (TT, TT) + (fun u => (TT, u)) | Pair A B a b - => fun k - => @partial_reduce_cps - T A a - (fun a' - => @partial_reduce_cps - T B b - (fun b' => k (Pair a' b'))) - | Var t v => fun k => k v - | Abs s d f - => fun k - => k (Abs (fun x => @partial_reduce_cps _ d (f (Var x)) id)) + => arguments.invert + (fun a => partial_reduceT (type.prod A B) a) + (let ab := (fst (@partial_reduce' A a arguments.generic), + fst (@partial_reduce' B b arguments.generic))%expr in + (ab, ab)) + (fun '(aA, aB) + => let '(a0, a1) := @partial_reduce' A a aA in + let '(b0, b1) := @partial_reduce' B b aB in + ((a0, b0)%expr, Some (a1, b1))) + | Var t v + => fun a => (v, arguments.expr.interp _ _ v) | Op s d opc args - => fun k - => @partial_reduce_cps - T s args - (fun args' - => k - match arguments.op.rewrite - opc - (arguments.expr.interp _ _ args') - with - | Some e => arguments.expr.reify _ _ e - | None => Op opc args' - end) + => let '(args0, args1) := @partial_reduce' s args (arguments.op.lookup_src opc) in + let e := + match arguments.op.rewrite opc args1 with + | Some e => arguments.expr.reify _ _ e + | None => Op opc args0 + end in + fun a => (e, arguments.expr.interp _ _ e) + | Abs s d f + => fun a + => let e' := Abs (fun x => fst (@partial_reduce' d (f (Var x)) (arguments.invert_arrow a))) in + arguments.invert + (fun a => partial_reduceT (type.arrow s d) a) + (e', e') + (fun ad + => (e', + (fun x => + snd (@partial_reduce' d (f x) ad)))) + a end. + Definition partial_reduce {t} (e : @expr (@expr var) t) : @expr var t - := @partial_reduce_cps t t e id. + := snd (@partial_reduce' t e arguments.generic). End partial_reduce. Definition PartialReduce {t} (e : Expr t) : Expr t @@ -1332,6 +1392,13 @@ Example base_25_5_mul (*(f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 g0 g1 g2 g3 g4 g5 g6 g7 g apply (f_equal (fun F => F f g)). cbv [n]. cbv delta [mulmod w to_associational mul to_associational reduce from_associational add_to_nth zeros place split]. + assert True. + { let v := Reify ((fun x => 2^x) 255)%Z in + pose v as E. + vm_compute in E. + pose (PartialReduce E) as E'. + vm_compute in E'. + constructor. } Reify_rhs (). let e := match goal with |- _ = Interp ?e => e end in pose e as E. @@ -1339,7 +1406,6 @@ Example base_25_5_mul (*(f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 g0 g1 g2 g3 g4 g5 g6 g7 g Timeout 2 vm_compute in E. pose (PartialReduce E) as E'. Timeout 2 vm_compute in E'. - (*cbv -[runtime_mul runtime_add]; cbv [runtime_mul runtime_add]. ring_simplify_subterms.*) (* ?fg = |