aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-07 14:03:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-07 14:11:29 -0700
commitb1b7d5930ecdc9412e7a3035bdd2be49e9cfc230 (patch)
tree44de31557bac8640d0b6086d5e3d2935bb97982b /tensorflow
parent996605b0e4ef96e6732f7496abf44b6e5e1eb504 (diff)
Add a tag constant, gpu, to present graph with GPU support.
PiperOrigin-RevId: 161242660
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/reader_test.py11
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py14
-rw-r--r--tensorflow/python/saved_model/tag_constants.py5
-rw-r--r--tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt4
4 files changed, 31 insertions, 3 deletions
diff --git a/tensorflow/contrib/saved_model/python/saved_model/reader_test.py b/tensorflow/contrib/saved_model/python/saved_model/reader_test.py
index 76d5a3e96d..a8331cbc8f 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/reader_test.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/reader_test.py
@@ -81,16 +81,23 @@ class ReaderTest(test.TestCase):
# Graph that updates the single variable. SavedModel is invoked:
# - to add the model (weights are not updated).
- # - multiple custom tags.
+ # - multiple predefined tags.
with self.test_session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 44)
+ builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU])
+
+ # Graph that updates the single variable. SavedModel is invoked:
+ # - to add the model (weights are not updated).
+ # - multiple custom tags.
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 45)
builder.add_meta_graph(["foo", "bar"])
# Save the SavedModel to disk.
builder.save()
actual_tags = reader.get_saved_model_tag_sets(saved_model_dir)
- expected_tags = [["train"], ["serve"], ["foo", "bar"]]
+ expected_tags = [["train"], ["serve"], ["serve", "gpu"], ["foo", "bar"]]
self.assertEqual(expected_tags, actual_tags)
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index fcd6bc3954..5639e6855d 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -207,6 +207,13 @@ class SavedModelTest(test.TestCase):
self._init_and_validate_variable(sess, "v", 43)
builder.add_meta_graph([tag_constants.SERVING])
+ # Graph that updates the single variable. SavedModel invoked to:
+ # - simply add the model (weights are not updated).
+ # - multiple tags (from predefined constants).
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 45)
+ builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU])
+
# Graph that updates the single variable. SavedModel is invoked:
# - to add the model (weights are not updated).
# - multiple custom tags.
@@ -230,6 +237,13 @@ class SavedModelTest(test.TestCase):
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
+ # Restore the graph with multiple predefined tags whose variables were not
+ # saved.
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, [tag_constants.SERVING, tag_constants.GPU], export_dir)
+ self.assertEqual(
+ 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
+
# Restore the graph with multiple tags. Provide duplicate tags to test set
# semantics.
with self.test_session(graph=ops.Graph()) as sess:
diff --git a/tensorflow/python/saved_model/tag_constants.py b/tensorflow/python/saved_model/tag_constants.py
index 4fb9645dea..52868bdf99 100644
--- a/tensorflow/python/saved_model/tag_constants.py
+++ b/tensorflow/python/saved_model/tag_constants.py
@@ -28,9 +28,12 @@ SERVING = "serve"
# Tag for the `training` graph.
TRAINING = "train"
+# Tag for the `gpu` graph.
+GPU = "gpu"
_allowed_symbols = [
"SERVING",
- "TRAINING"
+ "TRAINING",
+ "GPU"
]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt
index 7c24b7ad3c..35e49ee9f4 100644
--- a/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.saved_model.tag_constants"
tf_module {
member {
+ name: "GPU"
+ mtype: "<type \'str\'>"
+ }
+ member {
name: "SERVING"
mtype: "<type \'str\'>"
}