aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r--tensorflow/python/framework/test_util.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 71328a7f88..5ba226ce07 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -266,6 +266,13 @@ class TensorFlowTestCase(googletest.TestCase):
self._ClearCachedSession()
random.seed(random_seed.DEFAULT_GRAPH_SEED)
np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
+ # Note: The following line is necessary because some test methods may error
+ # out from within nested graph contexts (e.g., via assertRaises and
+ # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
+ # under certain versions of Python. That would cause
+ # ops.reset_default_graph() to throw an exception if the stack were not
+ # cleared first.
+ ops._default_graph_stack.reset() # pylint: disable=protected-access
ops.reset_default_graph()
ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED