aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/learn/BUILD12
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions.py34
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py200
3 files changed, 242 insertions, 4 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 485a1220cc..103759ac0b 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -212,6 +212,18 @@ py_test(
)
py_test(
+ name = "graph_actions_test",
+ size = "small",
+ srcs = ["python/learn/tests/graph_actions_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":learn",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+py_test(
name = "learn_runner_test",
size = "small",
srcs = ["python/learn/tests/learn_runner_test.py"],
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py
index 39a74844ca..890d369a0f 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions.py
@@ -49,13 +49,20 @@ from tensorflow.python.training import session_manager as session_manager_lib
from tensorflow.python.training import summary_io
from tensorflow.python.training import supervisor as tf_supervisor
-# Singletone for SummaryWriter per logdir folder.
+# Singleton for SummaryWriter per logdir folder.
_SUMMARY_WRITERS = {}
# Lock protecting _SUMMARY_WRITERS
_summary_writer_lock = threading.Lock()
+def clear_summary_writers():
+ """Clear cached summary writers. Currently only used for unit tests."""
+ _summary_writer_lock.acquire()
+ _SUMMARY_WRITERS.clear()
+ _summary_writer_lock.release()
+
+
def get_summary_writer(logdir):
"""Returns single SummaryWriter per logdir in current run.
@@ -398,7 +405,6 @@ def _write_summary_results(output_dir, eval_results, current_global_step):
summary_writer.flush()
-# TODO(ptucker): Add unit test.
def evaluate(graph,
output_dir,
checkpoint_path,
@@ -563,8 +569,8 @@ def run_n(output_dict, feed_dict=None, restore_checkpoint_path=None, n=1):
def run_feeds(output_dict, feed_dicts, restore_checkpoint_path=None):
"""Run `output_dict` tensors with each input in `feed_dicts`.
- If `checkpoint_path` is supplied, restore from checkpoint. Otherwise, init all
- variables.
+ If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
+ init all variables.
Args:
output_dict: A `dict` mapping string names to `Tensor` objects to run.
@@ -609,6 +615,26 @@ def run_feeds(output_dict, feed_dicts, restore_checkpoint_path=None):
def infer(restore_checkpoint_path, output_dict, feed_dict=None):
+ """Restore graph from `restore_checkpoint_path` and run `output_dict` tensors.
+
+ If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
+ init all variables.
+
+ Args:
+ restore_checkpoint_path: A string containing the path to a checkpoint to
+ restore.
+ output_dict: A `dict` mapping string names to `Tensor` objects to run.
+ Tensors must all be from the same graph.
+ feed_dict: `dict` object mapping `Tensor` objects to input values to feed.
+
+ Returns:
+ Dict of values read from `output_dict` tensors. Keys are the same as
+ `output_dict`, values are the results read from the corresponding `Tensor`
+ in `output_dict`.
+
+ Raises:
+ ValueError: if `output_dict` or `feed_dicts` is None or empty.
+ """
return run_feeds(output_dict=output_dict,
feed_dicts=[feed_dict] if feed_dict is not None else [None],
restore_checkpoint_path=restore_checkpoint_path)[0]
diff --git a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py
new file mode 100644
index 0000000000..e391265aad
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py
@@ -0,0 +1,200 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Graph actions tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.learn.python import learn
+from tensorflow.python.training import summary_io
+
+
+# TODO(ptucker): Replace with mock framework.
+class _FakeSummaryWriter(object):
+
+ def __init__(self, logdir, graph=None):
+ self._logdir = logdir
+ self._graph = graph
+ self._summaries = {}
+ self._flushed = True
+
+ @property
+ def logdir(self):
+ return self._logdir
+
+ @property
+ def graph(self):
+ return self._graph
+
+ @property
+ def summaries(self):
+ return self._summaries
+
+ @property
+ def flushed(self):
+ return self._flushed
+
+ def add_summary(self, summary, current_global_step):
+ if current_global_step in self._summaries:
+ raise ValueError('Dupe summary for step %s.' % current_global_step)
+ self._summaries[current_global_step] = summary
+ self._flushed = False
+
+ def flush(self):
+ self._flushed = True
+
+
+class _Feeder(object):
+ """Simple generator for `feed_fn`, returning 10 * step."""
+
+ def __init__(self, tensor):
+ self._step = 0
+ self._tensor = tensor
+
+ @property
+ def step(self):
+ return self._step
+
+ def feed_fn(self):
+ value = self._step * 10.0
+ self._step += 1
+ return {self._tensor: value}
+
+
+class GraphActionsTest(tf.test.TestCase):
+ """Graph actions tests."""
+
+ def setUp(self):
+ learn.graph_actions.clear_summary_writers()
+ self._original_summary_writer = summary_io.SummaryWriter
+ summary_io.SummaryWriter = _FakeSummaryWriter
+
+ def tearDown(self):
+ summary_io.SummaryWriter = self._original_summary_writer
+ learn.graph_actions.clear_summary_writers()
+
+ def _assert_fake_summary_writer(self, output_dir, expected_summaries=None):
+ writer = learn.graph_actions.get_summary_writer(output_dir)
+ self.assertTrue(isinstance(writer, _FakeSummaryWriter))
+ self.assertEqual(output_dir, writer.logdir)
+ self.assertTrue(tf.get_default_graph() is writer.graph)
+ self.assertTrue(writer.flushed)
+ expected_summaries = expected_summaries or {}
+ expected_steps = expected_summaries.keys()
+ self.assertEqual(set(expected_steps), set(writer.summaries.keys()))
+ for step in expected_steps:
+ actual_simple_values = {}
+ for v in writer.summaries[step].value:
+ actual_simple_values[v.tag] = v.simple_value
+ self.assertEqual(expected_summaries[step], actual_simple_values)
+
+ # TODO(ptucker): Test lock, multi-threaded access?
+ def test_summary_writer(self):
+ self._assert_fake_summary_writer('log/dir/0')
+ self.assertTrue(
+ learn.graph_actions.get_summary_writer('log/dir/0') is
+ learn.graph_actions.get_summary_writer('log/dir/0'))
+ self.assertTrue(
+ learn.graph_actions.get_summary_writer('log/dir/0') is not
+ learn.graph_actions.get_summary_writer('log/dir/1'))
+
+ # TODO(ptucker): Test restore_checkpoint_path.
+ # TODO(ptucker): Test start_queue_runners.
+ # TODO(ptucker): Test coord.request_stop & coord.join.
+
+ def _build_inference_graph(self):
+ """Build simple inference graph.
+
+ This includes a regular variable, local variable, and fake table.
+
+ Returns:
+ Tuple of 3 `Tensor` objects, 2 input and 1 output.
+ """
+ tf.contrib.framework.create_global_step()
+ in0 = tf.Variable(1.0)
+ in1 = tf.contrib.framework.local_variable(2.0)
+ fake_table = tf.Variable(
+ 3.0, trainable=False, collections=['fake_tables'],
+ name='fake_table_var')
+ in0.graph.add_to_collections(
+ [tf.GraphKeys.TABLE_INITIALIZERS], fake_table.initializer)
+ out = in0 + in1 + fake_table
+ return in0, in1, out
+
+ def test_infer(self):
+ with tf.Graph().as_default() as g, self.test_session(g):
+ in0, in1, out = self._build_inference_graph()
+ self.assertEqual(
+ {'a': 1.0, 'b': 2.0, 'c': 6.0},
+ learn.graph_actions.infer(None, {'a': in0, 'b': in1, 'c': out}))
+
+ def test_infer_different_default_graph(self):
+ with self.test_session():
+ with tf.Graph().as_default():
+ in0, in1, out = self._build_inference_graph()
+ with tf.Graph().as_default():
+ self.assertEqual(
+ {'a': 1.0, 'b': 2.0, 'c': 6.0},
+ learn.graph_actions.infer(None, {'a': in0, 'b': in1, 'c': out}))
+
+ def test_infer_invalid_feed(self):
+ with tf.Graph().as_default() as g, self.test_session(g):
+ in0, _, _ = self._build_inference_graph()
+ with self.assertRaisesRegexp(
+ tf.errors.InvalidArgumentError, 'both fed and fetched'):
+ learn.graph_actions.infer(None, {'a': in0}, feed_dict={in0: 4.0})
+
+ def test_infer_feed(self):
+ with tf.Graph().as_default() as g, self.test_session(g):
+ in0, _, out = self._build_inference_graph()
+ self.assertEqual(
+ {'c': 9.0},
+ learn.graph_actions.infer(None, {'c': out}, feed_dict={in0: 4.0}))
+
+ # TODO(ptucker): Test saver and ckpt_path.
+ # TODO(ptucker): Test eval for 1 epoch.
+
+ def test_evaluate(self):
+ with tf.Graph().as_default() as g, self.test_session(g):
+ _, _, out = self._build_inference_graph()
+ output_dir = 'out/dir'
+ self._assert_fake_summary_writer(output_dir, {})
+ results = learn.graph_actions.evaluate(
+ g, output_dir=output_dir, checkpoint_path=None, eval_dict={'a': out},
+ max_steps=1)
+ self.assertEqual(({'a': 6.0}, 0), results)
+ self._assert_fake_summary_writer(output_dir, {0: {'a': 6.0}})
+
+ def test_evaluate_feed_fn(self):
+ with tf.Graph().as_default() as g, self.test_session(g):
+ in0, _, out = self._build_inference_graph()
+ output_dir = 'out/dir'
+ self._assert_fake_summary_writer(output_dir, {})
+ feeder = _Feeder(in0)
+ results = learn.graph_actions.evaluate(
+ g, output_dir=output_dir, checkpoint_path=None, eval_dict={'a': out},
+ feed_fn=feeder.feed_fn, max_steps=3)
+ self.assertEqual(3, feeder.step)
+ self.assertEqual(({'a': 25.0}, 0), results)
+ self._assert_fake_summary_writer(output_dir, {0: {'a': 25.0}})
+
+
+if __name__ == '__main__':
+ tf.test.main()