aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-21 12:24:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-21 12:27:46 -0800
commit6c7bd707ce26cc89d542bbb326882026a613748c (patch)
tree71672568b60e68f5aa5c1fa5ec5f1014fdd2d074 /tensorflow/python/saved_model
parentb5dcb0161942c467be6cba19aa0ee05aef742d2e (diff)
Add tpu saved model tags. No cpu tag is added because cpu is assumed to be the implicit device.
PiperOrigin-RevId: 176544698
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py14
-rw-r--r--tensorflow/python/saved_model/tag_constants.py6
2 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index c6d2c32293..92ca7dec6f 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -214,6 +214,13 @@ class SavedModelTest(test.TestCase):
self._init_and_validate_variable(sess, "v", 45)
builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU])
+ # Graph that updates the single variable. SavedModel invoked to:
+ # - simply add the model (weights are not updated).
+ # - multiple tags (from predefined constants for serving on TPU).
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 45)
+ builder.add_meta_graph([tag_constants.SERVING, tag_constants.TPU])
+
# Graph that updates the single variable. SavedModel is invoked:
# - to add the model (weights are not updated).
# - multiple custom tags.
@@ -244,6 +251,13 @@ class SavedModelTest(test.TestCase):
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
+ # Restore the graph with multiple predefined tags (for serving on TPU)
+ # whose variables were not saved.
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, [tag_constants.SERVING, tag_constants.TPU], export_dir)
+ self.assertEqual(
+ 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
+
# Restore the graph with multiple tags. Provide duplicate tags to test set
# semantics.
with self.test_session(graph=ops.Graph()) as sess:
diff --git a/tensorflow/python/saved_model/tag_constants.py b/tensorflow/python/saved_model/tag_constants.py
index 52868bdf99..e2facafda5 100644
--- a/tensorflow/python/saved_model/tag_constants.py
+++ b/tensorflow/python/saved_model/tag_constants.py
@@ -31,9 +31,13 @@ TRAINING = "train"
# Tag for the `gpu` graph.
GPU = "gpu"
+# Tag for the `tpu` graph.
+TPU = "tpu"
+
_allowed_symbols = [
"SERVING",
"TRAINING",
- "GPU"
+ "GPU",
+ "TPU"
]
remove_undocumented(__name__, _allowed_symbols)