diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2017-07-07 10:42:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-07 10:46:47 -0700 |
commit | ac1c5224a1dd0c3142ef06d76f6b62097210959e (patch) | |
tree | 877961b683d82e5c0f36694a1647e8f567be8357 /tensorflow/contrib/predictor | |
parent | cc342cfb7a5a8cc174f1f84df0319161cddf48ea (diff) |
Automated g4 rollback of changelist 161203536
PiperOrigin-RevId: 161218103
Diffstat (limited to 'tensorflow/contrib/predictor')
-rw-r--r-- | tensorflow/contrib/predictor/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/contrib/predictor/__init__.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/predictor/saved_model_predictor.py | 32 |
3 files changed, 9 insertions, 32 deletions
diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index e298fd3cb2..c4b46551c1 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -1,6 +1,6 @@ # `Predictor` classes provide an interface for efficient, repeated inference. -package(default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"]) +package(default_visibility = ["//third_party/tensroflow/contrib/predictor:__subpackages__"]) licenses(["notice"]) # Apache 2.0 @@ -62,6 +62,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":base_predictor", + "//tensorflow/python/tools:saved_model_cli", ], ) diff --git a/tensorflow/contrib/predictor/__init__.py b/tensorflow/contrib/predictor/__init__.py index e0a2152b37..d270c3f798 100644 --- a/tensorflow/contrib/predictor/__init__.py +++ b/tensorflow/contrib/predictor/__init__.py @@ -19,6 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.predictor.predictor_factories import from_contrib_estimator -from tensorflow.contrib.predictor.predictor_factories import from_estimator -from tensorflow.contrib.predictor.predictor_factories import from_saved_model +from tensorflow.contrib.predictor import from_contrib_estimator +from tensorflow.contrib.predictor import from_estimator +from tensorflow.contrib.predictor import from_saved_model diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py index 0dbca0f813..ab2bafa0c8 100644 --- a/tensorflow/contrib/predictor/saved_model_predictor.py +++ b/tensorflow/contrib/predictor/saved_model_predictor.py @@ -22,12 +22,12 @@ from __future__ import print_function import logging from tensorflow.contrib.predictor import predictor -from tensorflow.contrib.saved_model.python.saved_model import reader from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import signature_constants +from tensorflow.python.tools import saved_model_cli DEFAULT_TAGS = 'serve' @@ -35,37 +35,13 @@ DEFAULT_TAGS = 'serve' _DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}' -def get_meta_graph_def(saved_model_dir, tags): - """Gets `MetaGraphDef` from a directory containing a `SavedModel`. - - Returns the `MetaGraphDef` for the given tag-set and SavedModel directory. - - Args: - saved_model_dir: Directory containing the SavedModel. - tags: Comma separated list of tags used to identify the correct - `MetaGraphDef`. - - Raises: - ValueError: An error when the given tags cannot be found. - - Returns: - A `MetaGraphDef` corresponding to the given tags. - """ - saved_model = reader.read_saved_model(saved_model_dir) - set_of_tags = set([tag.strip() for tag in tags.split(',')]) - for meta_graph_def in saved_model.meta_graphs: - if set(meta_graph_def.meta_info_def.tags) == set_of_tags: - return meta_graph_def - raise ValueError('Could not find MetaGraphDef with tags {}'.format(tags)) - - def _get_signature_def(signature_def_key, export_dir, tags): """Construct a `SignatureDef` proto.""" signature_def_key = ( signature_def_key or signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) - metagraph_def = get_meta_graph_def(export_dir, tags) + metagraph_def = saved_model_cli.get_meta_graph_def(export_dir, tags) try: signature_def = signature_def_utils.get_signature_def_by_key( @@ -138,8 +114,8 @@ class SavedModelPredictor(predictor.Predictor): output_names: A dictionary mapping strings to `Tensor`s in the `SavedModel` that represent the output. The keys can be any string of the user's choosing. - tags: Optional. Comma separated list of tags that will be used to retrieve - the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`. + tags: Optional. Tags that will be used to retrieve the correct + `SignatureDef`. Defaults to `DEFAULT_TAGS`. graph: Optional. The Tensorflow `graph` in which prediction should be done. Raises: |