diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-07-07 14:03:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-07 14:11:29 -0700 |
commit | b1b7d5930ecdc9412e7a3035bdd2be49e9cfc230 (patch) | |
tree | 44de31557bac8640d0b6086d5e3d2935bb97982b /tensorflow | |
parent | 996605b0e4ef96e6732f7496abf44b6e5e1eb504 (diff) |
Add a tag constant, gpu, to present graph with GPU support.
PiperOrigin-RevId: 161242660
Diffstat (limited to 'tensorflow')
4 files changed, 31 insertions, 3 deletions
diff --git a/tensorflow/contrib/saved_model/python/saved_model/reader_test.py b/tensorflow/contrib/saved_model/python/saved_model/reader_test.py index 76d5a3e96d..a8331cbc8f 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/reader_test.py +++ b/tensorflow/contrib/saved_model/python/saved_model/reader_test.py @@ -81,16 +81,23 @@ class ReaderTest(test.TestCase): # Graph that updates the single variable. SavedModel is invoked: # - to add the model (weights are not updated). - # - multiple custom tags. + # - multiple predefined tags. with self.test_session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 44) + builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU]) + + # Graph that updates the single variable. SavedModel is invoked: + # - to add the model (weights are not updated). + # - multiple custom tags. + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 45) builder.add_meta_graph(["foo", "bar"]) # Save the SavedModel to disk. builder.save() actual_tags = reader.get_saved_model_tag_sets(saved_model_dir) - expected_tags = [["train"], ["serve"], ["foo", "bar"]] + expected_tags = [["train"], ["serve"], ["serve", "gpu"], ["foo", "bar"]] self.assertEqual(expected_tags, actual_tags) diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index fcd6bc3954..5639e6855d 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -207,6 +207,13 @@ class SavedModelTest(test.TestCase): self._init_and_validate_variable(sess, "v", 43) builder.add_meta_graph([tag_constants.SERVING]) + # Graph that updates the single variable. SavedModel invoked to: + # - simply add the model (weights are not updated). + # - multiple tags (from predefined constants). + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 45) + builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU]) + # Graph that updates the single variable. SavedModel is invoked: # - to add the model (weights are not updated). # - multiple custom tags. @@ -230,6 +237,13 @@ class SavedModelTest(test.TestCase): self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) + # Restore the graph with multiple predefined tags whose variables were not + # saved. + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, [tag_constants.SERVING, tag_constants.GPU], export_dir) + self.assertEqual( + 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) + # Restore the graph with multiple tags. Provide duplicate tags to test set # semantics. with self.test_session(graph=ops.Graph()) as sess: diff --git a/tensorflow/python/saved_model/tag_constants.py b/tensorflow/python/saved_model/tag_constants.py index 4fb9645dea..52868bdf99 100644 --- a/tensorflow/python/saved_model/tag_constants.py +++ b/tensorflow/python/saved_model/tag_constants.py @@ -28,9 +28,12 @@ SERVING = "serve" # Tag for the `training` graph. TRAINING = "train" +# Tag for the `gpu` graph. +GPU = "gpu" _allowed_symbols = [ "SERVING", - "TRAINING" + "TRAINING", + "GPU" ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt index 7c24b7ad3c..35e49ee9f4 100644 --- a/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt @@ -1,6 +1,10 @@ path: "tensorflow.saved_model.tag_constants" tf_module { member { + name: "GPU" + mtype: "<type \'str\'>" + } + member { name: "SERVING" mtype: "<type \'str\'>" } |