aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/summary
diff options
context:
space:
mode:
authorGravatar Justine Tunney <jart@google.com>2017-11-15 11:31:43 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-15 11:35:49 -0800
commit6fb721d608c4cd3855fe8793099a629428b9853c (patch)
treefaef08ed8bac4f5a8b065825a4405ef8a12e875f /tensorflow/contrib/summary
parentb7b183b90aee8a4f4808f7d90a2c7a54a942e640 (diff)
Add graph writer op to contrib/summary
This change also defines a simple SQL data model for tf.GraphDef, which should move us closer to a world where TensorBoard can render the graph explorer without having to download the entire thing to the browser, as that could potentially be hundreds of megabytes. PiperOrigin-RevId: 175854921
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r--tensorflow/contrib/summary/BUILD29
-rw-r--r--tensorflow/contrib/summary/summary.py3
-rw-r--r--tensorflow/contrib/summary/summary_ops.py149
-rw-r--r--tensorflow/contrib/summary/summary_ops_graph_test.py52
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py47
-rw-r--r--tensorflow/contrib/summary/summary_test_internal.py59
6 files changed, 291 insertions, 48 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD
index d1beafcb28..3892654f25 100644
--- a/tensorflow/contrib/summary/BUILD
+++ b/tensorflow/contrib/summary/BUILD
@@ -25,13 +25,12 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":summary_ops",
+ ":summary_test_internal",
":summary_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:ops",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
@@ -41,6 +40,20 @@ py_test(
],
)
+py_test(
+ name = "summary_ops_graph_test",
+ srcs = ["summary_ops_graph_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":summary_ops",
+ ":summary_test_internal",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
+ ],
+)
+
py_library(
name = "summary_ops",
srcs = ["summary_ops.py"],
@@ -98,3 +111,15 @@ py_library(
"//tensorflow/python:platform",
],
)
+
+py_library(
+ name = "summary_test_internal",
+ testonly = 1,
+ srcs = ["summary_test_internal.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:private"],
+ deps = [
+ "//tensorflow/python:lib",
+ "//tensorflow/python:platform",
+ ],
+)
diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py
index 813e8b2b09..a73193f460 100644
--- a/tensorflow/contrib/summary/summary.py
+++ b/tensorflow/contrib/summary/summary.py
@@ -32,11 +32,14 @@ from tensorflow.contrib.summary.summary_ops import create_summary_db_writer
from tensorflow.contrib.summary.summary_ops import create_summary_file_writer
from tensorflow.contrib.summary.summary_ops import eval_dir
from tensorflow.contrib.summary.summary_ops import generic
+from tensorflow.contrib.summary.summary_ops import graph
from tensorflow.contrib.summary.summary_ops import histogram
from tensorflow.contrib.summary.summary_ops import image
from tensorflow.contrib.summary.summary_ops import import_event
+from tensorflow.contrib.summary.summary_ops import initialize
from tensorflow.contrib.summary.summary_ops import never_record_summaries
from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps
from tensorflow.contrib.summary.summary_ops import scalar
from tensorflow.contrib.summary.summary_ops import should_record_summaries
from tensorflow.contrib.summary.summary_ops import summary_writer_initializer_op
+from tensorflow.contrib.summary.summary_ops import SummaryWriter
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index f6be99f6ae..a72c0c80aa 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -27,6 +27,7 @@ import time
import six
from tensorflow.contrib.summary import gen_summary_ops
+from tensorflow.core.framework import graph_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -99,25 +100,32 @@ def never_record_summaries():
class SummaryWriter(object):
- """Encapsulates a summary writer."""
+ """Encapsulates a stateful summary writer resource.
- def __init__(self, resource):
+ See also:
+ - @{tf.contrib.summary.create_summary_file_writer}
+ - @{tf.contrib.summary.create_summary_db_writer}
+ """
+
+ def __init__(self, resource):
self._resource = resource
if context.in_eager_mode():
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device="cpu:0")
def set_as_default(self):
+ """Enables this summary writer for the current thread."""
context.context().summary_writer_resource = self._resource
@tf_contextlib.contextmanager
def as_default(self):
+ """Enables summary writing within a `with` block."""
if self._resource is None:
- yield
+ yield self
else:
old = context.context().summary_writer_resource
context.context().summary_writer_resource = self._resource
- yield
+ yield self
# Flushes the summary writer in eager mode or in graph functions, but not
# in legacy graph mode (you're on your own there).
with ops.device("cpu:0"):
@@ -125,6 +133,43 @@ class SummaryWriter(object):
context.context().summary_writer_resource = old
+def initialize(
+ graph=None, # pylint: disable=redefined-outer-name
+ session=None):
+ """Initializes summary writing for graph execution mode.
+
+ This helper method provides a higher-level alternative to using
+ @{tf.contrib.summary.summary_writer_initializer_op} and
+ @{tf.contrib.summary.graph}.
+
+ Most users will also want to call @{tf.train.create_global_step}
+ which can happen before or after this function is called.
+
+ Args:
+ graph: A @{tf.Graph} or @{tf.GraphDef} to output to the writer.
+ This function will not write the default graph by default. When
+ writing to an event log file, the associated step will be zero.
+ session: So this method can call @{tf.Session.run}. This defaults
+ to @{tf.get_default_session}.
+
+ Raises:
+ RuntimeError: If in eager mode, or if the current thread has no
+ default @{tf.contrib.summary.SummaryWriter}.
+ ValueError: If session wasn't passed and no default session.
+ """
+ if context.context().summary_writer_resource is None:
+ raise RuntimeError("No default tf.contrib.summary.SummaryWriter found")
+ if session is None:
+ session = ops.get_default_session()
+ if session is None:
+ raise ValueError("session must be passed if no default session exists")
+ session.run(summary_writer_initializer_op())
+ if graph is not None:
+ data = _serialize_graph(graph)
+ x = array_ops.placeholder(dtypes.string)
+ session.run(_graph(x, 0), feed_dict={x: data})
+
+
def create_summary_file_writer(logdir,
max_queue=None,
flush_millis=None,
@@ -192,10 +237,10 @@ def create_summary_db_writer(db_uri,
Experiment will not be associated with a User. Must be valid as
both a DNS label and Linux username.
name: Shared name for this SummaryWriter resource stored to default
- Graph.
+ @{tf.Graph}.
Returns:
- A new SummaryWriter instance.
+ A @{tf.contrib.summary.SummaryWriter} instance.
"""
with ops.device("cpu:0"):
if experiment_name is None:
@@ -240,7 +285,16 @@ def _nothing():
def all_summary_ops():
- """Graph-mode only. Returns all summary ops."""
+ """Graph-mode only. Returns all summary ops.
+
+ Please note this excludes @{tf.contrib.summary.graph} ops.
+
+ Returns:
+ The summary ops.
+
+ Raises:
+ RuntimeError: If in Eager mode.
+ """
if context.in_eager_mode():
raise RuntimeError(
"tf.contrib.summary.all_summary_ops is only supported in graph mode.")
@@ -248,7 +302,14 @@ def all_summary_ops():
def summary_writer_initializer_op():
- """Graph-mode only. Returns the list of ops to create all summary writers."""
+ """Graph-mode only. Returns the list of ops to create all summary writers.
+
+ Returns:
+ The initializer ops.
+
+ Raises:
+ RuntimeError: If in Eager mode.
+ """
if context.in_eager_mode():
raise RuntimeError(
"tf.contrib.summary.summary_writer_initializer_op is only "
@@ -367,21 +428,72 @@ def audio(name, tensor, sample_rate, max_outputs, family=None,
return summary_writer_function(name, tensor, function, family=family)
+def graph(param, step=None, name=None):
+ """Writes a TensorFlow graph to the summary interface.
+
+ The graph summary is, strictly speaking, not a summary. Conditions
+ like @{tf.contrib.summary.never_record_summaries} do not apply. Only
+ a single graph can be associated with a particular run. If multiple
+ graphs are written, then only the last one will be considered by
+ TensorBoard.
+
+ When not using eager execution mode, the user should consider passing
+ the `graph` parameter to @{tf.contrib.summary.initialize} instead of
+ calling this function. Otherwise special care needs to be taken when
+ using the graph to record the graph.
+
+ Args:
+ param: A @{tf.Tensor} containing a serialized graph proto. When
+ eager execution is enabled, this function will automatically
+ coerce @{tf.Graph}, @{tf.GraphDef}, and string types.
+ step: The global step variable. This doesn't have useful semantics
+ for graph summaries, but is used anyway, due to the structure of
+ event log files. This defaults to the global step.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created @{tf.Operation} or a @{tf.no_op} if summary writing has
+ not been enabled for this context.
+
+ Raises:
+ TypeError: If `param` isn't already a @{tf.Tensor} in graph mode.
+ """
+ if not context.in_eager_mode() and not isinstance(param, ops.Tensor):
+ raise TypeError("graph() needs a tf.Tensor (e.g. tf.placeholder) in graph "
+ "mode, but was: %s" % type(param))
+ writer = context.context().summary_writer_resource
+ if writer is None:
+ return control_flow_ops.no_op()
+ with ops.device("cpu:0"):
+ if step is None:
+ step = training_util.get_global_step()
+ else:
+ step = ops.convert_to_tensor(step, dtypes.int64)
+ if isinstance(param, (ops.Graph, graph_pb2.GraphDef)):
+ tensor = ops.convert_to_tensor(_serialize_graph(param), dtypes.string)
+ else:
+ tensor = array_ops.identity(param)
+ return gen_summary_ops.write_graph_summary(writer, step, tensor, name=name)
+
+_graph = graph # for functions with a graph parameter
+
+
def import_event(tensor, name=None):
- """Writes a tf.Event binary proto.
+ """Writes a @{tf.Event} binary proto.
When using create_summary_db_writer(), this can be used alongside
- tf.TFRecordReader to load event logs into the database. Please note
- that this is lower level than the other summary functions and will
- ignore any conditions set by methods like should_record_summaries().
+ @{tf.TFRecordReader} to load event logs into the database. Please
+ note that this is lower level than the other summary functions and
+ will ignore any conditions set by methods like
+ @{tf.contrib.summary.should_record_summaries}.
Args:
- tensor: A `Tensor` of type `string` containing a serialized `Event`
- proto.
+ tensor: A @{tf.Tensor} of type `string` containing a serialized
+ @{tf.Event} proto.
name: A name for the operation (optional).
Returns:
- The created Operation.
+ The created @{tf.Operation}.
"""
return gen_summary_ops.import_event(
context.context().summary_writer_resource, tensor, name=name)
@@ -390,3 +502,10 @@ def import_event(tensor, name=None):
def eval_dir(model_dir, name=None):
"""Construct a logdir for an eval summary writer."""
return os.path.join(model_dir, "eval" if not name else "eval_" + name)
+
+
+def _serialize_graph(arbitrary_graph):
+ if isinstance(arbitrary_graph, ops.Graph):
+ return arbitrary_graph.as_graph_def(add_shapes=True).SerializeToString()
+ else:
+ return arbitrary_graph.SerializeToString()
diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py
new file mode 100644
index 0000000000..8f85f67a25
--- /dev/null
+++ b/tensorflow/contrib/summary/summary_ops_graph_test.py
@@ -0,0 +1,52 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.contrib.summary import summary_ops
+from tensorflow.contrib.summary import summary_test_internal
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import training_util
+
+get_all = summary_test_internal.get_all
+
+
+class DbTest(summary_test_internal.SummaryDbTest):
+
+ def testGraphPassedToGraph_isForbiddenForThineOwnSafety(self):
+ with self.assertRaises(TypeError):
+ summary_ops.graph(ops.Graph())
+ with self.assertRaises(TypeError):
+ summary_ops.graph('')
+
+ def testGraphSummary(self):
+ training_util.get_or_create_global_step()
+ name = 'hi'
+ graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),))
+ with self.test_session():
+ with self.create_summary_db_writer().as_default():
+ summary_ops.initialize(graph=graph)
+ six.assertCountEqual(self, [name],
+ get_all(self.db, 'SELECT node_name FROM Nodes'))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 6e1a746815..09169fa6d7 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -12,20 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
-import os
import tempfile
import six
-import sqlite3
from tensorflow.contrib.summary import summary_ops
+from tensorflow.contrib.summary import summary_test_internal
from tensorflow.contrib.summary import summary_test_util
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import node_def_pb2
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
@@ -36,6 +35,9 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util
+get_all = summary_test_internal.get_all
+get_one = summary_test_internal.get_one
+
class TargetTest(test_util.TensorFlowTestCase):
@@ -108,22 +110,7 @@ class TargetTest(test_util.TensorFlowTestCase):
self.assertEqual(events[1].summary.value[0].tag, 'scalar')
-class DbTest(test_util.TensorFlowTestCase):
-
- def setUp(self):
- self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
- if os.path.exists(self.db_path):
- os.unlink(self.db_path)
- self.db = sqlite3.connect(self.db_path)
- self.create_summary_db_writer = functools.partial(
- summary_ops.create_summary_db_writer,
- db_uri=self.db_path,
- experiment_name='experiment',
- run_name='run',
- user_name='user')
-
- def tearDown(self):
- self.db.close()
+class DbTest(summary_test_internal.SummaryDbTest):
def testIntegerSummaries(self):
step = training_util.create_global_step()
@@ -186,13 +173,15 @@ class DbTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
self.create_summary_db_writer(user_name='@')
-
-def get_one(db, q, *p):
- return db.execute(q, p).fetchone()[0]
-
-
-def get_all(db, q, *p):
- return unroll(db.execute(q, p).fetchall())
+ def testGraphSummary(self):
+ training_util.get_or_create_global_step()
+ name = 'hi'
+ graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),))
+ with summary_ops.always_record_summaries():
+ with self.create_summary_db_writer().as_default():
+ summary_ops.graph(graph)
+ six.assertCountEqual(self, [name],
+ get_all(self.db, 'SELECT node_name FROM Nodes'))
def get_tensor(db, tag_id, step):
@@ -205,9 +194,5 @@ def int64(x):
return array_ops.constant(x, dtypes.int64)
-def unroll(list_of_tuples):
- return sum(list_of_tuples, ())
-
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/summary/summary_test_internal.py b/tensorflow/contrib/summary/summary_test_internal.py
new file mode 100644
index 0000000000..54233f2f50
--- /dev/null
+++ b/tensorflow/contrib/summary/summary_test_internal.py
@@ -0,0 +1,59 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Internal helpers for tests in this directory."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import os
+import sqlite3
+
+from tensorflow.contrib.summary import summary_ops
+from tensorflow.python.framework import test_util
+
+
+class SummaryDbTest(test_util.TensorFlowTestCase):
+ """Helper for summary database testing."""
+
+ def setUp(self):
+ super(SummaryDbTest, self).setUp()
+ self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
+ if os.path.exists(self.db_path):
+ os.unlink(self.db_path)
+ self.db = sqlite3.connect(self.db_path)
+ self.create_summary_db_writer = functools.partial(
+ summary_ops.create_summary_db_writer,
+ db_uri=self.db_path,
+ experiment_name='experiment',
+ run_name='run',
+ user_name='user')
+
+ def tearDown(self):
+ self.db.close()
+ super(SummaryDbTest, self).tearDown()
+
+
+def get_one(db, q, *p):
+ return db.execute(q, p).fetchone()[0]
+
+
+def get_all(db, q, *p):
+ return unroll(db.execute(q, p).fetchall())
+
+
+def unroll(list_of_tuples):
+ return sum(list_of_tuples, ())