diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/tests/monitors_test.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/tests/monitors_test.py | 29 |
1 files changed, 27 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/tests/monitors_test.py b/tensorflow/contrib/learn/python/learn/tests/monitors_test.py index 3327a67d53..09555f11cb 100644 --- a/tensorflow/contrib/learn/python/learn/tests/monitors_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/monitors_test.py @@ -24,6 +24,7 @@ import tensorflow as tf from tensorflow.contrib import testing from tensorflow.contrib.learn.python import learn +from tensorflow.python.platform import tf_logging as logging class _MyEveryN(learn.monitors.EveryN): @@ -56,6 +57,20 @@ class _MyEveryN(learn.monitors.EveryN): class MonitorsTest(tf.test.TestCase): """Monitors tests.""" + def setUp(self): + # Mock out logging calls so we can verify whether correct tensors are being + # monitored. + self._actual_log = logging.info + + def mockLog(*args, **kwargs): + self.logged_message = args + self._actual_log(*args, **kwargs) + + logging.info = mockLog + + def tearDown(self): + logging.info = self._actual_log + def _run_monitor(self, monitor, num_epochs=3, num_steps_per_epoch=10): monitor.begin(max_steps=(num_epochs * num_steps_per_epoch) - 1) for epoch in xrange(num_epochs): @@ -86,12 +101,22 @@ class MonitorsTest(tf.test.TestCase): self.assertEqual(expected_steps, monitor.steps_begun) self.assertEqual(expected_steps, monitor.steps_ended) - # TODO(b/29293803): This is just a sanity check for now, add better tests with - # a mocked logger. def test_print(self): with tf.Graph().as_default() as g, self.test_session(g): t = tf.constant(42.0, name='foo') self._run_monitor(learn.monitors.PrintTensor(tensor_names=[t.name])) + self.assertRegexpMatches(str(self.logged_message), t.name) + + def test_logging_trainable(self): + with tf.Graph().as_default() as g, self.test_session(g): + var = tf.Variable(tf.constant(42.0), name='foo') + var.initializer.run() + cof = tf.constant(1.0) + loss = tf.sub(tf.mul(var, cof), tf.constant(1.0)) + train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss) + tf.get_default_session().run(train_step) + self._run_monitor(learn.monitors.LoggingTrainable('foo')) + self.assertRegexpMatches(str(self.logged_message), var.name) def test_summary_saver(self): with tf.Graph().as_default() as g, self.test_session(g): |