aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/export/export_output.py
blob: 20382a58d8d6fa5be938ee08fcf1487043868301 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes for different types of export output."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six


from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.util.tf_export import estimator_export


@estimator_export('estimator.export.ExportOutput')
class ExportOutput(object):
  """Represents an output of a model that can be served.

  These typically correspond to model heads.
  """

  # NOTE(review): the `__metaclass__` attribute is only honored by Python 2;
  # under Python 3 this class is not a real ABC, so `as_signature_def` is not
  # enforced as abstract at instantiation time.
  __metaclass__ = abc.ABCMeta

  # Used to join the components of tuple-valued output keys (multi-head
  # models) into a single string key.
  _SEPARATOR_CHAR = '/'

  @abc.abstractmethod
  def as_signature_def(self, receiver_tensors):
    """Generate a SignatureDef proto for inclusion in a MetaGraphDef.

    The SignatureDef will specify outputs as described in this ExportOutput,
    and will use the provided receiver_tensors as inputs.

    Args:
      receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying
        input nodes that will be fed.

    Returns:
      A SignatureDef proto.
    """
    pass

  def _check_output_key(self, key, error_label):
    """Validates an output key and normalizes it to a single string.

    Args:
      key: a string, or (for multi-head models) a tuple of strings that is
        joined with `_SEPARATOR_CHAR`.
      error_label: descriptive string used as the prefix of error messages.

    Returns:
      The key as a string.

    Raises:
      ValueError: if `key` is neither a string nor a tuple of strings.
    """
    # For multi-head models, the key can be a tuple. Only join it when every
    # component is a string; otherwise str.join would raise a TypeError
    # instead of the ValueError below. A tuple with a non-string component is
    # left as-is and rejected by the string check.
    if isinstance(key, tuple) and all(
        isinstance(part, six.string_types) for part in key):
      key = self._SEPARATOR_CHAR.join(key)

    if not isinstance(key, six.string_types):
      raise ValueError(
          '{} output key must be a string; got {}.'.format(error_label, key))
    return key

  def _wrap_and_check_outputs(
      self, outputs, single_output_default_name, error_label=None):
    """Wraps raw tensors as dicts and checks type.

    Note that we create a new dict here so that we can overwrite the keys
    if necessary.

    Args:
      outputs: A `Tensor` or a dict of string to `Tensor`.
      single_output_default_name: A string key for use in the output dict
        if the provided `outputs` is a raw tensor.
      error_label: descriptive string for use in error messages. If none,
        single_output_default_name will be used.

    Returns:
      A dict of tensors

    Raises:
      ValueError: if the outputs dict keys are not strings or tuples of strings
        or the values are not Tensors.
    """
    if not isinstance(outputs, dict):
      outputs = {single_output_default_name: outputs}

    # The label does not depend on the individual entry, so compute it once.
    error_name = error_label or single_output_default_name

    # NOTE(review): a tuple key and a string key that normalize to the same
    # string (e.g. ('a', 'b') and 'a/b') silently overwrite each other here.
    output_dict = {}
    for key, value in outputs.items():
      key = self._check_output_key(key, error_name)
      if not isinstance(value, ops.Tensor):
        raise ValueError(
            '{} output value must be a Tensor; got {}.'.format(
                error_name, value))
      output_dict[key] = value
    return output_dict


