aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/monitored_session_test.py
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2017-12-04 17:36:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-04 17:40:05 -0800
commit4ff0f280053187e6360f0812198813ed576d6b62 (patch)
tree9e7c7b94f47bfd5d3656a408c3f2cb14a3d592d3 /tensorflow/python/training/monitored_session_test.py
parent601687d9f5046f411be556f28b6c82ac035696f9 (diff)
Reproduce an issue with MonitoredSession when saving a variable on a GPU.
Also arrange for continuous testing with GPUs. PiperOrigin-RevId: 177895214
Diffstat (limited to 'tensorflow/python/training/monitored_session_test.py')
-rw-r--r--tensorflow/python/training/monitored_session_test.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 159b2d5c16..349d8537cb 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
@@ -1968,6 +1969,19 @@ class MonitoredSessionTest(test.TestCase):
self.assertEqual(2, trace_the_exception['side_effect_counter'])
self.assertNear(0.62, session.run(graph_state), 0.1)
+ def test_saver_on_a_gpu(self):
+ if not test_util.is_gpu_available():
+ return
+ with ops.Graph().as_default():
+ with self.test_session():
+ with ops.device('/gpu:0'):
+ variables.Variable(0)
+ saver_lib.Saver()
+
+ # TODO(b/36964652): Reproduces the issue that needs to be fixed.
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ monitored_session.MonitoredSession()
+
class SingularMonitoredSessionTest(test.TestCase):
"""Tests SingularMonitoredSession."""