aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2018-07-14 19:37:32 +0100
committerGravatar Jason Gross <jasongross9@gmail.com>2018-07-15 21:57:47 +0100
commit83c6684b5d5c3f3dbdc68275290dbb8be359ef01 (patch)
tree99ec60ecd67303aa7115b6e980d11933ea0561fb /src
parentf79fdb77c7baff92444204c00787d2c95da18997 (diff)
Allow reification of nat_rect (fun _ => _ -> _)
We now support reification of nat_rect returning an arrow. This is needed for montgomery.
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/NewPipeline/AbstractInterpretation.v12
-rw-r--r--src/Experiments/NewPipeline/CStringification.v4
-rw-r--r--src/Experiments/NewPipeline/CompilersTestCases.v17
-rw-r--r--src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v28
-rw-r--r--src/Experiments/NewPipeline/Language.v12
-rw-r--r--src/Experiments/NewPipeline/Rewriter.v6
6 files changed, 78 insertions, 1 deletions
diff --git a/src/Experiments/NewPipeline/AbstractInterpretation.v b/src/Experiments/NewPipeline/AbstractInterpretation.v
index 3f905a869..e550ae150 100644
--- a/src/Experiments/NewPipeline/AbstractInterpretation.v
+++ b/src/Experiments/NewPipeline/AbstractInterpretation.v
@@ -437,6 +437,18 @@ Module Compilers.
n
| None => ZRange.type.base.option.None
end
+ | ident.nat_rect_arrow _ _
+ => fun O_case S_case n v
+ => match n with
+ | Some n
+ => nat_rect
+ _
+ O_case
+ (fun n' rec => S_case (Some n') rec)
+ n
+ v
+ | None => ZRange.type.base.option.None
+ end
| ident.list_rect _ _
=> fun N C ls
=> match ls with
diff --git a/src/Experiments/NewPipeline/CStringification.v b/src/Experiments/NewPipeline/CStringification.v
index bb94402f4..ffe23f2ef 100644
--- a/src/Experiments/NewPipeline/CStringification.v
+++ b/src/Experiments/NewPipeline/CStringification.v
@@ -220,6 +220,8 @@ Module Compilers.
| ident.bool_rect T => fun '(t, (f, ((b, br), tt))) => (fun lvl => maybe_wrap_parens (Nat.ltb lvl 200) ("if " ++ b 200%nat ++ " then " ++ maybe_wrap_cast with_casts t 200%nat ++ " else " ++ maybe_wrap_cast with_casts f 200%nat), ZRange.type.base.option.None)
| ident.nat_rect P
=> fun args => (show_application with_casts (fun _ => "nat_rect") args, ZRange.type.base.option.None)
+ | ident.nat_rect_arrow P Q
+ => fun args => (show_application with_casts (fun _ => "nat_rect(→)") args, ZRange.type.base.option.None)
| ident.list_rect A P
=> fun args => (show_application with_casts (fun _ => "list_rect") args, ZRange.type.base.option.None)
| ident.list_case A P
@@ -342,6 +344,7 @@ Module Compilers.
| ident.prod_rect A B T => "prod_rect"
| ident.bool_rect T => "bool_rect"
| ident.nat_rect P => "nat_rect"
+ | ident.nat_rect_arrow P Q => "nat_rect(→)"
| ident.list_rect A P => "list_rect"
| ident.list_case A P => "list_case"
| ident.List_length T => "length"
@@ -1135,6 +1138,7 @@ Module Compilers.
| ident.prod_rect _ _ _
| ident.bool_rect _
| ident.nat_rect _
+ | ident.nat_rect_arrow _ _
| ident.list_rect _ _
| ident.list_case _ _
| ident.List_length _
diff --git a/src/Experiments/NewPipeline/CompilersTestCases.v b/src/Experiments/NewPipeline/CompilersTestCases.v
index f69a0db42..0fa697fbe 100644
--- a/src/Experiments/NewPipeline/CompilersTestCases.v
+++ b/src/Experiments/NewPipeline/CompilersTestCases.v
@@ -374,3 +374,20 @@ Module test12.
exact I.
Qed.
End test12.
+Module test13.
+ Example test13 : True.
+ Proof.
+ let v0 := constr:(nat_rect (fun _ => nat -> nat) (fun v => v) (fun n' rec v => (n' + rec (S v))%nat) 3 0%nat) in
+ let v := Reify v0 in
+ pose v as E;
+ pose v0 as exp.
+ vm_compute in E.
+ vm_compute in exp.
+ pose (PartialEvaluate E) as E'.
+ vm_compute in E'.
+ clear E.
+ let r := Reify exp in
+ unify r E'.
+ exact I.
+ Qed.
+End test13.
diff --git a/src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v b/src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v
index 3b376ae4e..a6ad5a725 100644
--- a/src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v
+++ b/src/Experiments/NewPipeline/GENERATEDIdentifiersWithoutTypes.v
@@ -109,6 +109,20 @@ print_ident = r"""Inductive ident : defaults.type -> Set :=
(fun x : base.type => type.base x)
(base.type.type_base base.type.nat) ->
(fun x : base.type => type.base x) P)
+ | nat_rect_arrow : forall P Q : base.type,
+ ident
+ (((fun x : base.type => type.base x) P ->
+ (fun x : base.type => type.base x) Q) ->
+ ((fun x : base.type => type.base x)
+ (base.type.type_base base.type.nat) ->
+ ((fun x : base.type => type.base x) P ->
+ (fun x : base.type => type.base x) Q) ->
+ (fun x : base.type => type.base x) P ->
+ (fun x : base.type => type.base x) Q) ->
+ (fun x : base.type => type.base x)
+ (base.type.type_base base.type.nat) ->
+ (fun x : base.type => type.base x) P ->
+ (fun x : base.type => type.base x) Q)
| list_rect : forall A P : base.type,
ident
(((fun x : base.type => type.base x) ()%etype ->
@@ -570,6 +584,7 @@ show_match_ident = r"""match # with
| ident.prod_rect A B T =>
| ident.bool_rect T =>
| ident.nat_rect P =>
+ | ident.nat_rect_arrow P Q =>
| ident.list_rect A P =>
| ident.list_case A P =>
| ident.List_length T =>
@@ -859,6 +874,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| prod_rect
| bool_rect
| nat_rect
+ | nat_rect_arrow
| list_rect
| list_case
| List_length
@@ -939,6 +955,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| prod_rect, prod_rect
| bool_rect, bool_rect
| nat_rect, nat_rect
+ | nat_rect_arrow, nat_rect_arrow
| list_rect, list_rect
| list_case, list_case
| List_length, List_length
@@ -1017,6 +1034,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| prod_rect, _
| bool_rect, _
| nat_rect, _
+ | nat_rect_arrow, _
| list_rect, _
| list_case, _
| List_length, _
@@ -1101,6 +1119,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| Compilers.ident.prod_rect A B T => f _ (@Compilers.ident.prod_rect A B T)
| Compilers.ident.bool_rect T => f _ (@Compilers.ident.bool_rect T)
| Compilers.ident.nat_rect P => f _ (@Compilers.ident.nat_rect P)
+ | Compilers.ident.nat_rect_arrow P Q => f _ (@Compilers.ident.nat_rect_arrow P Q)
| Compilers.ident.list_rect A P => f _ (@Compilers.ident.list_rect A P)
| Compilers.ident.list_case A P => f _ (@Compilers.ident.list_case A P)
| Compilers.ident.List_length T => f _ (@Compilers.ident.List_length T)
@@ -1182,6 +1201,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| Compilers.ident.prod_rect A B T => prod_rect
| Compilers.ident.bool_rect T => bool_rect
| Compilers.ident.nat_rect P => nat_rect
+ | Compilers.ident.nat_rect_arrow P Q => nat_rect_arrow
| Compilers.ident.list_rect A P => list_rect
| Compilers.ident.list_case A P => list_case
| Compilers.ident.List_length T => List_length
@@ -1263,6 +1283,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| prod_rect => None
| bool_rect => None
| nat_rect => None
+ | nat_rect_arrow => None
| list_rect => None
| list_case => None
| List_length => None
@@ -1344,6 +1365,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| prod_rect => base.type * base.type * base.type
| bool_rect => base.type
| nat_rect => base.type
+ | nat_rect_arrow => base.type * base.type
| list_rect => base.type * base.type
| list_case => base.type * base.type
| List_length => base.type
@@ -1425,6 +1447,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| Compilers.ident.prod_rect A B T => tt
| Compilers.ident.bool_rect T => tt
| Compilers.ident.nat_rect P => tt
+ | Compilers.ident.nat_rect_arrow P Q => tt
| Compilers.ident.list_rect A P => tt
| Compilers.ident.list_case A P => tt
| Compilers.ident.List_length T => tt
@@ -1506,6 +1529,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| prod_rect, Compilers.ident.prod_rect A B T => Some (A, B, T)
| bool_rect, Compilers.ident.bool_rect T => Some T
| nat_rect, Compilers.ident.nat_rect P => Some P
+ | nat_rect_arrow, Compilers.ident.nat_rect_arrow P Q => Some (P, Q)
| list_rect, Compilers.ident.list_rect A P => Some (A, P)
| list_case, Compilers.ident.list_case A P => Some (A, P)
| List_length, Compilers.ident.List_length T => Some T
@@ -1583,6 +1607,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| prod_rect, _
| bool_rect, _
| nat_rect, _
+ | nat_rect_arrow, _
| list_rect, _
| list_case, _
| List_length, _
@@ -1668,6 +1693,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| prod_rect => fun arg => let '(A, B, T) := eta3 arg in ((type.base A -> type.base B -> type.base T) -> type.base (A * B)%etype -> type.base T)
| bool_rect => fun T => ((type.base ()%etype -> type.base T) -> (type.base ()%etype -> type.base T) -> type.base (base.type.type_base base.type.bool) -> type.base T)
| nat_rect => fun P => ((type.base ()%etype -> type.base P) -> (type.base (base.type.type_base base.type.nat) -> type.base P -> type.base P) -> type.base (base.type.type_base base.type.nat) -> type.base P)
+ | nat_rect_arrow => fun arg => let '(P, Q) := eta2 arg in ((type.base P -> type.base Q) -> (type.base (base.type.type_base base.type.nat) -> (type.base P -> type.base Q) -> type.base P -> type.base Q) -> type.base (base.type.type_base base.type.nat) -> type.base P -> type.base Q)
| list_rect => fun arg => let '(A, P) := eta2 arg in ((type.base ()%etype -> type.base P) -> (type.base A -> type.base (base.type.list A) -> type.base P -> type.base P) -> type.base (base.type.list A) -> type.base P)
| list_case => fun arg => let '(A, P) := eta2 arg in ((type.base ()%etype -> type.base P) -> (type.base A -> type.base (base.type.list A) -> type.base P) -> type.base (base.type.list A) -> type.base P)
| List_length => fun T => (type.base (base.type.list T) -> type.base (base.type.type_base base.type.nat))
@@ -1749,6 +1775,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| prod_rect => fun arg => match eta3 arg as args' return Compilers.ident.ident (type_of prod_rect args') with (A, B, T) => @Compilers.ident.prod_rect A B T end
| bool_rect => fun T => @Compilers.ident.bool_rect T
| nat_rect => fun P => @Compilers.ident.nat_rect P
+ | nat_rect_arrow => fun arg => match eta2 arg as args' return Compilers.ident.ident (type_of nat_rect_arrow args') with (P, Q) => @Compilers.ident.nat_rect_arrow P Q end
| list_rect => fun arg => match eta2 arg as args' return Compilers.ident.ident (type_of list_rect args') with (A, P) => @Compilers.ident.list_rect A P end
| list_case => fun arg => match eta2 arg as args' return Compilers.ident.ident (type_of list_case args') with (A, P) => @Compilers.ident.list_case A P end
| List_length => fun T => @Compilers.ident.List_length T
@@ -1830,6 +1857,7 @@ with open('GENERATEDIdentifiersWithoutTypes.v', 'w') as f:
| Compilers.ident.prod_rect A B T => fun _ => @Compilers.ident.prod_rect A B T
| Compilers.ident.bool_rect T => fun _ => @Compilers.ident.bool_rect T
| Compilers.ident.nat_rect P => fun _ => @Compilers.ident.nat_rect P
+ | Compilers.ident.nat_rect_arrow P Q => fun _ => @Compilers.ident.nat_rect_arrow P Q
| Compilers.ident.list_rect A P => fun _ => @Compilers.ident.list_rect A P
| Compilers.ident.list_case A P => fun _ => @Compilers.ident.list_case A P
| Compilers.ident.List_length T => fun _ => @Compilers.ident.List_length T
diff --git a/src/Experiments/NewPipeline/Language.v b/src/Experiments/NewPipeline/Language.v
index f4defdb6f..08da2ff6c 100644
--- a/src/Experiments/NewPipeline/Language.v
+++ b/src/Experiments/NewPipeline/Language.v
@@ -761,6 +761,7 @@ Module Compilers.
| prod_rect {A B T:base.type} : ident ((A -> B -> T) -> A * B -> T)
| bool_rect {T:base.type} : ident ((unit -> T) -> (unit -> T) -> bool -> T)
| nat_rect {P:base.type} : ident ((unit -> P) -> (nat -> P -> P) -> nat -> P)
+ | nat_rect_arrow {P Q:base.type} : ident ((P -> Q) -> (nat -> (P -> Q) -> (P -> Q)) -> nat -> P -> Q)
| list_rect {A P:base.type} : ident ((unit -> P) -> (A -> list A -> P -> P) -> list A -> P)
| list_case {A P:base.type} : ident ((unit -> P) -> (A -> list A -> P) -> list A -> P)
| List_length {T} : ident (list T -> nat)
@@ -884,6 +885,8 @@ Module Compilers.
=> fun t f => Datatypes.bool_rect _ (t tt) (f tt)
| nat_rect P
=> fun O_case S_case => Datatypes.nat_rect _ (O_case tt) S_case
+ | nat_rect_arrow P Q
+ => fun O_case S_case => Datatypes.nat_rect _ O_case S_case
| list_rect A P
=> fun N_case C_case => Datatypes.list_rect _ (N_case tt) C_case
| list_case A P
@@ -1051,7 +1054,14 @@ Module Compilers.
let rT := base.reify T in
then_tac (@ident.prod_rect rA rB rT)
| @Datatypes.nat_rect (fun _ => ?T) ?P0
- => reify_rec (@Thunked.nat_rect T (fun _ : Datatypes.unit => P0))
+ => lazymatch T with
+ | _ -> _ => else_tac ()
+ | _ => reify_rec (@Thunked.nat_rect T (fun _ : Datatypes.unit => P0))
+ end
+ | @Datatypes.nat_rect (fun _ => ?P -> ?Q)
+ => let rP := base.reify P in
+ let rQ := base.reify Q in
+ then_tac (@ident.nat_rect_arrow rP rQ)
| @Thunked.nat_rect ?T
=> let rT := base.reify T in
then_tac (@ident.nat_rect rT)
diff --git a/src/Experiments/NewPipeline/Rewriter.v b/src/Experiments/NewPipeline/Rewriter.v
index a3f9f4e8c..5fe127ca0 100644
--- a/src/Experiments/NewPipeline/Rewriter.v
+++ b/src/Experiments/NewPipeline/Rewriter.v
@@ -1240,6 +1240,12 @@ In the RHS, the follow notation applies:
=> S_case <- @castv _ (@type.base base.type base.type.nat -> type.base P -> type.base P) S_case;
ret (nat_rect _ (O_case ##tt) (fun n' rec => rec <-- rec; S_case ##n' rec) n))
; make_rewrite
+ (#pident.nat_rect_arrow @ ??{?? -> ??} @ ??{base.type.nat -> (?? -> ??) -> (?? -> ??)} @ #?ℕ @ ??)
+ (fun P Q O_case _ _ _ _ S_case n _ v
+ => S_case <- @castv _ (@type.base base.type base.type.nat -> (type.base P -> type.base Q) -> (type.base P -> type.base Q)) S_case;
+ v <- castbe v;
+ ret (nat_rect _ O_case (fun n' rec v => S_case ##n' rec v) n v))
+ ; make_rewrite
(#pident.List_length @ ??{list ??})
(fun _ xs => xs <- reflect_list_cps xs; ##(List.length xs))
; make_rewrite