diff options
-rw-r--r-- | tensorflow/contrib/learn/BUILD | 12 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/graph_actions.py | 34 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py | 200 |
3 files changed, 242 insertions, 4 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 485a1220cc..103759ac0b 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -212,6 +212,18 @@ py_test( ) py_test( + name = "graph_actions_test", + size = "small", + srcs = ["python/learn/tests/graph_actions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":learn", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + ], +) + +py_test( name = "learn_runner_test", size = "small", srcs = ["python/learn/tests/learn_runner_test.py"], diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index 39a74844ca..890d369a0f 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -49,13 +49,20 @@ from tensorflow.python.training import session_manager as session_manager_lib from tensorflow.python.training import summary_io from tensorflow.python.training import supervisor as tf_supervisor -# Singletone for SummaryWriter per logdir folder. +# Singleton for SummaryWriter per logdir folder. _SUMMARY_WRITERS = {} # Lock protecting _SUMMARY_WRITERS _summary_writer_lock = threading.Lock() +def clear_summary_writers(): + """Clear cached summary writers. Currently only used for unit tests.""" + _summary_writer_lock.acquire() + _SUMMARY_WRITERS.clear() + _summary_writer_lock.release() + + def get_summary_writer(logdir): """Returns single SummaryWriter per logdir in current run. @@ -398,7 +405,6 @@ def _write_summary_results(output_dir, eval_results, current_global_step): summary_writer.flush() -# TODO(ptucker): Add unit test. def evaluate(graph, output_dir, checkpoint_path, @@ -563,8 +569,8 @@ def run_n(output_dict, feed_dict=None, restore_checkpoint_path=None, n=1): def run_feeds(output_dict, feed_dicts, restore_checkpoint_path=None): """Run `output_dict` tensors with each input in `feed_dicts`. - If `checkpoint_path` is supplied, restore from checkpoint. Otherwise, init all - variables. + If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise, + init all variables. Args: output_dict: A `dict` mapping string names to `Tensor` objects to run. @@ -609,6 +615,26 @@ def run_feeds(output_dict, feed_dicts, restore_checkpoint_path=None): def infer(restore_checkpoint_path, output_dict, feed_dict=None): + """Restore graph from `restore_checkpoint_path` and run `output_dict` tensors. + + If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise, + init all variables. + + Args: + restore_checkpoint_path: A string containing the path to a checkpoint to + restore. + output_dict: A `dict` mapping string names to `Tensor` objects to run. + Tensors must all be from the same graph. + feed_dict: `dict` object mapping `Tensor` objects to input values to feed. + + Returns: + Dict of values read from `output_dict` tensors. Keys are the same as + `output_dict`, values are the results read from the corresponding `Tensor` + in `output_dict`. + + Raises: + ValueError: if `output_dict` or `feed_dicts` is None or empty. + """ return run_feeds(output_dict=output_dict, feed_dicts=[feed_dict] if feed_dict is not None else [None], restore_checkpoint_path=restore_checkpoint_path)[0] diff --git a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py new file mode 100644 index 0000000000..e391265aad --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py @@ -0,0 +1,200 @@ +# pylint: disable=g-bad-file-header +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Graph actions tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.learn.python import learn +from tensorflow.python.training import summary_io + + +# TODO(ptucker): Replace with mock framework. +class _FakeSummaryWriter(object): + + def __init__(self, logdir, graph=None): + self._logdir = logdir + self._graph = graph + self._summaries = {} + self._flushed = True + + @property + def logdir(self): + return self._logdir + + @property + def graph(self): + return self._graph + + @property + def summaries(self): + return self._summaries + + @property + def flushed(self): + return self._flushed + + def add_summary(self, summary, current_global_step): + if current_global_step in self._summaries: + raise ValueError('Dupe summary for step %s.' % current_global_step) + self._summaries[current_global_step] = summary + self._flushed = False + + def flush(self): + self._flushed = True + + +class _Feeder(object): + """Simple generator for `feed_fn`, returning 10 * step.""" + + def __init__(self, tensor): + self._step = 0 + self._tensor = tensor + + @property + def step(self): + return self._step + + def feed_fn(self): + value = self._step * 10.0 + self._step += 1 + return {self._tensor: value} + + +class GraphActionsTest(tf.test.TestCase): + """Graph actions tests.""" + + def setUp(self): + learn.graph_actions.clear_summary_writers() + self._original_summary_writer = summary_io.SummaryWriter + summary_io.SummaryWriter = _FakeSummaryWriter + + def tearDown(self): + summary_io.SummaryWriter = self._original_summary_writer + learn.graph_actions.clear_summary_writers() + + def _assert_fake_summary_writer(self, output_dir, expected_summaries=None): + writer = learn.graph_actions.get_summary_writer(output_dir) + self.assertTrue(isinstance(writer, _FakeSummaryWriter)) + self.assertEqual(output_dir, writer.logdir) + self.assertTrue(tf.get_default_graph() is writer.graph) + self.assertTrue(writer.flushed) + expected_summaries = expected_summaries or {} + expected_steps = expected_summaries.keys() + self.assertEqual(set(expected_steps), set(writer.summaries.keys())) + for step in expected_steps: + actual_simple_values = {} + for v in writer.summaries[step].value: + actual_simple_values[v.tag] = v.simple_value + self.assertEqual(expected_summaries[step], actual_simple_values) + + # TODO(ptucker): Test lock, multi-threaded access? + def test_summary_writer(self): + self._assert_fake_summary_writer('log/dir/0') + self.assertTrue( + learn.graph_actions.get_summary_writer('log/dir/0') is + learn.graph_actions.get_summary_writer('log/dir/0')) + self.assertTrue( + learn.graph_actions.get_summary_writer('log/dir/0') is not + learn.graph_actions.get_summary_writer('log/dir/1')) + + # TODO(ptucker): Test restore_checkpoint_path. + # TODO(ptucker): Test start_queue_runners. + # TODO(ptucker): Test coord.request_stop & coord.join. + + def _build_inference_graph(self): + """Build simple inference graph. + + This includes a regular variable, local variable, and fake table. + + Returns: + Tuple of 3 `Tensor` objects, 2 input and 1 output. + """ + tf.contrib.framework.create_global_step() + in0 = tf.Variable(1.0) + in1 = tf.contrib.framework.local_variable(2.0) + fake_table = tf.Variable( + 3.0, trainable=False, collections=['fake_tables'], + name='fake_table_var') + in0.graph.add_to_collections( + [tf.GraphKeys.TABLE_INITIALIZERS], fake_table.initializer) + out = in0 + in1 + fake_table + return in0, in1, out + + def test_infer(self): + with tf.Graph().as_default() as g, self.test_session(g): + in0, in1, out = self._build_inference_graph() + self.assertEqual( + {'a': 1.0, 'b': 2.0, 'c': 6.0}, + learn.graph_actions.infer(None, {'a': in0, 'b': in1, 'c': out})) + + def test_infer_different_default_graph(self): + with self.test_session(): + with tf.Graph().as_default(): + in0, in1, out = self._build_inference_graph() + with tf.Graph().as_default(): + self.assertEqual( + {'a': 1.0, 'b': 2.0, 'c': 6.0}, + learn.graph_actions.infer(None, {'a': in0, 'b': in1, 'c': out})) + + def test_infer_invalid_feed(self): + with tf.Graph().as_default() as g, self.test_session(g): + in0, _, _ = self._build_inference_graph() + with self.assertRaisesRegexp( + tf.errors.InvalidArgumentError, 'both fed and fetched'): + learn.graph_actions.infer(None, {'a': in0}, feed_dict={in0: 4.0}) + + def test_infer_feed(self): + with tf.Graph().as_default() as g, self.test_session(g): + in0, _, out = self._build_inference_graph() + self.assertEqual( + {'c': 9.0}, + learn.graph_actions.infer(None, {'c': out}, feed_dict={in0: 4.0})) + + # TODO(ptucker): Test saver and ckpt_path. + # TODO(ptucker): Test eval for 1 epoch. + + def test_evaluate(self): + with tf.Graph().as_default() as g, self.test_session(g): + _, _, out = self._build_inference_graph() + output_dir = 'out/dir' + self._assert_fake_summary_writer(output_dir, {}) + results = learn.graph_actions.evaluate( + g, output_dir=output_dir, checkpoint_path=None, eval_dict={'a': out}, + max_steps=1) + self.assertEqual(({'a': 6.0}, 0), results) + self._assert_fake_summary_writer(output_dir, {0: {'a': 6.0}}) + + def test_evaluate_feed_fn(self): + with tf.Graph().as_default() as g, self.test_session(g): + in0, _, out = self._build_inference_graph() + output_dir = 'out/dir' + self._assert_fake_summary_writer(output_dir, {}) + feeder = _Feeder(in0) + results = learn.graph_actions.evaluate( + g, output_dir=output_dir, checkpoint_path=None, eval_dict={'a': out}, + feed_fn=feeder.feed_fn, max_steps=3) + self.assertEqual(3, feeder.step) + self.assertEqual(({'a': 25.0}, 0), results) + self._assert_fake_summary_writer(output_dir, {0: {'a': 25.0}}) + + +if __name__ == '__main__': + tf.test.main() |