diff options
author | Jason Gross <jgross@mit.edu> | 2017-01-11 21:02:50 -0500 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2017-03-01 11:45:47 -0500 |
commit | 6b3048c37ad348dc88ecc03ef892ecfb121bfa7f (patch) | |
tree | 351e5438c5664ab0caf08b9d5054f296ff4aa2ee /src/Reflection/InputSyntax.v | |
parent | 80dc66a34fbf031bfac1214ccbb3bb1dcdef3d39 (diff) |
Switch to fully uncurried form for reflection
This will eventually make a number of proofs easier.
Unfortunately, the correctness lemmas for AddCoordinates and LadderStep
no longer work (because of different arities), and there's a proof in
Experiments/Ed25519 that I've admitted.
The correctness lemmas will be easy to re-add when we have a more
general version that handle arbitrary type shapes.
Diffstat (limited to 'src/Reflection/InputSyntax.v')
-rw-r--r-- | src/Reflection/InputSyntax.v | 138 |
1 files changed, 115 insertions, 23 deletions
diff --git a/src/Reflection/InputSyntax.v b/src/Reflection/InputSyntax.v index 258241391..12810d20d 100644 --- a/src/Reflection/InputSyntax.v +++ b/src/Reflection/InputSyntax.v @@ -2,7 +2,7 @@ Require Import Coq.Strings.String. Require Import Crypto.Reflection.Syntax. Require Import Crypto.Reflection.SmartMap. -Require Import Crypto.Reflection.Relations. +Require Import Crypto.Reflection.ExprInversion. Require Import Crypto.Reflection.InterpProofs. Require Import Crypto.Util.Tuple. Require Import Crypto.Util.Tactics. @@ -17,14 +17,20 @@ Section language. Context (base_type_code : Type). Local Notation flat_type := (flat_type base_type_code). - Local Notation type := (type base_type_code). + Inductive type := Tflat (A : flat_type) | Arrow (A : flat_type) (B : type). Section expr_param. Context (interp_base_type : base_type_code -> Type). Context (op : flat_type (* input tuple *) -> flat_type (* output type *) -> Type). - Local Notation interp_type := (interp_type interp_base_type). Local Notation interp_flat_type_gen := interp_flat_type. Local Notation interp_flat_type := (interp_flat_type interp_base_type). + + Fixpoint interp_type (t : type) := + match t with + | Tflat A => interp_flat_type A + | Arrow A B => (interp_flat_type A -> interp_type B)%type + end. + Section expr. Context {var : flat_type -> Type}. @@ -37,9 +43,11 @@ Section language. | Pair : forall {t1}, exprf t1 -> forall {t2}, exprf t2 -> exprf (Prod t1 t2) | MatchPair : forall {t1 t2}, exprf (Prod t1 t2) -> forall {tC}, (var t1 -> var t2 -> exprf tC) -> exprf tC. Inductive expr : type -> Type := - | Return {t} : exprf t -> expr t - | Abs {src dst} : (var (Tbase src) -> expr dst) -> expr (Arrow src dst). - Global Coercion Return : exprf >-> expr. + | Return {T} : exprf T -> expr (Tflat T) + | Abs {src dst} : (var src -> expr dst) -> expr (Arrow src dst). + + Definition Fst {t1 t2} (v : exprf (Prod t1 t2)) : exprf t1 := MatchPair v (fun x y => Var x). + Definition Snd {t1 t2} (v : exprf (Prod t1 t2)) : exprf t2 := MatchPair v (fun x y => Var y). End expr. Definition Expr (t : type) := forall var, @expr var t. @@ -69,6 +77,19 @@ Section language. Context {var : base_type_code -> Type} (make_const : forall t, interp_base_type t -> op Unit (Tbase t)). + Fixpoint compilet (t : type) : Syntax.type base_type_code + := Syntax.Arrow + match t with + | Tflat T => Unit + | Arrow A (Tflat B) => A + | Arrow A B + => A * domain (compilet B) + end%ctype + match t with + | Tflat T => T + | Arrow A B => codomain (compilet B) + end. + Fixpoint SmartConst (t : flat_type) : interp_flat_type t -> Syntax.exprf base_type_code op (var:=var) t := match t return interp_flat_type t -> Syntax.exprf _ _ t with | Unit => fun _ => TT @@ -87,16 +108,36 @@ Section language. | MatchPair _ _ ex _ eC => Syntax.LetIn (@compilef _ ex) (fun xy => @compilef _ (eC (fst xy) (snd xy))) end. - Fixpoint compile {t} (e : @expr (interp_flat_type_gen var) t) : @Syntax.expr base_type_code op var t - := match e in expr t return @Syntax.expr _ _ _ t with - | Return _ x => Syntax.Return (compilef x) - | Abs a _ f => Syntax.Abs (fun x : var a => @compile _ (f x)) - end. + (* ugh, so much manual annotation *) + Fixpoint compile {t} (e : @expr (interp_flat_type_gen var) t) : @Syntax.expr base_type_code op var (compilet t) + := match e in expr t return @Syntax.expr _ _ _ (compilet t) with + | Return _ v => Syntax.Abs (fun _ => compilef v) + | Abs src dst f + => let res := fun x => @compile _ (f x) in + match dst + return (_ -> Syntax.expr _ _ (compilet dst)) + -> Syntax.expr _ _ (compilet (Arrow src dst)) + with + | Tflat T + => fun resf => Syntax.Abs (fun x => invert_Abs (resf x) tt) + | Arrow A B as dst' + => match compilet dst' as cdst + return (_ -> Syntax.expr _ _ cdst) + -> Syntax.expr _ _ (Syntax.Arrow + (_ * domain cdst) + (codomain cdst)) + with + | Syntax.Arrow A' B' + => fun resf => Syntax.Abs (fun x : interp_flat_type_gen var (_ * _) + => invert_Abs (resf (fst x)) (snd x)) + end + end res + end. End compile. Definition Compile (make_const : forall t, interp_base_type t -> op Unit (Tbase t)) - {t} (e : Expr t) : Syntax.Expr base_type_code op t + {t} (e : Expr t) : Syntax.Expr base_type_code op (compilet t) := fun var => compile make_const (e _). Section compile_correct. @@ -127,32 +168,83 @@ Section language. end. Qed. - Lemma Compile_correct {t} (e : @Expr t) - : interp_type_gen_rel_pointwise (fun _ => @eq _) - (Syntax.Interp interp_op (Compile make_const e)) - (Interp interp_op e). + Lemma compile_flat_correct {T} (e : expr (Tflat T)) + : forall x, Syntax.interp interp_op (compile make_const e) x = interp interp_op e. + Proof. + intros []; simpl. + let G := match goal with |- ?G => G end in + let G := match (eval pattern T, e in G) with ?G _ _ => G end in + refine match e in expr t return match t return expr t -> _ with + | Tflat T => G T + | _ => fun _ => True + end e + with + | Return _ _ => _ + | Abs _ _ _ => I + end; simpl. + apply compilef_correct. + Qed. + + Lemma Compile_flat_correct_flat {T} (e : Expr (Tflat T)) + : forall x, Syntax.Interp interp_op (Compile make_const e) x = Interp interp_op e. + Proof. apply compile_flat_correct. Qed. + + Lemma Compile_correct {src dst} (e : @Expr (Arrow src (Tflat dst))) + : forall x, Syntax.Interp interp_op (Compile make_const e) x = Interp interp_op e x. Proof. unfold Interp, Compile, Syntax.Interp; simpl. pose (e interp_flat_type) as E. repeat match goal with |- context[e ?f] => change (e f) with E end. clearbody E; clear e. - induction E. - { apply compilef_correct. } - { simpl; auto. } + let G := match goal with |- ?G => G end in + let G := match (eval pattern src, dst, E in G) with ?G _ _ _ => G end in + refine match E in expr t return match t return expr t -> _ with + | Arrow src (Tflat dst) => G src dst + | _ => fun _ => True + end E + with + | Abs src dst e + => match dst + return (forall e : _ -> expr dst, + match dst return expr (Arrow src dst) -> _ with + | Tflat dst => G src dst + | _ => fun _ => True + end (Abs e)) + with + | Tflat _ + => fun e0 x + => _ + | Arrow _ _ => fun _ => I + end e + | Return _ _ => I + end; simpl. + refine match e0 x as e0x in expr t + return match t return expr t -> _ with + | Tflat _ + => fun e0x + => Syntax.interpf _ (invert_Abs (compile _ e0x) _) + = interp _ e0x + | _ => fun _ => True + end e0x + with + | Abs _ _ _ => I + | Return _ _ => _ + end; simpl. + apply compilef_correct. Qed. - - Lemma Compile_flat_correct {t : flat_type} (e : @Expr t) - : Syntax.Interp interp_op (Compile make_const e) = Interp interp_op e. - Proof. exact (@Compile_correct t e). Qed. End compile_correct. End expr_param. End language. +Global Arguments Arrow {_} _ _. +Global Arguments Tflat {_} _. Global Arguments Const {_ _ _ _ _} _. Global Arguments Var {_ _ _ _ _} _. Global Arguments Op {_ _ _ _ _ _} _ _. Global Arguments LetIn {_ _ _ _ _} _ {_} _. Global Arguments MatchPair {_ _ _ _ _ _} _ {_} _. +Global Arguments Fst {_ _ _ _ _ _} _. +Global Arguments Snd {_ _ _ _ _ _} _. Global Arguments Pair {_ _ _ _ _} _ {_} _. Global Arguments Return {_ _ _ _ _} _. Global Arguments Abs {_ _ _ _ _ _} _. |