diff options
Diffstat (limited to 'tensorflow/contrib/summary/summary_ops_test.py')
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 122 |
1 files changed, 0 insertions, 122 deletions
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 6e1a746815..de7ae6ec27 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -17,22 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools -import os import tempfile -import six -import sqlite3 - from tensorflow.contrib.summary import summary_ops from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import function from tensorflow.python.eager import test -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import state_ops from tensorflow.python.platform import gfile from tensorflow.python.training import training_util @@ -94,120 +86,6 @@ class TargetTest(test_util.TensorFlowTestCase): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') - def testSummaryGlobalStep(self): - global_step = training_util.get_or_create_global_step() - logdir = tempfile.mkdtemp() - with summary_ops.create_summary_file_writer( - logdir, max_queue=0, - name='t2').as_default(), summary_ops.always_record_summaries(): - - summary_ops.scalar('scalar', 2.0, global_step=global_step) - - events = summary_test_util.events_from_file(logdir) - self.assertEqual(len(events), 2) - self.assertEqual(events[1].summary.value[0].tag, 'scalar') - - -class DbTest(test_util.TensorFlowTestCase): - - def setUp(self): - self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') - if os.path.exists(self.db_path): - os.unlink(self.db_path) - self.db = sqlite3.connect(self.db_path) - self.create_summary_db_writer = functools.partial( - summary_ops.create_summary_db_writer, - db_uri=self.db_path, - experiment_name='experiment', - run_name='run', - user_name='user') - - def tearDown(self): - self.db.close() - - def testIntegerSummaries(self): - step = training_util.create_global_step() - - def adder(x, y): - state_ops.assign_add(step, 1) - summary_ops.generic('x', x) - summary_ops.generic('y', y) - sum_ = x + y - summary_ops.generic('sum', sum_) - return sum_ - - with summary_ops.always_record_summaries(): - with self.create_summary_db_writer().as_default(): - self.assertEqual(5, adder(int64(2), int64(3)).numpy()) - - six.assertCountEqual(self, [1, 1, 1], - get_all(self.db, 'SELECT step FROM Tensors')) - six.assertCountEqual(self, ['x', 'y', 'sum'], - get_all(self.db, 'SELECT tag_name FROM Tags')) - x_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "x"') - y_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "y"') - sum_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "sum"') - - with summary_ops.always_record_summaries(): - with self.create_summary_db_writer().as_default(): - self.assertEqual(9, adder(int64(4), int64(5)).numpy()) - - six.assertCountEqual(self, [1, 1, 1, 2, 2, 2], - get_all(self.db, 'SELECT step FROM Tensors')) - six.assertCountEqual(self, [x_id, y_id, sum_id], - get_all(self.db, 'SELECT tag_id FROM Tags')) - self.assertEqual(2, get_tensor(self.db, x_id, 1)) - self.assertEqual(3, get_tensor(self.db, y_id, 1)) - self.assertEqual(5, get_tensor(self.db, sum_id, 1)) - self.assertEqual(4, get_tensor(self.db, x_id, 2)) - self.assertEqual(5, get_tensor(self.db, y_id, 2)) - self.assertEqual(9, get_tensor(self.db, sum_id, 2)) - six.assertCountEqual( - self, ['experiment'], - get_all(self.db, 'SELECT experiment_name FROM Experiments')) - six.assertCountEqual(self, ['run'], - get_all(self.db, 'SELECT run_name FROM Runs')) - six.assertCountEqual(self, ['user'], - get_all(self.db, 'SELECT user_name FROM Users')) - - def testBadExperimentName(self): - with self.assertRaises(ValueError): - self.create_summary_db_writer(experiment_name='\0') - - def testBadRunName(self): - with self.assertRaises(ValueError): - self.create_summary_db_writer(run_name='\0') - - def testBadUserName(self): - with self.assertRaises(ValueError): - self.create_summary_db_writer(user_name='-hi') - with self.assertRaises(ValueError): - self.create_summary_db_writer(user_name='hi-') - with self.assertRaises(ValueError): - self.create_summary_db_writer(user_name='@') - - -def get_one(db, q, *p): - return db.execute(q, p).fetchone()[0] - - -def get_all(db, q, *p): - return unroll(db.execute(q, p).fetchall()) - - -def get_tensor(db, tag_id, step): - return get_one( - db, 'SELECT tensor FROM Tensors WHERE tag_id = ? AND step = ?', tag_id, - step) - - -def int64(x): - return array_ops.constant(x, dtypes.int64) - - -def unroll(list_of_tuples): - return sum(list_of_tuples, ()) - if __name__ == '__main__': test.main() |