aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py
blob: f137ada35524bf2467314f4a284ea35a82f06825 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks for use with GTFlow Estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.contrib.learn.python.learn import session_run_hook
from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArgs
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.summary_io import SummaryWriterCache


class FeatureImportanceSummarySaver(session_run_hook.SessionRunHook):
  """Session hook that periodically writes per-feature importance summaries.

  For every feature it writes four scalar summaries into the directory
  `os.path.join(model_dir, <feature name>)`:
  `feature_importance/usage_counts`, `feature_importance/usage_fraction`,
  `feature_importance/gains` and `feature_importance/gains_fraction`.
  """

  def __init__(self, model_dir, every_n_steps=1):
    """Creates a FeatureImportanceSummarySaver hook.

    Args:
      model_dir: model base output directory; each feature's summaries go to a
        subdirectory named after the feature.
      every_n_steps: minimum number of global steps between two consecutive
        summary writes. The first `after_run` call always writes.

    Raises:
      ValueError: If `model_dir` is None.
    """
    if model_dir is None:
      raise ValueError("model dir must be specified.")
    self._model_dir = model_dir
    self._every_n_steps = every_n_steps
    # Global step of the most recent write; None until the first write.
    self._last_triggered_step = None

  def begin(self):
    """Resolves the global step and the `gbdt/*` importance tensors.

    Raises:
      RuntimeError: If no global step has been created.
    """
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use FeatureImportanceSummarySaver.")
    lookup = ops.get_default_graph().get_tensor_by_name
    self._feature_names_tensor = lookup("gbdt/feature_names:0")
    self._feature_usage_counts_tensor = lookup("gbdt/feature_usage_counts:0")
    self._feature_gains_tensor = lookup("gbdt/feature_gains:0")

  def before_run(self, run_context):
    """Asks the session for the global step and the importance tensors."""
    del run_context  # Not needed by this hook.
    fetches = {}
    fetches["global_step"] = self._global_step_tensor
    fetches["feature_names"] = self._feature_names_tensor
    fetches["feature_usage_counts"] = self._feature_usage_counts_tensor
    fetches["feature_gains"] = self._feature_gains_tensor
    return SessionRunArgs(fetches)

  def after_run(self, run_context, run_values):
    """Writes the importance summaries when a write is due.

    Raises:
      RuntimeError: If the fetched names, usage counts and gains differ in
        length.
    """
    del run_context  # Not needed by this hook.
    fetched = run_values.results
    step = fetched["global_step"]
    names = fetched["feature_names"]
    usage_counts = fetched["feature_usage_counts"]
    gains = fetched["feature_gains"]

    # A write is due on the very first call, or once `every_n_steps` global
    # steps have passed since the previous write.
    write_due = (self._last_triggered_step is None or
                 step >= self._last_triggered_step + self._every_n_steps)
    if not write_due:
      return

    num_names = len(names)
    if num_names != len(usage_counts) or num_names != len(gains):
      raise RuntimeError(
          "Feature names and importance measures have inconsistent lengths.")

    # Normalizers that turn raw values into fractions of the total; a zero
    # total falls back to 1.0 so every fraction is simply the raw value (0).
    usage_total = 0.0
    for count in usage_counts:
      usage_total += count
    usage_scale = 1.0 / usage_total if usage_total else 1.0

    gain_total = 0.0
    for gain_value in gains:
      gain_total += gain_value
    gain_scale = 1.0 / gain_total if gain_total else 1.0

    self._last_triggered_step = step
    for raw_name, count, gain_value in zip(names, usage_counts, gains):
      # Feature names are fetched as bytes, hence the decode.
      writer = SummaryWriterCache.get(
          os.path.join(self._model_dir, raw_name.decode("utf-8")))
      scalars = (
          ("feature_importance/usage_counts", count),
          ("feature_importance/usage_fraction", count * usage_scale),
          ("feature_importance/gains", gain_value),
          ("feature_importance/gains_fraction", gain_value * gain_scale),
      )
      for tag, scalar in scalars:
        writer.add_summary(
            Summary(value=[Summary.Value(tag=tag, simple_value=scalar)]), step)


class FeedFnHook(session_run_hook.SessionRunHook):
  """Runs `feed_fn` before every step and feeds the dict it returns."""

  def __init__(self, feed_fn):
    """Initializes a `FeedFnHook`.

    Args:
      feed_fn: zero-argument callable, invoked once per `session.run`; its
        return value is used as that step's `feed_dict`.
    """
    self.feed_fn = feed_fn

  def before_run(self, run_context):
    """Calls `feed_fn` and returns its result as this step's feed_dict."""
    del run_context  # unused by FeedFnHook.
    # The callable must be invoked here: `feed_dict` expects the mapping itself,
    # not the function that produces it. Passing the function object through
    # unchanged would hand the session a non-mapping feed_dict.
    return session_run_hook.SessionRunArgs(
        fetches=None, feed_dict=self.feed_fn())