@estimator_export('estimator.export.ClassificationOutput')
class ClassificationOutput(ExportOutput):
  """Represents the output of a classification head.

  Either classes or scores or both must be set.

  The classes `Tensor` must provide string labels, not integer class IDs.

  If only classes is set, it is interpreted as providing top-k results in
  descending order.

  If only scores is set, it is interpreted as providing a score for every class
  in order of class ID.

  If both classes and scores are set, they are interpreted as zipped, so each
  score corresponds to the class at the same index.  Clients should not depend
  on the order of the entries.
  """

  def __init__(self, scores=None, classes=None):
    """Constructor for `ClassificationOutput`.

    Args:
      scores: A float `Tensor` giving scores (sometimes but not always
          interpretable as probabilities) for each class.  May be `None`, but
          only if `classes` is set.  Interpretation varies-- see class doc.
      classes: A string `Tensor` giving predicted class labels.  May be `None`,
          but only if `scores` is set.  Interpretation varies-- see class doc.

    Raises:
      ValueError: if neither classes nor scores is set, or one of them is not a
          `Tensor` with the correct dtype.
    """
    # NOTE(review): the message says "float32", but any floating-point dtype
    # passes the `is_floating` check.
    if (scores is not None
        and not (isinstance(scores, ops.Tensor)
                 and scores.dtype.is_floating)):
      raise ValueError('Classification scores must be a float32 Tensor; '
                       'got {}'.format(scores))
    if (classes is not None
        and not (isinstance(classes, ops.Tensor)
                 and dtypes.as_dtype(classes.dtype) == dtypes.string)):
      raise ValueError('Classification classes must be a string Tensor; '
                       'got {}'.format(classes))
    if scores is None and classes is None:
      raise ValueError('At least one of scores and classes must be set.')

    self._scores = scores
    self._classes = classes

  @property
  def scores(self):
    """The scores `Tensor` passed to the constructor, or `None`."""
    return self._scores

  @property
  def classes(self):
    """The classes `Tensor` passed to the constructor, or `None`."""
    return self._classes

  def as_signature_def(self, receiver_tensors):
    """Builds a classification SignatureDef.

    Args:
      receiver_tensors: a single string `Tensor`, or a dict holding exactly one
        string `Tensor`; that tensor is used as the examples input.

    Returns:
      The SignatureDef built by
      `signature_def_utils.classification_signature_def`.

    Raises:
      ValueError: if `receiver_tensors` does not hold exactly one input, or
        that input does not have a string dtype.
    """
    if isinstance(receiver_tensors, ops.Tensor):
      # The base-class contract allows a bare Tensor; it is the sole input.
      examples = receiver_tensors
    else:
      if len(receiver_tensors) != 1:
        raise ValueError('Classification input must be a single string Tensor; '
                         'got {}'.format(receiver_tensors))
      (_, examples), = receiver_tensors.items()
    if dtypes.as_dtype(examples.dtype) != dtypes.string:
      raise ValueError('Classification input must be a single string Tensor; '
                       'got {}'.format(receiver_tensors))
    return signature_def_utils.classification_signature_def(
        examples, self.classes, self.scores)


@estimator_export('estimator.export.RegressionOutput')
class RegressionOutput(ExportOutput):
  """Represents the output of a regression head."""

  def __init__(self, value):
    """Constructor for `RegressionOutput`.

    Args:
      value: a float `Tensor` giving the predicted values.  Required.

    Raises:
      ValueError: if the value is not a `Tensor` with a floating-point dtype.
    """
    # NOTE(review): the message says "float32", but any floating-point dtype
    # passes the `is_floating` check.
    if not (isinstance(value, ops.Tensor) and value.dtype.is_floating):
      raise ValueError('Regression output value must be a float32 Tensor; '
                       'got {}'.format(value))
    self._value = value

  @property
  def value(self):
    """The predicted-values `Tensor` passed to the constructor."""
    return self._value

  def as_signature_def(self, receiver_tensors):
    """Builds a regression SignatureDef.

    Args:
      receiver_tensors: a single string `Tensor`, or a dict holding exactly one
        string `Tensor`; that tensor is used as the examples input.

    Returns:
      The SignatureDef built by `signature_def_utils.regression_signature_def`.

    Raises:
      ValueError: if `receiver_tensors` does not hold exactly one input, or
        that input does not have a string dtype.
    """
    if isinstance(receiver_tensors, ops.Tensor):
      # The base-class contract allows a bare Tensor; it is the sole input.
      examples = receiver_tensors
    else:
      if len(receiver_tensors) != 1:
        raise ValueError('Regression input must be a single string Tensor; '
                         'got {}'.format(receiver_tensors))
      (_, examples), = receiver_tensors.items()
    if dtypes.as_dtype(examples.dtype) != dtypes.string:
      raise ValueError('Regression input must be a single string Tensor; '
                       'got {}'.format(receiver_tensors))
    return signature_def_utils.regression_signature_def(examples, self.value)


