diff options
author | Sourabh Bajaj <sourabhbajaj@google.com> | 2017-11-30 16:37:11 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-30 16:41:01 -0800 |
commit | b2db981a6731e978453862a73dab892bc674db68 (patch) | |
tree | c11a7c4038e2595268113c2859c1d0d3072ede4f /tensorflow/contrib/summary | |
parent | 0438ac79bdb503ed267bec2146e7136ac8e99ff9 (diff) |
Merge changes from github.
PiperOrigin-RevId: 177526301
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- | tensorflow/contrib/summary/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_graph_test.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_test_util.py | 39 |
4 files changed, 43 insertions, 22 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index 45d6454526..f34291c203 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -25,7 +25,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":summary_ops", - ":summary_test_internal", ":summary_test_util", "//tensorflow/python:array_ops", "//tensorflow/python:errors", @@ -46,7 +45,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":summary_ops", - ":summary_test_internal", ":summary_test_util", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -119,15 +117,3 @@ py_library( "//tensorflow/python:platform", ], ) - -py_library( - name = "summary_test_internal", - testonly = 1, - srcs = ["summary_test_internal.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:private"], - deps = [ - "//tensorflow/python:lib", - "//tensorflow/python:platform", - ], -) diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py index fe55bf93e2..703adb7b46 100644 --- a/tensorflow/contrib/summary/summary_ops_graph_test.py +++ b/tensorflow/contrib/summary/summary_ops_graph_test.py @@ -21,7 +21,6 @@ import tempfile import six from tensorflow.contrib.summary import summary_ops -from tensorflow.contrib.summary import summary_test_internal from tensorflow.contrib.summary import summary_test_util from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 @@ -33,10 +32,10 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test from tensorflow.python.training import training_util -get_all = summary_test_internal.get_all +get_all = summary_test_util.get_all -class DbTest(summary_test_internal.SummaryDbTest): +class DbTest(summary_test_util.SummaryDbTest): def testGraphPassedToGraph_isForbiddenForThineOwnSafety(self): with self.assertRaises(TypeError): diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 3fe421a7e9..54433deb28 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -21,7 +21,6 @@ import tempfile import six from tensorflow.contrib.summary import summary_ops -from tensorflow.contrib.summary import summary_test_internal from tensorflow.contrib.summary import summary_test_util from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 @@ -35,8 +34,8 @@ from tensorflow.python.ops import state_ops from tensorflow.python.platform import gfile from tensorflow.python.training import training_util -get_all = summary_test_internal.get_all -get_one = summary_test_internal.get_one +get_all = summary_test_util.get_all +get_one = summary_test_util.get_one class TargetTest(test_util.TensorFlowTestCase): @@ -137,7 +136,7 @@ class TargetTest(test_util.TensorFlowTestCase): self.assertEqual(3, get_total()) -class DbTest(summary_test_internal.SummaryDbTest): +class DbTest(summary_test_util.SummaryDbTest): def testIntegerSummaries(self): step = training_util.create_global_step() diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py index 794c5b8bab..915820e05b 100644 --- a/tensorflow/contrib/summary/summary_test_util.py +++ b/tensorflow/contrib/summary/summary_test_util.py @@ -19,13 +19,38 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import os +import sqlite3 +from tensorflow.contrib.summary import summary_ops from tensorflow.core.util import event_pb2 +from tensorflow.python.framework import test_util from tensorflow.python.lib.io import tf_record from tensorflow.python.platform import gfile +class SummaryDbTest(test_util.TensorFlowTestCase): + """Helper for summary database testing.""" + + def setUp(self): + super(SummaryDbTest, self).setUp() + self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') + if os.path.exists(self.db_path): + os.unlink(self.db_path) + self.db = sqlite3.connect(self.db_path) + self.create_summary_db_writer = functools.partial( + summary_ops.create_summary_db_writer, + db_uri=self.db_path, + experiment_name='experiment', + run_name='run', + user_name='user') + + def tearDown(self): + self.db.close() + super(SummaryDbTest, self).tearDown() + + def events_from_file(filepath): """Returns all events in a single event file. @@ -58,5 +83,17 @@ def events_from_logdir(logdir): """ assert gfile.Exists(logdir) files = gfile.ListDirectory(logdir) - assert len(files) == 1, "Found not exactly one file in logdir: %s" % files + assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files return events_from_file(os.path.join(logdir, files[0])) + + +def get_one(db, q, *p): + return db.execute(q, p).fetchone()[0] + + +def get_all(db, q, *p): + return unroll(db.execute(q, p).fetchall()) + + +def unroll(list_of_tuples): + return sum(list_of_tuples, ()) |