From 1701df710e52d4d4e6e97e608c09cfd80d7b7d8c Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Wed, 8 Feb 2017 16:04:33 -0500 Subject: Simpler version of MapCast Unfortunately, more of the casting logic is in MultiSizeTest2, now. I plan to make it more generic soon. --- src/Reflection/MultiSizeTest2.v | 206 +++++++++++++++++++++++++++++++++++----- 1 file changed, 183 insertions(+), 23 deletions(-) (limited to 'src/Reflection/MultiSizeTest2.v') diff --git a/src/Reflection/MultiSizeTest2.v b/src/Reflection/MultiSizeTest2.v index 8189b1325..0b8c7f958 100644 --- a/src/Reflection/MultiSizeTest2.v +++ b/src/Reflection/MultiSizeTest2.v @@ -1,9 +1,11 @@ Require Import Coq.omega.Omega. Require Import Crypto.Reflection.Syntax. Require Import Crypto.Reflection.SmartMap. +Require Import Crypto.Reflection.Linearize. +Require Import Crypto.Reflection.Inline. Require Import Crypto.Reflection.Equality. Require Import Crypto.Reflection.Application. -Require Import Crypto.Reflection.MapCastWithCastOp. +Require Import Crypto.Reflection.MapCast. (** * Preliminaries: bounded and unbounded number types *) @@ -26,6 +28,22 @@ Definition interp_base_type (t : base_type) | Word8 => word8 | Word9 => word9 end. +Definition base_type_max (x y : base_type) := + match x, y with + | Nat, _ => Nat + | _, Nat => Nat + | Word9, _ => Word9 + | _, Word9 => Word9 + | Word8, Word8 => Word8 + end. +Definition base_type_min (x y : base_type) := + match x, y with + | Word8, _ => Word8 + | _, Word8 => Word8 + | Word9, _ => Word9 + | _, Word9 => Word9 + | Nat, Nat => Nat + end. Definition interp_base_type_bounds (t : base_type) := nat. Local Notation TNat := (Tbase Nat). @@ -119,37 +137,179 @@ Definition bound_base_const t1 t2 (x1 : interp_base_type t1) (x2 : interp_base_t := bound (unbound x1). Local Notation new_flat_type (*: forall t, interp_flat_type interp_base_type2 t -> flat_type base_type_code1*) := (@SmartFlatTypeMap2 _ _ interp_base_type_bounds (fun t v => Tbase (bound_type t v))). +Fixpoint new_type {t} : forall (e_bounds : interp_type interp_base_type_bounds t) + (input_bounds : interp_all_binders_for' t interp_base_type_bounds), + type base_type + := match t return interp_type _ t -> interp_all_binders_for' t _ -> type _ with + | Tflat T => fun e_bounds _ => @new_flat_type T e_bounds + | Arrow A B + => fun e_bounds input_bounds + => Arrow (@bound_type A (fst input_bounds)) + (@new_type B (e_bounds (fst input_bounds)) (snd input_bounds)) + end. Definition bound_op ovar src1 dst1 src2 dst2 (opc1 : op src1 dst1) (opc2 : op src2 dst2) - : forall args2, - option { new_src : _ & (@exprf _ op ovar new_src - -> @exprf _ op ovar (new_flat_type (interpf (@interp_op_bounds) (Op opc2 args2))))%type } + : exprf base_type op (var:=ovar) src1 + -> interp_flat_type interp_base_type_bounds src2 + -> exprf base_type op (var:=ovar) dst1 := match opc1 in op src1 dst1, opc2 in op src2 dst2 - return (forall args2, - option { new_src : _ & (exprf _ _ new_src -> exprf _ _ (new_flat_type (interpf (@interp_op_bounds) (Op opc2 args2))))%type }) + return (exprf base_type op (var:=ovar) src1 + -> interp_flat_type interp_base_type_bounds src2 + -> exprf base_type op (var:=ovar) dst1) with - | Const t1 v1, Const t2 v2 - => fun args2 => Some (existT _ Unit (fun x => Op (Const (@bound_base_const t1 t2 v1 _)) TT)) - | Plus T1, Plus T2 => fun args2 => Some (existT _ _ (Op (Plus (bound_type T2 _)))) - | Cast _ _, Plus _ - | Cast _ _, Const _ _ - | Cast _ _, Cast _ _ - => fun args2 => Some (existT _ _ (fun args' => args')) + | Plus T1, Plus T2 + => fun args args2 + => LetIn args + (fun args + => Op (Cast _ _) (Op (Plus (base_type_max + (bound_type T2 (interp_op_bounds (Plus _) args2)) + (base_type_max + (bound_type T2 (fst args2)) + (bound_type T2 (snd args2))))) + (Pair (Op (Cast _ _) (Var (fst args))) + (Op (Cast _ _) (Var (snd args)))))) + | Const _ _ as e, _ + => fun args args2 => Op e TT + | Cast _ _ as e, Plus _ + | Cast _ _ as e, Const _ _ + | Cast _ _ as e, Cast _ _ + => fun args args2 => Op e args | Plus _, _ - | Const _ _, _ - => fun _ => None + => fun args args2 => @failf _ _ + end. + +Definition interpf_smart_bound {var t} + (e : interp_flat_type var t) (bounds : interp_flat_type interp_base_type_bounds t) + : interp_flat_type (fun t => exprf _ op (var:=var) (Tbase t)) (new_flat_type bounds) + := SmartFlatTypeMap2Interp2 + (f:=fun t v => Tbase _) + (fun t bs v => Op (Cast t (bound_type t bs)) (Var v)) bounds e. +Definition interpf_smart_unbound {var t} + (bounds : interp_flat_type interp_base_type_bounds t) + (e : interp_flat_type (fun t => exprf _ op (var:=var) (Tbase t)) (new_flat_type bounds)) + : interp_flat_type (fun t => @exprf base_type op var (Tbase t)) t + := SmartFlatTypeMapUnInterp2 + (f:=fun t v => Tbase (bound_type t _)) + (fun t bs v => Op (Cast (bound_type t bs) t) v) + e. + +Definition smart_boundf {var t1} (e1 : exprf base_type op (var:=var) t1) (bounds : interp_flat_type interp_base_type_bounds t1) + : exprf base_type op (var:=var) (new_flat_type bounds) + := LetIn e1 (fun e1' => SmartPairf (var:=var) (interpf_smart_bound e1' bounds)). +Fixpoint UnSmartArrow {P t} + : forall (e_bounds : interp_type interp_base_type_bounds t) + (input_bounds : interp_all_binders_for' t interp_base_type_bounds) + (e : P (SmartArrow (new_flat_type input_bounds) + (new_flat_type (ApplyInterpedAll' e_bounds input_bounds)))), + P (new_type e_bounds input_bounds) + := match t + return (forall (e_bounds : interp_type interp_base_type_bounds t) + (input_bounds : interp_all_binders_for' t interp_base_type_bounds) + (e : P (SmartArrow (new_flat_type input_bounds) + (new_flat_type (ApplyInterpedAll' e_bounds input_bounds)))), + P (new_type e_bounds input_bounds)) + with + | Tflat T => fun _ _ x => x + | Arrow A B => fun e_bounds input_bounds + => @UnSmartArrow + (fun t => P (Arrow (bound_type A (fst input_bounds)) t)) + B + (e_bounds (fst input_bounds)) + (snd input_bounds) + end. +Definition smart_bound {var t1} (e1 : expr base_type op (var:=var) t1) + (e_bounds : interp_type interp_base_type_bounds t1) + (input_bounds : interp_all_binders_for' t1 interp_base_type_bounds) + : expr base_type op (var:=var) (new_type e_bounds input_bounds) + := UnSmartArrow + e_bounds + input_bounds + (SmartAbs + (fun args + => LetIn + args + (fun args + => LetIn + (SmartPairf (interpf_smart_unbound input_bounds (SmartVarfMap (fun _ => Var) args))) + (fun v => smart_boundf + (ApplyAll e1 (interp_all_binders_for_of' v)) + (ApplyInterpedAll' e_bounds input_bounds))))). +Definition SmartBound {t1} (e : Expr base_type op t1) + (input_bounds : interp_all_binders_for' t1 interp_base_type_bounds) + : Expr base_type op (new_type _ input_bounds) + := fun var => smart_bound (e var) (interp (@interp_op_bounds) (e _)) input_bounds. + + +Definition SmartCast_base {var A A'} (x : exprf base_type op (var:=var) (Tbase A)) + : exprf base_type op (var:=var) (Tbase A') + := match base_type_eq_dec A A' with + | left pf => match pf with + | eq_refl => x + end + | right _ => Op (Cast _ A') x + end. +(** We can squash [a -> b -> c] into [a -> c] if [min a b c = min a + c], i.e., if the narrowest type we pass through in the original + case is the same as the narrowest type we pass through in the new + case. *) +Definition squash_cast {var} (a b c : base_type) : @exprf base_type op var (Tbase a) -> @exprf base_type op var (Tbase c) + := if base_type_beq (base_type_min b (base_type_min a c)) (base_type_min a c) + then SmartCast_base + else fun x => Op (Cast b c) (Op (Cast a b) x). +Fixpoint maybe_push_cast {var t} (v : @exprf base_type op var t) : option (@exprf base_type op var t) + := match v in exprf _ _ t return option (exprf _ _ t) with + | Var _ _ as v' + => Some v' + | Op t1 tR opc args + => match opc in op src dst return exprf _ _ src -> option (exprf _ _ dst) with + | Cast b c + => fun args + => match @maybe_push_cast _ _ args with + | Some (Op _ _ opc' args') + => match opc' in op src' dst' return exprf _ _ src' -> option (exprf _ _ (Tbase c)) with + | Cast a b + => fun args'' + => Some (squash_cast a b c args'') + | Const _ v + => fun args'' + => Some (SmartCast_base (Op (Const v) TT)) + | _ => fun _ => None + end args' + | Some (Var _ v as v') => Some (SmartCast_base v') + | Some _ => None + | None => None + end + | Const _ v + => fun _ => Some (Op (Const v) TT) + | _ => fun _ => None + end args + | Pair _ _ _ _ + | LetIn _ _ _ _ + | TT + => None + end. +Definition push_cast {var t} : @exprf base_type op var t -> inline_directive (op:=op) (var:=var) t + := match t with + | Tbase _ => fun v => match maybe_push_cast v with + | Some e => inline e + | None => default_inline v + end + | _ => default_inline end. Definition Boundify {t1} (e1 : Expr base_type op t1) args2 : Expr _ _ _ - := @MapInterpCastWithCastOp - base_type interp_base_type_bounds - op (@interp_op_bounds) base_type_beq internal_base_type_dec_bl - (@failf) (@bound_type) - (fun var t1 t2 => Op (Cast t1 t2)) - (fun _ _ opc => match opc with Cast _ _ => true | _ => false end) - bound_op - t1 e1 (interp_all_binders_for_to' args2). + := InlineConstGen + (@push_cast) + (Linearize + (SmartBound + (@MapInterpCast + base_type interp_base_type_bounds + op (@interp_op_bounds) + (@failf) + (@bound_op) + t1 e1 (interp_all_binders_for_to' args2)) + (interp_all_binders_for_to' args2))). (** * Examples *) -- cgit v1.2.3