aboutsummaryrefslogtreecommitdiff
path: root/src/Reflection/InputSyntax.v
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-01-11 21:02:50 -0500
committerGravatar Jason Gross <jasongross9@gmail.com>2017-03-01 11:45:47 -0500
commit6b3048c37ad348dc88ecc03ef892ecfb121bfa7f (patch)
tree351e5438c5664ab0caf08b9d5054f296ff4aa2ee /src/Reflection/InputSyntax.v
parent80dc66a34fbf031bfac1214ccbb3bb1dcdef3d39 (diff)
Switch to fully uncurried form for reflection
This will eventually make a number of proofs easier. Unfortunately, the correctness lemmas for AddCoordinates and LadderStep no longer work (because of different arities), and there's a proof in Experiments/Ed25519 that I've admitted. The correctness lemmas will be easy to re-add when we have a more general version that handle arbitrary type shapes.
Diffstat (limited to 'src/Reflection/InputSyntax.v')
-rw-r--r--src/Reflection/InputSyntax.v138
1 files changed, 115 insertions, 23 deletions
diff --git a/src/Reflection/InputSyntax.v b/src/Reflection/InputSyntax.v
index 258241391..12810d20d 100644
--- a/src/Reflection/InputSyntax.v
+++ b/src/Reflection/InputSyntax.v
@@ -2,7 +2,7 @@
Require Import Coq.Strings.String.
Require Import Crypto.Reflection.Syntax.
Require Import Crypto.Reflection.SmartMap.
-Require Import Crypto.Reflection.Relations.
+Require Import Crypto.Reflection.ExprInversion.
Require Import Crypto.Reflection.InterpProofs.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.Tactics.
@@ -17,14 +17,20 @@ Section language.
Context (base_type_code : Type).
Local Notation flat_type := (flat_type base_type_code).
- Local Notation type := (type base_type_code).
+ Inductive type := Tflat (A : flat_type) | Arrow (A : flat_type) (B : type).
Section expr_param.
Context (interp_base_type : base_type_code -> Type).
Context (op : flat_type (* input tuple *) -> flat_type (* output type *) -> Type).
- Local Notation interp_type := (interp_type interp_base_type).
Local Notation interp_flat_type_gen := interp_flat_type.
Local Notation interp_flat_type := (interp_flat_type interp_base_type).
+
+ Fixpoint interp_type (t : type) :=
+ match t with
+ | Tflat A => interp_flat_type A
+ | Arrow A B => (interp_flat_type A -> interp_type B)%type
+ end.
+
Section expr.
Context {var : flat_type -> Type}.
@@ -37,9 +43,11 @@ Section language.
| Pair : forall {t1}, exprf t1 -> forall {t2}, exprf t2 -> exprf (Prod t1 t2)
| MatchPair : forall {t1 t2}, exprf (Prod t1 t2) -> forall {tC}, (var t1 -> var t2 -> exprf tC) -> exprf tC.
Inductive expr : type -> Type :=
- | Return {t} : exprf t -> expr t
- | Abs {src dst} : (var (Tbase src) -> expr dst) -> expr (Arrow src dst).
- Global Coercion Return : exprf >-> expr.
+ | Return {T} : exprf T -> expr (Tflat T)
+ | Abs {src dst} : (var src -> expr dst) -> expr (Arrow src dst).
+
+ Definition Fst {t1 t2} (v : exprf (Prod t1 t2)) : exprf t1 := MatchPair v (fun x y => Var x).
+ Definition Snd {t1 t2} (v : exprf (Prod t1 t2)) : exprf t2 := MatchPair v (fun x y => Var y).
End expr.
Definition Expr (t : type) := forall var, @expr var t.
@@ -69,6 +77,19 @@ Section language.
Context {var : base_type_code -> Type}
(make_const : forall t, interp_base_type t -> op Unit (Tbase t)).
+ Fixpoint compilet (t : type) : Syntax.type base_type_code
+ := Syntax.Arrow
+ match t with
+ | Tflat T => Unit
+ | Arrow A (Tflat B) => A
+ | Arrow A B
+ => A * domain (compilet B)
+ end%ctype
+ match t with
+ | Tflat T => T
+ | Arrow A B => codomain (compilet B)
+ end.
+
Fixpoint SmartConst (t : flat_type) : interp_flat_type t -> Syntax.exprf base_type_code op (var:=var) t
:= match t return interp_flat_type t -> Syntax.exprf _ _ t with
| Unit => fun _ => TT
@@ -87,16 +108,36 @@ Section language.
| MatchPair _ _ ex _ eC => Syntax.LetIn (@compilef _ ex) (fun xy => @compilef _ (eC (fst xy) (snd xy)))
end.
- Fixpoint compile {t} (e : @expr (interp_flat_type_gen var) t) : @Syntax.expr base_type_code op var t
- := match e in expr t return @Syntax.expr _ _ _ t with
- | Return _ x => Syntax.Return (compilef x)
- | Abs a _ f => Syntax.Abs (fun x : var a => @compile _ (f x))
- end.
+ (* ugh, so much manual annotation *)
+ Fixpoint compile {t} (e : @expr (interp_flat_type_gen var) t) : @Syntax.expr base_type_code op var (compilet t)
+ := match e in expr t return @Syntax.expr _ _ _ (compilet t) with
+ | Return _ v => Syntax.Abs (fun _ => compilef v)
+ | Abs src dst f
+ => let res := fun x => @compile _ (f x) in
+ match dst
+ return (_ -> Syntax.expr _ _ (compilet dst))
+ -> Syntax.expr _ _ (compilet (Arrow src dst))
+ with
+ | Tflat T
+ => fun resf => Syntax.Abs (fun x => invert_Abs (resf x) tt)
+ | Arrow A B as dst'
+ => match compilet dst' as cdst
+ return (_ -> Syntax.expr _ _ cdst)
+ -> Syntax.expr _ _ (Syntax.Arrow
+ (_ * domain cdst)
+ (codomain cdst))
+ with
+ | Syntax.Arrow A' B'
+ => fun resf => Syntax.Abs (fun x : interp_flat_type_gen var (_ * _)
+ => invert_Abs (resf (fst x)) (snd x))
+ end
+ end res
+ end.
End compile.
Definition Compile
(make_const : forall t, interp_base_type t -> op Unit (Tbase t))
- {t} (e : Expr t) : Syntax.Expr base_type_code op t
+ {t} (e : Expr t) : Syntax.Expr base_type_code op (compilet t)
:= fun var => compile make_const (e _).
Section compile_correct.
@@ -127,32 +168,83 @@ Section language.
end.
Qed.
- Lemma Compile_correct {t} (e : @Expr t)
- : interp_type_gen_rel_pointwise (fun _ => @eq _)
- (Syntax.Interp interp_op (Compile make_const e))
- (Interp interp_op e).
+ Lemma compile_flat_correct {T} (e : expr (Tflat T))
+ : forall x, Syntax.interp interp_op (compile make_const e) x = interp interp_op e.
+ Proof.
+ intros []; simpl.
+ let G := match goal with |- ?G => G end in
+ let G := match (eval pattern T, e in G) with ?G _ _ => G end in
+ refine match e in expr t return match t return expr t -> _ with
+ | Tflat T => G T
+ | _ => fun _ => True
+ end e
+ with
+ | Return _ _ => _
+ | Abs _ _ _ => I
+ end; simpl.
+ apply compilef_correct.
+ Qed.
+
+ Lemma Compile_flat_correct_flat {T} (e : Expr (Tflat T))
+ : forall x, Syntax.Interp interp_op (Compile make_const e) x = Interp interp_op e.
+ Proof. apply compile_flat_correct. Qed.
+
+ Lemma Compile_correct {src dst} (e : @Expr (Arrow src (Tflat dst)))
+ : forall x, Syntax.Interp interp_op (Compile make_const e) x = Interp interp_op e x.
Proof.
unfold Interp, Compile, Syntax.Interp; simpl.
pose (e interp_flat_type) as E.
repeat match goal with |- context[e ?f] => change (e f) with E end.
clearbody E; clear e.
- induction E.
- { apply compilef_correct. }
- { simpl; auto. }
+ let G := match goal with |- ?G => G end in
+ let G := match (eval pattern src, dst, E in G) with ?G _ _ _ => G end in
+ refine match E in expr t return match t return expr t -> _ with
+ | Arrow src (Tflat dst) => G src dst
+ | _ => fun _ => True
+ end E
+ with
+ | Abs src dst e
+ => match dst
+ return (forall e : _ -> expr dst,
+ match dst return expr (Arrow src dst) -> _ with
+ | Tflat dst => G src dst
+ | _ => fun _ => True
+ end (Abs e))
+ with
+ | Tflat _
+ => fun e0 x
+ => _
+ | Arrow _ _ => fun _ => I
+ end e
+ | Return _ _ => I
+ end; simpl.
+ refine match e0 x as e0x in expr t
+ return match t return expr t -> _ with
+ | Tflat _
+ => fun e0x
+ => Syntax.interpf _ (invert_Abs (compile _ e0x) _)
+ = interp _ e0x
+ | _ => fun _ => True
+ end e0x
+ with
+ | Abs _ _ _ => I
+ | Return _ _ => _
+ end; simpl.
+ apply compilef_correct.
Qed.
-
- Lemma Compile_flat_correct {t : flat_type} (e : @Expr t)
- : Syntax.Interp interp_op (Compile make_const e) = Interp interp_op e.
- Proof. exact (@Compile_correct t e). Qed.
End compile_correct.
End expr_param.
End language.
+Global Arguments Arrow {_} _ _.
+Global Arguments Tflat {_} _.
Global Arguments Const {_ _ _ _ _} _.
Global Arguments Var {_ _ _ _ _} _.
Global Arguments Op {_ _ _ _ _ _} _ _.
Global Arguments LetIn {_ _ _ _ _} _ {_} _.
Global Arguments MatchPair {_ _ _ _ _ _} _ {_} _.
+Global Arguments Fst {_ _ _ _ _ _} _.
+Global Arguments Snd {_ _ _ _ _ _} _.
Global Arguments Pair {_ _ _ _ _} _ {_} _.
Global Arguments Return {_ _ _ _ _} _.
Global Arguments Abs {_ _ _ _ _ _} _.