aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/summary
diff options
context:
space:
mode:
authorGravatar Sourabh Bajaj <sourabhbajaj@google.com>2017-11-30 16:37:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-30 16:41:01 -0800
commitb2db981a6731e978453862a73dab892bc674db68 (patch)
treec11a7c4038e2595268113c2859c1d0d3072ede4f /tensorflow/contrib/summary
parent0438ac79bdb503ed267bec2146e7136ac8e99ff9 (diff)
Merge changes from github.
PiperOrigin-RevId: 177526301
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r--tensorflow/contrib/summary/BUILD14
-rw-r--r--tensorflow/contrib/summary/summary_ops_graph_test.py5
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py7
-rw-r--r--tensorflow/contrib/summary/summary_test_util.py39
4 files changed, 43 insertions, 22 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD
index 45d6454526..f34291c203 100644
--- a/tensorflow/contrib/summary/BUILD
+++ b/tensorflow/contrib/summary/BUILD
@@ -25,7 +25,6 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":summary_ops",
- ":summary_test_internal",
":summary_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
@@ -46,7 +45,6 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":summary_ops",
- ":summary_test_internal",
":summary_test_util",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@@ -119,15 +117,3 @@ py_library(
"//tensorflow/python:platform",
],
)
-
-py_library(
- name = "summary_test_internal",
- testonly = 1,
- srcs = ["summary_test_internal.py"],
- srcs_version = "PY2AND3",
- visibility = ["//visibility:private"],
- deps = [
- "//tensorflow/python:lib",
- "//tensorflow/python:platform",
- ],
-)
diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py
index fe55bf93e2..703adb7b46 100644
--- a/tensorflow/contrib/summary/summary_ops_graph_test.py
+++ b/tensorflow/contrib/summary/summary_ops_graph_test.py
@@ -21,7 +21,6 @@ import tempfile
import six
from tensorflow.contrib.summary import summary_ops
-from tensorflow.contrib.summary import summary_test_internal
from tensorflow.contrib.summary import summary_test_util
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
@@ -33,10 +32,10 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.training import training_util
-get_all = summary_test_internal.get_all
+get_all = summary_test_util.get_all
-class DbTest(summary_test_internal.SummaryDbTest):
+class DbTest(summary_test_util.SummaryDbTest):
def testGraphPassedToGraph_isForbiddenForThineOwnSafety(self):
with self.assertRaises(TypeError):
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 3fe421a7e9..54433deb28 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -21,7 +21,6 @@ import tempfile
import six
from tensorflow.contrib.summary import summary_ops
-from tensorflow.contrib.summary import summary_test_internal
from tensorflow.contrib.summary import summary_test_util
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
@@ -35,8 +34,8 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util
-get_all = summary_test_internal.get_all
-get_one = summary_test_internal.get_one
+get_all = summary_test_util.get_all
+get_one = summary_test_util.get_one
class TargetTest(test_util.TensorFlowTestCase):
@@ -137,7 +136,7 @@ class TargetTest(test_util.TensorFlowTestCase):
self.assertEqual(3, get_total())
-class DbTest(summary_test_internal.SummaryDbTest):
+class DbTest(summary_test_util.SummaryDbTest):
def testIntegerSummaries(self):
step = training_util.create_global_step()
diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py
index 794c5b8bab..915820e05b 100644
--- a/tensorflow/contrib/summary/summary_test_util.py
+++ b/tensorflow/contrib/summary/summary_test_util.py
@@ -19,13 +19,38 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
import os
+import sqlite3
+from tensorflow.contrib.summary import summary_ops
from tensorflow.core.util import event_pb2
+from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import gfile
+class SummaryDbTest(test_util.TensorFlowTestCase):
+ """Helper for summary database testing."""
+
+ def setUp(self):
+ super(SummaryDbTest, self).setUp()
+ self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
+ if os.path.exists(self.db_path):
+ os.unlink(self.db_path)
+ self.db = sqlite3.connect(self.db_path)
+ self.create_summary_db_writer = functools.partial(
+ summary_ops.create_summary_db_writer,
+ db_uri=self.db_path,
+ experiment_name='experiment',
+ run_name='run',
+ user_name='user')
+
+ def tearDown(self):
+ self.db.close()
+ super(SummaryDbTest, self).tearDown()
+
+
def events_from_file(filepath):
"""Returns all events in a single event file.
@@ -58,5 +83,17 @@ def events_from_logdir(logdir):
"""
assert gfile.Exists(logdir)
files = gfile.ListDirectory(logdir)
- assert len(files) == 1, "Found not exactly one file in logdir: %s" % files
+ assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files
return events_from_file(os.path.join(logdir, files[0]))
+
+
+def get_one(db, q, *p):
+ return db.execute(q, p).fetchone()[0]
+
+
+def get_all(db, q, *p):
+ return unroll(db.execute(q, p).fetchall())
+
+
+def unroll(list_of_tuples):
+ return sum(list_of_tuples, ())