aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2018-10-01 17:48:12 -0400
committerGravatar Jason Gross <jgross@mit.edu>2018-10-01 17:48:12 -0400
commit05aabe205a94b41966115df9ce52056387566193 (patch)
treeb5b1edcbcad26f00944b73ec0f2520f3deea5848
parentcaf6699097f02b8e5b7f69f94b03507386c3ea04 (diff)
Add pattern.ident.to_typed
-rw-r--r--src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v145
-rw-r--r--src/Experiments/NewPipeline/Rewriter.v22
2 files changed, 142 insertions, 25 deletions
diff --git a/src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v b/src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v
index 02e781610..c991c3454 100644
--- a/src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v
+++ b/src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v
@@ -33,6 +33,18 @@ Module Compilers.
| Compilers.base.type.list A => type.list (relax A)
end.
+ Fixpoint subst_default (ptype : type) (evar_map : EvarMap) : Compilers.base.type
+ := match ptype with
+ | type.var p => match PositiveMap.find p evar_map with
+ | Some t => t
+ | None => Compilers.base.type.type_base base.type.unit
+ end
+ | type.type_base t => Compilers.base.type.type_base t
+ | type.prod A B
+ => Compilers.base.type.prod (subst_default A evar_map) (subst_default B evar_map)
+ | type.list A => Compilers.base.type.list (subst_default A evar_map)
+ end.
+
Module Notations.
Global Coercion type.type_base : Compilers.base.type.base >-> type.type.
Bind Scope pbtype_scope with type.type.
@@ -67,6 +79,12 @@ Module Compilers.
| type.base t => type.base (base.relax t)
| type.arrow s d => type.arrow (relax s) (relax d)
end.
+
+ Fixpoint subst_default (ptype : type) (evar_map : EvarMap) : type.type Compilers.base.type
+ := match ptype with
+ | type.base t => type.base (base.subst_default t evar_map)
+ | type.arrow A B => type.arrow (subst_default A evar_map) (subst_default B evar_map)
+ end.
End type.
(*
@@ -831,6 +849,18 @@ Module Compilers.
| Compilers.base.type.list A => type.list (relax A)
end.
+ Fixpoint subst_default (ptype : type) (evar_map : EvarMap) : Compilers.base.type
+ := match ptype with
+ | type.var p => match PositiveMap.find p evar_map with
+ | Some t => t
+ | None => Compilers.base.type.type_base base.type.unit
+ end
+ | type.type_base t => Compilers.base.type.type_base t
+ | type.prod A B
+ => Compilers.base.type.prod (subst_default A evar_map) (subst_default B evar_map)
+ | type.list A => Compilers.base.type.list (subst_default A evar_map)
+ end.
+
Module Notations.
Global Coercion type.type_base : Compilers.base.type.base >-> type.type.
Bind Scope pbtype_scope with type.type.
@@ -865,6 +895,12 @@ Module Compilers.
| type.base t => type.base (base.relax t)
| type.arrow s d => type.arrow (relax s) (relax d)
end.
+
+ Fixpoint subst_default (ptype : type) (evar_map : EvarMap) : type.type Compilers.base.type
+ := match ptype with
+ | type.base t => type.base (base.subst_default t evar_map)
+ | type.arrow A B => type.arrow (subst_default A evar_map) (subst_default B evar_map)
+ end.
End type.
(""" + """*
@@ -896,11 +932,13 @@ retcode += addnewline(r"""%s(*
maxeta = max([len(ttype + ctype) for ttype, ctype in zip(ttypes, ctypes)])
#if any(len(ttype) > 0 and len(ctype) > 0 for ttype, ctype in zip(ttypes, ctypes)):
-# retcode += addnewline(r"""%sLocal Notation eta_sigT x := (existT _ (projT1 x) (projT2 x)) (only parsing).""" % indent1)
+# retcode += addnewline(r"""%sLocal Notation eta_sigT x := (existT _ (projT1 x) (projT2 x)) (only parsing).""" % indent0)
if maxeta >= 2:
- retcode += addnewline(r"""%sLocal Notation eta2 x := (Datatypes.fst x, Datatypes.snd x) (only parsing).""" % indent1)
+ retcode += addnewline(r"""%sLocal Notation eta2 x := (Datatypes.fst x, Datatypes.snd x) (only parsing).""" % indent0)
+ retcode += addnewline(r"""%sLocal Notation eta2r x := (Datatypes.fst x, Datatypes.snd x) (only parsing).""" % indent0)
for i in range(3, maxeta+1):
- retcode += addnewline(r"""%sLocal Notation eta%d x := (eta%d (Datatypes.fst x), Datatypes.snd x) (only parsing).""" % (indent1, i, i-1))
+ retcode += addnewline(r"""%sLocal Notation eta%d x := (eta%d (Datatypes.fst x), Datatypes.snd x) (only parsing).""" % (indent0, i, i-1))
+ retcode += addnewline(r"""%sLocal Notation eta%dr x := (Datatypes.fst x, eta%dr (Datatypes.snd x)) (only parsing).""" % (indent0, i, i-1))
retcode += addnewline('')
def do_adjust_type(ctor, ctype):
@@ -943,6 +981,10 @@ def make_fun_project(named_ttype, ctype, ctor):
pr2_eta = ('projT2 arg' if len(ctype) == 1 else 'eta%d (projT2 arg)' % len(ctype))
return "fun arg => let '(%s, %s) := (%s, %s) in " % (pr1, pr2, pr1_eta, pr2_eta)
+def make_fun_project_list(ctype, ctor):
+ if len(ctype) == 0: return 'fun _ => '
+ return "fun arg => let '%s := eta%dr arg in " % (fold_right_pair(ctor[-len(ctype):], tt='_'), len(ctype)+1)
+
def make_fun_project_match(named_ttype, ctype, ctor, retty, body):
if len(named_ttype + ctype) == 0: return 'fun _ => ' + body
if len(named_ttype + ctype) == 1: return 'fun ' + ctor[-1] + ' => ' + body
@@ -1056,6 +1098,16 @@ retcode += addnewline((r"""%sDefinition type_vars {t} (idc : ident t) : list typ
'\n | '.join('@' + pctor + ' => [' + '; '.join(to_type_var(n, t) for n, t in named_ttype) + ']'
for pctor, named_ttype in zip(pctors_with_args, named_ttypes)))).replace('\n', '\n' + indent1))
+retcode += addnewline((r"""%sDefinition to_typed {t} (idc : ident t) (evm : EvarMap) : type_of_list (arg_types idc) -> %sident.ident (pattern.type.subst_default t evm)
+ := match idc in ident t return type_of_list (arg_types idc) -> %sident.ident (pattern.type.subst_default t evm) with
+ | %s
+ end.
+""" % (indent1, prefix, prefix,
+ '\n | '.join('@' + pctor + ' => '
+ + make_fun_project_list(ctype, ctor)
+ + '@' + ' '.join([ctor[0]] + ['_' for _ in ctor[1:len(ctor)-len(ctype)]] + ctor[len(ctor)-len(ctype):])
+ for pctor, ctor, ctype in zip(pctors_with_args, ctors_with_prefix, ctypes)))).replace('\n', '\n' + indent1))
+
assert(ctors[0][0] == 'ident.Literal')
assert(len(ctypes[0]) == 1)
@@ -1183,8 +1235,10 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
>>>
*)
- Local Notation eta2 x := (Datatypes.fst x, Datatypes.snd x) (only parsing).
- Local Notation eta3 x := (eta2 (Datatypes.fst x), Datatypes.snd x) (only parsing).
+ Local Notation eta2 x := (Datatypes.fst x, Datatypes.snd x) (only parsing).
+ Local Notation eta2r x := (Datatypes.fst x, Datatypes.snd x) (only parsing).
+ Local Notation eta3 x := (eta2 (Datatypes.fst x), Datatypes.snd x) (only parsing).
+ Local Notation eta3r x := (Datatypes.fst x, eta2r (Datatypes.snd x)) (only parsing).
Module Raw.
@@ -2161,6 +2215,87 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| @fancy_addm => []
end%type.
+ Definition to_typed {t} (idc : ident t) (evm : EvarMap) : type_of_list (arg_types idc) -> Compilers.ident.ident (pattern.type.subst_default t evm)
+ := match idc in ident t return type_of_list (arg_types idc) -> Compilers.ident.ident (pattern.type.subst_default t evm) with
+ | @Literal t => fun arg => let '(v, _) := eta2r arg in @Compilers.ident.Literal _ v
+ | @Nat_succ => fun _ => @Compilers.ident.Nat_succ
+ | @Nat_pred => fun _ => @Compilers.ident.Nat_pred
+ | @Nat_max => fun _ => @Compilers.ident.Nat_max
+ | @Nat_mul => fun _ => @Compilers.ident.Nat_mul
+ | @Nat_add => fun _ => @Compilers.ident.Nat_add
+ | @Nat_sub => fun _ => @Compilers.ident.Nat_sub
+ | @nil t => fun _ => @Compilers.ident.nil _
+ | @cons t => fun _ => @Compilers.ident.cons _
+ | @pair A B => fun _ => @Compilers.ident.pair _ _
+ | @fst A B => fun _ => @Compilers.ident.fst _ _
+ | @snd A B => fun _ => @Compilers.ident.snd _ _
+ | @prod_rect A B T => fun _ => @Compilers.ident.prod_rect _ _ _
+ | @bool_rect T => fun _ => @Compilers.ident.bool_rect _
+ | @nat_rect P => fun _ => @Compilers.ident.nat_rect _
+ | @nat_rect_arrow P Q => fun _ => @Compilers.ident.nat_rect_arrow _ _
+ | @list_rect A P => fun _ => @Compilers.ident.list_rect _ _
+ | @list_case A P => fun _ => @Compilers.ident.list_case _ _
+ | @List_length T => fun _ => @Compilers.ident.List_length _
+ | @List_seq => fun _ => @Compilers.ident.List_seq
+ | @List_firstn A => fun _ => @Compilers.ident.List_firstn _
+ | @List_skipn A => fun _ => @Compilers.ident.List_skipn _
+ | @List_repeat A => fun _ => @Compilers.ident.List_repeat _
+ | @List_combine A B => fun _ => @Compilers.ident.List_combine _ _
+ | @List_map A B => fun _ => @Compilers.ident.List_map _ _
+ | @List_app A => fun _ => @Compilers.ident.List_app _
+ | @List_rev A => fun _ => @Compilers.ident.List_rev _
+ | @List_flat_map A B => fun _ => @Compilers.ident.List_flat_map _ _
+ | @List_partition A => fun _ => @Compilers.ident.List_partition _
+ | @List_fold_right A B => fun _ => @Compilers.ident.List_fold_right _ _
+ | @List_update_nth T => fun _ => @Compilers.ident.List_update_nth _
+ | @List_nth_default T => fun _ => @Compilers.ident.List_nth_default _
+ | @Z_add => fun _ => @Compilers.ident.Z_add
+ | @Z_mul => fun _ => @Compilers.ident.Z_mul
+ | @Z_pow => fun _ => @Compilers.ident.Z_pow
+ | @Z_sub => fun _ => @Compilers.ident.Z_sub
+ | @Z_opp => fun _ => @Compilers.ident.Z_opp
+ | @Z_div => fun _ => @Compilers.ident.Z_div
+ | @Z_modulo => fun _ => @Compilers.ident.Z_modulo
+ | @Z_log2 => fun _ => @Compilers.ident.Z_log2
+ | @Z_log2_up => fun _ => @Compilers.ident.Z_log2_up
+ | @Z_eqb => fun _ => @Compilers.ident.Z_eqb
+ | @Z_leb => fun _ => @Compilers.ident.Z_leb
+ | @Z_geb => fun _ => @Compilers.ident.Z_geb
+ | @Z_of_nat => fun _ => @Compilers.ident.Z_of_nat
+ | @Z_to_nat => fun _ => @Compilers.ident.Z_to_nat
+ | @Z_shiftr => fun _ => @Compilers.ident.Z_shiftr
+ | @Z_shiftl => fun _ => @Compilers.ident.Z_shiftl
+ | @Z_land => fun _ => @Compilers.ident.Z_land
+ | @Z_lor => fun _ => @Compilers.ident.Z_lor
+ | @Z_bneg => fun _ => @Compilers.ident.Z_bneg
+ | @Z_lnot_modulo => fun _ => @Compilers.ident.Z_lnot_modulo
+ | @Z_mul_split => fun _ => @Compilers.ident.Z_mul_split
+ | @Z_add_get_carry => fun _ => @Compilers.ident.Z_add_get_carry
+ | @Z_add_with_carry => fun _ => @Compilers.ident.Z_add_with_carry
+ | @Z_add_with_get_carry => fun _ => @Compilers.ident.Z_add_with_get_carry
+ | @Z_sub_get_borrow => fun _ => @Compilers.ident.Z_sub_get_borrow
+ | @Z_sub_with_get_borrow => fun _ => @Compilers.ident.Z_sub_with_get_borrow
+ | @Z_zselect => fun _ => @Compilers.ident.Z_zselect
+ | @Z_add_modulo => fun _ => @Compilers.ident.Z_add_modulo
+ | @Z_rshi => fun _ => @Compilers.ident.Z_rshi
+ | @Z_cc_m => fun _ => @Compilers.ident.Z_cc_m
+ | @Z_cast => fun arg => let '(range, _) := eta2r arg in @Compilers.ident.Z_cast range
+ | @Z_cast2 => fun arg => let '(range, _) := eta2r arg in @Compilers.ident.Z_cast2 range
+ | @fancy_add => fun arg => let '(log2wordmax, (imm, _)) := eta3r arg in @Compilers.ident.fancy_add log2wordmax imm
+ | @fancy_addc => fun arg => let '(log2wordmax, (imm, _)) := eta3r arg in @Compilers.ident.fancy_addc log2wordmax imm
+ | @fancy_sub => fun arg => let '(log2wordmax, (imm, _)) := eta3r arg in @Compilers.ident.fancy_sub log2wordmax imm
+ | @fancy_subb => fun arg => let '(log2wordmax, (imm, _)) := eta3r arg in @Compilers.ident.fancy_subb log2wordmax imm
+ | @fancy_mulll => fun arg => let '(log2wordmax, _) := eta2r arg in @Compilers.ident.fancy_mulll log2wordmax
+ | @fancy_mullh => fun arg => let '(log2wordmax, _) := eta2r arg in @Compilers.ident.fancy_mullh log2wordmax
+ | @fancy_mulhl => fun arg => let '(log2wordmax, _) := eta2r arg in @Compilers.ident.fancy_mulhl log2wordmax
+ | @fancy_mulhh => fun arg => let '(log2wordmax, _) := eta2r arg in @Compilers.ident.fancy_mulhh log2wordmax
+ | @fancy_rshi => fun arg => let '(log2wordmax, (x, _)) := eta3r arg in @Compilers.ident.fancy_rshi log2wordmax x
+ | @fancy_selc => fun _ => @Compilers.ident.fancy_selc
+ | @fancy_selm => fun arg => let '(log2wordmax, _) := eta2r arg in @Compilers.ident.fancy_selm log2wordmax
+ | @fancy_sell => fun _ => @Compilers.ident.fancy_sell
+ | @fancy_addm => fun _ => @Compilers.ident.fancy_addm
+ end.
+
Definition unify {t t'} (pidc : ident t) (idc : Compilers.ident.ident t') : option (type_of_list (@arg_types t pidc))
:= match pidc, idc return option (type_of_list (arg_types pidc)) with
| @Literal Compilers.base.type.unit, Compilers.ident.Literal Compilers.base.type.unit v => Some (v, tt)
diff --git a/src/Experiments/NewPipeline/Rewriter.v b/src/Experiments/NewPipeline/Rewriter.v
index 396e6c595..93fa108d6 100644
--- a/src/Experiments/NewPipeline/Rewriter.v
+++ b/src/Experiments/NewPipeline/Rewriter.v
@@ -51,18 +51,6 @@ Module Compilers.
| type.list A => option_map Compilers.base.type.list (subst A evar_map)
end%option.
- Fixpoint subst_default (ptype : type) (evar_map : EvarMap) : Compilers.base.type
- := match ptype with
- | type.var p => match PositiveMap.find p evar_map with
- | Some t => t
- | None => Compilers.base.type.type_base base.type.unit
- end
- | type.type_base t => Compilers.base.type.type_base t
- | type.prod A B
- => Compilers.base.type.prod (subst_default A evar_map) (subst_default B evar_map)
- | type.list A => Compilers.base.type.list (subst_default A evar_map)
- end.
-
Fixpoint subst_default_relax P {t evm} : P t -> P (subst_default (relax t) evm)
:= match t return P t -> P (subst_default (relax t) evm) with
| Compilers.base.type.type_base t => fun x => x
@@ -163,14 +151,8 @@ Module Compilers.
Some (type.arrow s' d'))
end%option.
- Fixpoint subst_default (ptype : type) (evar_map : EvarMap) : type.type Compilers.base.type
- := match ptype with
- | type.base t => type.base (base.subst_default t evar_map)
- | type.arrow A B => type.arrow (subst_default A evar_map) (subst_default B evar_map)
- end.
-
- Fixpoint subst_default_relax P {t evm} : P t -> P (subst_default (type.relax t) evm)
- := match t return P t -> P (subst_default (type.relax t) evm) with
+ Fixpoint subst_default_relax P {t evm} : P t -> P (type.subst_default (type.relax t) evm)
+ := match t return P t -> P (type.subst_default (type.relax t) evm) with
| type.base t => base.subst_default_relax (fun t => P (type.base t))
| type.arrow A B
=> fun v