aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/predictor
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-12 09:01:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-12 09:06:15 -0700
commit576c7b1ec86408e4aff0b32f63d1d3b306b32e41 (patch)
treedc7918c24a2e15a16b670db12e2257f9aa35fa72 /tensorflow/contrib/predictor
parent786bf6cd656d0d67e56bf50047ff116bae884b9e (diff)
Automated g4 rollback of changelist 161218103
PiperOrigin-RevId: 161671226
Diffstat (limited to 'tensorflow/contrib/predictor')
-rw-r--r--tensorflow/contrib/predictor/BUILD7
-rw-r--r--tensorflow/contrib/predictor/__init__.py16
-rw-r--r--tensorflow/contrib/predictor/saved_model_predictor.py32
3 files changed, 45 insertions, 10 deletions
diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD
index c4b46551c1..8bd8c5f618 100644
--- a/tensorflow/contrib/predictor/BUILD
+++ b/tensorflow/contrib/predictor/BUILD
@@ -1,6 +1,6 @@
# `Predictor` classes provide an interface for efficient, repeated inference.
-package(default_visibility = ["//third_party/tensroflow/contrib/predictor:__subpackages__"])
+package(default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"])
licenses(["notice"]) # Apache 2.0
@@ -62,7 +62,10 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":base_predictor",
- "//tensorflow/python/tools:saved_model_cli",
+ "//tensorflow/contrib/saved_model:saved_model_py",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow/python/saved_model:signature_def_utils",
],
)
diff --git a/tensorflow/contrib/predictor/__init__.py b/tensorflow/contrib/predictor/__init__.py
index d270c3f798..68146aea17 100644
--- a/tensorflow/contrib/predictor/__init__.py
+++ b/tensorflow/contrib/predictor/__init__.py
@@ -13,12 +13,20 @@
# limitations under the License.
# ==============================================================================
-"""Modules for `Predictor`s."""
+"""Modules for `Predictor`s.
+
+@@from_contrib_estimator
+@@from_estimator
+@@from_saved_model
+"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.predictor import from_contrib_estimator
-from tensorflow.contrib.predictor import from_estimator
-from tensorflow.contrib.predictor import from_saved_model
+from tensorflow.contrib.predictor.predictor_factories import from_contrib_estimator
+from tensorflow.contrib.predictor.predictor_factories import from_estimator
+from tensorflow.contrib.predictor.predictor_factories import from_saved_model
+
+from tensorflow.python.util.all_util import remove_undocumented
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py
index ab2bafa0c8..0dbca0f813 100644
--- a/tensorflow/contrib/predictor/saved_model_predictor.py
+++ b/tensorflow/contrib/predictor/saved_model_predictor.py
@@ -22,12 +22,12 @@ from __future__ import print_function
import logging
from tensorflow.contrib.predictor import predictor
+from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.tools import saved_model_cli
DEFAULT_TAGS = 'serve'
@@ -35,13 +35,37 @@ DEFAULT_TAGS = 'serve'
_DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}'
+def get_meta_graph_def(saved_model_dir, tags):
+ """Gets `MetaGraphDef` from a directory containing a `SavedModel`.
+
+ Returns the `MetaGraphDef` for the given tag-set and SavedModel directory.
+
+ Args:
+ saved_model_dir: Directory containing the SavedModel.
+ tags: Comma separated list of tags used to identify the correct
+ `MetaGraphDef`.
+
+ Raises:
+ ValueError: An error when the given tags cannot be found.
+
+ Returns:
+ A `MetaGraphDef` corresponding to the given tags.
+ """
+ saved_model = reader.read_saved_model(saved_model_dir)
+ set_of_tags = set([tag.strip() for tag in tags.split(',')])
+ for meta_graph_def in saved_model.meta_graphs:
+ if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
+ return meta_graph_def
+ raise ValueError('Could not find MetaGraphDef with tags {}'.format(tags))
+
+
def _get_signature_def(signature_def_key, export_dir, tags):
"""Construct a `SignatureDef` proto."""
signature_def_key = (
signature_def_key or
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
- metagraph_def = saved_model_cli.get_meta_graph_def(export_dir, tags)
+ metagraph_def = get_meta_graph_def(export_dir, tags)
try:
signature_def = signature_def_utils.get_signature_def_by_key(
@@ -114,8 +138,8 @@ class SavedModelPredictor(predictor.Predictor):
output_names: A dictionary mapping strings to `Tensor`s in the
`SavedModel` that represent the output. The keys can be any string of
the user's choosing.
- tags: Optional. Tags that will be used to retrieve the correct
- `SignatureDef`. Defaults to `DEFAULT_TAGS`.
+ tags: Optional. Comma separated list of tags that will be used to retrieve
+ the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
Raises: