aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator/python/estimator/extenders.py
blob: 2e944cbdd96f5dd53cf623120bb31f8d2556a89a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extenders of tf.estimator.Estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.util import tf_inspect

_VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config'])


def add_metrics(estimator, metric_fn):
  """Creates a new ${tf.estimator.Estimator} which has given metrics.

  Example:

  ```python
    def my_auc(labels, predictions):
      return {'auc': tf.metrics.auc(labels, predictions['logistic'])}

    estimator = tf.estimator.DNNClassifier(...)
    estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
    estimator.train(...)
    estimator.evaluate(...)
  ```

  Example usage of custom metric which uses features:

  ```python
    def my_auc(features, labels, predictions):
      return {'auc': tf.metrics.auc(
        labels, predictions['logistic'], weights=features['weight'])}

    estimator = tf.estimator.DNNClassifier(...)
    estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
    estimator.train(...)
    estimator.evaluate(...)
  ```

  Args:
    estimator: A ${tf.estimator.Estimator} object.
    metric_fn: A function which should obey the following signature:
      - Args: can only have following four arguments in any order:
        * predictions: Predictions `Tensor` or dict of `Tensor` created by given
          `estimator`.
        * features: Input `dict` of `Tensor` objects created by `input_fn` which
          is given to `estimator.evaluate` as an argument.
        * labels:  Labels `Tensor` or dict of `Tensor` created by `input_fn`
          which is given to `estimator.evaluate` as an argument.
        * config: config attribute of the `estimator`.
      - Returns:
        Dict of metric results keyed by name. Final metrics are a union of this
        and `estimator`'s existing metrics. If there is a name conflict between
        this and `estimator`'s existing metrics, this will override the
        existing one. The values of the dict are the results of calling a
        metric function, namely a `(metric_tensor, update_op)` tuple.

  Returns:
      A new ${tf.estimator.Estimator} which has a union of original metrics with
        given ones.

  Raises:
    ValueError: If `metric_fn` declares an argument other than `features`,
      `labels`, `predictions` or `config`.
  """
  _verify_metric_fn_args(metric_fn)

  def new_model_fn(features, labels, mode, config):
    """Wraps `estimator.model_fn`, adding `metric_fn` results in EVAL mode."""
    spec = estimator.model_fn(features, labels, mode, config)
    # Extra metrics only apply to evaluation; other modes pass through as-is.
    if mode != model_fn_lib.ModeKeys.EVAL:
      return spec
    new_metrics = _call_metric_fn(metric_fn, features, labels, spec.predictions,
                                  config)
    # Merge into a copy so the dict owned by the wrapped model_fn's spec is not
    # mutated as a side effect of adding the new metrics.
    all_metrics = dict(spec.eval_metric_ops or {})
    all_metrics.update(new_metrics)
    return spec._replace(eval_metric_ops=all_metrics)

  return estimator_lib.Estimator(
      model_fn=new_model_fn,
      model_dir=estimator.model_dir,
      config=estimator.config)


def _verify_metric_fn_args(metric_fn):
  """Checks that `metric_fn` only declares supported argument names.

  Args:
    metric_fn: Callable whose argument names are compared against
      `_VALID_METRIC_FN_ARGS`.

  Raises:
    ValueError: If `metric_fn` has an argument name that is not in
      `_VALID_METRIC_FN_ARGS`.
  """
  declared = set(estimator_util.fn_args(metric_fn))
  # A method's `self` is not a metric argument, so drop it before validating.
  if tf_inspect.ismethod(metric_fn):
    declared.discard('self')
  unexpected = list(declared - _VALID_METRIC_FN_ARGS)
  if not unexpected:
    return
  raise ValueError('metric_fn (%s) has following not expected args: %s' %
                   (metric_fn, unexpected))


def _call_metric_fn(metric_fn, features, labels, predictions, config):
  """Invokes `metric_fn` with only the arguments its signature names.

  Args:
    metric_fn: Callable to invoke.
    features: Value passed as `features` if `metric_fn` declares it.
    labels: Value passed as `labels` if `metric_fn` declares it.
    predictions: Value passed as `predictions` if `metric_fn` declares it.
    config: Value passed as `config` if `metric_fn` declares it.

  Returns:
    Whatever `metric_fn` returns.
  """
  accepted = estimator_util.fn_args(metric_fn)
  candidates = (
      ('features', features),
      ('labels', labels),
      ('predictions', predictions),
      ('config', config),
  )
  # Keep only the candidates that `metric_fn` actually asks for by name.
  selected = {name: value for name, value in candidates if name in accepted}
  return metric_fn(**selected)