From 0bbbdfede48aed7a74ac2fb95440256ed60fb6e8 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Fri, 1 Feb 2019 18:17:11 -0500 Subject: Add support for reifying `zrange` and `option` This is needed to reify statements for the rewriter. --- src/Language.v | 199 +++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 152 insertions(+), 47 deletions(-) (limited to 'src/Language.v') diff --git a/src/Language.v b/src/Language.v index efcc16b63..044d19a3e 100644 --- a/src/Language.v +++ b/src/Language.v @@ -15,7 +15,7 @@ Require Import Crypto.Util.CPSNotations. Require Import Crypto.Util.Notations. Require Import Crypto.Util.Tactics.RunTacticAsConstr. Require Import Crypto.Util.Tactics.DebugPrint. -Import ListNotations. Local Open Scope bool_scope. Local Open Scope Z_scope. +Import Coq.Lists.List ListNotations. Local Open Scope bool_scope. Local Open Scope Z_scope. Module Compilers. Local Set Boolean Equality Schemes. @@ -293,8 +293,8 @@ Module Compilers. Module base. Local Notation einterp := type.interp. Module type. - Inductive base := unit | Z | bool | nat. (* Not Variant because COQBUG(https://github.com/coq/coq/issues/7738) *) - Inductive type := type_base (t : base) | prod (A B : type) | list (A : type). + Inductive base := unit | Z | bool | nat | zrange. (* Not Variant because COQBUG(https://github.com/coq/coq/issues/7738) *) + Inductive type := type_base (t : base) | prod (A B : type) | list (A : type) | option (A : type). Global Coercion type_base : base >-> type. End type. Global Coercion type.type_base : type.base >-> type.type. @@ -305,12 +305,14 @@ Module Compilers. | type.Z => BinInt.Z | type.bool => Datatypes.bool | type.nat => Datatypes.nat + | type.zrange => zrange end. Fixpoint interp (ty : type) := match ty with | type.type_base t => base_interp t | type.prod A B => interp A * interp B | type.list A => Datatypes.list (interp A) + | type.option A => Datatypes.option (interp A) end%type. Definition try_make_base_transport_cps @@ -321,11 +323,13 @@ Module Compilers. | type.Z, type.Z | type.bool, type.bool | type.nat, type.nat + | type.zrange, type.zrange => (return (Some id)) | type.unit, _ | type.Z, _ | type.bool, _ | type.nat, _ + | type.zrange, _ => (return None) end%cps. Fixpoint try_make_transport_cps @@ -339,9 +343,11 @@ Module Compilers. trB <-- try_make_transport_cps (fun B => P (type.prod _ B)) _ _; return (Some (fun v => trB (trA v)))) | type.list A, type.list A' => try_make_transport_cps (fun A => P (type.list A)) A A' + | type.option A, type.option A' => try_make_transport_cps (fun A => P (type.option A)) A A' | type.type_base _, _ | type.prod _ _, _ | type.list _, _ + | type.option _, _ => (return None) end%cps. @@ -351,31 +357,6 @@ Module Compilers. Definition try_transport (P : type -> Type) (t1 t2 : type) (v : P t1) : option (P t2) := try_transport_cps P t1 t2 v _ id. - (* - Fixpoint try_transport - (P : type -> Type) (t1 t2 : type) : P t1 -> option (P t2) - := match t1, t2 return P t1 -> option (P t2) with - | type.unit, type.unit - | type.Z, type.Z - | type.bool, type.bool - | type.nat, type.nat - => @Some _ - | type.list A, type.list A' - => @try_transport (fun A => P (type.list A)) A A' - | type.prod s d, type.prod s' d' - => fun v - => (v <- (try_transport (fun s => P (type.prod s d)) s s' v); - (try_transport (fun d => P (type.prod s' d)) d d' v))%option - - | type.unit, _ - | type.Z, _ - | type.bool, _ - | type.nat, _ - | type.prod _ _, _ - | type.list _, _ - => fun _ => None - end. - *) Ltac reify_base ty := let __ := Reify.debug_enter_reify_base_type ty in @@ -384,6 +365,7 @@ Module Compilers. | Datatypes.nat => type.nat | Datatypes.bool => type.bool | BinInt.Z => type.Z + | zrange => type.zrange | interp (type.type_base ?T) => T | @einterp type interp (@Compilers.type.base type (type.type_base ?T)) => T | _ => let __ := match goal with @@ -401,6 +383,9 @@ Module Compilers. | Datatypes.list ?T => let rT := reify T in constr:(type.list rT) + | Datatypes.option ?T + => let rT := reify T in + constr:(type.option rT) | interp ?T => T | @einterp type interp (@Compilers.type.base type ?T) => T | ?ty => let rT := reify_base ty in @@ -616,6 +601,9 @@ Module Compilers. | match ?x with Datatypes.pair a b => @?f a b end => let T := type of term in reify_rec (@prod_rect _ _ (fun _ => T) f x) + | match ?x with ZRange.Build_zrange a b => @?f a b end + => let T := type of term in + reify_rec (@ZRange.zrange_rect (fun _ => T) f x) | match ?x with nil => ?N | cons a b => @?C a b end => let T := type of term in reify_rec (@list_case _ (fun _ => T) N C x) @@ -892,13 +880,17 @@ Module Compilers. | Z_log2_up : ident (Z -> Z) | Z_eqb : ident (Z -> Z -> bool) | Z_leb : ident (Z -> Z -> bool) + | Z_ltb : ident (Z -> Z -> bool) | Z_geb : ident (Z -> Z -> bool) + | Z_gtb : ident (Z -> Z -> bool) | Z_of_nat : ident (nat -> Z) | Z_to_nat : ident (Z -> nat) | Z_shiftr : ident (Z -> Z -> Z) | Z_shiftl : ident (Z -> Z -> Z) | Z_land : ident (Z -> Z -> Z) | Z_lor : ident (Z -> Z -> Z) + | Z_min : ident (Z -> Z -> Z) + | Z_max : ident (Z -> Z -> Z) | Z_bneg : ident (Z -> Z) | Z_lnot_modulo : ident (Z -> Z -> Z) | Z_mul_split : ident (Z -> Z -> Z -> Z * Z) @@ -911,8 +903,13 @@ Module Compilers. | Z_add_modulo : ident (Z -> Z -> Z -> Z) | Z_rshi : ident (Z -> Z -> Z -> Z -> Z) | Z_cc_m : ident (Z -> Z -> Z) - | Z_cast (range : zrange) : ident (Z -> Z) - | Z_cast2 (range : zrange * zrange) : ident ((Z * Z) -> (Z * Z)) + | Z_cast (range : ZRange.zrange) : ident (Z -> Z) + | Z_cast2 (range : ZRange.zrange * ZRange.zrange) : ident ((Z * Z) -> (Z * Z)) + | option_Some {A:base.type} : ident (A -> option A) + | option_None {A:base.type} : ident (option A) + | option_rect {A P : base.type} : ident ((A -> P) -> (unit -> P) -> option A -> P) + | Build_zrange : ident (Z -> Z -> zrange) + | zrange_rect {P:base.type} : ident ((Z -> Z -> P) -> zrange -> P) | fancy_add (log2wordmax : BinInt.Z) (imm : BinInt.Z) : ident (Z * Z -> Z * Z) | fancy_addc (log2wordmax : BinInt.Z) (imm : BinInt.Z) : ident (Z * Z * Z -> Z * Z) | fancy_sub (log2wordmax : BinInt.Z) (imm : BinInt.Z) : ident (Z * Z -> Z * Z) @@ -927,30 +924,34 @@ Module Compilers. | fancy_sell : ident (Z * Z * Z -> Z) | fancy_addm : ident (Z * Z * Z -> Z) . + Notation Some := option_Some. + Notation None := option_None. Global Arguments Z_cast2 _%zrange_scope. - Definition to_fancy {s d : base.type} (idc : ident (s -> d)) : option (fancy.ident s d) - := match idc in ident t return option match t with + Definition to_fancy {s d : base.type} (idc : ident (s -> d)) : Datatypes.option (fancy.ident s d) + := match idc in ident t return Datatypes.option match t with | type.base s -> type.base d => fancy.ident s d | _ => Datatypes.unit end%etype with - | fancy_add log2wordmax imm => Some (fancy.with_wordmax log2wordmax (fancy.add imm)) - | fancy_addc log2wordmax imm => Some (fancy.with_wordmax log2wordmax (fancy.addc imm)) - | fancy_sub log2wordmax imm => Some (fancy.with_wordmax log2wordmax (fancy.sub imm)) - | fancy_subb log2wordmax imm => Some (fancy.with_wordmax log2wordmax (fancy.subb imm)) - | fancy_mulll log2wordmax => Some (fancy.with_wordmax log2wordmax fancy.mulll) - | fancy_mullh log2wordmax => Some (fancy.with_wordmax log2wordmax fancy.mullh) - | fancy_mulhl log2wordmax => Some (fancy.with_wordmax log2wordmax fancy.mulhl) - | fancy_mulhh log2wordmax => Some (fancy.with_wordmax log2wordmax fancy.mulhh) - | fancy_rshi log2wordmax x => Some (fancy.with_wordmax log2wordmax (fancy.rshi x)) - | fancy_selc => Some fancy.selc - | fancy_selm log2wordmax => Some (fancy.with_wordmax log2wordmax fancy.selm) - | fancy_sell => Some fancy.sell - | fancy_addm => Some fancy.addm - | _ => None + | fancy_add log2wordmax imm => Datatypes.Some (fancy.with_wordmax log2wordmax (fancy.add imm)) + | fancy_addc log2wordmax imm => Datatypes.Some (fancy.with_wordmax log2wordmax (fancy.addc imm)) + | fancy_sub log2wordmax imm => Datatypes.Some (fancy.with_wordmax log2wordmax (fancy.sub imm)) + | fancy_subb log2wordmax imm => Datatypes.Some (fancy.with_wordmax log2wordmax (fancy.subb imm)) + | fancy_mulll log2wordmax => Datatypes.Some (fancy.with_wordmax log2wordmax fancy.mulll) + | fancy_mullh log2wordmax => Datatypes.Some (fancy.with_wordmax log2wordmax fancy.mullh) + | fancy_mulhl log2wordmax => Datatypes.Some (fancy.with_wordmax log2wordmax fancy.mulhl) + | fancy_mulhh log2wordmax => Datatypes.Some (fancy.with_wordmax log2wordmax fancy.mulhh) + | fancy_rshi log2wordmax x => Datatypes.Some (fancy.with_wordmax log2wordmax (fancy.rshi x)) + | fancy_selc => Datatypes.Some fancy.selc + | fancy_selm log2wordmax => Datatypes.Some (fancy.with_wordmax log2wordmax fancy.selm) + | fancy_sell => Datatypes.Some fancy.sell + | fancy_addm => Datatypes.Some fancy.addm + | _ => Datatypes.None end. End with_scope. + Notation Some := option_Some. + Notation None := option_None. Section gen. Context (cast_outside_of_range : zrange -> BinInt.Z -> BinInt.Z). @@ -1036,7 +1037,9 @@ Module Compilers. | Z_modulo => Z.modulo | Z_eqb => Z.eqb | Z_leb => Z.leb + | Z_ltb => Z.ltb | Z_geb => Z.geb + | Z_gtb => Z.gtb | Z_log2 => Z.log2 | Z_log2_up => Z.log2_up | Z_of_nat => Z.of_nat @@ -1045,6 +1048,8 @@ Module Compilers. | Z_shiftl => Z.shiftl | Z_land => Z.land | Z_lor => Z.lor + | Z_min => Z.min + | Z_max => Z.max | Z_mul_split => Z.mul_split | Z_add_get_carry => Z.add_get_carry_full | Z_add_with_carry => Z.add_with_carry @@ -1059,6 +1064,12 @@ Module Compilers. | Z_cc_m => Z.cc_m | Z_cast r => cast r | Z_cast2 (r1, r2) => fun '(x1, x2) => (cast r1 x1, cast r2 x2) + | Some A => @Datatypes.Some _ + | None A => @Datatypes.None _ + | option_rect A P + => fun S_case N_case o => @Datatypes.option_rect _ _ S_case (N_case tt) o + | Build_zrange => ZRange.Build_zrange + | zrange_rect A => @ZRange.zrange_rect _ | fancy_add _ _ as idc | fancy_addc _ _ as idc | fancy_sub _ _ as idc @@ -1079,6 +1090,8 @@ Module Compilers. Definition cast_outside_of_range (r : zrange) (v : BinInt.Z) : BinInt.Z. Proof. exact v. Qed. End with_base. + Notation Some := option_Some. + Notation None := option_None. (** Interpret identifiers where [Z_cast] is an opaque identity function when the value is not inside the range *) @@ -1087,6 +1100,7 @@ Module Compilers. Notation LiteralZ := (@Literal base.type.Z). Notation LiteralBool := (@Literal base.type.bool). Notation LiteralNat := (@Literal base.type.nat). + Notation LiteralZRange := (@Literal base.type.zrange). (** TODO: MOVE ME? *) Module Thunked. @@ -1098,6 +1112,8 @@ Module Compilers. := ListUtil.list_case (fun _ => P) (N tt) C ls. Definition nat_rect P (O_case : unit -> P) (S_case : nat -> P -> P) (n : nat) : P := Datatypes.nat_rect (fun _ => P) (O_case tt) S_case n. + Definition option_rect {A} P (S_case : A -> P) (N_case : unit -> P) (o : option A) : P + := Datatypes.option_rect (fun _ => P) S_case (N_case tt) o. End Thunked. Ltac require_primitive_const term := @@ -1113,6 +1129,10 @@ Module Compilers. | xI ?p => require_primitive_const p | xO ?p => require_primitive_const p | xH => idtac + | Datatypes.Some ?x => require_primitive_const x + | Datatypes.None => idtac + | ZRange.Build_zrange ?x ?y + => require_primitive_const x; require_primitive_const y | ?term => fail 0 "Not a known const:" term end. Ltac is_primitive_const term := @@ -1175,6 +1195,16 @@ Module Compilers. | @Thunked.bool_rect ?T => let rT := base.reify T in then_tac (@ident.bool_rect rT) + | @Datatypes.option_rect ?A ?T0 ?PSome ?PNone + => lazymatch (eval cbv beta in T0) with + | fun _ => ?T => reify_rec (@Thunked.option_rect A T PSome (fun _ : Datatypes.unit => PNone)) + | T0 => else_tac () + | ?T' => reify_rec (@Datatypes.option_rect A T' PSome PNone) + end + | @Thunked.option_rect ?A ?T + => let rA := base.reify A in + let rT := base.reify T in + then_tac (@ident.option_rect rA rT) | @Datatypes.prod_rect ?A ?B ?T0 => lazymatch (eval cbv beta in T0) with | fun _ => ?T @@ -1185,6 +1215,14 @@ Module Compilers. | T0 => else_tac () | ?T' => reify_rec (@Datatypes.prod_rect A B T') end + | @ZRange.zrange_rect ?T0 + => lazymatch (eval cbv beta in T0) with + | fun _ => ?T + => let rT := base.reify T in + then_tac (@ident.zrange_rect rT) + | T0 => else_tac () + | ?T' => reify_rec (@ZRange.zrange_rect T') + end | @Datatypes.nat_rect ?T0 ?P0 => lazymatch (eval cbv beta in T0) with | fun _ => _ -> _ => else_tac () @@ -1277,13 +1315,17 @@ Module Compilers. | Z.modulo => then_tac ident.Z_modulo | Z.eqb => then_tac ident.Z_eqb | Z.leb => then_tac ident.Z_leb + | Z.ltb => then_tac ident.Z_ltb | Z.geb => then_tac ident.Z_geb + | Z.gtb => then_tac ident.Z_gtb | Z.log2 => then_tac ident.Z_log2 | Z.log2_up => then_tac ident.Z_log2_up | Z.shiftl => then_tac ident.Z_shiftl | Z.shiftr => then_tac ident.Z_shiftr | Z.land => then_tac ident.Z_land | Z.lor => then_tac ident.Z_lor + | Z.min => then_tac ident.Z_min + | Z.max => then_tac ident.Z_max | Z.bneg => then_tac ident.Z_bneg | Z.lnot_modulo => then_tac ident.Z_lnot_modulo | Z.of_nat => then_tac ident.Z_of_nat @@ -1298,6 +1340,14 @@ Module Compilers. | Z.add_modulo => then_tac ident.Z_add_modulo | Z.rshi => then_tac ident.Z_rshi | Z.cc_m => then_tac ident.Z_cc_m + | ident.cast _ => then_tac ident.Z_cast + | @Some ?A + => let rA := base.reify A in + then_tac (@ident.Some rA) + | @None ?A + => let rA := base.reify A in + then_tac (@ident.None rA) + | ZRange.Build_zrange => then_tac ident.Build_zrange | _ => else_tac () end end. @@ -1309,6 +1359,13 @@ Module Compilers. (fun x _ xs => expr.Ident ident.cons @ x @ xs)%expr ls. + Definition reify_option {var} {t} (v : option (@expr.expr base.type ident var (type.base t))) : @expr.expr base.type ident var (type.base (base.type.option t)) + := Datatypes.option_rect + (fun _ => _) + (fun x => expr.Ident ident.Some @ x)%expr + (expr.Ident ident.None) + v. + Fixpoint smart_Literal {var} {t:base.type} : base.interp t -> @expr.expr base.type ident var (type.base t) := match t with | base.type.type_base t => fun v => expr.Ident (ident.Literal v) @@ -1318,6 +1375,9 @@ Module Compilers. | base.type.list A => fun v : list (base.interp A) => reify_list (List.map (@smart_Literal var A) v) + | base.type.option A + => fun v : option (base.interp A) + => reify_option (option_map (@smart_Literal var A) v) end%expr. Module Export Notations. @@ -1354,7 +1414,7 @@ Module Compilers. Notation "x 'mod' y" := (#Z_modulo @ x @ y)%expr : expr_scope. Notation "- x" := (#Z_opp @ x)%expr : expr_scope. Global Arguments gen_interp _ _ !_. - + Global Arguments ident.Z_cast _%zrange_scope. Global Arguments ident.Z_cast2 _%zrange_scope. End Notations. End ident. @@ -1519,6 +1579,42 @@ Module Compilers. | Some (ident.nil _) => true | _ => false end. + Definition invert_None {t} (e : expr (base.type.option t)) : bool + := match invert_Ident e with + | Some (ident.None _) => true + | _ => false + end. + Local Notation if_arrow f t + := (match t return Type with + | (a -> b)%etype => f a b + | _ => unit + end) (only parsing). + Definition invert_Some {t} (e : expr (base.type.option t)) + : option (expr t) + := match invert_AppIdent e with + | Some (existT s (idc, e)) + => match idc in ident.ident t + return if_arrow (fun a b => expr a) t + -> option match t return Type with + | (a -> type.base (base.type.option t)) + => expr t + | _ => unit + end%etype + with + | ident.Some _ => fun x => Some x + | _ => fun _ => None + end e + | None => None + end. + + Definition reflect_option {t} (e : expr (base.type.option t)) + : option (option (expr t)) + := match invert_None e, invert_Some e with + | true, _ => Some None + | _, Some x => Some (Some x) + | false, None => None + end. + Local Notation if_arrow2 f t := (match t return Type with | (a -> b -> c)%etype => f a b c @@ -1590,9 +1686,11 @@ Module Compilers. | base.type.Z => (-1)%Z | base.type.nat => 0%nat | base.type.bool => true + | base.type.zrange => r[0~>0]%zrange | base.type.list _ => nil | base.type.prod A B => (@default A, @default B) + | base.type.option A => None end. End base. Fixpoint default {t} : type.interp base.interp t @@ -1611,10 +1709,12 @@ Module Compilers. | base.type.prod A B => (@default A, @default B) | base.type.list A => #ident.nil + | base.type.option A => #ident.None | base.type.unit as t | base.type.Z as t | base.type.nat as t | base.type.bool as t + | base.type.zrange as t => ##(@type.base.default t) end%expr. End with_var. @@ -1649,6 +1749,7 @@ Module Compilers. End defaults. Notation reify_list := ident.reify_list. + Notation reify_option := ident.reify_option. Module GallinaReify. Module base. @@ -1663,10 +1764,14 @@ Module Compilers. | base.type.list A as t => fun x : list (base.interp A) => reify_list (List.map (@reify A) x) + | base.type.option A as t + => fun x : option (base.interp A) + => reify_option (option_map (@reify A) x) | base.type.unit as t | base.type.Z as t | base.type.bool as t | base.type.nat as t + | base.type.zrange as t => fun x : base.interp t => (##x)%expr end. -- cgit v1.2.3