diff options
Diffstat (limited to 'src/AbstractInterpretation.v')
-rw-r--r-- | src/AbstractInterpretation.v | 1089 |
1 files changed, 1089 insertions, 0 deletions
diff --git a/src/AbstractInterpretation.v b/src/AbstractInterpretation.v new file mode 100644 index 000000000..463cc72cd --- /dev/null +++ b/src/AbstractInterpretation.v @@ -0,0 +1,1089 @@ +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.ListUtil.FoldBool. +Require Import Crypto.Util.ZRange. +Require Import Crypto.Util.ZRange.Operations. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.OptionList. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Language. +Require Import Crypto.UnderLets. +Import ListNotations. Local Open Scope bool_scope. Local Open Scope Z_scope. + +Module Compilers. + Export Language.Compilers. + Export UnderLets.Compilers. + Import invert_expr. + + Module ZRange. + Module type. + Local Notation binterp := base.interp. + Local Notation tinterp_gen := type.interp. + Local Notation einterp := (type.interp base.interp). + Module base. + (** turn a [base.type] into a [Set] describing the type of + bounds on that primitive; Z is a range, nat and bool are exact values *) + Fixpoint interp (t : base.type) : Set + := match t with + | base.type.Z => zrange + | base.type.unit as t + | base.type.nat as t + | base.type.bool as t + => base.interp t + | base.type.prod A B => interp A * interp B + | base.type.list A => list (interp A) + end%type. + Definition is_neg {t} : interp t -> bool + := match t with + | base.type.Z => fun r => (lower r <? 0) && (upper r <=? 0) + | _ => fun _ => false + end. + Fixpoint is_tighter_than {t} : interp t -> interp t -> bool + := match t with + | base.type.Z => is_tighter_than_bool + | base.type.nat => Nat.eqb + | base.type.unit => fun _ _ => true + | base.type.bool => bool_eq + | base.type.prod A B + => fun '(a, b) '(a', b') + => @is_tighter_than A a a' && @is_tighter_than B b b' + | base.type.list A + => fold_andb_map (@is_tighter_than A) + end%bool. + Fixpoint is_bounded_by {t} : interp t -> binterp t -> bool + := match t with + | base.type.Z => fun r z => ZRange.is_bounded_by_bool z r + | base.type.nat => Nat.eqb + | base.type.unit => fun _ _ => true + | base.type.bool => bool_eq + | base.type.prod A B + => fun '(a, b) '(a', b') + => @is_bounded_by A a a' && @is_bounded_by B b b' + | base.type.list A + => fold_andb_map (@is_bounded_by A) + end. + Module option. + (** turn a [base.type] into a [Set] describing the type + of optional bounds on that primitive; bounds on a [Z] + may be either a range, or [None], generally indicating + that the [Z] is unbounded. *) + Fixpoint interp (t : base.type) : Set + := match t with + | base.type.Z => option zrange + | base.type.unit => unit + | base.type.nat as t + | base.type.bool as t + => option (base.interp t) + | base.type.prod A B => interp A * interp B + | base.type.list A => option (list (interp A)) + end%type. + Fixpoint None {t} : interp t + := match t with + | base.type.unit => tt + | base.type.list _ + | base.type.Z + | base.type.nat + | base.type.bool + => Datatypes.None + | base.type.prod A B + => (@None A, @None B) + end. + Fixpoint Some {t} : base.interp t -> interp t + := match t with + | base.type.Z + | base.type.nat + | base.type.bool + => Datatypes.Some + | base.type.list A + => fun ls => Datatypes.Some (List.map (@Some A) ls) + | base.type.prod A B + => fun '(a, b) + => (@Some A a, @Some B b) + | _ => fun _ => tt + end. + Fixpoint lift_Some {t} : interp t -> option (base.interp t) + := match t with + | base.type.Z + | base.type.nat + | base.type.bool + => fun x => x + | base.type.unit + => fun x => Datatypes.Some tt + | base.type.list A + => fun ls => ls <- ls; ls <-- List.map (@lift_Some A) ls; Datatypes.Some ls + | base.type.prod A B + => fun '(a, b) => a <- @lift_Some A a; b <- @lift_Some B b; Datatypes.Some (a, b) + end%option. + (** Keep data about list length and nat value, but not zrange *) + Fixpoint strip_ranges {t} : interp t -> interp t + := match t with + | base.type.Z => fun _ => Datatypes.None + | base.type.nat + | base.type.bool + | base.type.unit + => fun x => x + | base.type.list A + => fun ls => ls <- ls; Datatypes.Some (List.map (@strip_ranges A) ls) + | base.type.prod A B + => fun '(a, b) + => (@strip_ranges A a, @strip_ranges B b) + end%option. + Definition is_neg {t} : interp t -> bool + := match t with + | base.type.Z + => fun v => match v with + | Datatypes.Some v => @is_neg base.type.Z v + | Datatypes.None => false + end + | t => fun _ => false + end. + Fixpoint is_tighter_than {t} : interp t -> interp t -> bool + := match t with + | base.type.Z as t + | base.type.nat as t + | base.type.bool as t + => fun r1 r2 + => match r1, r2 with + | _, Datatypes.None => true + | Datatypes.None, Datatypes.Some _ => false + | Datatypes.Some r1, Datatypes.Some r2 => base.is_tighter_than (t:=t) r1 r2 + end + | base.type.prod A B + => fun '(a, b) '(a', b') + => @is_tighter_than A a a' && @is_tighter_than B b b' + | base.type.list A + => fun ls1 ls2 + => match ls1, ls2 with + | Datatypes.None, Datatypes.None => true + | Datatypes.Some _, Datatypes.None => true + | Datatypes.None, Datatypes.Some _ => false + | Datatypes.Some ls1, Datatypes.Some ls2 => fold_andb_map (@is_tighter_than A) ls1 ls2 + end + | _ => fun 'tt 'tt => true + end. + Fixpoint is_bounded_by {t} : interp t -> binterp t -> bool + := match t with + | base.type.Z as t + | base.type.nat as t + | base.type.bool as t + => fun r + => match r with + | Datatypes.Some r => @base.is_bounded_by t r + | Datatypes.None => fun _ => true + end + | base.type.prod A B + => fun '(a, b) '(a', b') + => @is_bounded_by A a a' && @is_bounded_by B b b' + | base.type.list A + => fun ls1 ls2 + => match ls1 with + | Datatypes.None => true + | Datatypes.Some ls1 => fold_andb_map (@is_bounded_by A) ls1 ls2 + end + | _ => fun 'tt _ => true + end. + + Lemma is_bounded_by_Some {t} r val + : is_bounded_by (@Some t r) val = base.is_bounded_by r val. + Proof. + induction t; + repeat first [ reflexivity + | progress cbn in * + | progress destruct_head'_prod + | progress destruct_head' base.type.base + | rewrite fold_andb_map_map1 + | match goal with H : _ |- _ => rewrite H end + | match goal with H : _ |- _ => setoid_rewrite H end ]. + Qed. + + Lemma is_tighter_than_is_bounded_by {t} r1 r2 val + (Htight : @is_tighter_than t r1 r2 = true) + (Hbounds : is_bounded_by r1 val = true) + : is_bounded_by r2 val = true. + Proof. + induction t; + repeat first [ progress destruct_head'_prod + | progress destruct_head'_and + | progress destruct_head'_unit + | progress cbn in * + | progress destruct_head' option + | solve [ eauto with nocore ] + | progress cbv [ZRange.is_bounded_by_bool is_tighter_than_bool] in * + | progress rewrite ?Bool.andb_true_iff in * + | discriminate + | apply conj + | Z.ltb_to_lt; omega + | progress break_innermost_match_hyps + | progress subst + | rewrite NPeano.Nat.eqb_refl + | reflexivity + | match goal with + | [ H : Nat.eqb _ _ = true |- _ ] => apply beq_nat_true in H + | [ H : bool_eq _ _ = true |- _ ] => apply bool_eq_ok in H + | [ |- bool_eq ?x ?x = true ] => destruct x; reflexivity + end ]. + { lazymatch goal with + | [ r1 : list (interp t), r2 : list (interp t), val : list (binterp t) |- _ ] + => revert r1 r2 val Htight Hbounds IHt + end; intros r1 r2 val; revert r1 r2 val. + induction r1, r2, val; cbn; auto with nocore; try congruence; []. + rewrite !Bool.andb_true_iff; intros; destruct_head'_and; split; eauto with nocore. } + Qed. + + Lemma is_tighter_than_Some_is_bounded_by {t} r1 r2 val + (Htight : @is_tighter_than t r1 (Some r2) = true) + (Hbounds : is_bounded_by r1 val = true) + : base.is_bounded_by r2 val = true. + Proof. + rewrite <- is_bounded_by_Some. + eapply is_tighter_than_is_bounded_by; eassumption. + Qed. + End option. + End base. + + (** turn a [type] into a [Set] describing the type of bounds on + that type; this lifts [base.interp] from + [type.base] to [type] *) + Definition interp (t : type base.type) + := type.interp base.interp t. + Fixpoint is_tighter_than {t} : interp t -> interp t -> bool + := match t with + | type.base x => @base.is_tighter_than x + | type.arrow s d => fun _ _ => false + end. + Fixpoint is_bounded_by {t} : interp t -> einterp t -> bool + := match t return interp t -> einterp t -> bool with + | type.base x => @base.is_bounded_by x + | type.arrow s d => fun _ _ => false + end. + Module option. + (** turn a [type] into a [Set] describing the type of optional + bounds on that base type; bounds on a [Z] may be either a + range, or [None], generally indicating that the [Z] is + unbounded. This lifts [base.option.interp] from + [base.type] to [type] *) + Definition interp (t : type base.type) + := tinterp_gen base.option.interp t. + Fixpoint None {t : type base.type} : interp t + := match t with + | type.base x => @base.option.None x + | type.arrow s d => fun _ => @None d + end. + Fixpoint Some {t : type base.type} : type.interp t -> interp t + := match t with + | type.base x => @base.option.Some x + | type.arrow s d => fun _ _ => @None d + end. + Fixpoint strip_ranges {t : type base.type} : interp t -> interp t + := match t with + | type.base x => @base.option.strip_ranges x + | type.arrow s d => fun f x => @strip_ranges d (f x) + end. + Fixpoint is_tighter_than {t} : interp t -> interp t -> bool + := match t with + | type.base x => @base.option.is_tighter_than x + | type.arrow s d => fun _ _ => false + end. + Fixpoint is_bounded_by {t} : interp t -> einterp t -> bool + := match t with + | type.base x => @base.option.is_bounded_by x + | type.arrow s d => fun _ _ => false + end. + + Lemma is_bounded_by_Some {t} r val + : is_bounded_by (@Some t r) val = type.is_bounded_by r val. + Proof. + induction t; [ apply base.option.is_bounded_by_Some | reflexivity ]. + Qed. + + Lemma is_tighter_than_is_bounded_by {t} r1 r2 val + (Htight : @is_tighter_than t r1 r2 = true) + (Hbounds : is_bounded_by r1 val = true) + : is_bounded_by r2 val = true. + Proof. + induction t; cbn in *; + eauto using base.option.is_tighter_than_is_bounded_by. + Qed. + + Lemma is_tighter_than_Some_is_bounded_by {t} r1 r2 val + (Htight : @is_tighter_than t r1 (Some r2) = true) + (Hbounds : is_bounded_by r1 val = true) + : type.is_bounded_by r2 val = true. + Proof. + rewrite <- is_bounded_by_Some. + eapply is_tighter_than_is_bounded_by; eassumption. + Qed. + End option. + End type. + + Module ident. + Module option. + Local Open Scope zrange_scope. + + Fixpoint of_literal {t} : base.interp t -> type.base.option.interp t + := match t with + | base.type.Z => fun z => Some r[z~>z]%zrange + | base.type.nat + | base.type.bool + => fun n => Some n + | base.type.unit + => fun _ => tt + | base.type.prod A B + => fun '(a, b) => (@of_literal A a, @of_literal B b) + | base.type.list A + => fun ls => Some (List.map (@of_literal A) ls) + end. + Fixpoint to_literal {t} : type.base.option.interp t -> option (base.interp t) + := match t with + | base.type.Z => fun r => r <- r; if r.(lower) =? r.(upper) then Some r.(lower) else None + | base.type.nat + | base.type.bool + => fun v => v + | base.type.unit + => fun _ => Some tt + | base.type.prod A B + => fun '(a, b) => a <- @to_literal A a; b <- @to_literal B b; Some (a, b) + | base.type.list A + => fun ls => ls <- ls; fold_right (fun x xs => x <- x; xs <- xs; Some (x :: xs)) + (Some nil) + (List.map (@to_literal A) ls) + end%option%Z. + Local Notation rSome v + := (ZRange.type.base.option.Some (t:=base.reify_norm_type_of v) v) + (only parsing). + (** do bounds analysis on identifiers; take in optional bounds + on arguments, return optional bounds on outputs. *) + (** Casts are like assertions; we only guarantee anything when they're true *) + Definition interp_Z_cast (r : zrange) (v : option zrange) : option zrange + := match v with + | Some v => if is_tighter_than_bool v r (* the value is definitely inside the range *) + then Some v + else None + | None => None + end. + Definition interp {t} (idc : ident t) : type.option.interp t + := match idc in ident.ident t return type.option.interp t with + | ident.Literal _ v => of_literal v + | ident.Nat_succ as idc + | ident.Nat_pred as idc + => option_map (ident.interp idc) + | ident.Z_of_nat as idc + => option_map (fun n => r[Z.of_nat n~>Z.of_nat n]%zrange) + | ident.Z_to_nat as idc + => fun v => v <- to_literal v; Some (ident.interp idc v) + | ident.List_length _ + => option_map (@List.length _) + | ident.Nat_max as idc + | ident.Nat_mul as idc + | ident.Nat_add as idc + | ident.Nat_sub as idc + | ident.List_seq as idc + => fun x y => x <- x; y <- y; rSome (ident.interp idc x y) + | ident.List_repeat _ + => fun x y => y <- y; Some (repeat x y) + | ident.List_firstn _ + => fun n ls => n <- n; ls <- ls; Some (firstn n ls) + | ident.List_skipn _ + => fun n ls => n <- n; ls <- ls; Some (skipn n ls) + | ident.List_combine _ _ + => fun x y => x <- x; y <- y; Some (List.combine x y) + | ident.List_flat_map _ _ + => fun f ls + => (ls <- ls; + let fls := List.map f ls in + List.fold_right + (fun ls1 ls2 => ls1 <- ls1; ls2 <- ls2; Some (ls1 ++ ls2)) + (Some nil) + fls) + | ident.List_partition _ + => fun f ls + => match ls with + | Some ls + => list_rect + _ + (Some nil, Some nil) + (fun x tl partition_tl + => let '(g, d) := partition_tl in + ((fx <- f x; + if fx then (g <- g; Some (x::g)) else g), + (fx <- f x; + if fx then d else (d <- d; Some (x::d))))) + ls + | None => (None, None) + end + | ident.Z_eqb as idc + | ident.Z_leb as idc + | ident.Z_geb as idc + | ident.Z_pow as idc + | ident.Z_modulo as idc + => fun x y => match to_literal x, to_literal y with + | Some x, Some y => of_literal (ident.interp idc x y) + | _, _ => ZRange.type.base.option.None + end + | ident.Z_bneg as idc + => fun x => match to_literal x with + | Some x => of_literal (ident.interp idc x) + | None => Datatypes.Some r[0~>1] + end + | ident.Z_lnot_modulo as idc + => fun v m + => match to_literal m, to_literal v with + | Some m, Some v => of_literal (ident.interp idc v m) + | Some m, None => Some (if (0 <? m)%Z + then r[0 ~> m-1] + else if (m =? 0)%Z + then r[0 ~> 0] + else r[m+1 ~> 0]) + | _, _ => None + end + | ident.bool_rect _ + => fun t f b + => match b with + | Some b => if b then t tt else f tt + | None => ZRange.type.base.option.None + end + | ident.nat_rect _ + => fun O_case S_case n + => match n with + | Some n + => nat_rect + _ + (O_case tt) + (fun n' rec => S_case (Some n') rec) + n + | None => ZRange.type.base.option.None + end + | ident.nat_rect_arrow _ _ + => fun O_case S_case n v + => match n with + | Some n + => nat_rect + _ + O_case + (fun n' rec => S_case (Some n') rec) + n + v + | None => ZRange.type.base.option.None + end + | ident.list_rect _ _ + => fun N C ls + => match ls with + | Some ls + => list_rect + _ + (N tt) + (fun x xs rec => C x (Some xs) rec) + ls + | None => ZRange.type.base.option.None + end + | ident.list_case _ _ + => fun N C ls + => match ls with + | Some ls + => list_case + _ + (N tt) + (fun x xs => C x (Some xs)) + ls + | None => ZRange.type.base.option.None + end + | ident.List_fold_right _ _ + => fun f v ls + => match ls with + | Some ls + => fold_right f v ls + | None => ZRange.type.base.option.None + end + | ident.List_nth_default _ + => fun d ls n + => match ls, n with + | Some ls, Some n + => nth_default d ls n + | _, _ => ZRange.type.base.option.None + end + | ident.List_update_nth _ + => fun n f ls => ls <- ls; n <- n; Some (update_nth n f ls) + | ident.nil t => Some nil + | ident.cons t => fun x => option_map (cons x) + | ident.pair A B => pair + | ident.fst A B => fst + | ident.snd A B => snd + | ident.prod_rect A B P => fun f '(a, b) => f a b + | ident.List_map _ _ + => fun f ls => ls <- ls; Some (List.map f ls) + | ident.List_app _ + => fun ls1 ls2 => ls1 <- ls1; ls2 <- ls2; Some (List.app ls1 ls2) + | ident.List_rev _ => option_map (@List.rev _) + | ident.Z_opp as idc + | ident.Z_log2 as idc + | ident.Z_log2_up as idc + => fun x => x <- x; Some (ZRange.two_corners (ident.interp idc) x) + | ident.Z_add as idc + | ident.Z_mul as idc + | ident.Z_sub as idc + => fun x y => x <- x; y <- y; Some (ZRange.four_corners (ident.interp idc) x y) + | ident.Z_div as idc + | ident.Z_shiftr as idc + | ident.Z_shiftl as idc + => fun x y => x <- x; y <- y; Some (ZRange.four_corners_and_zero (ident.interp idc) x y) + | ident.Z_add_with_carry as idc + => fun x y z => x <- x; y <- y; z <- z; Some (ZRange.eight_corners (ident.interp idc) x y z) + | ident.Z_cc_m as idc + => fun s x => s <- to_literal s; x <- x; Some (ZRange.two_corners (ident.interp idc s) x) + | ident.Z_rshi as idc + => fun s x y offset + => s <- to_literal s; x <- x; y <- y; offset <- to_literal offset; + if (0 <? s) then Some r[0~>s-1] else None + | ident.Z_land + => fun x y => x <- x; y <- y; Some (ZRange.land_bounds x y) + | ident.Z_lor + => fun x y => x <- x; y <- y; Some (ZRange.lor_bounds x y) + | ident.Z_mul_split + => fun split_at x y + => match to_literal split_at, x, y with + | Some split_at, Some x, Some y + => ZRange.type.base.option.Some + (t:=base.type.Z*base.type.Z) + (ZRange.split_bounds (ZRange.four_corners Z.mul x y) split_at) + | _, _, _ => ZRange.type.base.option.None + end + | ident.Z_add_get_carry + => fun split_at x y + => match to_literal split_at, x, y with + | Some split_at, Some x, Some y + => ZRange.type.base.option.Some + (t:=base.type.Z*base.type.Z) + (ZRange.split_bounds (ZRange.four_corners Z.add x y) split_at) + | _, _, _ => ZRange.type.base.option.None + end + | ident.Z_add_with_get_carry + => fun split_at x y z + => match to_literal split_at, x, y, z with + | Some split_at, Some x, Some y, Some z + => ZRange.type.base.option.Some + (t:=base.type.Z*base.type.Z) + (ZRange.split_bounds + (ZRange.eight_corners (fun x y z => (x + y + z)%Z) x y z) + split_at) + | _, _, _, _ => ZRange.type.base.option.None + end + | ident.Z_sub_get_borrow + => fun split_at x y + => match to_literal split_at, x, y with + | Some split_at, Some x, Some y + => ZRange.type.base.option.Some + (t:=base.type.Z*base.type.Z) + (let b := ZRange.split_bounds (ZRange.four_corners BinInt.Z.sub x y) split_at in + (* N.B. sub_get_borrow returns - ((x - y) / split_at) as the borrow, so we need to negate *) + (fst b, ZRange.opp (snd b))) + | _, _, _ => ZRange.type.base.option.None + end + | ident.Z_sub_with_get_borrow + => fun split_at x y z + => match to_literal split_at, x, y, z with + | Some split_at, Some x, Some y, Some z + => ZRange.type.base.option.Some + (t:=base.type.Z*base.type.Z) + (let b := ZRange.split_bounds (ZRange.eight_corners (fun x y z => (y - z - x)%Z) x y z) split_at in + (* N.B. sub_get_borrow returns - ((x - y) / split_at) as the borrow, so we need to negate *) + (fst b, ZRange.opp (snd b))) + | _, _, _, _ => ZRange.type.base.option.None + end + | ident.Z_zselect + => fun _ y z => y <- y; z <- z; Some (ZRange.union y z) + | ident.Z_add_modulo + => fun x y m + => (x <- x; + y <- y; + m <- m; + Some (ZRange.union + (ZRange.four_corners Z.add x y) + (ZRange.eight_corners (fun x y m => Z.max 0 (x + y - m)) + x y m))) + | ident.Z_cast range + => fun r : option zrange + => interp_Z_cast range r + | ident.Z_cast2 (r1, r2) + => fun '((r1', r2') : option zrange * option zrange) + => (interp_Z_cast r1 r1', interp_Z_cast r2 r2') + (** TODO(jadep): fill in fancy bounds analysis rules *) + | ident.fancy_add log2wordmax _ + | ident.fancy_sub log2wordmax _ + => let wordmax := 2^log2wordmax in + let r := r[0~>wordmax-1] in + fun args + => if ZRange.type.base.option.is_tighter_than args (Some r, Some r) + then (Some r, Some r[0~>1]) + else ZRange.type.base.option.None + | ident.fancy_addc log2wordmax _ + | ident.fancy_subb log2wordmax _ + => let wordmax := 2^log2wordmax in + let r := r[0~>wordmax-1] in + fun args + => if ZRange.type.base.option.is_tighter_than args (Some r[0~>1], Some r, Some r) + then (Some r, Some r[0~>1]) + else ZRange.type.base.option.None + | ident.fancy_mulll log2wordmax + | ident.fancy_mullh log2wordmax + | ident.fancy_mulhl log2wordmax + | ident.fancy_mulhh log2wordmax + => let wordmax := 2^log2wordmax in + let r := r[0~>wordmax-1] in + fun args + => if ZRange.type.base.option.is_tighter_than args (Some r, Some r) + then if (Z.eqb (log2wordmax mod 2) 0) + then Some r + else ZRange.type.base.option.None + else ZRange.type.base.option.None + | ident.fancy_rshi log2wordmax n as idc + => let wordmax := 2^log2wordmax in + let r := r[0~>wordmax-1] in + let r_nbits := r[0~>2^n-1] in + fun args + => + if (0 <=? log2wordmax)%Z + then if (ZRange.type.base.option.is_tighter_than args (Some r_nbits, Some r) && (0 <=? n)%Z) + then + hi_range <- fst args; + lo_range <- snd args; + Some (ZRange.four_corners (fun x y => ident.interp idc (x, y)) hi_range lo_range) + else if ZRange.type.base.option.is_tighter_than args (Some r, Some r) + then Some r + else ZRange.type.base.option.None + else ZRange.type.base.option.None + | ident.fancy_selm _ + | ident.fancy_selc + | ident.fancy_sell + => fun '(_, y, z) => y <- y; z <- z; Some (ZRange.union y z) + | ident.fancy_addm + => fun '(x, y, m) + => (x <- x; + y <- y; + m <- m; + Some (ZRange.union + (ZRange.four_corners Z.add x y) + (ZRange.eight_corners (fun x y m => Z.max 0 (x + y - m)) + x y m))) + end%option. + End option. + End ident. + End ZRange. + + (** XXX TODO: Do we still need to do UnderLets here? *) + Module partial. + Import UnderLets. + Section with_var. + Context {base_type : Type}. + Local Notation type := (type base_type). + Let type_base (x : base_type) : type := type.base x. + Local Coercion type_base : base_type >-> type. + Context {ident : type -> Type} + {var : type -> Type}. + Local Notation expr := (@expr base_type ident). + Local Notation UnderLets := (@UnderLets base_type ident var). + Context (abstract_domain' : base_type -> Type) + (annotate : forall (is_let_bound : bool) t, abstract_domain' t -> @expr var t -> UnderLets (@expr var t)) + (bottom' : forall A, abstract_domain' A) + (abstract_interp_ident : forall t, ident t -> type.interp abstract_domain' t). + + Definition abstract_domain (t : type) + := type.interp abstract_domain' t. + + Fixpoint value (t : type) + := match t return Type (* COQBUG(https://github.com/coq/coq/issues/7727) *) with + | type.base t + => abstract_domain t * @expr var t + | type.arrow s d + => value s -> UnderLets (value d) + end%type. + + Definition value_with_lets (t : type) + := UnderLets (value t). + + Context (interp_ident : forall t, ident t -> value_with_lets t). + + Fixpoint bottom {t} : abstract_domain t + := match t with + | type.base t => bottom' t + | type.arrow s d => fun _ => @bottom d + end. + + Fixpoint bottom_for_each_lhs_of_arrow {t} : type.for_each_lhs_of_arrow abstract_domain t + := match t return type.for_each_lhs_of_arrow abstract_domain t with + | type.base t => tt + | type.arrow s d => (bottom, @bottom_for_each_lhs_of_arrow d) + end. + + Definition state_of_value {t} : value t -> abstract_domain t + := match t return value t -> abstract_domain t with + | type.base t => fun '(st, v) => st + | type.arrow s d => fun _ => bottom + end. + + (** We need to make sure that we ignore the state of + higher-order arrows *everywhere*, or else the proofs don't go + through. So we sometimes need to replace the state of + arrow-typed values with [⊥]. *) + Fixpoint bottomify {t} : value t -> value_with_lets t + := match t return value t -> value_with_lets t with + | type.base t => fun '(st, v) => Base (bottom' t, v) + | type.arrow s d => fun f => Base (fun x => fx <-- f x; @bottomify d fx) + end%under_lets. + + (** We drop the state of higher-order arrows *) + Fixpoint reify (is_let_bound : bool) {t} : value t -> type.for_each_lhs_of_arrow abstract_domain t -> UnderLets (@expr var t) + := match t return value t -> type.for_each_lhs_of_arrow abstract_domain t -> UnderLets (@expr var t) with + | type.base t + => fun '(st, v) 'tt + => annotate is_let_bound t st v + | type.arrow s d + => fun f_e '(sv, dv) + => let sv := match s with + | type.base _ => sv + | type.arrow _ _ => bottom + end in + Base + (λ x , (UnderLets.to_expr + (fx <-- f_e (@reflect _ (expr.Var x) sv); + @reify false _ fx dv))) + end%core%expr + with reflect {t} : @expr var t -> abstract_domain t -> value t + := match t return @expr var t -> abstract_domain t -> value t with + | type.base t + => fun e st => (st, e) + | type.arrow s d + => fun e absf + => (fun v + => let stv := state_of_value v in + (rv <-- (@reify false s v bottom_for_each_lhs_of_arrow); + Base (@reflect d (e @ rv) (absf stv))%expr)) + end%under_lets. + + Fixpoint interp {t} (e : @expr value_with_lets t) : value_with_lets t + := match e in expr.expr t return value_with_lets t with + | expr.Ident t idc => interp_ident _ idc (* Base (reflect (###idc) (abstract_interp_ident _ idc))*) + | expr.Var t v => v + | expr.Abs s d f => Base (fun x => @interp d (f (Base x))) + | expr.App (type.base s) d f x + => (x' <-- @interp _ x; + f' <-- @interp (_ -> d)%etype f; + f' x') + | expr.App (type.arrow s' d') d f x + => (x' <-- @interp (s' -> d')%etype x; + x'' <-- bottomify x'; + f' <-- @interp (_ -> d)%etype f; + f' x'') + | expr.LetIn (type.arrow _ _) B x f + => (x' <-- @interp _ x; + @interp _ (f (Base x'))) + | expr.LetIn (type.base A) B x f + => (x' <-- @interp _ x; + x'' <-- reify true (* this forces a let-binder here *) x' tt; + @interp _ (f (Base (reflect x'' (state_of_value x'))))) + end%under_lets. + + Definition eval_with_bound' {t} (e : @expr value_with_lets t) + (st : type.for_each_lhs_of_arrow abstract_domain t) + : expr t + := UnderLets.to_expr (e' <-- interp e; reify false e' st). + + Definition eval' {t} (e : @expr value_with_lets t) : expr t + := eval_with_bound' e bottom_for_each_lhs_of_arrow. + + Definition eta_expand_with_bound' {t} (e : @expr var t) + (st : type.for_each_lhs_of_arrow abstract_domain t) + : expr t + := UnderLets.to_expr (reify false (reflect e bottom) st). + + Section extract. + Context (ident_extract : forall t, ident t -> abstract_domain t). + + (** like [expr.interp (@ident_extract) e], except we replace + all higher-order state with bottom *) + Fixpoint extract' {t} (e : @expr abstract_domain t) : abstract_domain t + := match e in expr.expr t return abstract_domain t with + | expr.Ident t idc => ident_extract t idc + | expr.Var t v => v + | expr.Abs s d f => fun v : abstract_domain s => @extract' _ (f v) + | expr.App (type.base s) d f x + => @extract' _ f (@extract' _ x) + | expr.App (type.arrow s' d') d f x + => @extract' _ f (@bottom (type.arrow s' d')) + | expr.LetIn A B x f => dlet y := @extract' _ x in @extract' _ (f y) + end. + + Definition extract_gen {t} (e : @expr abstract_domain t) (bound : type.for_each_lhs_of_arrow abstract_domain t) + : abstract_domain' (type.final_codomain t) + := type.app_curried (extract' e) bound. + End extract. + End with_var. + + Module ident. + Section with_var. + Local Notation type := (type base.type). + Let type_base (x : base.type) : type := type.base x. + Local Coercion type_base : base.type >-> type. + Context {var : type -> Type}. + Local Notation expr := (@expr base.type ident). + Local Notation UnderLets := (@UnderLets base.type ident var). + Context (abstract_domain' : base.type -> Type). + Local Notation abstract_domain := (@abstract_domain base.type abstract_domain'). + Context (annotate_ident : forall t, abstract_domain' t -> option (ident (t -> t))) + (bottom' : forall A, abstract_domain' A) + (abstract_interp_ident : forall t, ident t -> type.interp abstract_domain' t) + (update_literal_with_state : forall A : base.type.base, abstract_domain' A -> base.interp A -> base.interp A) + (extract_list_state : forall A, abstract_domain' (base.type.list A) -> option (list (abstract_domain' A))) + (is_annotated_for : forall t t', ident t -> abstract_domain' t' -> bool). + + (** TODO: Is it okay to commute annotations? *) + Definition update_annotation {t} (st : abstract_domain' t) (e : @expr var t) : @expr var t + := match e, annotate_ident _ st with + | (#cst' @ e'), Some cst + => if is_annotated_for _ _ cst' st + then e + else ###cst @ e + | _, Some cst => ###cst @ e + | _, None => e + end%expr_pat%expr. + + Definition annotate_with_ident (is_let_bound : bool) {t} + (st : abstract_domain' t) (e : @expr var t) + : UnderLets (@expr var t) + := let cst_e := update_annotation st e (*match annotate_ident _ st with + | Some cst => ###cst @ e + | None => e + end%expr*) in + if is_let_bound + then UnderLet cst_e (fun v => Base ($v)%expr) + else Base cst_e. + + Definition annotate_base (is_let_bound : bool) {t : base.type.base} + (st : abstract_domain' t) (e : @expr var t) + : UnderLets (@expr var t) + := match invert_Literal e with + | Some v => Base ##(update_literal_with_state _ st v) + | None => annotate_with_ident is_let_bound st e + end%expr. + + Fixpoint annotate (is_let_bound : bool) {t : base.type} : abstract_domain' t -> @expr var t -> UnderLets (@expr var t) + := match t return abstract_domain' t -> @expr var t -> UnderLets (@expr var t) with + | base.type.type_base t => annotate_base is_let_bound + | base.type.prod A B + => fun st e + => match invert_pair e with + | Some (x, y) + => let stx := abstract_interp_ident _ ident.fst st in + let sty := abstract_interp_ident _ ident.snd st in + (x' <-- @annotate is_let_bound A stx x; + y' <-- @annotate is_let_bound B sty y; + Base (x', y')%expr) + | None => annotate_with_ident is_let_bound st e + end + | base.type.list A + => fun st e + => match extract_list_state _ st, reflect_list e with + | Some ls_st, Some ls_e + => (retv <---- (List.map + (fun '(st', e') => @annotate is_let_bound A st' e') + (List.combine ls_st ls_e)); + Base (reify_list retv)) + | Some ls_st, None + => (retv <---- (List.map + (fun '(n, st') + => let e' := (#ident.List_nth_default @ DefaultValue.expr.base.default @ e @ ##(n:nat))%expr in + @annotate is_let_bound A st' e') + (List.combine (List.seq 0 (List.length ls_st)) ls_st)); + Base (reify_list retv)) + | None, _ => annotate_with_ident is_let_bound st e + end + end%under_lets. + + Local Notation value_with_lets := (@value_with_lets base.type ident var abstract_domain'). + Local Notation reify := (@reify base.type ident var abstract_domain' annotate bottom'). + Local Notation reflect := (@reflect base.type ident var abstract_domain' annotate bottom'). + + (** We manually rewrite with the rule for [nth_default], as the eliminator for eta-expanding lists in the input *) + Definition interp_ident {t} (idc : ident t) : value_with_lets t + := match idc in ident t return value_with_lets t with + | ident.List_nth_default T as idc + => let default := reflect (###idc) (abstract_interp_ident _ idc) in + Base + (fun default_arg + => default <-- default default_arg; + Base + (fun ls_arg + => default <-- default ls_arg; + Base + (fun n_arg + => default <-- default n_arg; + ls' <-- @reify false (base.type.list T) ls_arg tt; + Base + (fst default, + match reflect_list ls', invert_Literal (snd n_arg) with + | Some ls, Some n + => nth_default (snd default_arg) ls n + | _, _ => snd default + end)))) + | idc => Base (reflect (###idc) (abstract_interp_ident _ idc)) + end%core%under_lets%expr. + + Definition eval_with_bound {t} (e : @expr value_with_lets t) + (st : type.for_each_lhs_of_arrow abstract_domain t) + : @expr var t + := @eval_with_bound' base.type ident var abstract_domain' annotate bottom' (@interp_ident) t e st. + + Definition eval {t} (e : @expr value_with_lets t) : @expr var t + := @eval' base.type ident var abstract_domain' annotate bottom' (@interp_ident) t e. + + Definition eta_expand_with_bound {t} (e : @expr var t) + (st : type.for_each_lhs_of_arrow abstract_domain t) + : @expr var t + := @eta_expand_with_bound' base.type ident var abstract_domain' annotate bottom' t e st. + + Definition extract {t} (e : @expr _ t) (bound : type.for_each_lhs_of_arrow abstract_domain t) : abstract_domain' (type.final_codomain t) + := @extract_gen base.type ident abstract_domain' bottom' abstract_interp_ident t e bound. + End with_var. + End ident. + + Definition default_relax_zrange (v : zrange) : option zrange := Some v. + + Section specialized. + Local Notation abstract_domain' := ZRange.type.base.option.interp. + Local Notation abstract_domain := (@partial.abstract_domain base.type abstract_domain'). + Notation expr := (@expr base.type ident). + Notation Expr := (@expr.Expr base.type ident). + Local Notation type := (type base.type). + Let type_base (x : base.type) : type := type.base x. + Local Coercion type_base : base.type >-> type. + + Section with_relax. + Context (relax_zrange : zrange -> option zrange). + + Let always_relax_zrange : zrange -> zrange + := fun range => match relax_zrange (ZRange.normalize range) with + | Some r => r + | None => range + end. + + Definition annotation_of_state (st : abstract_domain' base.type.Z) : option zrange + := option_map always_relax_zrange st. + + Definition annotate_ident t : abstract_domain' t -> option (ident (t -> t)) + := match t return abstract_domain' t -> option (ident (t -> t)) with + | base.type.Z + => fun st => st' <- annotation_of_state st; Some (ident.Z_cast st') + | base.type.Z * base.type.Z + => fun '(sta, stb) => sta' <- annotation_of_state sta; stb' <- annotation_of_state stb; Some (ident.Z_cast2 (sta', stb')) + | _ => fun _ => None + end%option%etype. + Definition is_annotated_for t t' (idc : ident t) : abstract_domain' t' -> bool + := match idc, t' with + | ident.Z_cast r, base.type.type_base base.type.Z + => fun r' + => option_beq zrange_beq (Some r) (annotation_of_state r') + | ident.Z_cast2 (r1, r2), base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z) + => fun '(r1', r2') + => (option_beq zrange_beq (Some r1) (annotation_of_state r1')) + && (option_beq zrange_beq (Some r2) (annotation_of_state r2')) + | _, _ => fun _ => false + end. + Definition is_annotation t (idc : ident t) : bool + := match idc with + | ident.Z_cast _ + | ident.Z_cast2 _ + => true + | _ => false + end. + Definition bottom' T : abstract_domain' T + := ZRange.type.base.option.None. + Definition abstract_interp_ident t (idc : ident t) : type.interp abstract_domain' t + := ZRange.ident.option.interp idc. + Definition update_Z_literal_with_state : abstract_domain' base.type.Z -> Z -> Z + := fun r n + => match r with + | Some r => if ZRange.type.base.is_bounded_by (t:=base.type.Z) r n + then n + else ident.cast_outside_of_range r n + | None => n + end. + Definition update_literal_with_state (t : base.type.base) : abstract_domain' t -> base.interp t -> base.interp t + := match t with + | base.type.Z => update_Z_literal_with_state + | base.type.unit + | base.type.bool + | base.type.nat + => fun _ => id + end. + Definition extract_list_state A (st : abstract_domain' (base.type.list A)) : option (list (abstract_domain' A)) + := st. + + Definition eval_with_bound {var} {t} (e : @expr _ t) (bound : type.for_each_lhs_of_arrow abstract_domain t) : expr t + := (@partial.ident.eval_with_bound) + var abstract_domain' annotate_ident bottom' abstract_interp_ident update_literal_with_state extract_list_state is_annotated_for t e bound. + + Definition eta_expand_with_bound {var} {t} (e : @expr _ t) (bound : type.for_each_lhs_of_arrow abstract_domain t) : expr t + := (@partial.ident.eta_expand_with_bound) + var abstract_domain' annotate_ident bottom' abstract_interp_ident update_literal_with_state extract_list_state is_annotated_for t e bound. + + Definition EvalWithBound {t} (e : Expr t) (bound : type.for_each_lhs_of_arrow abstract_domain t) : Expr t + := fun var => eval_with_bound (e _) bound. + Definition EtaExpandWithBound {t} (e : Expr t) (bound : type.for_each_lhs_of_arrow abstract_domain t) : Expr t + := fun var => eta_expand_with_bound (e _) bound. + End with_relax. + + Definition eval {var} {t} (e : @expr _ t) : expr t + := (@partial.ident.eval) + var abstract_domain' (annotate_ident default_relax_zrange) bottom' abstract_interp_ident update_literal_with_state extract_list_state (is_annotated_for default_relax_zrange) t e. + Definition Eval {t} (e : Expr t) : Expr t + := fun var => eval (e _). + Definition EtaExpandWithListInfoFromBound {t} (e : Expr t) (bound : type.for_each_lhs_of_arrow abstract_domain t) : Expr t + := EtaExpandWithBound default_relax_zrange e (type.map_for_each_lhs_of_arrow (@ZRange.type.option.strip_ranges) bound). + Definition extract {t} (e : expr t) (bound : type.for_each_lhs_of_arrow abstract_domain t) : abstract_domain' (type.final_codomain t) + := @partial.ident.extract abstract_domain' bottom' abstract_interp_ident t e bound. + Definition Extract {t} (e : Expr t) (bound : type.for_each_lhs_of_arrow abstract_domain t) : abstract_domain' (type.final_codomain t) + := @partial.ident.extract abstract_domain' bottom' abstract_interp_ident t (e _) bound. + End specialized. + End partial. + Import defaults. + + Module Import CheckCasts. + Fixpoint get_casts {t} (e : expr t) : list { t : _ & ident t } + := match e with + | expr.Ident t idc => if partial.is_annotation _ idc then [existT _ t idc] else nil + | expr.Var t v => v + | expr.Abs s d f => @get_casts _ (f nil) + | expr.App s d f x => @get_casts _ f ++ @get_casts _ x + | expr.LetIn A B x f => @get_casts _ x ++ @get_casts _ (f nil) + end%list. + + Definition GetUnsupportedCasts {t} (e : Expr t) : list { t : _ & ident t } + := get_casts (e _). + End CheckCasts. + + Definition PartialEvaluateWithBounds + (relax_zrange : zrange -> option zrange) {t} (e : Expr t) + (bound : type.for_each_lhs_of_arrow ZRange.type.option.interp t) + : Expr t + := partial.EvalWithBound relax_zrange (GeneralizeVar.GeneralizeVar (e _)) bound. + Definition PartialEvaluateWithListInfoFromBounds {t} (e : Expr t) + (bound : type.for_each_lhs_of_arrow ZRange.type.option.interp t) + : Expr t + := partial.EtaExpandWithListInfoFromBound (GeneralizeVar.GeneralizeVar (e _)) bound. + + Definition CheckedPartialEvaluateWithBounds + (relax_zrange : zrange -> option zrange) + {t} (E : Expr t) + (b_in : type.for_each_lhs_of_arrow ZRange.type.option.interp t) + (b_out : ZRange.type.base.option.interp (type.final_codomain t)) + : Expr t + (ZRange.type.base.option.interp (type.final_codomain t) * Expr t + list { t : _ & ident t }) + := dlet_nd e := GeneralizeVar.ToFlat E in + let E := GeneralizeVar.FromFlat e in + let b_computed := partial.Extract E b_in in + match CheckCasts.GetUnsupportedCasts E with + | nil => (let E := PartialEvaluateWithBounds relax_zrange E b_in in + if ZRange.type.base.option.is_tighter_than b_computed b_out + then @inl (Expr t) _ E + else inr (@inl (ZRange.type.base.option.interp (type.final_codomain t) * Expr t) _ (b_computed, E))) + | unsupported_casts => inr (inr unsupported_casts) + end. +End Compilers. |