diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-07-07 08:20:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-07 08:24:37 -0700 |
commit | 81e81b796dd40a8294d166eb457e609aaedb4540 (patch) | |
tree | 1d40f177a910ffcfe5c385a775fd5eaf0872174d /tensorflow/contrib/predictor | |
parent | aee58d07201e6cd247cd367ca34973c6d8611564 (diff) |
Fix dependencies and import statements for predictor module.
PiperOrigin-RevId: 161203536
Diffstat (limited to 'tensorflow/contrib/predictor')
-rw-r--r-- | tensorflow/contrib/predictor/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/contrib/predictor/__init__.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/predictor/saved_model_predictor.py | 32 |
3 files changed, 32 insertions, 9 deletions
diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index c4b46551c1..e298fd3cb2 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -1,6 +1,6 @@ # `Predictor` classes provide an interface for efficient, repeated inference. -package(default_visibility = ["//third_party/tensroflow/contrib/predictor:__subpackages__"]) +package(default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"]) licenses(["notice"]) # Apache 2.0 @@ -62,7 +62,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":base_predictor", - "//tensorflow/python/tools:saved_model_cli", ], ) diff --git a/tensorflow/contrib/predictor/__init__.py b/tensorflow/contrib/predictor/__init__.py index d270c3f798..e0a2152b37 100644 --- a/tensorflow/contrib/predictor/__init__.py +++ b/tensorflow/contrib/predictor/__init__.py @@ -19,6 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.predictor import from_contrib_estimator -from tensorflow.contrib.predictor import from_estimator -from tensorflow.contrib.predictor import from_saved_model +from tensorflow.contrib.predictor.predictor_factories import from_contrib_estimator +from tensorflow.contrib.predictor.predictor_factories import from_estimator +from tensorflow.contrib.predictor.predictor_factories import from_saved_model diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py index ab2bafa0c8..0dbca0f813 100644 --- a/tensorflow/contrib/predictor/saved_model_predictor.py +++ b/tensorflow/contrib/predictor/saved_model_predictor.py @@ -22,12 +22,12 @@ from __future__ import print_function import logging from tensorflow.contrib.predictor import predictor +from tensorflow.contrib.saved_model.python.saved_model import reader from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import signature_constants -from tensorflow.python.tools import saved_model_cli DEFAULT_TAGS = 'serve' @@ -35,13 +35,37 @@ DEFAULT_TAGS = 'serve' _DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}' +def get_meta_graph_def(saved_model_dir, tags): + """Gets `MetaGraphDef` from a directory containing a `SavedModel`. + + Returns the `MetaGraphDef` for the given tag-set and SavedModel directory. + + Args: + saved_model_dir: Directory containing the SavedModel. + tags: Comma separated list of tags used to identify the correct + `MetaGraphDef`. + + Raises: + ValueError: An error when the given tags cannot be found. + + Returns: + A `MetaGraphDef` corresponding to the given tags. + """ + saved_model = reader.read_saved_model(saved_model_dir) + set_of_tags = set([tag.strip() for tag in tags.split(',')]) + for meta_graph_def in saved_model.meta_graphs: + if set(meta_graph_def.meta_info_def.tags) == set_of_tags: + return meta_graph_def + raise ValueError('Could not find MetaGraphDef with tags {}'.format(tags)) + + def _get_signature_def(signature_def_key, export_dir, tags): """Construct a `SignatureDef` proto.""" signature_def_key = ( signature_def_key or signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) - metagraph_def = saved_model_cli.get_meta_graph_def(export_dir, tags) + metagraph_def = get_meta_graph_def(export_dir, tags) try: signature_def = signature_def_utils.get_signature_def_by_key( @@ -114,8 +138,8 @@ class SavedModelPredictor(predictor.Predictor): output_names: A dictionary mapping strings to `Tensor`s in the `SavedModel` that represent the output. The keys can be any string of the user's choosing. - tags: Optional. Tags that will be used to retrieve the correct - `SignatureDef`. Defaults to `DEFAULT_TAGS`. + tags: Optional. Comma separated list of tags that will be used to retrieve + the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`. graph: Optional. The Tensorflow `graph` in which prediction should be done. Raises: |