# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class Evaluator holds Metrics for the duration of an evaluation run."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.contrib.eager.python import datasets
from tensorflow.contrib.eager.python import metrics
from tensorflow.contrib.summary import summary_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops


class Evaluator(object):
  """This holds and updates Metrics for the duration of a single eval run.

  Usage:
    evaluator = my_model.evaluator() # or MyEvaluator(my_model)
    for example_batch in ...:
      evaluator(example_batch)
    results = evaluator.all_metric_results(optional_summary_logdir)

  Or, if you are getting your examples from a tf.data.Dataset, you can use
  the evaluate_on_dataset() method.

  Implementers of Evaluators should
  (a) Call `track_metric()` and/or `track_evaluator()` in __init__().
  (b) Override the `call()` method. It will be passed the output of the
      model's `eval_data()` method, and should call its contained metrics
      (treating them as callables) and any child Evaluators (using their
      call() method to avoid calling eval_data() again).
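
  For example, a minimal subclass might look like this sketch (`MyEvaluator`
  and the "loss" key in `eval_data` are illustrative, not part of this API):

    class MyEvaluator(Evaluator):
      def __init__(self, model):
        super(MyEvaluator, self).__init__(model)
        # Metrics must be registered via track_metric() to be reported.
        self.avg_loss = self.track_metric(metrics.Mean("Avg Loss"))

      def call(self, eval_data):
        # eval_data is the output of model.eval_data() for one mini-batch.
        self.avg_loss(eval_data["loss"])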

  Args:
    model: A `Model` object with an `eval_data()` method.
  """

  def __init__(self, model):
    self._model = model
    self._metrics = {}
    self._evaluators = {}
    if context.in_graph_mode():
      self.call = function.defun(self.call)

  # ---- API for users ----
  def __call__(self, *args, **kwargs):
    """Update metrics with a minibatch of input examples.

    Args:
      *args: Positional arguments representing an input mini-batch of examples
        to pass to self.model.eval_data().
      **kwargs: Keyword arguments representing the same input mini-batch, also
        passed to self.model.eval_data().

    Returns:
      The op to execute or None if executing eagerly.
    """
    return self.call(self._model.eval_data(*args, **kwargs))

  def init_variables(self):
    """Return an op for initializing all contained uninitialized variables.

    Only for graph execution. Should be called after variables are created
    in the first execution of __call__().

    Returns:
      An op.

    Raises:
      RuntimeError: if eager execution is enabled.

    @compatibility(eager)
    Only for graph execution.
    @end_compatibility
    """
    if context.in_eager_mode():
      raise RuntimeError("Evaluator.init_variables() not needed when "
                         "eager execution is enabled.")
    return control_flow_ops.group([m.init_variables() for _, m in self.metrics])

  def all_metric_results(self, summary_logdir=None):
    """Computes results for all contained metrics.

    Args:
      summary_logdir: An optional string. If specified, metric results
        will be written as summaries to this directory.

    Returns:
      A `dict` mapping string names to tensors.
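
    For example (a sketch; the summary directory path is hypothetical):

    ```python
      results = evaluator.all_metric_results("/tmp/eval_summaries")
    ```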
    """
    if summary_logdir is None:
      with summary_ops.never_record_summaries():
        return self._all_metric_results()
    else:
      def f():
        with summary_ops.create_file_writer(
            summary_logdir).as_default(), summary_ops.always_record_summaries():
          return self._all_metric_results()
      if context.in_eager_mode():
        return f()
      else:
        return function.defun(f)()

  def _all_metric_results(self):
    """Implementation of `all_metric_results` in the summary context."""
    results = {}
    for name, metric in six.iteritems(self._metrics):
      results[name] = metric.result()
    for prefix, evaluator in six.iteritems(self._evaluators):
      for name, metric in six.iteritems(evaluator._metrics):  # pylint: disable=protected-access
        results[prefix + "/" + name] = metric.result()
    return results

  def evaluate_on_dataset(self, dataset, *args, **kwargs):
    """Convenience method for performing an eval on a Dataset.

    Args:
      dataset: Dataset object with the input data to evaluate on.
      *args: Optional additional positional arguments to __call__().
      **kwargs: Optional additional keyword arguments to __call__(), except
        `summary_logdir`: if specified, metrics will be written as summaries
        to this directory.

    Returns:
      @compatibility(eager)
      When eager execution is enabled, this returns the result of performing
      an evaluation as a dictionary. With graph execution, this returns a tuple
      (init_op, call_op, results_op) which may be executed using this code:
      ```python
        sess.run(init_op)
        try:
          while True:
            sess.run(call_op)
        except tf.errors.OutOfRangeError:
          pass
        return sess.run(results_op)  # A dictionary

        # equivalently:
        return evaluator.run_evaluation(init_op, call_op, results_op, sess=sess)
      ```
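
      With eager execution, the equivalent is simply (a sketch; `my_dataset`
      is a hypothetical tf.data.Dataset):
      ```python
        results = evaluator.evaluate_on_dataset(my_dataset)  # A dictionary
      ```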
      @end_compatibility
    """
    summary_logdir = kwargs.pop("summary_logdir", None)
    if context.in_graph_mode():
      call_op = self.__call__(dataset.make_one_shot_iterator().get_next(),
                              *args, **kwargs)
      init_op = self.init_variables()
      results_op = self.all_metric_results(summary_logdir)
      return (init_op, call_op, results_op)
    # Eager case
    for example in datasets.Iterator(dataset):
      self.__call__(example, *args, **kwargs)
    return self.all_metric_results(summary_logdir)

  @staticmethod
  def run_evaluation(init_op, call_op, results_op, sess=None):
    """Convenience method for running the ops returned by evaluate_on_dataset.

    Args:
      init_op: An op that initializes/resets evaluation state.
      call_op: An op that updates evaluation state on a mini-batch of examples.
        Must raise a tf.errors.OutOfRangeError when the input is exhausted.
      results_op: A dictionary of tensors that compute the final evaluation
        results from the evaluation state.
      sess: The Session to run the evaluation in. Defaults to the default
        Session.

    Returns:
      A dictionary of values, parallel to results_op.

    Raises:
      RuntimeError: if eager execution is enabled.

    @compatibility(eager)
    Only for graph execution.
    @end_compatibility
    """
    if context.in_eager_mode():
      raise RuntimeError("Evaluator.run_evaluation() not supported when "
                         "eager execution is enabled.")
    sess = sess or ops.get_default_session()
    sess.run(init_op)
    try:
      while True:
        sess.run(call_op)
    except errors_impl.OutOfRangeError:
      pass
    return sess.run(results_op)

  # ---- To be implemented by descendants ----
  def call(self, eval_data):
    """Update metrics using the output of self.model.

    Note: This function is executed as a graph function in graph mode.
    This means:
    a) Operations on the same resource are executed in textual order.
       This should make it easier to do things like add the updated
       value of a variable to another, for example.
    b) You don't need to worry about collecting the update ops to execute.
       All update ops added to the graph by this function will be executed.
    As a result, code should generally work the same way with graph or
    eager execution.

    Args:
      eval_data: The output of self.model.eval_data() on a mini-batch of
        examples.
    """
    raise NotImplementedError("Evaluators must define a call member function.")

  # ---- For use by descendants ----
  @property
  def model(self):
    return self._model

  def track_metric(self, metric):
    """Add a Metric to be tracked.

    Metrics can only be tracked by one `Evaluator`. Metrics must be
    tracked or they will not appear in `all_metric_results()`.

    Args:
      metric: A `Metric` object.

    Returns:
      The `metric` passed into this function.

    Raises:
      RuntimeError: If called before __init__.
      TypeError: If `metric` is not of the correct type.
      ValueError: If there is a name collision between Metrics or `metric`
        has already been added to another `Evaluator`.
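
    Example (as in `SparseSoftmaxEvaluator` below; the attribute name
    `avg_loss` is illustrative):

    ```python
      # In a subclass's __init__():
      self.avg_loss = self.track_metric(metrics.Mean("Avg Loss"))
    ```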
    """
    if not hasattr(self, "_metrics"):
      raise RuntimeError(
          "Need to call Evaluator.__init__ before adding metrics")
    if not isinstance(metric, metrics.Metric):
      raise TypeError(
          "Evaluator.track_metric() passed type %s, not a tfe.metrics.Metric" %
          (type(metric),))
    if metric.name in self._metrics:
      if metric is self._metrics[metric.name]:
        return metric
      raise ValueError(
          "Attempt to add two Metrics with the name '%s' to the same Evaluator "
          "'%s'" % (metric.name, self.name))
    # pylint: disable=protected-access
    if hasattr(metric, "_added_to_an_evaluator"):
      raise ValueError("Metric %s already added to Evaluator %s" %
                       (metric.name, metric._added_to_an_evaluator))
    metric._added_to_an_evaluator = self.__class__.__name__
    # pylint: enable=protected-access
    self._metrics[metric.name] = metric
    return metric

  def track_evaluator(self, prefix, evaluator):
    """Add a contained `Evaluator`.

    This is for delegating to another `Evaluator`, e.g. when you have a
    model with multiple heads. Users should manually invoke the child
    `Evaluator`'s `call()` method from their own `call()` method.

    Args:
      prefix: A string. Metrics from `evaluator` are exported with this
        prefix and a '/'.
      evaluator: An `Evaluator` object.

    Returns:
      The value of `evaluator` passed into this function.

    Raises:
      RuntimeError: If called before __init__.
      TypeError: If `evaluator` is not of the correct type.
      ValueError: If an `Evaluator` has already been added with that `prefix`.
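
    Example (an illustrative sketch; `HeadEvaluator` and `model.head_a` are
    hypothetical):

    ```python
      # In the parent Evaluator's __init__():
      self.head_a = self.track_evaluator(
          "head_a", HeadEvaluator(model.head_a))
      # Then, in the parent's call(), delegate with self.head_a.call(...).
    ```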
    """
    if not hasattr(self, "_evaluators"):
      raise RuntimeError(
          "Need to call Evaluator.__init__ before adding evaluators")
    if not isinstance(evaluator, Evaluator):
      raise TypeError(
          "Evaluator.track_evaluator() passed type %s, not a tfe.Evaluator." %
          (type(evaluator),))
    if prefix in self._evaluators:
      if evaluator is self._evaluators[prefix]:
        return evaluator
      raise ValueError(
          "Attempt to add two Evaluators with the same prefix '%s'." % prefix)
    self._evaluators[prefix] = evaluator
    return evaluator

  @property
  def metric_variables(self):
    v = []
    for metric in six.itervalues(self._metrics):
      v += metric.variables
    for evaluator in six.itervalues(self._evaluators):
      v += evaluator.metric_variables
    return v

  @property
  def metrics(self):
    """Returns a list of (prefix, metric) pairs."""
    m = []
    for metric in six.itervalues(self._metrics):
      m.append(("", metric))
    for prefix, evaluator in six.iteritems(self._evaluators):
      m += [(prefix + "/" + p, child_metric)
            for p, child_metric in evaluator.metrics]
    return m


class SparseSoftmaxEvaluator(Evaluator):
  """Evaluator for a sparse softmax model.

  Computes a standard set of metrics for single-label, multi-class
  models.

  Args:
    model: A `SparseSoftmaxModel` object or a `Model` whose `eval_data()`
      method produces a `dict` containing values for the loss, true
      label, predicted class, and optional weights.
    loss_key: Optional key for looking up the value of the loss in the
      `eval_data()` dict. Defaults to "loss".
    label_key: Optional key for looking up the value of the label in the
      `eval_data()` dict. Defaults to "label".
    predicted_class_key: Optional key for looking up the value of the
      predicted class in the `eval_data()` dict. Defaults to "predicted_class".
    weights_key: Optional key for looking up the value of the weights
      in the `eval_data()` dict. Defaults to "weights". Note that weights
      are optional, and default to 1 if not present in `eval_data`.
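
  For example, with the default keys, `eval_data()` might return a dict like
  the following for one mini-batch (the values shown are purely illustrative):

    {"loss": 0.3,
     "label": [3, 1],
     "predicted_class": [3, 2],
     "weights": [1.0, 0.5]}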
  """

  def __init__(self, model, loss_key="loss", label_key="label",
               predicted_class_key="predicted_class", weights_key="weights"):
    super(SparseSoftmaxEvaluator, self).__init__(model)
    # TODO(josh11b): Expand this to include everything from the standard
    # SparseSoftmax Head.
    self.avg_loss = self.track_metric(metrics.Mean("Avg Loss"))
    self.accuracy = self.track_metric(metrics.Accuracy())
    self.loss_key = loss_key
    self.label_key = label_key
    self.predicted_class_key = predicted_class_key
    self.weights_key = weights_key

  def call(self, eval_data):
    """Update metrics for `eval_data` dict (described above)."""
    weights = eval_data.get(self.weights_key, None)
    if weights is None:
      self.avg_loss(eval_data[self.loss_key])
      self.accuracy(eval_data[self.label_key],
                    eval_data[self.predicted_class_key])
    else:
      self.avg_loss(eval_data[self.loss_key], weights=weights)
      self.accuracy(eval_data[self.label_key],
                    eval_data[self.predicted_class_key],
                    weights=weights)