@estimator_export('estimator.export.PredictOutput')
class PredictOutput(ExportOutput):
  """Represents the output of a generic prediction head.

  A generic prediction need not be either a classification or a regression.

  Outputs may be given as a single `Tensor`, which is stored under the key
  'output', or as a dict from string (or tuple of strings) to `Tensor`.
  """

  # Key used when the constructor receives a bare Tensor instead of a dict.
  _SINGLE_OUTPUT_DEFAULT_NAME = 'output'

  def __init__(self, outputs):
    """Constructor for PredictOutput.

    Args:
      outputs: A `Tensor` or a dict of string to `Tensor` representing the
        predictions.

    Raises:
      ValueError: if any key of `outputs` is not a string or tuple of strings,
          or any of its values is not a `Tensor`.
    """
    checked_outputs = self._wrap_and_check_outputs(
        outputs,
        self._SINGLE_OUTPUT_DEFAULT_NAME,
        error_label='Prediction')
    self._outputs = checked_outputs

  @property
  def outputs(self):
    """Dict of string key to `Tensor` holding the validated predictions."""
    return self._outputs

  def as_signature_def(self, receiver_tensors):
    """Builds a predict SignatureDef from `receiver_tensors` and the outputs."""
    prediction_outputs = self.outputs
    return signature_def_utils.predict_signature_def(
        receiver_tensors, prediction_outputs)


