aboutsummaryrefslogtreecommitdiff
path: root/src/Language.v
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2019-02-01 18:17:11 -0500
committerGravatar Jason Gross <jasongross9@gmail.com>2019-02-18 22:52:44 -0500
commit0bbbdfede48aed7a74ac2fb95440256ed60fb6e8 (patch)
tree09ae7896243a599ebd99224a00dcc1065869933b /src/Language.v
parenta7bc3fde287c451d2b0e77602cd9fab560d62a43 (diff)
Add support for reifying `zrange` and `option`
This is needed to reify statements for the rewriter.
Diffstat (limited to 'src/Language.v')
-rw-r--r--src/Language.v199
1 files changed, 152 insertions, 47 deletions
diff --git a/src/Language.v b/src/Language.v
index efcc16b63..044d19a3e 100644
--- a/src/Language.v
+++ b/src/Language.v
@@ -15,7 +15,7 @@ Require Import Crypto.Util.CPSNotations.
Require Import Crypto.Util.Notations.
Require Import Crypto.Util.Tactics.RunTacticAsConstr.
Require Import Crypto.Util.Tactics.DebugPrint.
-Import ListNotations. Local Open Scope bool_scope. Local Open Scope Z_scope.
+Import Coq.Lists.List ListNotations. Local Open Scope bool_scope. Local Open Scope Z_scope.
Module Compilers.
Local Set Boolean Equality Schemes.
@@ -293,8 +293,8 @@ Module Compilers.
Module base.
Local Notation einterp := type.interp.
Module type.
- Inductive base := unit | Z | bool | nat. (* Not Variant because COQBUG(https://github.com/coq/coq/issues/7738) *)
- Inductive type := type_base (t : base) | prod (A B : type) | list (A : type).
+ Inductive base := unit | Z | bool | nat | zrange. (* Not Variant because COQBUG(https://github.com/coq/coq/issues/7738) *)
+ Inductive type := type_base (t : base) | prod (A B : type) | list (A : type) | option (A : type).
Global Coercion type_base : base >-> type.
End type.
Global Coercion type.type_base : type.base >-> type.type.
@@ -305,12 +305,14 @@ Module Compilers.
| type.Z => BinInt.Z
| type.bool => Datatypes.bool
| type.nat => Datatypes.nat
+ | type.zrange => zrange
end.
Fixpoint interp (ty : type)
:= match ty with
| type.type_base t => base_interp t
| type.prod A B => interp A * interp B
| type.list A => Datatypes.list (interp A)
+ | type.option A => Datatypes.option (interp A)
end%type.
Definition try_make_base_transport_cps
@@ -321,11 +323,13 @@ Module Compilers.
| type.Z, type.Z
| type.bool, type.bool
| type.nat, type.nat
+ | type.zrange, type.zrange
=> (return (Some id))
| type.unit, _
| type.Z, _
| type.bool, _
| type.nat, _
+ | type.zrange, _
=> (return None)
end%cps.
Fixpoint try_make_transport_cps
@@ -339,9 +343,11 @@ Module Compilers.
trB <-- try_make_transport_cps (fun B => P (type.prod _ B)) _ _;
return (Some (fun v => trB (trA v))))
| type.list A, type.list A' => try_make_transport_cps (fun A => P (type.list A)) A A'
+ | type.option A, type.option A' => try_make_transport_cps (fun A => P (type.option A)) A A'
| type.type_base _, _
| type.prod _ _, _
| type.list _, _
+ | type.option _, _
=> (return None)
end%cps.
@@ -351,31 +357,6 @@ Module Compilers.
Definition try_transport (P : type -> Type) (t1 t2 : type) (v : P t1) : option (P t2)
:= try_transport_cps P t1 t2 v _ id.
- (*
- Fixpoint try_transport
- (P : type -> Type) (t1 t2 : type) : P t1 -> option (P t2)
- := match t1, t2 return P t1 -> option (P t2) with
- | type.unit, type.unit
- | type.Z, type.Z
- | type.bool, type.bool
- | type.nat, type.nat
- => @Some _
- | type.list A, type.list A'
- => @try_transport (fun A => P (type.list A)) A A'
- | type.prod s d, type.prod s' d'
- => fun v
- => (v <- (try_transport (fun s => P (type.prod s d)) s s' v);
- (try_transport (fun d => P (type.prod s' d)) d d' v))%option
-
- | type.unit, _
- | type.Z, _
- | type.bool, _
- | type.nat, _
- | type.prod _ _, _
- | type.list _, _
- => fun _ => None
- end.
- *)
Ltac reify_base ty :=
let __ := Reify.debug_enter_reify_base_type ty in
@@ -384,6 +365,7 @@ Module Compilers.
| Datatypes.nat => type.nat
| Datatypes.bool => type.bool
| BinInt.Z => type.Z
+ | zrange => type.zrange
| interp (type.type_base ?T) => T
| @einterp type interp (@Compilers.type.base type (type.type_base ?T)) => T
| _ => let __ := match goal with
@@ -401,6 +383,9 @@ Module Compilers.
| Datatypes.list ?T
=> let rT := reify T in
constr:(type.list rT)
+ | Datatypes.option ?T
+ => let rT := reify T in
+ constr:(type.option rT)
| interp ?T => T
| @einterp type interp (@Compilers.type.base type ?T) => T
| ?ty => let rT := reify_base ty in
@@ -616,6 +601,9 @@ Module Compilers.
| match ?x with Datatypes.pair a b => @?f a b end
=> let T := type of term in
reify_rec (@prod_rect _ _ (fun _ => T) f x)
+ | match ?x with ZRange.Build_zrange a b => @?f a b end
+ => let T := type of term in
+ reify_rec (@ZRange.zrange_rect (fun _ => T) f x)
| match ?x with nil => ?N | cons a b => @?C a b end
=> let T := type of term in
reify_rec (@list_case _ (fun _ => T) N C x)
@@ -892,13 +880,17 @@ Module Compilers.
| Z_log2_up : ident (Z -> Z)
| Z_eqb : ident (Z -> Z -> bool)
| Z_leb : ident (Z -> Z -> bool)
+ | Z_ltb : ident (Z -> Z -> bool)
| Z_geb : ident (Z -> Z -> bool)
+ | Z_gtb : ident (Z -> Z -> bool)
| Z_of_nat : ident (nat -> Z)
| Z_to_nat : ident (Z -> nat)
| Z_shiftr : ident (Z -> Z -> Z)
| Z_shiftl : ident (Z -> Z -> Z)
| Z_land : ident (Z -> Z -> Z)
| Z_lor : ident (Z -> Z -> Z)
+ | Z_min : ident (Z -> Z -> Z)
+ | Z_max : ident (Z -> Z -> Z)
| Z_bneg : ident (Z -> Z)
| Z_lnot_modulo : ident (Z -> Z -> Z)
| Z_mul_split : ident (Z -> Z -> Z -> Z * Z)
@@ -911,8 +903,13 @@ Module Compilers.
| Z_add_modulo : ident (Z -> Z -> Z -> Z)
| Z_rshi : ident (Z -> Z -> Z -> Z -> Z)
| Z_cc_m : ident (Z -> Z -> Z)
- | Z_cast (range : zrange) : ident (Z -> Z)
- | Z_cast2 (range : zrange * zrange) : ident ((Z * Z) -> (Z * Z))
+ | Z_cast (range : ZRange.zrange) : ident (Z -> Z)
+ | Z_cast2 (range : ZRange.zrange * ZRange.zrange) : ident ((Z * Z) -> (Z * Z))
+ | option_Some {A:base.type} : ident (A -> option A)
+ | option_None {A:base.type} : ident (option A)
+ | option_rect {A P : base.type} : ident ((A -> P) -> (unit -> P) -> option A -> P)
+ | Build_zrange : ident (Z -> Z -> zrange)
+ | zrange_rect {P:base.type} : ident ((Z -> Z -> P) -> zrange -> P)
| fancy_add (log2wordmax : BinInt.Z) (imm : BinInt.Z) : ident (Z * Z -> Z * Z)
| fancy_addc (log2wordmax : BinInt.Z) (imm : BinInt.Z) : ident (Z * Z * Z -> Z * Z)
| fancy_sub (log2wordmax : BinInt.Z) (imm : BinInt.Z) : ident (Z * Z -> Z * Z)
@@ -927,30 +924,34 @@ Module Compilers.
| fancy_sell : ident (Z * Z * Z -> Z)
| fancy_addm : ident (Z * Z * Z -> Z)
.
+ Notation Some := option_Some.
+ Notation None := option_None.
Global Arguments Z_cast2 _%zrange_scope.
- Definition to_fancy {s d : base.type} (idc : ident (s -> d)) : option (fancy.ident s d)
- := match idc in ident t return option match t with
+ Definition to_fancy {s d : base.type} (idc : ident (s -> d)) : Datatypes.option (fancy.ident s d)
+ := match idc in ident t return Datatypes.option match t with
| type.base s -> type.base d => fancy.ident s d
| _ => Datatypes.unit
end%etype with
- | fancy_add log2wordmax imm => Some (fancy.with_wordmax log2wordmax (fancy.add imm))
- | fancy_addc log2wordmax imm => Some (fancy.with_wordmax log2wordmax (fancy.addc imm))
- | fancy_sub log2wordmax imm => Some (fancy.with_wordmax log2wordmax (fancy.sub imm))
- | fancy_subb log2wordmax imm => Some (fancy.with_wordmax log2wordmax (fancy.subb imm))
- | fancy_mulll log2wordmax => Some (fancy.with_wordmax log2wordmax fancy.mulll)
- | fancy_mullh log2wordmax => Some (fancy.with_wordmax log2wordmax fancy.mullh)
- | fancy_mulhl log2wordmax => Some (fancy.with_wordmax log2wordmax fancy.mulhl)
- | fancy_mulhh log2wordmax => Some (fancy.with_wordmax log2wordmax fancy.mulhh)
- | fancy_rshi log2wordmax x => Some (fancy.with_wordmax log2wordmax (fancy.rshi x))
- | fancy_selc => Some fancy.selc
- | fancy_selm log2wordmax => Some (fancy.with_wordmax log2wordmax fancy.selm)
- | fancy_sell => Some fancy.sell
- | fancy_addm => Some fancy.addm
- | _ => None
+ | fancy_add log2wordmax imm => Datatypes.Some (fancy.with_wordmax log2wordmax (fancy.add imm))
+ | fancy_addc log2wordmax imm => Datatypes.Some (fancy.with_wordmax log2wordmax (fancy.addc imm))
+ | fancy_sub log2wordmax imm => Datatypes.Some (fancy.with_wordmax log2wordmax (fancy.sub imm))
+ | fancy_subb log2wordmax imm => Datatypes.Some (fancy.with_wordmax log2wordmax (fancy.subb imm))
+ | fancy_mulll log2wordmax => Datatypes.Some (fancy.with_wordmax log2wordmax fancy.mulll)
+ | fancy_mullh log2wordmax => Datatypes.Some (fancy.with_wordmax log2wordmax fancy.mullh)
+ | fancy_mulhl log2wordmax => Datatypes.Some (fancy.with_wordmax log2wordmax fancy.mulhl)
+ | fancy_mulhh log2wordmax => Datatypes.Some (fancy.with_wordmax log2wordmax fancy.mulhh)
+ | fancy_rshi log2wordmax x => Datatypes.Some (fancy.with_wordmax log2wordmax (fancy.rshi x))
+ | fancy_selc => Datatypes.Some fancy.selc
+ | fancy_selm log2wordmax => Datatypes.Some (fancy.with_wordmax log2wordmax fancy.selm)
+ | fancy_sell => Datatypes.Some fancy.sell
+ | fancy_addm => Datatypes.Some fancy.addm
+ | _ => Datatypes.None
end.
End with_scope.
+ Notation Some := option_Some.
+ Notation None := option_None.
Section gen.
Context (cast_outside_of_range : zrange -> BinInt.Z -> BinInt.Z).
@@ -1036,7 +1037,9 @@ Module Compilers.
| Z_modulo => Z.modulo
| Z_eqb => Z.eqb
| Z_leb => Z.leb
+ | Z_ltb => Z.ltb
| Z_geb => Z.geb
+ | Z_gtb => Z.gtb
| Z_log2 => Z.log2
| Z_log2_up => Z.log2_up
| Z_of_nat => Z.of_nat
@@ -1045,6 +1048,8 @@ Module Compilers.
| Z_shiftl => Z.shiftl
| Z_land => Z.land
| Z_lor => Z.lor
+ | Z_min => Z.min
+ | Z_max => Z.max
| Z_mul_split => Z.mul_split
| Z_add_get_carry => Z.add_get_carry_full
| Z_add_with_carry => Z.add_with_carry
@@ -1059,6 +1064,12 @@ Module Compilers.
| Z_cc_m => Z.cc_m
| Z_cast r => cast r
| Z_cast2 (r1, r2) => fun '(x1, x2) => (cast r1 x1, cast r2 x2)
+ | Some A => @Datatypes.Some _
+ | None A => @Datatypes.None _
+ | option_rect A P
+ => fun S_case N_case o => @Datatypes.option_rect _ _ S_case (N_case tt) o
+ | Build_zrange => ZRange.Build_zrange
+ | zrange_rect A => @ZRange.zrange_rect _
| fancy_add _ _ as idc
| fancy_addc _ _ as idc
| fancy_sub _ _ as idc
@@ -1079,6 +1090,8 @@ Module Compilers.
Definition cast_outside_of_range (r : zrange) (v : BinInt.Z) : BinInt.Z.
Proof. exact v. Qed.
End with_base.
+ Notation Some := option_Some.
+ Notation None := option_None.
(** Interpret identifiers where [Z_cast] is an opaque identity
function when the value is not inside the range *)
@@ -1087,6 +1100,7 @@ Module Compilers.
Notation LiteralZ := (@Literal base.type.Z).
Notation LiteralBool := (@Literal base.type.bool).
Notation LiteralNat := (@Literal base.type.nat).
+ Notation LiteralZRange := (@Literal base.type.zrange).
(** TODO: MOVE ME? *)
Module Thunked.
@@ -1098,6 +1112,8 @@ Module Compilers.
:= ListUtil.list_case (fun _ => P) (N tt) C ls.
Definition nat_rect P (O_case : unit -> P) (S_case : nat -> P -> P) (n : nat) : P
:= Datatypes.nat_rect (fun _ => P) (O_case tt) S_case n.
+ Definition option_rect {A} P (S_case : A -> P) (N_case : unit -> P) (o : option A) : P
+ := Datatypes.option_rect (fun _ => P) S_case (N_case tt) o.
End Thunked.
Ltac require_primitive_const term :=
@@ -1113,6 +1129,10 @@ Module Compilers.
| xI ?p => require_primitive_const p
| xO ?p => require_primitive_const p
| xH => idtac
+ | Datatypes.Some ?x => require_primitive_const x
+ | Datatypes.None => idtac
+ | ZRange.Build_zrange ?x ?y
+ => require_primitive_const x; require_primitive_const y
| ?term => fail 0 "Not a known const:" term
end.
Ltac is_primitive_const term :=
@@ -1175,6 +1195,16 @@ Module Compilers.
| @Thunked.bool_rect ?T
=> let rT := base.reify T in
then_tac (@ident.bool_rect rT)
+ | @Datatypes.option_rect ?A ?T0 ?PSome ?PNone
+ => lazymatch (eval cbv beta in T0) with
+ | fun _ => ?T => reify_rec (@Thunked.option_rect A T PSome (fun _ : Datatypes.unit => PNone))
+ | T0 => else_tac ()
+ | ?T' => reify_rec (@Datatypes.option_rect A T' PSome PNone)
+ end
+ | @Thunked.option_rect ?A ?T
+ => let rA := base.reify A in
+ let rT := base.reify T in
+ then_tac (@ident.option_rect rA rT)
| @Datatypes.prod_rect ?A ?B ?T0
=> lazymatch (eval cbv beta in T0) with
| fun _ => ?T
@@ -1185,6 +1215,14 @@ Module Compilers.
| T0 => else_tac ()
| ?T' => reify_rec (@Datatypes.prod_rect A B T')
end
+ | @ZRange.zrange_rect ?T0
+ => lazymatch (eval cbv beta in T0) with
+ | fun _ => ?T
+ => let rT := base.reify T in
+ then_tac (@ident.zrange_rect rT)
+ | T0 => else_tac ()
+ | ?T' => reify_rec (@ZRange.zrange_rect T')
+ end
| @Datatypes.nat_rect ?T0 ?P0
=> lazymatch (eval cbv beta in T0) with
| fun _ => _ -> _ => else_tac ()
@@ -1277,13 +1315,17 @@ Module Compilers.
| Z.modulo => then_tac ident.Z_modulo
| Z.eqb => then_tac ident.Z_eqb
| Z.leb => then_tac ident.Z_leb
+ | Z.ltb => then_tac ident.Z_ltb
| Z.geb => then_tac ident.Z_geb
+ | Z.gtb => then_tac ident.Z_gtb
| Z.log2 => then_tac ident.Z_log2
| Z.log2_up => then_tac ident.Z_log2_up
| Z.shiftl => then_tac ident.Z_shiftl
| Z.shiftr => then_tac ident.Z_shiftr
| Z.land => then_tac ident.Z_land
| Z.lor => then_tac ident.Z_lor
+ | Z.min => then_tac ident.Z_min
+ | Z.max => then_tac ident.Z_max
| Z.bneg => then_tac ident.Z_bneg
| Z.lnot_modulo => then_tac ident.Z_lnot_modulo
| Z.of_nat => then_tac ident.Z_of_nat
@@ -1298,6 +1340,14 @@ Module Compilers.
| Z.add_modulo => then_tac ident.Z_add_modulo
| Z.rshi => then_tac ident.Z_rshi
| Z.cc_m => then_tac ident.Z_cc_m
+ | ident.cast _ => then_tac ident.Z_cast
+ | @Some ?A
+ => let rA := base.reify A in
+ then_tac (@ident.Some rA)
+ | @None ?A
+ => let rA := base.reify A in
+ then_tac (@ident.None rA)
+ | ZRange.Build_zrange => then_tac ident.Build_zrange
| _ => else_tac ()
end
end.
@@ -1309,6 +1359,13 @@ Module Compilers.
(fun x _ xs => expr.Ident ident.cons @ x @ xs)%expr
ls.
+ Definition reify_option {var} {t} (v : option (@expr.expr base.type ident var (type.base t))) : @expr.expr base.type ident var (type.base (base.type.option t))
+ := Datatypes.option_rect
+ (fun _ => _)
+ (fun x => expr.Ident ident.Some @ x)%expr
+ (expr.Ident ident.None)
+ v.
+
Fixpoint smart_Literal {var} {t:base.type} : base.interp t -> @expr.expr base.type ident var (type.base t)
:= match t with
| base.type.type_base t => fun v => expr.Ident (ident.Literal v)
@@ -1318,6 +1375,9 @@ Module Compilers.
| base.type.list A
=> fun v : list (base.interp A)
=> reify_list (List.map (@smart_Literal var A) v)
+ | base.type.option A
+ => fun v : option (base.interp A)
+ => reify_option (option_map (@smart_Literal var A) v)
end%expr.
Module Export Notations.
@@ -1354,7 +1414,7 @@ Module Compilers.
Notation "x 'mod' y" := (#Z_modulo @ x @ y)%expr : expr_scope.
Notation "- x" := (#Z_opp @ x)%expr : expr_scope.
Global Arguments gen_interp _ _ !_.
-
+ Global Arguments ident.Z_cast _%zrange_scope.
Global Arguments ident.Z_cast2 _%zrange_scope.
End Notations.
End ident.
@@ -1519,6 +1579,42 @@ Module Compilers.
| Some (ident.nil _) => true
| _ => false
end.
+ Definition invert_None {t} (e : expr (base.type.option t)) : bool
+ := match invert_Ident e with
+ | Some (ident.None _) => true
+ | _ => false
+ end.
+ Local Notation if_arrow f t
+ := (match t return Type with
+ | (a -> b)%etype => f a b
+ | _ => unit
+ end) (only parsing).
+ Definition invert_Some {t} (e : expr (base.type.option t))
+ : option (expr t)
+ := match invert_AppIdent e with
+ | Some (existT s (idc, e))
+ => match idc in ident.ident t
+ return if_arrow (fun a b => expr a) t
+ -> option match t return Type with
+ | (a -> type.base (base.type.option t))
+ => expr t
+ | _ => unit
+ end%etype
+ with
+ | ident.Some _ => fun x => Some x
+ | _ => fun _ => None
+ end e
+ | None => None
+ end.
+
+ Definition reflect_option {t} (e : expr (base.type.option t))
+ : option (option (expr t))
+ := match invert_None e, invert_Some e with
+ | true, _ => Some None
+ | _, Some x => Some (Some x)
+ | false, None => None
+ end.
+
Local Notation if_arrow2 f t
:= (match t return Type with
| (a -> b -> c)%etype => f a b c
@@ -1590,9 +1686,11 @@ Module Compilers.
| base.type.Z => (-1)%Z
| base.type.nat => 0%nat
| base.type.bool => true
+ | base.type.zrange => r[0~>0]%zrange
| base.type.list _ => nil
| base.type.prod A B
=> (@default A, @default B)
+ | base.type.option A => None
end.
End base.
Fixpoint default {t} : type.interp base.interp t
@@ -1611,10 +1709,12 @@ Module Compilers.
| base.type.prod A B
=> (@default A, @default B)
| base.type.list A => #ident.nil
+ | base.type.option A => #ident.None
| base.type.unit as t
| base.type.Z as t
| base.type.nat as t
| base.type.bool as t
+ | base.type.zrange as t
=> ##(@type.base.default t)
end%expr.
End with_var.
@@ -1649,6 +1749,7 @@ Module Compilers.
End defaults.
Notation reify_list := ident.reify_list.
+ Notation reify_option := ident.reify_option.
Module GallinaReify.
Module base.
@@ -1663,10 +1764,14 @@ Module Compilers.
| base.type.list A as t
=> fun x : list (base.interp A)
=> reify_list (List.map (@reify A) x)
+ | base.type.option A as t
+ => fun x : option (base.interp A)
+ => reify_option (option_map (@reify A) x)
| base.type.unit as t
| base.type.Z as t
| base.type.bool as t
| base.type.nat as t
+ | base.type.zrange as t
=> fun x : base.interp t
=> (##x)%expr
end.