aboutsummaryrefslogtreecommitdiff
path: root/src/Compilers/Z/Bounds/Relax.v
blob: 8178592eda1f9d06c207c323eefe1b62a08dcc9f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
Require Import Coq.ZArith.ZArith.
Require Import Coq.Arith.Arith.
Require Import Coq.Classes.Morphisms.
Require Import Crypto.Compilers.Syntax.
Require Import Crypto.Compilers.TypeInversion.
Require Import Crypto.Compilers.Relations.
Require Import Crypto.Compilers.SmartMap.
Require Import Crypto.Compilers.Z.Syntax.
Require Import Crypto.Compilers.Z.Syntax.Equality.
Require Import Crypto.Compilers.Z.Syntax.Util.
Require Import Crypto.Compilers.Z.Bounds.Interpretation.
Require Import Crypto.Compilers.Z.Bounds.RoundUpLemmas.
Require Import Crypto.Util.Tactics.DestructHead.
Require Import Crypto.Util.Tactics.SpecializeBy.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Tactics.SplitInContext.
Require Import Crypto.Util.Option.
Require Import Crypto.Util.ZUtil.Log2.
Require Import Crypto.Util.ZUtil.Tactics.LtbToLt.
Require Import Crypto.Util.Bool.

(** [helper] restates the bound [v < 2 ^ 2 ^ Z.of_nat logsz] (powers
    associate to the right, so this is [2 ^ (2 ^ logsz)]) as a condition
    on [logsz] in [nat]: the bound holds exactly when
    [Z.to_nat (Z.log2_up (Z.log2_up (1 + v)))] is at most [logsz].
    It is used via [rewrite helper in *] in the base-type case of
    [relax_output_bounds']. *)
Local Lemma helper logsz v
  : (v < 2 ^ 2 ^ Z.of_nat logsz)%Z <-> (Z.to_nat (Z.log2_up (Z.log2_up (1 + v))) <= logsz)%nat.
Proof.
  (* Move the [nat] comparison into [Z]; the non-negativity side
     condition of [Z2Nat.id] is discharged by [auto with zarith]. *)
  rewrite Nat2Z.inj_le, Z2Nat.id by auto with zarith.
  (* Restate the strict bound [v < _] as [1 + v <= _] so that the
     [Z.log2_up] characterization lemma can be applied. *)
  transitivity (1 + v <= 2^2^Z.of_nat logsz)%Z; [ omega | ].
  (* Strip both powers of two, one [Z.log2_up] at a time. *)
  rewrite !Z.log2_up_le_pow2_full by auto with zarith.
  reflexivity.
Qed.

(** Reduction hints for [simpl] in the proofs that follow (local to this
    file): the binary operations [Z.add], [Z.sub] and [Z.mul] are only
    unfolded when both of their arguments reduce to constructors, and
    [Z.pow] is never unfolded at all. *)
Local Arguments Z.add !_ !_.
Local Arguments Z.sub !_ !_.
Local Arguments Z.mul !_ !_.
Local Arguments Z.pow : simpl never.
(** Transfer a cast-back result from tight output bounds to relaxed ones.

    [Hv] states that mapping [Bounds.bounds_to_base_type round_up] over
    [relaxed_output_bounds] and over [tight_output_bounds] yields the same
    flat type. A value [v] at the relaxed-side type can therefore be
    transported along [Hv]; the transported value is [v'].

    [Htighter] says that casting [v'] back with the tight bounds gives a
    value that is bounded by [tight_output_bounds] and equals [k].
    [Hrelax] says the tight bounds are (boolean-)tighter than the relaxed
    ones. Under these hypotheses, casting [v] back with the relaxed bounds
    gives a value that is bounded by [relaxed_output_bounds] and is also
    equal to [k]. *)
Lemma relax_output_bounds'
      round_up
      t (tight_output_bounds relaxed_output_bounds : interp_flat_type Bounds.interp_base_type t)
      (Hv : SmartFlatTypeMap (@Bounds.bounds_to_base_type round_up) relaxed_output_bounds
            = SmartFlatTypeMap (@Bounds.bounds_to_base_type round_up) tight_output_bounds)
      v k
      (v' := eq_rect _ (interp_flat_type _) v _ Hv)
      (Htighter : @Bounds.is_bounded_by
                    t tight_output_bounds
                    (@cast_back_flat_const
                       (@Bounds.interp_base_type) t (@Bounds.bounds_to_base_type round_up) tight_output_bounds
                       v')
                  /\ @cast_back_flat_const
                       (@Bounds.interp_base_type) t (@Bounds.bounds_to_base_type round_up) tight_output_bounds
                       v'
                     = k)
      (Hrelax : Bounds.is_tighter_thanb tight_output_bounds relaxed_output_bounds = true)
  : @Bounds.is_bounded_by
      t relaxed_output_bounds
      (@cast_back_flat_const
         (@Bounds.interp_base_type) t (@Bounds.bounds_to_base_type round_up) relaxed_output_bounds
         v)
    /\ @cast_back_flat_const
         (@Bounds.interp_base_type) t (@Bounds.bounds_to_base_type round_up) relaxed_output_bounds
         v
       = k.
Proof.
  (* Split [Htighter] and eliminate [k] and the local definition [v']. *)
  destruct Htighter as [H0 H1]; subst v' k.
  cbv [Bounds.is_bounded_by cast_back_flat_const Bounds.is_tighter_thanb] in *.
  (* Turn the boolean tightness check into a pointwise relation on the
     bounds so it can be decomposed along the structure of [t]. *)
  apply interp_flat_type_rel_pointwise_iff_relb in Hrelax.
  (* Structural induction on the flat type [t]; each of the three braced
     subproofs below handles one constructor. *)
  induction t; unfold SmartFlatTypeMap in *; simpl @smart_interp_flat_map in *; inversion_flat_type.
  (* Single base type: unfold every bound and bit-width definition, then
     brute-force. The loop case-splits on bounds and base types, converts
     boolean comparisons to [Z] propositions, uses [helper] to rewrite the
     power-of-two width bound, and closes goals with omega, tauto or
     congruence. *)
  { cbv [Bounds.is_tighter_thanb' Bounds.bounds_to_base_type ZRange.is_tighter_than_bool is_true SmartFlatTypeMap ZRange.is_bounded_by' Bounds.smallest_logsz ZRange.is_bounded_by Bounds.is_bounded_by' Bounds.bit_width_of_base_type] in *; simpl in *;
      repeat first [ progress inversion_flat_type
                   | progress inversion_base_type_constr
                   | progress subst
                   | progress destruct_head bounds
                   | progress destruct_head base_type
                   | progress split_andb
                   | progress Z.ltb_to_lt
                   | progress break_match_hyps
                   | progress destruct_head'_and
                   | progress simpl in *
                   | rewrite helper in *
                   | omega
                   | tauto
                   | congruence
                   | progress destruct_head @eq; (reflexivity || omega)
                   | progress break_innermost_match_step
                   | apply conj ]. }
  (* Second constructor (presumably the unit type): closes by computation. *)
  { compute in *; tauto. }
  (* Pair case: reuse the induction hypotheses on each component. *)
  { simpl in *.
    (* Instantiate [IHt1] and [IHt2] with the [fst] and [snd] components of
       the bounds and of [v], leaving their [Hv] premise abstracted. *)
    specialize (fun Hv => IHt1 (fst tight_output_bounds) (fst relaxed_output_bounds) Hv (fst v)).
    specialize (fun Hv => IHt2 (snd tight_output_bounds) (snd relaxed_output_bounds) Hv (snd v)).
    (* Discharge those [Hv] premises with equality hypotheses from context. *)
    do 2 match goal with
         | [ H : _ = _, H' : forall x, _ |- _ ] => specialize (H' H)
         end.
    simpl in *.
    split_and.
    (* Boundedness conjuncts follow directly from hypotheses. The equality
       conjunct is split into [fst] and [snd] parts via [f_equal2 pair],
       and each part is chained through a hypothesis by [etransitivity]. *)
    repeat apply conj;
      [ match goal with H : _ |- _ => apply H end..
      | apply (f_equal2 (@pair _ _)); (etransitivity; [ match goal with H : _ |- _ => apply H end | ]) ];
      (* Leftover goals mention [eq_rect] transports. They are closed by
         reducing to [f_equal], by generalizing the relation [R] so the
         pointwise-bound hypothesis applies, or by generalizing one side of
         an equation so it can be [subst]ed away. *)
      repeat first [ progress destruct_head prod
                   | progress simpl in *
                   | reflexivity
                   | assumption
                   | match goal with
                     | [ |- ?P (eq_rect _ _ _ _ _) = ?P _ ]
                       => apply f_equal; clear
                     | [ H : interp_flat_type_rel_pointwise (@Bounds.is_bounded_by') ?x ?y |- interp_flat_type_rel_pointwise (@Bounds.is_bounded_by') ?x ?y' ]
                       => clear -H;
                          match goal with |- ?R _ _ => generalize dependent R; intros end
                     | [ H : ?x = ?y |- _ ]
                       => first [ generalize dependent x | generalize dependent y ];
                          let k := fresh in intro k; intros; subst k
                     end ]. }
Qed.

(** Variant of [relax_output_bounds'] that states boundedness directly
    about [k] rather than about the cast-back value.

    The hypothesis [Htighter] requires [k] to be bounded by
    [tight_output_bounds], and also that casting back [v'] (the transport of
    [v] along [Hv]) with the tight bounds equals [k]. Given [Hrelax], the
    conclusion is that [k] is bounded by [relaxed_output_bounds] and that
    casting back [v] with the relaxed bounds also yields [k]. *)
Lemma relax_output_bounds
      round_up
      t (tight_output_bounds relaxed_output_bounds : interp_flat_type Bounds.interp_base_type t)
      (Hv : SmartFlatTypeMap (@Bounds.bounds_to_base_type round_up) relaxed_output_bounds
            = SmartFlatTypeMap (@Bounds.bounds_to_base_type round_up) tight_output_bounds)
      v k
      (v' := eq_rect _ (interp_flat_type _) v _ Hv)
      (Htighter : @Bounds.is_bounded_by t tight_output_bounds k
                  /\ @cast_back_flat_const
                       (@Bounds.interp_base_type) t (@Bounds.bounds_to_base_type round_up) tight_output_bounds
                       v'
                     = k)
      (Hrelax : Bounds.is_tighter_thanb tight_output_bounds relaxed_output_bounds = true)
  : @Bounds.is_bounded_by t relaxed_output_bounds k
    /\ @cast_back_flat_const
         (@Bounds.interp_base_type) t (@Bounds.bounds_to_base_type round_up) relaxed_output_bounds
         v
       = k.
Proof.
  (* Instantiate [relax_output_bounds'], keeping its boundedness premise as
     an abstracted argument [pf]. The equality part comes from
     [proj2 Htighter]. *)
  pose proof (fun pf => @relax_output_bounds' round_up t tight_output_bounds relaxed_output_bounds Hv v k (conj pf (proj2 Htighter)) Hrelax) as H.
  (* Destructing the implication also produces its premise [pf] as a goal.
     The main goal is closed by rewriting [k] back into cast-back form with
     [H2] and using [H1]. The premise is left for the lines below. *)
  destruct H as [H1 H2]; [ | rewrite <- H2; tauto ].
  (* Premise: unfold [v'], then substitute [k] from [Htighter] so that the
     boundedness of [k] matches the required cast-back boundedness. *)
  subst v'.
  destruct Htighter; subst k; assumption.
Qed.