aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-12 20:09:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-12 21:16:14 -0700
commitb7c375735732641b376ac4bdcc84bd74f1e747e4 (patch)
treec81f67dcaeb404783dcc8b7f7c97f4a16b59cda3
parent818e82f2211eff7c65b2c5da838aba70fa42c347 (diff)
Add a monitor that writes trainable variable values into log. This is useful for monitoring the distribution of learned weights and debugging change of weights in every n steps.
Change: 127276229
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors.py39
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/monitors_test.py29
2 files changed, 66 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py
index 10f053aee2..f6af7a6f8c 100644
--- a/tensorflow/contrib/learn/python/learn/monitors.py
+++ b/tensorflow/contrib/learn/python/learn/monitors.py
@@ -379,6 +379,45 @@ class PrintTensor(EveryN):
logging.info("Step %d: %s", step, ", ".join(stats))
+class LoggingTrainable(EveryN):
+ """Writes trainable varialbe values into log every N steps.
+
+ Write the tensors in trainable variables `every_n` steps,
+ starting with the `first_n`th step.
+
+ """
+
+ def __init__(self, scope=None, every_n=100, first_n=1):
+ """Initializes LoggingTrainable monitor.
+
+ Args:
+ scope: An optional string to match variable names using re.match.
+ every_n: Print every N steps.
+ first_n: Print first N steps.
+ """
+ super(LoggingTrainable, self).__init__(every_n, first_n)
+ self._scope = scope
+
+ def every_n_step_begin(self, step):
+ super(LoggingTrainable, self).every_n_step_begin(step)
+ # Get a list of trainable variables at the begining of every N steps.
+ # We cannot get this in __init__ because train_op has not been generated.
+ trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES,
+ scope=self._scope)
+ self._names = {}
+ for var in trainables:
+ self._names[var.name] = var.value().name
+ return list(self._names.values())
+
+ def every_n_step_end(self, step, outputs):
+ super(LoggingTrainable, self).every_n_step_end(step, outputs)
+ stats = []
+ for tag, tensor_name in six.iteritems(self._names):
+ if tensor_name in outputs:
+ stats.append("%s = %s" % (tag, str(outputs[tensor_name])))
+ logging.info("Logging Trainable: Step %d: %s", step, ", ".join(stats))
+
+
class SummarySaver(EveryN):
"""Saves summaries every N steps."""
diff --git a/tensorflow/contrib/learn/python/learn/tests/monitors_test.py b/tensorflow/contrib/learn/python/learn/tests/monitors_test.py
index 3327a67d53..09555f11cb 100644
--- a/tensorflow/contrib/learn/python/learn/tests/monitors_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/monitors_test.py
@@ -24,6 +24,7 @@ import tensorflow as tf
from tensorflow.contrib import testing
from tensorflow.contrib.learn.python import learn
+from tensorflow.python.platform import tf_logging as logging
class _MyEveryN(learn.monitors.EveryN):
@@ -56,6 +57,20 @@ class _MyEveryN(learn.monitors.EveryN):
class MonitorsTest(tf.test.TestCase):
"""Monitors tests."""
+ def setUp(self):
+ # Mock out logging calls so we can verify whether correct tensors are being
+ # monitored.
+ self._actual_log = logging.info
+
+ def mockLog(*args, **kwargs):
+ self.logged_message = args
+ self._actual_log(*args, **kwargs)
+
+ logging.info = mockLog
+
+ def tearDown(self):
+ logging.info = self._actual_log
+
def _run_monitor(self, monitor, num_epochs=3, num_steps_per_epoch=10):
monitor.begin(max_steps=(num_epochs * num_steps_per_epoch) - 1)
for epoch in xrange(num_epochs):
@@ -86,12 +101,22 @@ class MonitorsTest(tf.test.TestCase):
self.assertEqual(expected_steps, monitor.steps_begun)
self.assertEqual(expected_steps, monitor.steps_ended)
- # TODO(b/29293803): This is just a sanity check for now, add better tests with
- # a mocked logger.
def test_print(self):
with tf.Graph().as_default() as g, self.test_session(g):
t = tf.constant(42.0, name='foo')
self._run_monitor(learn.monitors.PrintTensor(tensor_names=[t.name]))
+ self.assertRegexpMatches(str(self.logged_message), t.name)
+
+ def test_logging_trainable(self):
+ with tf.Graph().as_default() as g, self.test_session(g):
+ var = tf.Variable(tf.constant(42.0), name='foo')
+ var.initializer.run()
+ cof = tf.constant(1.0)
+ loss = tf.sub(tf.mul(var, cof), tf.constant(1.0))
+ train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
+ tf.get_default_session().run(train_step)
+ self._run_monitor(learn.monitors.LoggingTrainable('foo'))
+ self.assertRegexpMatches(str(self.logged_message), var.name)
def test_summary_saver(self):
with tf.Graph().as_default() as g, self.test_session(g):