aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/summary/summary_test_internal.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/summary/summary_test_internal.py')
-rw-r--r--tensorflow/contrib/summary/summary_test_internal.py59
1 files changed, 59 insertions, 0 deletions
diff --git a/tensorflow/contrib/summary/summary_test_internal.py b/tensorflow/contrib/summary/summary_test_internal.py
new file mode 100644
index 0000000000..54233f2f50
--- /dev/null
+++ b/tensorflow/contrib/summary/summary_test_internal.py
@@ -0,0 +1,59 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Internal helpers for tests in this directory."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import os
+import sqlite3
+
+from tensorflow.contrib.summary import summary_ops
+from tensorflow.python.framework import test_util
+
+
+class SummaryDbTest(test_util.TensorFlowTestCase):
+ """Helper for summary database testing."""
+
+ def setUp(self):
+ super(SummaryDbTest, self).setUp()
+ self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
+ if os.path.exists(self.db_path):
+ os.unlink(self.db_path)
+ self.db = sqlite3.connect(self.db_path)
+ self.create_summary_db_writer = functools.partial(
+ summary_ops.create_summary_db_writer,
+ db_uri=self.db_path,
+ experiment_name='experiment',
+ run_name='run',
+ user_name='user')
+
+ def tearDown(self):
+ self.db.close()
+ super(SummaryDbTest, self).tearDown()
+
+
+def get_one(db, q, *p):
+ return db.execute(q, p).fetchone()[0]
+
+
+def get_all(db, q, *p):
+ return unroll(db.execute(q, p).fetchall())
+
+
+def unroll(list_of_tuples):
+ return sum(list_of_tuples, ())