aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-01-29 15:34:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-29 15:38:16 -0800
commite08f0080f822543d0a306075878c2e35dabf8cc0 (patch)
tree9b5092cd9d8901c867513de03911e17c8ce88203 /tensorflow/python/saved_model
parente4021e7060166ead2fc14a94c048b5fc5336e495 (diff)
Make with_c_api more robust and enable C API in most of saved_model_test.py.
This change makes the test_util.with_c_api decorator call reset_default_graph() after enabling or disabling the C API instead of creating a new Graph. This makes it more robust to tests that call reset_default_graph(), which requires that the current default graph isn't nested (which the C API-enabled Graph previously was). In addition, enables the C API with saved_model_test.py (which required the above change). A few tests still need further changes, which I'll post in subsequent patches. PiperOrigin-RevId: 183739148
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py61
1 files changed, 34 insertions, 27 deletions
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 1ea619ff55..f92247d52e 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -54,8 +54,14 @@ def tearDownModule():
file_io.delete_recursively(test.get_temp_dir())
+@test_util.with_c_api
class SavedModelTest(test.TestCase):
+ def _get_export_dir(self, label):
+ if ops._USE_C_API:
+ label += "_c_api"
+ return os.path.join(test.get_temp_dir(), label)
+
def _init_and_validate_variable(self, sess, variable_name, variable_value):
v = variables.Variable(variable_value, name=variable_name)
sess.run(variables.global_variables_initializer())
@@ -123,8 +129,7 @@ class SavedModelTest(test.TestCase):
self.assertFalse(loader.maybe_saved_model_directory(base_path))
def testBadSavedModelFileFormat(self):
- export_dir = os.path.join(test.get_temp_dir(),
- "test_bad_saved_model_file_format")
+ export_dir = self._get_export_dir("test_bad_saved_model_file_format")
# Attempt to load a SavedModel from an export directory that does not exist.
with self.test_session(graph=ops.Graph()) as sess:
with self.assertRaisesRegexp(IOError,
@@ -157,8 +162,7 @@ class SavedModelTest(test.TestCase):
loader.load(sess, ["foo"], export_dir)
def testVerifySessionGraphUsage(self):
- export_dir = os.path.join(test.get_temp_dir(),
- "test_verify_session_graph_usage")
+ export_dir = self._get_export_dir("test_verify_session_graph_usage")
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.test_session(graph=ops.Graph()) as sess:
@@ -178,7 +182,7 @@ class SavedModelTest(test.TestCase):
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
def testSequence(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_sequence")
+ export_dir = self._get_export_dir("test_sequence")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Expect an assertion error since add_meta_graph_and_variables() should be
@@ -195,7 +199,7 @@ class SavedModelTest(test.TestCase):
sess, ["baz"])
def testTags(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_tags")
+ export_dir = self._get_export_dir("test_tags")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Graph with a single variable. SavedModel invoked to:
@@ -284,7 +288,7 @@ class SavedModelTest(test.TestCase):
export_dir)
def testVariables(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_variables")
+ export_dir = self._get_export_dir("test_variables")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Graph with two variables. SavedModel invoked to:
@@ -336,7 +340,7 @@ class SavedModelTest(test.TestCase):
export_dir)
def testGraphWithoutVariables(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_graph_has_variables")
+ export_dir = self._get_export_dir("test_graph_has_variables")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Graph with no variables.
@@ -371,7 +375,7 @@ class SavedModelTest(test.TestCase):
self.assertEqual(30.0, sess.run(c))
def testNoOverwrite(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_no_overwrite")
+ export_dir = self._get_export_dir("test_no_overwrite")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Graph with a single variable. SavedModel invoked to:
@@ -395,7 +399,7 @@ class SavedModelTest(test.TestCase):
export_dir)
def testSaveAsText(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_astext")
+ export_dir = self._get_export_dir("test_astext")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Graph with a single variable. SavedModel invoked to:
@@ -426,7 +430,7 @@ class SavedModelTest(test.TestCase):
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
def testCollections(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_collections")
+ export_dir = self._get_export_dir("test_collections")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Graph with a single variable added to a collection. SavedModel invoked to:
@@ -476,7 +480,7 @@ class SavedModelTest(test.TestCase):
self.assertEqual(len(ops.get_collection("foo_vars")), 0)
def testSignatureDefs(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_signature_defs")
+ export_dir = self._get_export_dir("test_signature_defs")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Graph with a single variable and a single entry in the signature def map.
@@ -536,8 +540,7 @@ class SavedModelTest(test.TestCase):
self.assertEqual("foo_new", bar_signature["foo_key"].method_name)
def testSignatureDefValidation(self):
- export_dir = os.path.join(test.get_temp_dir(),
- "test_signature_def_validation")
+ export_dir = self._get_export_dir("test_signature_def_validation")
builder = saved_model_builder.SavedModelBuilder(export_dir)
tensor_without_name = meta_graph_pb2.TensorInfo()
@@ -555,7 +558,7 @@ class SavedModelTest(test.TestCase):
self._validate_outputs_tensor_info(builder, tensor_empty)
def testAssets(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_assets")
+ export_dir = self._get_export_dir("test_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.test_session(graph=ops.Graph()) as sess:
@@ -588,7 +591,7 @@ class SavedModelTest(test.TestCase):
self.assertFalse(file_io.file_exists(ignored_asset_path))
def testCustomMainOp(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_main_op")
+ export_dir = self._get_export_dir("test_main_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.test_session(graph=ops.Graph()) as sess:
@@ -623,7 +626,7 @@ class SavedModelTest(test.TestCase):
self.assertEqual(3, ops.get_collection("v")[2].eval())
def testLegacyInitOp(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_legacy_init_op")
+ export_dir = self._get_export_dir("test_legacy_init_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.test_session(graph=ops.Graph()) as sess:
@@ -657,8 +660,8 @@ class SavedModelTest(test.TestCase):
self.assertEqual(3, ops.get_collection("v")[2].eval())
def testLegacyInitOpWithNonEmptyCollection(self):
- export_dir = os.path.join(test.get_temp_dir(),
- "test_legacy_init_op_with_non_empty_collection")
+ export_dir = self._get_export_dir(
+ "test_legacy_init_op_with_non_empty_collection")
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.test_session(graph=ops.Graph()) as sess:
@@ -685,7 +688,7 @@ class SavedModelTest(test.TestCase):
sess, ["foo"], legacy_init_op=legacy_init_op)
def testMultipleAssets(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_multiple_assets")
+ export_dir = self._get_export_dir("test_multiple_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.test_session(graph=ops.Graph()) as sess:
@@ -727,7 +730,7 @@ class SavedModelTest(test.TestCase):
"asset_file_tensor:0")
def testDuplicateAssets(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_duplicate_assets")
+ export_dir = self._get_export_dir("test_duplicate_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.test_session(graph=ops.Graph()) as sess:
@@ -775,7 +778,7 @@ class SavedModelTest(test.TestCase):
"asset_file_tensor:0")
def testOp(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_op")
+ export_dir = self._get_export_dir("test_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)
with session.Session(
@@ -818,7 +821,7 @@ class SavedModelTest(test.TestCase):
self.assertEqual(3, ops.get_collection("v")[2].eval())
def testCustomSaveable(self):
- export_dir = os.path.join(test.get_temp_dir(), "custom_saveable")
+ export_dir = self._get_export_dir("custom_saveable")
builder = saved_model_builder.SavedModelBuilder(export_dir)
with session.Session(
@@ -847,7 +850,7 @@ class SavedModelTest(test.TestCase):
self.assertEqual(3.0, v1.values().eval())
def testClearDevices(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_clear_devices")
+ export_dir = self._get_export_dir("test_clear_devices")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Specify a device and save a variable.
@@ -871,7 +874,9 @@ class SavedModelTest(test.TestCase):
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
def testStripDefaultAttrs(self):
- export_dir = os.path.join(test.get_temp_dir(), "test_strip_default_attrs")
+ if ops._USE_C_API: return # TODO(skyewm): get this working
+
+ export_dir = self._get_export_dir("test_strip_default_attrs")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Add a graph with two float32 variables and a Complex Op composing them
@@ -941,8 +946,10 @@ class SavedModelTest(test.TestCase):
self.assertIn("Tout", node_def.attr)
def testStripDefaultAttrsInconsistentConsumerDefaults(self):
- export_dir = os.path.join(test.get_temp_dir(),
- "test_strip_default_attrs_no_consumer_defaults")
+ if ops._USE_C_API: return # TODO(skyewm): get this working
+
+ export_dir = self._get_export_dir(
+ "test_strip_default_attrs_no_consumer_defaults")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Add a graph with two float32 variables and a Complex Op composing them