diff options
author | Jason Gross <jgross@mit.edu> | 2018-07-14 19:37:32 +0100 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-07-15 21:57:47 +0100 |
commit | 83c6684b5d5c3f3dbdc68275290dbb8be359ef01 (patch) | |
tree | 99ec60ecd67303aa7115b6e980d11933ea0561fb /src | |
parent | f79fdb77c7baff92444204c00787d2c95da18997 (diff) |
Allow reification of nat_rect (fun _ => _ -> _)
We now support reification of nat_rect returning an arrow. This is
needed for montgomery.
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/NewPipeline/AbstractInterpretation.v | 12 | ||||
-rw-r--r-- | src/Experiments/NewPipeline/CStringification.v | 4 | ||||
-rw-r--r-- | src/Experiments/NewPipeline/CompilersTestCases.v | 17 | ||||
-rw-r--r-- | src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v | 28 | ||||
-rw-r--r-- | src/Experiments/NewPipeline/Language.v | 12 | ||||
-rw-r--r-- | src/Experiments/NewPipeline/Rewriter.v | 6 |
6 files changed, 78 insertions, 1 deletions
diff --git a/src/Experiments/NewPipeline/AbstractInterpretation.v b/src/Experiments/NewPipeline/AbstractInterpretation.v index 3f905a869..e550ae150 100644 --- a/src/Experiments/NewPipeline/AbstractInterpretation.v +++ b/src/Experiments/NewPipeline/AbstractInterpretation.v @@ -437,6 +437,18 @@ Module Compilers. n | None => ZRange.type.base.option.None end + | ident.nat_rect_arrow _ _ + => fun O_case S_case n v + => match n with + | Some n + => nat_rect + _ + O_case + (fun n' rec => S_case (Some n') rec) + n + v + | None => ZRange.type.base.option.None + end | ident.list_rect _ _ => fun N C ls => match ls with diff --git a/src/Experiments/NewPipeline/CStringification.v b/src/Experiments/NewPipeline/CStringification.v index bb94402f4..ffe23f2ef 100644 --- a/src/Experiments/NewPipeline/CStringification.v +++ b/src/Experiments/NewPipeline/CStringification.v @@ -220,6 +220,8 @@ Module Compilers. | ident.bool_rect T => fun '(t, (f, ((b, br), tt))) => (fun lvl => maybe_wrap_parens (Nat.ltb lvl 200) ("if " ++ b 200%nat ++ " then " ++ maybe_wrap_cast with_casts t 200%nat ++ " else " ++ maybe_wrap_cast with_casts f 200%nat), ZRange.type.base.option.None) | ident.nat_rect P => fun args => (show_application with_casts (fun _ => "nat_rect") args, ZRange.type.base.option.None) + | ident.nat_rect_arrow P Q + => fun args => (show_application with_casts (fun _ => "nat_rect(→)") args, ZRange.type.base.option.None) | ident.list_rect A P => fun args => (show_application with_casts (fun _ => "list_rect") args, ZRange.type.base.option.None) | ident.list_case A P @@ -342,6 +344,7 @@ Module Compilers. | ident.prod_rect A B T => "prod_rect" | ident.bool_rect T => "bool_rect" | ident.nat_rect P => "nat_rect" + | ident.nat_rect_arrow P Q => "nat_rect(→)" | ident.list_rect A P => "list_rect" | ident.list_case A P => "list_case" | ident.List_length T => "length" @@ -1135,6 +1138,7 @@ Module Compilers. | ident.prod_rect _ _ _ | ident.bool_rect _ | ident.nat_rect _ + | ident.nat_rect_arrow _ _ | ident.list_rect _ _ | ident.list_case _ _ | ident.List_length _ diff --git a/src/Experiments/NewPipeline/CompilersTestCases.v b/src/Experiments/NewPipeline/CompilersTestCases.v index f69a0db42..0fa697fbe 100644 --- a/src/Experiments/NewPipeline/CompilersTestCases.v +++ b/src/Experiments/NewPipeline/CompilersTestCases.v @@ -374,3 +374,20 @@ Module test12. exact I. Qed. End test12. +Module test13. + Example test13 : True. + Proof. + let v0 := constr:(nat_rect (fun _ => nat -> nat) (fun v => v) (fun n' rec v => (n' + rec (S v))%nat) 3 0%nat) in + let v := Reify v0 in + pose v as E; + pose v0 as exp. + vm_compute in E. + vm_compute in exp. + pose (PartialEvaluate E) as E'. + vm_compute in E'. + clear E. + let r := Reify exp in + unify r E'. + exact I. + Qed. +End test13. diff --git a/src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v b/src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v index 3b376ae4e..a6ad5a725 100644 --- a/src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v +++ b/src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v @@ -109,6 +109,20 @@ print_ident = r"""Inductive ident : defaults.type -> Set := (fun x : base.type => type.base x) (base.type.type_base base.type.nat) -> (fun x : base.type => type.base x) P) + | nat_rect_arrow : forall P Q : base.type, + ident + (((fun x : base.type => type.base x) P -> + (fun x : base.type => type.base x) Q) -> + ((fun x : base.type => type.base x) + (base.type.type_base base.type.nat) -> + ((fun x : base.type => type.base x) P -> + (fun x : base.type => type.base x) Q) -> + (fun x : base.type => type.base x) P -> + (fun x : base.type => type.base x) Q) -> + (fun x : base.type => type.base x) + (base.type.type_base base.type.nat) -> + (fun x : base.type => type.base x) P -> + (fun x : base.type => type.base x) Q) | list_rect : forall A P : base.type, ident (((fun x : base.type => type.base x) ()%etype -> @@ -570,6 +584,7 @@ show_match_ident = r"""match # with | ident.prod_rect A B T => | ident.bool_rect T => | ident.nat_rect P => + | ident.nat_rect_arrow P Q => | ident.list_rect A P => | ident.list_case A P => | ident.List_length T => @@ -859,6 +874,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | prod_rect | bool_rect | nat_rect + | nat_rect_arrow | list_rect | list_case | List_length @@ -939,6 +955,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | prod_rect, prod_rect | bool_rect, bool_rect | nat_rect, nat_rect + | nat_rect_arrow, nat_rect_arrow | list_rect, list_rect | list_case, list_case | List_length, List_length @@ -1017,6 +1034,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | prod_rect, _ | bool_rect, _ | nat_rect, _ + | nat_rect_arrow, _ | list_rect, _ | list_case, _ | List_length, _ @@ -1101,6 +1119,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Compilers.ident.prod_rect A B T => f _ (@Compilers.ident.prod_rect A B T) | Compilers.ident.bool_rect T => f _ (@Compilers.ident.bool_rect T) | Compilers.ident.nat_rect P => f _ (@Compilers.ident.nat_rect P) + | Compilers.ident.nat_rect_arrow P Q => f _ (@Compilers.ident.nat_rect_arrow P Q) | Compilers.ident.list_rect A P => f _ (@Compilers.ident.list_rect A P) | Compilers.ident.list_case A P => f _ (@Compilers.ident.list_case A P) | Compilers.ident.List_length T => f _ (@Compilers.ident.List_length T) @@ -1182,6 +1201,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Compilers.ident.prod_rect A B T => prod_rect | Compilers.ident.bool_rect T => bool_rect | Compilers.ident.nat_rect P => nat_rect + | Compilers.ident.nat_rect_arrow P Q => nat_rect_arrow | Compilers.ident.list_rect A P => list_rect | Compilers.ident.list_case A P => list_case | Compilers.ident.List_length T => List_length @@ -1263,6 +1283,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | prod_rect => None | bool_rect => None | nat_rect => None + | nat_rect_arrow => None | list_rect => None | list_case => None | List_length => None @@ -1344,6 +1365,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | prod_rect => base.type * base.type * base.type | bool_rect => base.type | nat_rect => base.type + | nat_rect_arrow => base.type * base.type | list_rect => base.type * base.type | list_case => base.type * base.type | List_length => base.type @@ -1425,6 +1447,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Compilers.ident.prod_rect A B T => tt | Compilers.ident.bool_rect T => tt | Compilers.ident.nat_rect P => tt + | Compilers.ident.nat_rect_arrow P Q => tt | Compilers.ident.list_rect A P => tt | Compilers.ident.list_case A P => tt | Compilers.ident.List_length T => tt @@ -1506,6 +1529,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | prod_rect, Compilers.ident.prod_rect A B T => Some (A, B, T) | bool_rect, Compilers.ident.bool_rect T => Some T | nat_rect, Compilers.ident.nat_rect P => Some P + | nat_rect_arrow, Compilers.ident.nat_rect_arrow P Q => Some (P, Q) | list_rect, Compilers.ident.list_rect A P => Some (A, P) | list_case, Compilers.ident.list_case A P => Some (A, P) | List_length, Compilers.ident.List_length T => Some T @@ -1583,6 +1607,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | prod_rect, _ | bool_rect, _ | nat_rect, _ + | nat_rect_arrow, _ | list_rect, _ | list_case, _ | List_length, _ @@ -1668,6 +1693,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | prod_rect => fun arg => let '(A, B, T) := eta3 arg in ((type.base A -> type.base B -> type.base T) -> type.base (A * B)%etype -> type.base T) | bool_rect => fun T => ((type.base ()%etype -> type.base T) -> (type.base ()%etype -> type.base T) -> type.base (base.type.type_base base.type.bool) -> type.base T) | nat_rect => fun P => ((type.base ()%etype -> type.base P) -> (type.base (base.type.type_base base.type.nat) -> type.base P -> type.base P) -> type.base (base.type.type_base base.type.nat) -> type.base P) + | nat_rect_arrow => fun arg => let '(P, Q) := eta2 arg in ((type.base P -> type.base Q) -> (type.base (base.type.type_base base.type.nat) -> (type.base P -> type.base Q) -> type.base P -> type.base Q) -> type.base (base.type.type_base base.type.nat) -> type.base P -> type.base Q) | list_rect => fun arg => let '(A, P) := eta2 arg in ((type.base ()%etype -> type.base P) -> (type.base A -> type.base (base.type.list A) -> type.base P -> type.base P) -> type.base (base.type.list A) -> type.base P) | list_case => fun arg => let '(A, P) := eta2 arg in ((type.base ()%etype -> type.base P) -> (type.base A -> type.base (base.type.list A) -> type.base P) -> type.base (base.type.list A) -> type.base P) | List_length => fun T => (type.base (base.type.list T) -> type.base (base.type.type_base base.type.nat)) @@ -1749,6 +1775,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | prod_rect => fun arg => match eta3 arg as args' return Compilers.ident.ident (type_of prod_rect args') with (A, B, T) => @Compilers.ident.prod_rect A B T end | bool_rect => fun T => @Compilers.ident.bool_rect T | nat_rect => fun P => @Compilers.ident.nat_rect P + | nat_rect_arrow => fun arg => match eta2 arg as args' return Compilers.ident.ident (type_of nat_rect_arrow args') with (P, Q) => @Compilers.ident.nat_rect_arrow P Q end | list_rect => fun arg => match eta2 arg as args' return Compilers.ident.ident (type_of list_rect args') with (A, P) => @Compilers.ident.list_rect A P end | list_case => fun arg => match eta2 arg as args' return Compilers.ident.ident (type_of list_case args') with (A, P) => @Compilers.ident.list_case A P end | List_length => fun T => @Compilers.ident.List_length T @@ -1830,6 +1857,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f: | Compilers.ident.prod_rect A B T => fun _ => @Compilers.ident.prod_rect A B T | Compilers.ident.bool_rect T => fun _ => @Compilers.ident.bool_rect T | Compilers.ident.nat_rect P => fun _ => @Compilers.ident.nat_rect P + | Compilers.ident.nat_rect_arrow P Q => fun _ => @Compilers.ident.nat_rect_arrow P Q | Compilers.ident.list_rect A P => fun _ => @Compilers.ident.list_rect A P | Compilers.ident.list_case A P => fun _ => @Compilers.ident.list_case A P | Compilers.ident.List_length T => fun _ => @Compilers.ident.List_length T diff --git a/src/Experiments/NewPipeline/Language.v b/src/Experiments/NewPipeline/Language.v index f4defdb6f..08da2ff6c 100644 --- a/src/Experiments/NewPipeline/Language.v +++ b/src/Experiments/NewPipeline/Language.v @@ -761,6 +761,7 @@ Module Compilers. | prod_rect {A B T:base.type} : ident ((A -> B -> T) -> A * B -> T) | bool_rect {T:base.type} : ident ((unit -> T) -> (unit -> T) -> bool -> T) | nat_rect {P:base.type} : ident ((unit -> P) -> (nat -> P -> P) -> nat -> P) + | nat_rect_arrow {P Q:base.type} : ident ((P -> Q) -> (nat -> (P -> Q) -> (P -> Q)) -> nat -> P -> Q) | list_rect {A P:base.type} : ident ((unit -> P) -> (A -> list A -> P -> P) -> list A -> P) | list_case {A P:base.type} : ident ((unit -> P) -> (A -> list A -> P) -> list A -> P) | List_length {T} : ident (list T -> nat) @@ -884,6 +885,8 @@ Module Compilers. => fun t f => Datatypes.bool_rect _ (t tt) (f tt) | nat_rect P => fun O_case S_case => Datatypes.nat_rect _ (O_case tt) S_case + | nat_rect_arrow P Q + => fun O_case S_case => Datatypes.nat_rect _ O_case S_case | list_rect A P => fun N_case C_case => Datatypes.list_rect _ (N_case tt) C_case | list_case A P @@ -1051,7 +1054,14 @@ Module Compilers. let rT := base.reify T in then_tac (@ident.prod_rect rA rB rT) | @Datatypes.nat_rect (fun _ => ?T) ?P0 - => reify_rec (@Thunked.nat_rect T (fun _ : Datatypes.unit => P0)) + => lazymatch T with + | _ -> _ => else_tac () + | _ => reify_rec (@Thunked.nat_rect T (fun _ : Datatypes.unit => P0)) + end + | @Datatypes.nat_rect (fun _ => ?P -> ?Q) + => let rP := base.reify P in + let rQ := base.reify Q in + then_tac (@ident.nat_rect_arrow rP rQ) | @Thunked.nat_rect ?T => let rT := base.reify T in then_tac (@ident.nat_rect rT) diff --git a/src/Experiments/NewPipeline/Rewriter.v b/src/Experiments/NewPipeline/Rewriter.v index a3f9f4e8c..5fe127ca0 100644 --- a/src/Experiments/NewPipeline/Rewriter.v +++ b/src/Experiments/NewPipeline/Rewriter.v @@ -1240,6 +1240,12 @@ In the RHS, the follow notation applies: => S_case <- @castv _ (@type.base base.type base.type.nat -> type.base P -> type.base P) S_case; ret (nat_rect _ (O_case ##tt) (fun n' rec => rec <-- rec; S_case ##n' rec) n)) ; make_rewrite + (#pident.nat_rect_arrow @ ??{?? -> ??} @ ??{base.type.nat -> (?? -> ??) -> (?? -> ??)} @ #?ℕ @ ??) + (fun P Q O_case _ _ _ _ S_case n _ v + => S_case <- @castv _ (@type.base base.type base.type.nat -> (type.base P -> type.base Q) -> (type.base P -> type.base Q)) S_case; + v <- castbe v; + ret (nat_rect _ O_case (fun n' rec v => S_case ##n' rec v) n v)) + ; make_rewrite (#pident.List_length @ ??{list ??}) (fun _ xs => xs <- reflect_list_cps xs; ##(List.length xs)) ; make_rewrite |