class StopAfterNTrees(session_run_hook.SessionRunHook):
  """Stop training after building N full trees.

  Training also stops once more than `2 * n` trees have been attempted. This
  avoids looping forever when tree building keeps failing.
  """

  def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor,
               override_global_step_value=None):
    """Initializes a `StopAfterNTrees` hook.

    Args:
      n: number of finalized trees at which a stop is requested.
      num_attempted_trees_tensor: tensor fetched every step; a stop is requested
        once its value exceeds `2 * n`.
      num_finalized_trees_tensor: tensor fetched every step; a stop is requested
        once its value reaches `n`.
      override_global_step_value: if not None, the global step is assigned this
        value right before the stop is requested.
    """
    self._num_trees = n
    self._num_attempted_trees_tensor = num_attempted_trees_tensor
    self._num_finalized_trees_tensor = num_finalized_trees_tensor
    self._override_global_step_value = override_global_step_value

  def begin(self):
    """Looks up the global step and builds the optional override op.

    Raises:
      RuntimeError: If no global step has been created.
    """
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created.")

    # The assign op is built once here because the graph is finalized after
    # `begin`; it is only run when stopping.
    if self._override_global_step_value is not None:
      self._override_global_step_op = state_ops.assign(
          self._global_step_tensor, self._override_global_step_value)

  def before_run(self, run_context):
    """Requests the attempted and finalized tree counts for this step."""
    del run_context  # unused by StopAfterNTrees.
    return session_run_hook.SessionRunArgs({
        "num_attempted_trees": self._num_attempted_trees_tensor,
        "num_finalized_trees": self._num_finalized_trees_tensor,
    })

  def after_run(self, run_context, run_values):
    """Requests a stop once enough trees are finalized or attempted.

    Raises:
      RuntimeError: If either tree count was not fetched.
    """
    num_attempted_trees = run_values.results["num_attempted_trees"]
    num_finalized_trees = run_values.results["num_finalized_trees"]
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if num_attempted_trees is None or num_finalized_trees is None:
      raise RuntimeError(
          "StopAfterNTrees expects both tree counts to be fetched.")

    # Stop when the required number of finalized trees is reached, or when we
    # try enough times to build a tree but keep failing.
    reached_target = num_finalized_trees >= self._num_trees
    too_many_attempts = num_attempted_trees > 2 * self._num_trees
    if not (reached_target or too_many_attempts):
      return

    if reached_target:
      logging.info("Requesting stop since we have reached %d trees.",
                   num_finalized_trees)
    else:
      logging.info(
          "Requesting stop since we have attempted %d trees but only "
          "finalized %d.", num_attempted_trees, num_finalized_trees)
    if self._override_global_step_value is not None:
      logging.info("Overriding global steps value.")
      run_context.session.run(self._override_global_step_op)
    run_context.request_stop()


class SwitchTrainOp(session_run_hook.SessionRunHook):
  """Hook that switches the train op after specified number of steps.

  Hook that replaces the train op depending on the number of steps of training
  that have taken place. The first_train_op is used till train_steps steps
  are reached. Thereafter the second_train_op is used.
  """

  def __init__(self, first_train_op, train_steps, second_train_op):
    """Initializes a `SwitchTrainOp`.

    Args:
      first_train_op: op fetched while the global step is below `train_steps`.
      train_steps: global step value at which to switch ops.
      second_train_op: op fetched once the global step reaches `train_steps`.
    """
    self._first_train_op = first_train_op
    self._second_train_op = second_train_op
    self._train_steps = train_steps

  def _get_train_op_for_global_step(self, current_step):
    """Returns the op to fetch on the step after `current_step`."""
    if current_step < self._train_steps:
      return self._first_train_op
    return self._second_train_op

  def begin(self):
    """Resolves the global step and seeds the fetched op with a no-op.

    Raises:
      RuntimeError: If no global step has been created.
    """
    self._global_step_tensor = training_util.get_global_step()
    # Validate before adding anything to the graph, so a failing `begin` does
    # not leave an orphan no-op behind.
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SwitchTrainOp.")
    # The global step is unknown until the first run completes, so the first
    # step fetches a no-op; `after_run` then selects the real train op.
    self._current_train_op = control_flow_ops.no_op()

  def before_run(self, run_context):
    """Requests the global step and the currently selected train op."""
    del run_context  # unused by SwitchTrainOp.
    return session_run_hook.SessionRunArgs(
        {"global_step": self._global_step_tensor,
         "train_op": self._current_train_op})

  def after_run(self, run_context, run_values):
    """Selects the train op for the next step from the fetched global step."""
    del run_context  # unused by SwitchTrainOp.
    self._current_train_op = self._get_train_op_for_global_step(
        run_values.results["global_step"])