aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/SimplyTypedArithmetic.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v124
1 files changed, 95 insertions, 29 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 8bda15482..3fd269fe7 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -687,6 +687,51 @@ Module Compilers.
| Z : arguments type.Z
| bool : arguments type.bool.
+ Definition preinvertT (t : type) :=
+ match t with
+ | type.unit => Datatypes.unit
+ | type.prod A B => arguments A * arguments B
+ | type.arrow s d => arguments d
+ | type.list A => arguments A
+ | type.nat => Datatypes.unit
+ | type.Z => Datatypes.unit
+ | type.bool => Datatypes.unit
+ end%type.
+ Definition invertT (t : type) :=
+ option (* [None] means "generic" *) (preinvertT t).
+
+ Definition invert {t : type} (P : arguments t -> Type)
+ (generic_case : P generic)
+ (non_generic_cases
+ : forall v : preinvertT t,
+ match t return forall v : preinvertT t, (arguments t -> Type) -> Type with
+ | type.unit
+ => fun v P => P unit
+ | type.prod A B
+ => fun '((a, b) : arguments A * arguments B) P
+ => P (prod a b)
+ | type.arrow s d => fun v P => P (arrow v)
+ | type.list A => fun v P => P (list v)
+ | type.nat => fun v P => P nat
+ | type.Z => fun v P => P Z
+ | type.bool => fun v P => P bool
+ end v P)
+ (a : arguments t)
+ : P a.
+ Proof.
+ destruct a;
+ try specialize (fun a b => non_generic_cases (a, b));
+ cbn in *;
+ [ exact generic_case | apply non_generic_cases; apply tt .. ].
+ Defined.
+
+ Definition invert_arrow {s d} (a : arguments (type.arrow s d)) : arguments d
+ := @invert (type.arrow s d) (fun _ => arguments d) generic (fun ad => ad) a.
+
+ Definition invert_prod {A B} (a : arguments (type.prod A B)) : arguments A * arguments B
+ := @invert (type.prod A B) (fun _ => arguments A * arguments B)%type (generic, generic) (fun '(a, b) => (a, b)) a.
+
+
Fixpoint ground {t : type} : arguments t
:= match t with
| type.unit => unit
@@ -997,39 +1042,54 @@ Module Compilers.
Section partial_reduce.
Context {var : type -> Type}.
- Fixpoint partial_reduce_cps {T} {t} (e : @expr (@expr var) t)
- : (@expr var t -> @expr var T) -> @expr var T
- := match e in expr t return (expr t -> expr T) -> expr T with
- | TT => fun k => k TT
+ Local Notation partial_reduceT t a
+ := ((@expr var t * arguments.type.option.interp (@expr var) (@expr var) a)%type)
+ (only parsing).
+
+ Fixpoint partial_reduce' {t} (e : @expr (@expr var) t)
+ : forall a : arguments t, partial_reduceT t a
+ := match e in expr t return (forall a : arguments t, partial_reduceT t a) with
+ | TT
+ => arguments.invert
+ (fun a : arguments type.unit => partial_reduceT type.unit a)
+ (TT, TT)
+ (fun u => (TT, u))
| Pair A B a b
- => fun k
- => @partial_reduce_cps
- T A a
- (fun a'
- => @partial_reduce_cps
- T B b
- (fun b' => k (Pair a' b')))
- | Var t v => fun k => k v
- | Abs s d f
- => fun k
- => k (Abs (fun x => @partial_reduce_cps _ d (f (Var x)) id))
+ => arguments.invert
+ (fun a => partial_reduceT (type.prod A B) a)
+ (let ab := (fst (@partial_reduce' A a arguments.generic),
+ fst (@partial_reduce' B b arguments.generic))%expr in
+ (ab, ab))
+ (fun '(aA, aB)
+ => let '(a0, a1) := @partial_reduce' A a aA in
+ let '(b0, b1) := @partial_reduce' B b aB in
+ ((a0, b0)%expr, Some (a1, b1)))
+ | Var t v
+ => fun a => (v, arguments.expr.interp _ _ v)
| Op s d opc args
- => fun k
- => @partial_reduce_cps
- T s args
- (fun args'
- => k
- match arguments.op.rewrite
- opc
- (arguments.expr.interp _ _ args')
- with
- | Some e => arguments.expr.reify _ _ e
- | None => Op opc args'
- end)
+ => let '(args0, args1) := @partial_reduce' s args (arguments.op.lookup_src opc) in
+ let e :=
+ match arguments.op.rewrite opc args1 with
+ | Some e => arguments.expr.reify _ _ e
+ | None => Op opc args0
+ end in
+ fun a => (e, arguments.expr.interp _ _ e)
+ | Abs s d f
+ => fun a
+ => let e' := Abs (fun x => fst (@partial_reduce' d (f (Var x)) (arguments.invert_arrow a))) in
+ arguments.invert
+ (fun a => partial_reduceT (type.arrow s d) a)
+ (e', e')
+ (fun ad
+ => (e',
+ (fun x =>
+ snd (@partial_reduce' d (f x) ad))))
+ a
end.
+
Definition partial_reduce {t} (e : @expr (@expr var) t) : @expr var t
- := @partial_reduce_cps t t e id.
+ := snd (@partial_reduce' t e arguments.generic).
End partial_reduce.
Definition PartialReduce {t} (e : Expr t) : Expr t
@@ -1332,6 +1392,13 @@ Example base_25_5_mul (*(f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 g0 g1 g2 g3 g4 g5 g6 g7 g
apply (f_equal (fun F => F f g)).
cbv [n].
cbv delta [mulmod w to_associational mul to_associational reduce from_associational add_to_nth zeros place split].
+ assert True.
+ { let v := Reify ((fun x => 2^x) 255)%Z in
+ pose v as E.
+ vm_compute in E.
+ pose (PartialReduce E) as E'.
+ vm_compute in E'.
+ constructor. }
Reify_rhs ().
let e := match goal with |- _ = Interp ?e => e end in
pose e as E.
@@ -1339,7 +1406,6 @@ Example base_25_5_mul (*(f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 g0 g1 g2 g3 g4 g5 g6 g7 g
Timeout 2 vm_compute in E.
pose (PartialReduce E) as E'.
Timeout 2 vm_compute in E'.
-
(*cbv -[runtime_mul runtime_add]; cbv [runtime_mul runtime_add].
ring_simplify_subterms.*)
(* ?fg =