aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/basic_loops_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-06-09 13:56:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-09 15:03:40 -0700
commit48f2522176f0e2b3d30304b247658c240d81fc88 (patch)
tree11b39962af5f062ed712244556608f0212c506a8 /tensorflow/python/training/basic_loops_test.py
parent1fde2817529524f8b0d59ab24400ca8c2675fcf7 (diff)
Add basic_train_loop() as an example for higher level frameworks to copy or
reuse. It can also be used directly for simple training. Fix Coordinator.clear_stop() to also clear the exception to raise. Add test. Add SummaryWriter.reopen(), with tests. This is needed to properly handle summaries when creating a session more than once in a Supervisor. In Supervisor.prepare_or_wait_for_session() reopen the summary writer. At the end of Supervisor.managed_session() correctly close the summary writer and clear the running threads even if an exception was reported. Change: 124500982
Diffstat (limited to 'tensorflow/python/training/basic_loops_test.py')
-rw-r--r--tensorflow/python/training/basic_loops_test.py95
1 files changed, 95 insertions, 0 deletions
diff --git a/tensorflow/python/training/basic_loops_test.py b/tensorflow/python/training/basic_loops_test.py
new file mode 100644
index 0000000000..fc442c414c
--- /dev/null
+++ b/tensorflow/python/training/basic_loops_test.py
@@ -0,0 +1,95 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for basic_loops.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+
+import tensorflow as tf
+
+
+def _test_dir(test_name):
+  """Returns a path under the TF temp dir for `test_name`, removed if present.
+
+  Args:
+    test_name: Name of the sub-directory, joined onto tf.test.get_temp_dir().
+
+  Returns:
+    The joined path. Anything that already existed at that path has been
+    deleted with shutil.rmtree(); the directory itself is not created here.
+  """
+  path = os.path.join(tf.test.get_temp_dir(), test_name)
+  if not os.path.exists(path):
+    return path
+  # Leftovers from an earlier run: wipe them so the test starts clean.
+  shutil.rmtree(path)
+  return path
+
+
+class BasicTrainLoopTest(tf.test.TestCase):
+  """Tests for tf.train.basic_train_loop() driven by a tf.train.Supervisor."""
+
+  def testBasicTrainLoop(self):
+    """Extra args/kwargs reach train_fn; the loop ends after request_stop()."""
+    logdir = _test_dir("basic_train_loop")
+    # Counts the number of calls. A one-element list lets the nested function
+    # update the count without rebinding a name in the enclosing scope.
+    num_calls = [0]
+
+    def train_fn(unused_sess, sv, y, a):
+      num_calls[0] += 1
+      self.assertEqual("y", y)
+      self.assertEqual("A", a)
+      if num_calls[0] == 3:
+        sv.request_stop()
+
+    with tf.Graph().as_default():
+      # Build the Supervisor inside the fresh graph. Its `graph` argument
+      # defaults to the graph that is the default at construction time, so
+      # creating it here keeps this test off the process-wide default graph.
+      sv = tf.train.Supervisor(logdir=logdir)
+      tf.train.basic_train_loop(sv, train_fn, args=(sv, "y"), kwargs={"a": "A"})
+      self.assertEqual(3, num_calls[0])
+
+  def testBasicTrainLoopExceptionAborts(self):
+    """A RuntimeError raised by train_fn propagates out of the loop."""
+    logdir = _test_dir("basic_train_loop_exception_aborts")
+
+    def train_fn(unused_sess):
+      train_fn.counter += 1
+      if train_fn.counter == 3:
+        raise RuntimeError("Failed")
+
+    # Function attribute used to count the number of calls.
+    train_fn.counter = 0
+
+    with tf.Graph().as_default():
+      # See testBasicTrainLoop: the Supervisor is created in the fresh graph.
+      sv = tf.train.Supervisor(logdir=logdir)
+      # assertRaisesRegexp is kept (rather than assertRaisesRegex) because the
+      # file still targets Python 2, as its __future__ imports show.
+      with self.assertRaisesRegexp(RuntimeError, "Failed"):
+        tf.train.basic_train_loop(sv, train_fn)
+
+  def testBasicTrainLoopRetryOnAborted(self):
+    """The loop keeps calling train_fn after AbortedError, then fails."""
+    # Own directory name: previously this reused the name of
+    # testBasicTrainLoopExceptionAborts, so both tests shared (and deleted)
+    # the same log directory.
+    logdir = _test_dir("basic_train_loop_retry_on_aborted")
+
+    class AbortAndRetry(object):
+      """Callable state holder whose train_fn fails on a fixed schedule."""
+
+      def __init__(self):
+        self.num_calls = 0
+        self.retries_left = 2
+
+      def train_fn(self, unused_sess):
+        self.num_calls += 1
+        # Calls 2, 5, 8, ... fail. The first failure still has a retry left
+        # and raises AbortedError; the failure that uses up the last retry
+        # raises RuntimeError instead.
+        if self.num_calls % 3 == 2:
+          self.retries_left -= 1
+          if self.retries_left > 0:
+            raise tf.errors.AbortedError(None, None, "Aborted here")
+          else:
+            raise RuntimeError("Failed Again")
+
+    with tf.Graph().as_default():
+      # See testBasicTrainLoop: the Supervisor is created in the fresh graph.
+      sv = tf.train.Supervisor(logdir=logdir)
+      aar = AbortAndRetry()
+      with self.assertRaisesRegexp(RuntimeError, "Failed Again"):
+        tf.train.basic_train_loop(sv, aar.train_fn)
+      # Reaching zero proves train_fn was invoked again after the AbortedError.
+      self.assertEqual(0, aar.retries_left)
+
+
+# Hand control to tf.test.main() when this file is executed as a script.
+if __name__ == "__main__":
+ tf.test.main()