class _SupervisedOutput(ExportOutput):
  """Represents the output of a supervised training or eval process."""
  # NOTE(review): `__metaclass__` is only honored by Python 2.
  __metaclass__ = abc.ABCMeta

  # Prefixes applied to the exported loss, prediction and metric keys.
  LOSS_NAME = 'loss'
  PREDICTIONS_NAME = 'predictions'
  METRICS_NAME = 'metrics'

  # Suffixes appended to each metric key for its value and its update op.
  METRIC_VALUE_SUFFIX = 'value'
  METRIC_UPDATE_SUFFIX = 'update_op'

  # Class-level defaults so the properties return None for any output that
  # was not passed to the constructor.
  _loss = None
  _predictions = None
  _metrics = None

  def __init__(self, loss=None, predictions=None, metrics=None):
    """Constructor for SupervisedOutput (ie, Train or Eval output).

    Args:
      loss: dict of Tensors or single Tensor representing calculated loss.
      predictions: dict of Tensors or single Tensor representing model
        predictions.
      metrics: dict of (metric_value, update_op) tuples, or a single tuple.
        metric_value must be a Tensor, and update_op must be a Tensor or Op.

    Raises:
      ValueError: if any of the outputs' dict keys are not strings or tuples of
        strings or the values are not Tensors (or Operations in the case of
        update_op), or a metric entry is not a (metric_value, update_op) pair.
    """

    if loss is not None:
      loss_dict = self._wrap_and_check_outputs(loss, self.LOSS_NAME)
      self._loss = self._prefix_output_keys(loss_dict, self.LOSS_NAME)
    if predictions is not None:
      pred_dict = self._wrap_and_check_outputs(
          predictions, self.PREDICTIONS_NAME)
      self._predictions = self._prefix_output_keys(
          pred_dict, self.PREDICTIONS_NAME)
    if metrics is not None:
      self._metrics = self._wrap_and_check_metrics(metrics)

  def _prefix_output_keys(self, output_dict, output_name):
    """Prepend output_name to the output_dict keys if it doesn't exist.

    This produces predictable prefixes for the pre-determined outputs
    of SupervisedOutput.

    Args:
      output_dict: dict of string to Tensor, assumed valid.
      output_name: prefix string to prepend to existing keys.

    Returns:
      dict with updated keys and existing values.
    """
    return {self._prefix_key(key, output_name): val
            for key, val in output_dict.items()}

  def _prefix_key(self, key, output_name):
    """Returns `key`, prefixed with `output_name` + separator if needed.

    Args:
      key: string key.
      output_name: prefix string.

    Returns:
      `key` unchanged if it already starts with `output_name`, otherwise
      `output_name + _SEPARATOR_CHAR + key`.
    """
    # NOTE(review): this is a plain string-prefix test, so a key such as
    # 'lossy' counts as already prefixed by 'loss'.
    if not key.startswith(output_name):
      key = output_name + self._SEPARATOR_CHAR + key
    return key

  def _wrap_and_check_metrics(self, metrics):
    """Handle the saving of metrics.

    Metrics is either a tuple of (value, update_op), or a dict of such tuples.
    Here, we separate out the tuples and create a dict with names to tensors.

    Args:
      metrics: dict of (metric_value, update_op) tuples, or a single tuple.

    Returns:
      dict of output_names to tensors

    Raises:
      ValueError: if the dict key is not a string, a metric is not a
        (value, update_op) pair, or the metric values or ops are not tensors.
    """
    if not isinstance(metrics, dict):
      metrics = {self.METRICS_NAME: metrics}

    outputs = {}
    for key, metric in metrics.items():
      key = self._check_output_key(key, self.METRICS_NAME)
      key = self._prefix_key(key, self.METRICS_NAME)

      # Unpack explicitly so a malformed entry (e.g. a bare Tensor or a
      # sequence of the wrong length) surfaces as the documented ValueError
      # rather than an unpacking TypeError/ValueError with no context.
      try:
        metric_val, metric_op = metric
      except (TypeError, ValueError):
        raise ValueError(
            '{} must be a (metric_value, update_op) tuple; got {}.'.format(
                key, metric))

      val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX
      op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX
      if not isinstance(metric_val, ops.Tensor):
        raise ValueError(
            '{} output value must be a Tensor; got {}.'.format(
                key, metric_val))
      if not isinstance(metric_op, (ops.Tensor, ops.Operation)):
        raise ValueError(
            '{} update_op must be a Tensor or Operation; got {}.'.format(
                key, metric_op))

      # We must wrap any ops in a Tensor before export, as the SignatureDef
      # proto expects tensors only. The empty constant carries a control
      # dependency on the op. See b/109740581
      metric_op_tensor = metric_op
      if isinstance(metric_op, ops.Operation):
        with ops.control_dependencies([metric_op]):
          metric_op_tensor = constant_op.constant([], name='metric_op_wrapper')

      outputs[val_name] = metric_val
      outputs[op_name] = metric_op_tensor

    return outputs

  @property
  def loss(self):
    """Dict of prefixed loss keys to `Tensor`, or `None` if not provided."""
    return self._loss

  @property
  def predictions(self):
    """Dict of prefixed prediction keys to `Tensor`, or `None`."""
    return self._predictions

  @property
  def metrics(self):
    """Dict of metric value/update_op keys to `Tensor`, or `None`."""
    return self._metrics

  @abc.abstractmethod
  def _get_signature_def_fn(self):
    """Returns a function that produces a SignatureDef given desired outputs."""
    pass

  def as_signature_def(self, receiver_tensors):
    """Builds a SignatureDef with the subclass-selected builder.

    Args:
      receiver_tensors: input tensors, passed through to the builder.

    Returns:
      The result of calling the function returned by `_get_signature_def_fn`
      with `receiver_tensors`, `loss`, `predictions` and `metrics`.
    """
    signature_def_fn = self._get_signature_def_fn()
    return signature_def_fn(
        receiver_tensors, self.loss, self.predictions, self.metrics)


class TrainOutput(_SupervisedOutput):
  """Supervised output exported for a training process.

  Loss, predictions and metrics are type-checked and wrapped by the base
  class; this subclass only selects the training SignatureDef builder.
  """

  def _get_signature_def_fn(self):
    """Returns `signature_def_utils.supervised_train_signature_def`."""
    train_signature_fn = signature_def_utils.supervised_train_signature_def
    return train_signature_fn


class EvalOutput(_SupervisedOutput):
  """Supervised output exported for an evaluation process.

  Loss, predictions and metrics are type-checked and wrapped by the base
  class; this subclass only selects the eval SignatureDef builder.
  """

  def _get_signature_def_fn(self):
    """Returns `signature_def_utils.supervised_eval_signature_def`."""
    eval_signature_fn = signature_def_utils.supervised_eval_signature_def
    return eval_signature_fn