diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-11-21 12:24:12 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-21 12:27:46 -0800 |
commit | 6c7bd707ce26cc89d542bbb326882026a613748c (patch) | |
tree | 71672568b60e68f5aa5c1fa5ec5f1014fdd2d074 /tensorflow/python/saved_model | |
parent | b5dcb0161942c467be6cba19aa0ee05aef742d2e (diff) |
Add tpu saved model tags. No cpu tag is added because cpu is assumed to be the implicit device.
PiperOrigin-RevId: 176544698
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/saved_model_test.py | 14 | ||||
-rw-r--r-- | tensorflow/python/saved_model/tag_constants.py | 6 |
2 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index c6d2c32293..92ca7dec6f 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -214,6 +214,13 @@ class SavedModelTest(test.TestCase): self._init_and_validate_variable(sess, "v", 45) builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU]) + # Graph that updates the single variable. SavedModel invoked to: + # - simply add the model (weights are not updated). + # - multiple tags (from predefined constants for serving on TPU). + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 45) + builder.add_meta_graph([tag_constants.SERVING, tag_constants.TPU]) + # Graph that updates the single variable. SavedModel is invoked: # - to add the model (weights are not updated). # - multiple custom tags. @@ -244,6 +251,13 @@ class SavedModelTest(test.TestCase): self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) + # Restore the graph with multiple predefined tags (for serving on TPU) + # whose variables were not saved. + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, [tag_constants.SERVING, tag_constants.TPU], export_dir) + self.assertEqual( + 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) + # Restore the graph with multiple tags. Provide duplicate tags to test set # semantics. with self.test_session(graph=ops.Graph()) as sess: diff --git a/tensorflow/python/saved_model/tag_constants.py b/tensorflow/python/saved_model/tag_constants.py index 52868bdf99..e2facafda5 100644 --- a/tensorflow/python/saved_model/tag_constants.py +++ b/tensorflow/python/saved_model/tag_constants.py @@ -31,9 +31,13 @@ TRAINING = "train" # Tag for the `gpu` graph. GPU = "gpu" +# Tag for the `tpu` graph. +TPU = "tpu" + _allowed_symbols = [ "SERVING", "TRAINING", - "GPU" + "GPU", + "TPU" ] remove_undocumented(__name__, _allowed_symbols) |