aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/predictor/saved_model_predictor.py
blob: 03399396df55029bf2761364c5b6670f926b5f4b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A `Predictor` constructed from a `SavedModel`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

from tensorflow.contrib.predictor import predictor
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants


# Tag-set (a comma separated string) used by `SavedModelPredictor` when the
# caller does not pass `tags`.
DEFAULT_TAGS = 'serve'

# Template for the fallback signature key that `_get_signature_def` tries when
# the requested signature key is not found; `{}` receives the requested key.
_DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}'


def get_meta_graph_def(saved_model_dir, tags):
  """Finds the `MetaGraphDef` of a `SavedModel` that matches a tag-set.

  The `SavedModel` stored in `saved_model_dir` is read and its meta graphs are
  scanned in order; the first one whose tags are exactly the requested tag-set
  is returned.

  Args:
    saved_model_dir: Directory containing the SavedModel.
    tags: Comma separated list of tags used to identify the correct
      `MetaGraphDef`. Whitespace around each tag is ignored.

  Raises:
    ValueError: If no `MetaGraphDef` has exactly the given tags.

  Returns:
    The `MetaGraphDef` whose tag-set equals the given tags.
  """
  model_proto = reader.read_saved_model(saved_model_dir)
  wanted_tags = {piece.strip() for piece in tags.split(',')}

  # Lazily scan the meta graphs and stop at the first exact tag-set match.
  matching = next(
      (candidate for candidate in model_proto.meta_graphs
       if set(candidate.meta_info_def.tags) == wanted_tags),
      None)
  if matching is None:
    raise ValueError('Could not find MetaGraphDef with tags {}'.format(tags))
  return matching


def _get_signature_def(signature_def_key, export_dir, tags):
  """Looks up a `SignatureDef` proto in the `SavedModel` at `export_dir`.

  The requested key is tried first. If it is absent, the key produced by
  `_DEFAULT_INPUT_ALTERNATIVE_FORMAT` is tried as a fallback, and a warning is
  logged when the fallback is used.

  Args:
    signature_def_key: Key of the signature to fetch. If falsy,
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used.
    export_dir: Directory containing the `SavedModel`.
    tags: Comma separated list of tags identifying the `MetaGraphDef`.

  Returns:
    The `SignatureDef` stored under the requested key, or under the fallback
    key when the requested key is missing.

  Raises:
    ValueError: If neither the requested key nor the fallback key is present
      (also propagated from `get_meta_graph_def` when the tags match no
      `MetaGraphDef`).
  """
  signature_def_key = (
      signature_def_key or
      signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

  signature_defs = get_meta_graph_def(export_dir, tags).signature_def

  # Use membership tests instead of `try: ... except KeyError`. On a
  # `MetaGraphDef`, `signature_def` is a protobuf map with message values, and
  # subscripting a missing key on such a map inserts an empty entry rather than
  # raising `KeyError`, so a missing signature would go unnoticed.
  if signature_def_key in signature_defs:
    return signature_defs[signature_def_key]

  formatted_key = _DEFAULT_INPUT_ALTERNATIVE_FORMAT.format(signature_def_key)
  if formatted_key not in signature_defs:
    raise ValueError(
        'Got signature_def_key "{}". Available signatures are {}.'.format(
            signature_def_key, list(signature_defs)))

  logging.warning('Could not find signature def "%s". '
                  'Using "%s" instead', signature_def_key, formatted_key)
  return signature_defs[formatted_key]


def _check_signature_arguments(signature_def_key,
                               signature_def,
                               input_names,
                               output_names):
  """Validates signature arguments for `SavedModelPredictor`."""
  signature_def_key_specified = signature_def_key is not None
  signature_def_specified = signature_def is not None
  input_names_specified = input_names is not None
  output_names_specified = output_names is not None
  if input_names_specified != output_names_specified:
    raise ValueError(
        'input_names and output_names must both be specified or both be '
        'unspecified.'
    )

  if (signature_def_key_specified + signature_def_specified +
      input_names_specified > 1):
    raise ValueError(
        'You must specify at most one of signature_def_key OR signature_def OR'
        '(input_names AND output_names).'
    )


class SavedModelPredictor(predictor.Predictor):
  """A `Predictor` constructed from a `SavedModel`."""

  def __init__(self,
               export_dir,
               signature_def_key=None,
               signature_def=None,
               input_names=None,
               output_names=None,
               tags=None,
               graph=None,
               config=None):
    """Initialize a `SavedModelPredictor`.

    Loads the `SavedModel` into a new session on `graph` and resolves the feed
    and fetch tensors, either from explicit tensor names or from a
    `SignatureDef`.

    Args:
      export_dir: a path to a directory containing a `SavedModel`.
      signature_def_key: Optional string specifying the signature to use. If
        `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. Only one of
        `signature_def_key` and `signature_def` should be specified.
      signature_def: A `SignatureDef` proto specifying the inputs and outputs
        for prediction. Only one of `signature_def_key` and `signature_def`
        should be specified.
      input_names: A dictionary mapping strings to the names of `Tensor`s in
        the `SavedModel` that represent the input. The keys can be any string
        of the user's choosing.
      output_names: A dictionary mapping strings to the names of `Tensor`s in
        the `SavedModel` that represent the output. The keys can be any string
        of the user's choosing.
      tags: Optional. Comma separated list of tags that will be used to retrieve
        the correct `SignatureDef`. Whitespace around each tag is ignored.
        Defaults to `DEFAULT_TAGS`.
      graph: Optional. The Tensorflow `graph` in which prediction should be
        done. A new graph is created when omitted.
      config: `ConfigProto` proto used to configure the session.
    Raises:
      ValueError: If more than one of signature_def_key OR signature_def OR
        (input_names AND output_names) is specified.
    """
    _check_signature_arguments(
        signature_def_key, signature_def, input_names, output_names)
    tags = tags or DEFAULT_TAGS
    self._graph = graph or ops.Graph()

    # Strip whitespace so that e.g. 'serve, gpu' names the same tag-set here as
    # it does in `get_meta_graph_def`, which strips each tag before matching.
    tag_list = [tag.strip() for tag in tags.split(',')]

    with self._graph.as_default():
      self._session = session.Session(config=config)
      loader.load(self._session, tag_list, export_dir)

    if input_names is None:
      # No explicit tensor names: take them from the given `SignatureDef`, or
      # look one up in the `SavedModel` by key.
      if signature_def is None:
        signature_def = _get_signature_def(signature_def_key, export_dir, tags)
      input_names = {k: v.name for k, v in signature_def.inputs.items()}
      output_names = {k: v.name for k, v in signature_def.outputs.items()}

    # Resolve tensor names to the tensors of the loaded graph.
    self._feed_tensors = {k: self._graph.get_tensor_by_name(v)
                          for k, v in input_names.items()}
    self._fetch_tensors = {k: self._graph.get_tensor_by_name(v)
                           for k, v in output_names.items()}