aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/summary/summary_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/summary/summary_ops_test.py')
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py122
1 files changed, 0 insertions, 122 deletions
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 6e1a746815..de7ae6ec27 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -17,22 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
-import os
import tempfile
-import six
-import sqlite3
-
from tensorflow.contrib.summary import summary_ops
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import function
from tensorflow.python.eager import test
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util
@@ -94,120 +86,6 @@ class TargetTest(test_util.TensorFlowTestCase):
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'scalar')
- def testSummaryGlobalStep(self):
- global_step = training_util.get_or_create_global_step()
- logdir = tempfile.mkdtemp()
- with summary_ops.create_summary_file_writer(
- logdir, max_queue=0,
- name='t2').as_default(), summary_ops.always_record_summaries():
-
- summary_ops.scalar('scalar', 2.0, global_step=global_step)
-
- events = summary_test_util.events_from_file(logdir)
- self.assertEqual(len(events), 2)
- self.assertEqual(events[1].summary.value[0].tag, 'scalar')
-
-
-class DbTest(test_util.TensorFlowTestCase):
-
- def setUp(self):
- self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
- if os.path.exists(self.db_path):
- os.unlink(self.db_path)
- self.db = sqlite3.connect(self.db_path)
- self.create_summary_db_writer = functools.partial(
- summary_ops.create_summary_db_writer,
- db_uri=self.db_path,
- experiment_name='experiment',
- run_name='run',
- user_name='user')
-
- def tearDown(self):
- self.db.close()
-
- def testIntegerSummaries(self):
- step = training_util.create_global_step()
-
- def adder(x, y):
- state_ops.assign_add(step, 1)
- summary_ops.generic('x', x)
- summary_ops.generic('y', y)
- sum_ = x + y
- summary_ops.generic('sum', sum_)
- return sum_
-
- with summary_ops.always_record_summaries():
- with self.create_summary_db_writer().as_default():
- self.assertEqual(5, adder(int64(2), int64(3)).numpy())
-
- six.assertCountEqual(self, [1, 1, 1],
- get_all(self.db, 'SELECT step FROM Tensors'))
- six.assertCountEqual(self, ['x', 'y', 'sum'],
- get_all(self.db, 'SELECT tag_name FROM Tags'))
- x_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "x"')
- y_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "y"')
- sum_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "sum"')
-
- with summary_ops.always_record_summaries():
- with self.create_summary_db_writer().as_default():
- self.assertEqual(9, adder(int64(4), int64(5)).numpy())
-
- six.assertCountEqual(self, [1, 1, 1, 2, 2, 2],
- get_all(self.db, 'SELECT step FROM Tensors'))
- six.assertCountEqual(self, [x_id, y_id, sum_id],
- get_all(self.db, 'SELECT tag_id FROM Tags'))
- self.assertEqual(2, get_tensor(self.db, x_id, 1))
- self.assertEqual(3, get_tensor(self.db, y_id, 1))
- self.assertEqual(5, get_tensor(self.db, sum_id, 1))
- self.assertEqual(4, get_tensor(self.db, x_id, 2))
- self.assertEqual(5, get_tensor(self.db, y_id, 2))
- self.assertEqual(9, get_tensor(self.db, sum_id, 2))
- six.assertCountEqual(
- self, ['experiment'],
- get_all(self.db, 'SELECT experiment_name FROM Experiments'))
- six.assertCountEqual(self, ['run'],
- get_all(self.db, 'SELECT run_name FROM Runs'))
- six.assertCountEqual(self, ['user'],
- get_all(self.db, 'SELECT user_name FROM Users'))
-
- def testBadExperimentName(self):
- with self.assertRaises(ValueError):
- self.create_summary_db_writer(experiment_name='\0')
-
- def testBadRunName(self):
- with self.assertRaises(ValueError):
- self.create_summary_db_writer(run_name='\0')
-
- def testBadUserName(self):
- with self.assertRaises(ValueError):
- self.create_summary_db_writer(user_name='-hi')
- with self.assertRaises(ValueError):
- self.create_summary_db_writer(user_name='hi-')
- with self.assertRaises(ValueError):
- self.create_summary_db_writer(user_name='@')
-
-
-def get_one(db, q, *p):
- return db.execute(q, p).fetchone()[0]
-
-
-def get_all(db, q, *p):
- return unroll(db.execute(q, p).fetchall())
-
-
-def get_tensor(db, tag_id, step):
- return get_one(
- db, 'SELECT tensor FROM Tensors WHERE tag_id = ? AND step = ?', tag_id,
- step)
-
-
-def int64(x):
- return array_ops.constant(x, dtypes.int64)
-
-
-def unroll(list_of_tuples):
- return sum(list_of_tuples, ())
-
if __name__ == '__main__':
test.main()