path: root/tensorflow/contrib/estimator/python/estimator/extenders.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extenders of tf.estimator.Estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import clip_ops
from tensorflow.python.training import optimizer as optimizer_lib


_VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config'])


def add_metrics(estimator, metric_fn):
  """Creates a new ${tf.estimator.Estimator} which has given metrics.

  Example:

  ```python
    def my_auc(labels, predictions):
      return {'auc': tf.metrics.auc(labels, predictions['logistic'])}

    estimator = tf.estimator.DNNClassifier(...)
    estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
    estimator.train(...)
    estimator.evaluate(...)
  ```
  Example usage of a custom metric that uses features:

  ```python
    def my_auc(features, labels, predictions):
      return {'auc': tf.metrics.auc(
        labels, predictions['logistic'], weights=features['weight'])}

    estimator = tf.estimator.DNNClassifier(...)
    estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
    estimator.train(...)
    estimator.evaluate(...)
  ```
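
  The `config` argument can be requested as well; it is the estimator's
  `RunConfig`. A minimal sketch, assuming a classifier whose predictions
  include `class_ids` (the metric name is illustrative):

  ```python
    def my_accuracy(labels, predictions, config):
      # e.g. config.model_dir identifies the run being evaluated.
      return {'accuracy': tf.metrics.accuracy(
          labels, predictions['class_ids'])}

    estimator = tf.estimator.DNNClassifier(...)
    estimator = tf.contrib.estimator.add_metrics(estimator, my_accuracy)
  ```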

  Args:
    estimator: A ${tf.estimator.Estimator} object.
    metric_fn: A function which should obey the following signature:
      - Args: can only have the following four arguments in any order:
        * predictions: Predictions `Tensor` or dict of `Tensor` created by given
          `estimator`.
        * features: Input `dict` of `Tensor` objects created by `input_fn` which
          is given to `estimator.evaluate` as an argument.
        * labels: Labels `Tensor` or dict of `Tensor` created by `input_fn`
          which is given to `estimator.evaluate` as an argument.
        * config: config attribute of the `estimator`.
      - Returns:
        Dict of metric results keyed by name. Final metrics are a union of this
        and `estimator`'s existing metrics. If there is a name conflict between
        this and `estimator`'s existing metrics, this will override the existing
        one. The values of the dict are the results of calling a metric
        function, namely a `(metric_tensor, update_op)` tuple.

  Returns:
      A new ${tf.estimator.Estimator} which has a union of the original metrics
        and the given ones.
  """
  _verify_metric_fn_args(metric_fn)

  def new_model_fn(features, labels, mode, config):
    spec = estimator.model_fn(features, labels, mode, config)
    if mode != model_fn_lib.ModeKeys.EVAL:
      return spec
    new_metrics = _call_metric_fn(metric_fn, features, labels, spec.predictions,
                                  config)
    # Make a copy so we do not mutate the dict held by the original spec.
    all_metrics = dict(spec.eval_metric_ops or {})
    all_metrics.update(new_metrics)
    return spec._replace(eval_metric_ops=all_metrics)

  return estimator_lib.Estimator(
      model_fn=new_model_fn,
      model_dir=estimator.model_dir,
      config=estimator.config)
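
# The estimator returned by `add_metrics` reuses the wrapped estimator's
# `model_dir` and `config`, so it evaluates against the same checkpoints.
# A minimal sketch, with `my_auc` as a hypothetical metric_fn:
#
#   estimator = tf.estimator.DNNClassifier(model_dir='/tmp/model', ...)
#   new_estimator = add_metrics(estimator, my_auc)
#   assert new_estimator.model_dir == estimator.model_dir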


def clip_gradients_by_norm(optimizer, clip_norm):
  """Returns an optimizer which clips gradients before applying them.

  Example:

  ```python
  optimizer = tf.train.ProximalAdagradOptimizer(
      learning_rate=0.1,
      l1_regularization_strength=0.001)
  optimizer = tf.contrib.estimator.clip_gradients_by_norm(
      optimizer, clip_norm=5.0)
  estimator = tf.estimator.DNNClassifier(
      feature_columns=[...],
      hidden_units=[1024, 512, 256],
      optimizer=optimizer)
  ```
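
  Under the hood, the returned optimizer delegates `compute_gradients` to the
  wrapped optimizer and clips by global norm before `apply_gradients`. A
  minimal sketch of the equivalent manual steps (names are illustrative):

  ```python
  grads_and_vars = optimizer.compute_gradients(loss)
  gradients, variables = zip(*grads_and_vars)
  gradients, _ = tf.clip_by_global_norm(gradients, clip_norm)
  train_op = optimizer.apply_gradients(list(zip(gradients, variables)))
  ```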

  Args:
    optimizer: A `tf.Optimizer` object to apply gradients.
    clip_norm: A 0-D (scalar) `Tensor` > 0. The maximum allowed global norm of
      the gradients.

  Returns:
    A `tf.Optimizer`.
  """

  def clip_grads(grads_and_vars):
    gradients, variables = zip(*grads_and_vars)
    gradients = clip_ops.clip_by_global_norm(gradients, clip_norm)[0]
    grads_and_vars = list(zip(gradients, variables))
    return grads_and_vars

  return _TransformGradients(
      optimizer=optimizer,
      transform_grads_fn=clip_grads,
      name='ClipByNorm' + optimizer.get_name())


def forward_features(estimator, keys=None):
  """Forward features to predictions dictionary.

  In some cases, a user wants to see some of the features in the estimator's
  prediction output. As an example, consider a batch prediction service: the
  service simply runs inference on the user's graph and returns the results.
  Keys are essential because there is no ordering guarantee on the outputs, so
  they need to be rejoined to the inputs via keys or by including the inputs in
  the outputs.

  Example:

  ```python
    def input_fn():
      features, labels = ...
      features['unique_example_id'] = ...
      return features, labels

    estimator = tf.estimator.LinearClassifier(...)
    estimator = tf.contrib.estimator.forward_features(
        estimator, 'unique_example_id')
    estimator.train(...)
    assert 'unique_example_id' in next(estimator.predict(...))
  ```
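
  A `list` of keys works as well; a sketch with illustrative feature names:

  ```python
    estimator = tf.contrib.estimator.forward_features(
        estimator, ['unique_example_id', 'timestamp'])
  ```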

  Args:
    estimator: A ${tf.estimator.Estimator} object.
    keys: A `string` or a `list` of `string`. If it is `None`, all of the
      `features` in the `dict` are forwarded to the `predictions`. If it is a
      `string`, only the given key is forwarded. If it is a `list` of strings,
      all the given `keys` are forwarded.

  Returns:
      A new ${tf.estimator.Estimator} which forwards features to predictions.

  Raises:
    ValueError:
      * if `keys` is already part of `predictions`. We don't allow
        override.
      * if `keys` does not exist in `features`.
      * if feature key refers to a `SparseTensor`, since we don't support
        `SparseTensor` in `predictions`. `SparseTensor` is common in `features`.
    TypeError: if `keys` type is not one of `string` or list/tuple of `string`.
  """

  def verify_key_types(keys):  # pylint: disable=missing-docstring
    if keys is None:
      return keys
    if isinstance(keys, six.string_types):
      return [keys]
    if not isinstance(keys, (list, tuple)):
      raise TypeError('keys should be either a string or a list of strings. '
                      'Given: {}'.format(type(keys)))
    for key in keys:
      if not isinstance(key, six.string_types):
        raise TypeError('All items in the given keys list should be a string. '
                        'There exists an item of type: {}'.format(type(key)))
    return keys

  def get_keys(features):
    if keys is None:
      return features.keys()
    return keys

  def verify_keys_and_predictions(features, predictions):
    if not isinstance(predictions, dict):
      raise ValueError(
          'Predictions should be a dict to be able to forward features. '
          'Given: {}'.format(type(predictions)))
    for key in get_keys(features):
      if key not in features:
        raise ValueError(
            'Key "{}" does not exist in the features dict. The features dict '
            'has the following keys: {}. Please check the arguments of '
            'forward_features.'.format(key, features.keys()))
      if key in predictions:
        raise ValueError(
            'Cannot forward feature key ({}) since it already exists in '
            'predictions. Existing prediction keys: {}. Please check the '
            'arguments of forward_features.'.format(key, predictions.keys()))

  keys = verify_key_types(keys)

  def new_model_fn(features, labels, mode, config):  # pylint: disable=missing-docstring
    spec = estimator.model_fn(features, labels, mode, config)
    predictions = spec.predictions
    if predictions is None:
      return spec
    verify_keys_and_predictions(features, predictions)
    for key in get_keys(features):
      feature = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
          features[key])
      if not isinstance(feature, ops.Tensor):
        raise ValueError(
            'Forwarded feature ({}) should be a Tensor. Please use keys '
            'argument of forward_features to filter unwanted features. Type of '
            'features[{}] is {}.'.format(key, key, type(feature)))
      predictions[key] = feature
    return spec._replace(predictions=predictions)

  return estimator_lib.Estimator(
      model_fn=new_model_fn,
      model_dir=estimator.model_dir,
      config=estimator.config)
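
# `forward_features` raises for features that arrive as `SparseTensor` (common
# for string/categorical columns). A sketch of working around this with the
# `keys` argument, assuming hypothetical feature names where 'terms' is sparse
# and 'example_id' is dense:
#
#   estimator = forward_features(estimator, keys=['example_id'])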


class _TransformGradients(optimizer_lib.Optimizer):
  """Add given gradient transformation to the optimizer."""

  def __init__(self, optimizer, transform_grads_fn, name=None):
    """Construct an `tf.Optimizer` wrapper to apply given transformations.

    Example:

    ```python
    optimizer = tf.train.ProximalAdagradOptimizer(
        learning_rate=0.1,
        l1_regularization_strength=0.001)
    def clip_grads(grads_and_vars):
      gradients, variables = zip(*grads_and_vars)
      gradients = tf.clip_by_global_norm(gradients, my_norm)[0]
      grads_and_vars = list(zip(gradients, variables))
      return grads_and_vars
    optimizer = _TransformGradients(
        optimizer=optimizer, transform_grads_fn=clip_grads)
    estimator = tf.estimator.DNNClassifier(
        feature_columns=[...],
        hidden_units=[1024, 512, 256],
        optimizer=optimizer)
    ```

    Args:
      optimizer: A `tf.Optimizer` object to apply gradients.
      transform_grads_fn: A function which takes a single argument, a list of
        gradient to variable pairs (tuples), performs any requested gradient
        updates, such as gradient clipping or multipliers, and returns the
        updated list.
      name: A string which will be used for debugging purposes.
    """
    super(_TransformGradients, self).__init__(
        use_locking=False, name=name or optimizer.get_name())
    self._optimizer = optimizer
    self._transform_grads_fn = transform_grads_fn

  def compute_gradients(self, *args, **kwargs):
    """See `tf.Optimizer`."""
    return self._optimizer.compute_gradients(*args, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    Calls `transform_grads_fn`, and then applies the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the Optimizer constructor.

    Returns:
      An `Operation` that applies the gradients. If `global_step` was not None,
      that operation also increments `global_step`.

    Raises:
      ValueError: If the grads_and_vars is malformed.
    """
    grads_and_vars = self._transform_grads_fn(grads_and_vars)
    return self._optimizer.apply_gradients(grads_and_vars, global_step, name)

  def get_slot(self, *args, **kwargs):
    """See `tf.Optimizer`."""
    return self._optimizer.get_slot(*args, **kwargs)

  def get_slot_names(self, *args, **kwargs):
    """See `tf.Optimizer`."""
    return self._optimizer.get_slot_names(*args, **kwargs)


def _verify_metric_fn_args(metric_fn):
  args = set(estimator_util.fn_args(metric_fn))
  invalid_args = list(args - _VALID_METRIC_FN_ARGS)
  if invalid_args:
    raise ValueError('metric_fn (%s) has the following unexpected args: %s' %
                     (metric_fn, invalid_args))


def _call_metric_fn(metric_fn, features, labels, predictions, config):
  """Calls metric fn with proper arguments."""
  metric_fn_args = estimator_util.fn_args(metric_fn)
  kwargs = {}
  if 'features' in metric_fn_args:
    kwargs['features'] = features
  if 'labels' in metric_fn_args:
    kwargs['labels'] = labels
  if 'predictions' in metric_fn_args:
    kwargs['predictions'] = predictions
  if 'config' in metric_fn_args:
    kwargs['config'] = config
  return metric_fn(**kwargs)
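
# A minimal sketch of the dispatch performed by `_call_metric_fn`, with
# `my_auc` as a hypothetical metric_fn:
#
#   def my_auc(labels, predictions):
#     return {'auc': tf.metrics.auc(labels, predictions['logistic'])}
#
#   _call_metric_fn(my_auc, features, labels, predictions, config)
#   # fn_args(my_auc) == ('labels', 'predictions'), so the call above invokes
#   # my_auc(labels=labels, predictions=predictions); `features` and `config`
#   # are not passed because my_auc does not accept them.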