aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/tests/monitors_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/tests/monitors_test.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/monitors_test.py29
1 files changed, 27 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/tests/monitors_test.py b/tensorflow/contrib/learn/python/learn/tests/monitors_test.py
index 3327a67d53..09555f11cb 100644
--- a/tensorflow/contrib/learn/python/learn/tests/monitors_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/monitors_test.py
@@ -24,6 +24,7 @@ import tensorflow as tf
from tensorflow.contrib import testing
from tensorflow.contrib.learn.python import learn
+from tensorflow.python.platform import tf_logging as logging
class _MyEveryN(learn.monitors.EveryN):
@@ -56,6 +57,20 @@ class _MyEveryN(learn.monitors.EveryN):
class MonitorsTest(tf.test.TestCase):
"""Monitors tests."""
+ def setUp(self):
+ # Mock out logging calls so we can verify whether correct tensors are being
+ # monitored.
+ self._actual_log = logging.info
+
+ def mockLog(*args, **kwargs):
+ self.logged_message = args
+ self._actual_log(*args, **kwargs)
+
+ logging.info = mockLog
+
+ def tearDown(self):
+ logging.info = self._actual_log
+
def _run_monitor(self, monitor, num_epochs=3, num_steps_per_epoch=10):
monitor.begin(max_steps=(num_epochs * num_steps_per_epoch) - 1)
for epoch in xrange(num_epochs):
@@ -86,12 +101,22 @@ class MonitorsTest(tf.test.TestCase):
self.assertEqual(expected_steps, monitor.steps_begun)
self.assertEqual(expected_steps, monitor.steps_ended)
- # TODO(b/29293803): This is just a sanity check for now, add better tests with
- # a mocked logger.
def test_print(self):
with tf.Graph().as_default() as g, self.test_session(g):
t = tf.constant(42.0, name='foo')
self._run_monitor(learn.monitors.PrintTensor(tensor_names=[t.name]))
+ self.assertRegexpMatches(str(self.logged_message), t.name)
+
+ def test_logging_trainable(self):
+ with tf.Graph().as_default() as g, self.test_session(g):
+ var = tf.Variable(tf.constant(42.0), name='foo')
+ var.initializer.run()
+ cof = tf.constant(1.0)
+ loss = tf.sub(tf.mul(var, cof), tf.constant(1.0))
+ train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
+ tf.get_default_session().run(train_step)
+ self._run_monitor(learn.monitors.LoggingTrainable('foo'))
+ self.assertRegexpMatches(str(self.logged_message), var.name)
def test_summary_saver(self):
with tf.Graph().as_default() as g, self.test_session(g):