aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py30
-rw-r--r--tensorflow/contrib/testing/python/framework/fake_summary_writer.py7
2 files changed, 13 insertions, 24 deletions
diff --git a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py
index df5f2184c6..643ab55fd4 100644
--- a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py
@@ -151,18 +151,19 @@ class GraphActionsTest(tf.test.TestCase):
def test_evaluate(self):
with tf.Graph().as_default() as g, self.test_session(g):
_, _, out = self._build_inference_graph()
- self._assert_summaries(self._output_dir)
+ self._assert_summaries(self._output_dir, expected_session_logs=[])
results = learn.graph_actions.evaluate(
g, output_dir=self._output_dir, checkpoint_path=None,
eval_dict={'a': out}, max_steps=1)
self.assertEqual(({'a': 6.0}, 0), results)
self._assert_summaries(
- self._output_dir, expected_summaries={0: {'a': 6.0}})
+ self._output_dir, expected_summaries={0: {'a': 6.0}},
+ expected_session_logs=[])
def test_evaluate_feed_fn(self):
with tf.Graph().as_default() as g, self.test_session(g):
in0, _, out = self._build_inference_graph()
- self._assert_summaries(self._output_dir)
+ self._assert_summaries(self._output_dir, expected_session_logs=[])
feeder = _Feeder(in0)
results = learn.graph_actions.evaluate(
g, output_dir=self._output_dir, checkpoint_path=None,
@@ -170,7 +171,8 @@ class GraphActionsTest(tf.test.TestCase):
self.assertEqual(3, feeder.step)
self.assertEqual(({'a': 25.0}, 0), results)
self._assert_summaries(
- self._output_dir, expected_summaries={0: {'a': 25.0}})
+ self._output_dir, expected_summaries={0: {'a': 25.0}},
+ expected_session_logs=[])
def test_train_invalid_args(self):
with tf.Graph().as_default() as g, self.test_session(g):
@@ -198,14 +200,7 @@ class GraphActionsTest(tf.test.TestCase):
# TODO(ptucker): Resume training from previous ckpt.
# TODO(ptucker): !supervisor_is_chief
# TODO(ptucker): Custom init op for training.
-
- def _expected_train_session_logs(self):
- return [
- tf.SessionLog(status=tf.SessionLog.START),
- tf.SessionLog(
- status=tf.SessionLog.CHECKPOINT,
- checkpoint_path='%s/model.ckpt' % self._output_dir),
- ]
+ # TODO(ptucker): Mock supervisor, and assert all interactions.
def test_train(self):
with tf.Graph().as_default() as g, self.test_session(g):
@@ -216,10 +211,7 @@ class GraphActionsTest(tf.test.TestCase):
g, output_dir=self._output_dir, train_op=train_op,
loss_op=tf.constant(2.0), steps=1)
self.assertEqual(2.0, loss)
- self._assert_summaries(
- self._output_dir,
- expected_graphs=[g],
- expected_session_logs=self._expected_train_session_logs())
+ self._assert_summaries(self._output_dir, expected_graphs=[g])
def test_train_loss(self):
with tf.Graph().as_default() as g, self.test_session(g):
@@ -233,10 +225,7 @@ class GraphActionsTest(tf.test.TestCase):
g, output_dir=self._output_dir, train_op=train_op,
loss_op=loss_var.value(), steps=6)
self.assertEqual(4.0, loss)
- self._assert_summaries(
- self._output_dir,
- expected_graphs=[g],
- expected_session_logs=self._expected_train_session_logs())
+ self._assert_summaries(self._output_dir, expected_graphs=[g])
def test_train_summaries(self):
with tf.Graph().as_default() as g, self.test_session(g):
@@ -252,7 +241,6 @@ class GraphActionsTest(tf.test.TestCase):
self._assert_summaries(
self._output_dir,
expected_graphs=[g],
- expected_session_logs=self._expected_train_session_logs(),
expected_summaries={1: {'loss': 2.0}})
diff --git a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py
index 8cc2c4c464..d742608cf1 100644
--- a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py
+++ b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py
@@ -71,9 +71,10 @@ class FakeSummaryWriter(object):
if 'global_step/sec' != v.tag:
actual_simple_values[v.tag] = v.simple_value
test_case.assertEqual(expected_summaries[step], actual_simple_values)
- test_case.assertEqual(expected_added_graphs or [], self._added_graphs)
- test_case.assertEqual(
- expected_session_logs or [], self._added_session_logs)
+ if expected_added_graphs is not None:
+ test_case.assertEqual(expected_added_graphs, self._added_graphs)
+ if expected_session_logs is not None:
+ test_case.assertEqual(expected_session_logs, self._added_session_logs)
def add_summary(self, summary, current_global_step):
"""Add summary."""