diff options
author | 2016-07-12 20:09:25 -0800 | |
---|---|---|
committer | 2016-07-12 21:16:14 -0700 | |
commit | b7c375735732641b376ac4bdcc84bd74f1e747e4 (patch) | |
tree | c81f67dcaeb404783dcc8b7f7c97f4a16b59cda3 | |
parent | 818e82f2211eff7c65b2c5da838aba70fa42c347 (diff) |
Add a monitor that writes trainable variable values to the log. This is useful for monitoring the distribution of learned weights and for debugging how the weights change every n steps.
Change: 127276229
-rw-r--r-- | tensorflow/contrib/learn/python/learn/monitors.py | 39 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/tests/monitors_test.py | 29 |
2 files changed, 66 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 10f053aee2..f6af7a6f8c 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -379,6 +379,45 @@ class PrintTensor(EveryN): logging.info("Step %d: %s", step, ", ".join(stats)) +class LoggingTrainable(EveryN): + """Writes trainable varialbe values into log every N steps. + + Write the tensors in trainable variables `every_n` steps, + starting with the `first_n`th step. + + """ + + def __init__(self, scope=None, every_n=100, first_n=1): + """Initializes LoggingTrainable monitor. + + Args: + scope: An optional string to match variable names using re.match. + every_n: Print every N steps. + first_n: Print first N steps. + """ + super(LoggingTrainable, self).__init__(every_n, first_n) + self._scope = scope + + def every_n_step_begin(self, step): + super(LoggingTrainable, self).every_n_step_begin(step) + # Get a list of trainable variables at the begining of every N steps. + # We cannot get this in __init__ because train_op has not been generated. 
+ trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, + scope=self._scope) + self._names = {} + for var in trainables: + self._names[var.name] = var.value().name + return list(self._names.values()) + + def every_n_step_end(self, step, outputs): + super(LoggingTrainable, self).every_n_step_end(step, outputs) + stats = [] + for tag, tensor_name in six.iteritems(self._names): + if tensor_name in outputs: + stats.append("%s = %s" % (tag, str(outputs[tensor_name]))) + logging.info("Logging Trainable: Step %d: %s", step, ", ".join(stats)) + + class SummarySaver(EveryN): """Saves summaries every N steps.""" diff --git a/tensorflow/contrib/learn/python/learn/tests/monitors_test.py b/tensorflow/contrib/learn/python/learn/tests/monitors_test.py index 3327a67d53..09555f11cb 100644 --- a/tensorflow/contrib/learn/python/learn/tests/monitors_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/monitors_test.py @@ -24,6 +24,7 @@ import tensorflow as tf from tensorflow.contrib import testing from tensorflow.contrib.learn.python import learn +from tensorflow.python.platform import tf_logging as logging class _MyEveryN(learn.monitors.EveryN): @@ -56,6 +57,20 @@ class _MyEveryN(learn.monitors.EveryN): class MonitorsTest(tf.test.TestCase): """Monitors tests.""" + def setUp(self): + # Mock out logging calls so we can verify whether correct tensors are being + # monitored. 
+ self._actual_log = logging.info + + def mockLog(*args, **kwargs): + self.logged_message = args + self._actual_log(*args, **kwargs) + + logging.info = mockLog + + def tearDown(self): + logging.info = self._actual_log + def _run_monitor(self, monitor, num_epochs=3, num_steps_per_epoch=10): monitor.begin(max_steps=(num_epochs * num_steps_per_epoch) - 1) for epoch in xrange(num_epochs): @@ -86,12 +101,22 @@ class MonitorsTest(tf.test.TestCase): self.assertEqual(expected_steps, monitor.steps_begun) self.assertEqual(expected_steps, monitor.steps_ended) - # TODO(b/29293803): This is just a sanity check for now, add better tests with - # a mocked logger. def test_print(self): with tf.Graph().as_default() as g, self.test_session(g): t = tf.constant(42.0, name='foo') self._run_monitor(learn.monitors.PrintTensor(tensor_names=[t.name])) + self.assertRegexpMatches(str(self.logged_message), t.name) + + def test_logging_trainable(self): + with tf.Graph().as_default() as g, self.test_session(g): + var = tf.Variable(tf.constant(42.0), name='foo') + var.initializer.run() + cof = tf.constant(1.0) + loss = tf.sub(tf.mul(var, cof), tf.constant(1.0)) + train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss) + tf.get_default_session().run(train_step) + self._run_monitor(learn.monitors.LoggingTrainable('foo')) + self.assertRegexpMatches(str(self.logged_message), var.name) def test_summary_saver(self): with tf.Graph().as_default() as g, self.test_session(g): |