diff options
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py | 30 | ||||
-rw-r--r-- | tensorflow/contrib/testing/python/framework/fake_summary_writer.py | 7 |
2 files changed, 13 insertions, 24 deletions
diff --git a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py index df5f2184c6..643ab55fd4 100644 --- a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py @@ -151,18 +151,19 @@ class GraphActionsTest(tf.test.TestCase): def test_evaluate(self): with tf.Graph().as_default() as g, self.test_session(g): _, _, out = self._build_inference_graph() - self._assert_summaries(self._output_dir) + self._assert_summaries(self._output_dir, expected_session_logs=[]) results = learn.graph_actions.evaluate( g, output_dir=self._output_dir, checkpoint_path=None, eval_dict={'a': out}, max_steps=1) self.assertEqual(({'a': 6.0}, 0), results) self._assert_summaries( - self._output_dir, expected_summaries={0: {'a': 6.0}}) + self._output_dir, expected_summaries={0: {'a': 6.0}}, + expected_session_logs=[]) def test_evaluate_feed_fn(self): with tf.Graph().as_default() as g, self.test_session(g): in0, _, out = self._build_inference_graph() - self._assert_summaries(self._output_dir) + self._assert_summaries(self._output_dir, expected_session_logs=[]) feeder = _Feeder(in0) results = learn.graph_actions.evaluate( g, output_dir=self._output_dir, checkpoint_path=None, @@ -170,7 +171,8 @@ class GraphActionsTest(tf.test.TestCase): self.assertEqual(3, feeder.step) self.assertEqual(({'a': 25.0}, 0), results) self._assert_summaries( - self._output_dir, expected_summaries={0: {'a': 25.0}}) + self._output_dir, expected_summaries={0: {'a': 25.0}}, + expected_session_logs=[]) def test_train_invalid_args(self): with tf.Graph().as_default() as g, self.test_session(g): @@ -198,14 +200,7 @@ class GraphActionsTest(tf.test.TestCase): # TODO(ptucker): Resume training from previous ckpt. # TODO(ptucker): !supervisor_is_chief # TODO(ptucker): Custom init op for training. - - def _expected_train_session_logs(self): - return [ - tf.SessionLog(status=tf.SessionLog.START), - tf.SessionLog( - status=tf.SessionLog.CHECKPOINT, - checkpoint_path='%s/model.ckpt' % self._output_dir), - ] + # TODO(ptucker): Mock supervisor, and assert all interactions. def test_train(self): with tf.Graph().as_default() as g, self.test_session(g): @@ -216,10 +211,7 @@ class GraphActionsTest(tf.test.TestCase): g, output_dir=self._output_dir, train_op=train_op, loss_op=tf.constant(2.0), steps=1) self.assertEqual(2.0, loss) - self._assert_summaries( - self._output_dir, - expected_graphs=[g], - expected_session_logs=self._expected_train_session_logs()) + self._assert_summaries(self._output_dir, expected_graphs=[g]) def test_train_loss(self): with tf.Graph().as_default() as g, self.test_session(g): @@ -233,10 +225,7 @@ class GraphActionsTest(tf.test.TestCase): g, output_dir=self._output_dir, train_op=train_op, loss_op=loss_var.value(), steps=6) self.assertEqual(4.0, loss) - self._assert_summaries( - self._output_dir, - expected_graphs=[g], - expected_session_logs=self._expected_train_session_logs()) + self._assert_summaries(self._output_dir, expected_graphs=[g]) def test_train_summaries(self): with tf.Graph().as_default() as g, self.test_session(g): @@ -252,7 +241,6 @@ class GraphActionsTest(tf.test.TestCase): self._assert_summaries( self._output_dir, expected_graphs=[g], - expected_session_logs=self._expected_train_session_logs(), expected_summaries={1: {'loss': 2.0}}) diff --git a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py index 8cc2c4c464..d742608cf1 100644 --- a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py +++ b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py @@ -71,9 +71,10 @@ class FakeSummaryWriter(object): if 'global_step/sec' != v.tag: actual_simple_values[v.tag] = v.simple_value test_case.assertEqual(expected_summaries[step], actual_simple_values) - test_case.assertEqual(expected_added_graphs or [], self._added_graphs) - test_case.assertEqual( - expected_session_logs or [], self._added_session_logs) + if expected_added_graphs is not None: + test_case.assertEqual(expected_added_graphs, self._added_graphs) + if expected_session_logs is not None: + test_case.assertEqual(expected_session_logs, self._added_session_logs) def add_summary(self, summary, current_global_step): """Add summary.""" |