diff options
Diffstat (limited to 'tensorflow/python/framework/meta_graph_test.py')
-rw-r--r-- | tensorflow/python/framework/meta_graph_test.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index 6e5f7aafac..fc98b91a01 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -117,7 +117,7 @@ class SimpleMetaGraphTest(test.TestCase): self.assertEqual(new_output_value, output_value) def testStrippedOpListNestedFunctions(self): - with self.test_session(): + with self.cached_session(): # Square two levels deep @function.Defun(dtypes.int32) def f0(x): @@ -169,7 +169,7 @@ class SimpleMetaGraphTest(test.TestCase): # and "Tout" maps to complex64. Since these attr values map to their # defaults, they must be stripped unless stripping of default attrs is # disabled. - with self.test_session(): + with self.cached_session(): real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real") imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag") math_ops.complex(real_num, imag_num, name="complex") @@ -212,7 +212,8 @@ class SimpleMetaGraphTest(test.TestCase): def testDefaultAttrStrippingNestedFunctions(self): """Verifies that default attributes are stripped from function node defs.""" - with self.test_session(): + with self.cached_session(): + @function.Defun(dtypes.float32, dtypes.float32) def f0(i, j): return math_ops.complex(i, j, name="double_nested_complex") @@ -251,7 +252,7 @@ class SimpleMetaGraphTest(test.TestCase): meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef() meta_info_def.stripped_op_list.op.add() - with self.test_session(): + with self.cached_session(): meta_graph_def = meta_graph.create_meta_graph_def( meta_info_def=meta_info_def, graph_def=graph_def, strip_default_attrs=True) |