diff options
author | 2017-11-15 11:31:43 -0800 | |
---|---|---|
committer | 2017-11-15 11:35:49 -0800 | |
commit | 6fb721d608c4cd3855fe8793099a629428b9853c (patch) | |
tree | faef08ed8bac4f5a8b065825a4405ef8a12e875f /tensorflow/contrib/summary | |
parent | b7b183b90aee8a4f4808f7d90a2c7a54a942e640 (diff) |
Add graph writer op to contrib/summary
This change also defines a simple SQL data model for tf.GraphDef, which
should move us closer to a world where TensorBoard can render the graph
explorer without having to download the entire thing to the browser, as
that could potentially be hundreds of megabytes.
PiperOrigin-RevId: 175854921
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- | tensorflow/contrib/summary/BUILD | 29 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary.py | 3 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops.py | 149 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_graph_test.py | 52 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 47 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_test_internal.py | 59 |
6 files changed, 291 insertions, 48 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index d1beafcb28..3892654f25 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -25,13 +25,12 @@ py_test( srcs_version = "PY2AND3", deps = [ ":summary_ops", + ":summary_test_internal", ":summary_test_util", "//tensorflow/python:array_ops", "//tensorflow/python:errors", "//tensorflow/python:framework", "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:ops", "//tensorflow/python:platform", "//tensorflow/python:state_ops", "//tensorflow/python:training", @@ -41,6 +40,20 @@ py_test( ], ) +py_test( + name = "summary_ops_graph_test", + srcs = ["summary_ops_graph_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":summary_ops", + ":summary_test_internal", + "//tensorflow/python:client_testlib", + "//tensorflow/python:ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + ], +) + py_library( name = "summary_ops", srcs = ["summary_ops.py"], @@ -98,3 +111,15 @@ py_library( "//tensorflow/python:platform", ], ) + +py_library( + name = "summary_test_internal", + testonly = 1, + srcs = ["summary_test_internal.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/python:lib", + "//tensorflow/python:platform", + ], +) diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index 813e8b2b09..a73193f460 100644 --- a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -32,11 +32,14 @@ from tensorflow.contrib.summary.summary_ops import create_summary_db_writer from tensorflow.contrib.summary.summary_ops import create_summary_file_writer from tensorflow.contrib.summary.summary_ops import eval_dir from tensorflow.contrib.summary.summary_ops import generic +from tensorflow.contrib.summary.summary_ops import graph from tensorflow.contrib.summary.summary_ops import histogram from tensorflow.contrib.summary.summary_ops import image from tensorflow.contrib.summary.summary_ops import import_event +from tensorflow.contrib.summary.summary_ops import initialize from tensorflow.contrib.summary.summary_ops import never_record_summaries from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps from tensorflow.contrib.summary.summary_ops import scalar from tensorflow.contrib.summary.summary_ops import should_record_summaries from tensorflow.contrib.summary.summary_ops import summary_writer_initializer_op +from tensorflow.contrib.summary.summary_ops import SummaryWriter diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index f6be99f6ae..a72c0c80aa 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -27,6 +27,7 @@ import time import six from tensorflow.contrib.summary import gen_summary_ops +from tensorflow.core.framework import graph_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -99,25 +100,32 @@ def never_record_summaries(): class SummaryWriter(object): - """Encapsulates a summary writer.""" + """Encapsulates a stateful summary writer resource. - def __init__(self, resource): + See also: + - @{tf.contrib.summary.create_summary_file_writer} + - @{tf.contrib.summary.create_summary_db_writer} + """ + + def __init__(self, resource): self._resource = resource if context.in_eager_mode(): self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="cpu:0") def set_as_default(self): + """Enables this summary writer for the current thread.""" context.context().summary_writer_resource = self._resource @tf_contextlib.contextmanager def as_default(self): + """Enables summary writing within a `with` block.""" if self._resource is None: - yield + yield self else: old = context.context().summary_writer_resource context.context().summary_writer_resource = self._resource - yield + yield self # Flushes the summary writer in eager mode or in graph functions, but not # in legacy graph mode (you're on your own there). with ops.device("cpu:0"): @@ -125,6 +133,43 @@ class SummaryWriter(object): context.context().summary_writer_resource = old +def initialize( + graph=None, # pylint: disable=redefined-outer-name + session=None): + """Initializes summary writing for graph execution mode. + + This helper method provides a higher-level alternative to using + @{tf.contrib.summary.summary_writer_initializer_op} and + @{tf.contrib.summary.graph}. + + Most users will also want to call @{tf.train.create_global_step} + which can happen before or after this function is called. + + Args: + graph: A @{tf.Graph} or @{tf.GraphDef} to output to the writer. + This function will not write the default graph by default. When + writing to an event log file, the associated step will be zero. + session: So this method can call @{tf.Session.run}. This defaults + to @{tf.get_default_session}. + + Raises: + RuntimeError: If in eager mode, or if the current thread has no + default @{tf.contrib.summary.SummaryWriter}. + ValueError: If session wasn't passed and no default session. + """ + if context.context().summary_writer_resource is None: + raise RuntimeError("No default tf.contrib.summary.SummaryWriter found") + if session is None: + session = ops.get_default_session() + if session is None: + raise ValueError("session must be passed if no default session exists") + session.run(summary_writer_initializer_op()) + if graph is not None: + data = _serialize_graph(graph) + x = array_ops.placeholder(dtypes.string) + session.run(_graph(x, 0), feed_dict={x: data}) + + def create_summary_file_writer(logdir, max_queue=None, flush_millis=None, @@ -192,10 +237,10 @@ def create_summary_db_writer(db_uri, Experiment will not be associated with a User. Must be valid as both a DNS label and Linux username. name: Shared name for this SummaryWriter resource stored to default - Graph. + @{tf.Graph}. Returns: - A new SummaryWriter instance. + A @{tf.contrib.summary.SummaryWriter} instance. """ with ops.device("cpu:0"): if experiment_name is None: @@ -240,7 +285,16 @@ def _nothing(): def all_summary_ops(): - """Graph-mode only. Returns all summary ops.""" + """Graph-mode only. Returns all summary ops. + + Please note this excludes @{tf.contrib.summary.graph} ops. + + Returns: + The summary ops. + + Raises: + RuntimeError: If in Eager mode. + """ if context.in_eager_mode(): raise RuntimeError( "tf.contrib.summary.all_summary_ops is only supported in graph mode.") @@ -248,7 +302,14 @@ def all_summary_ops(): def summary_writer_initializer_op(): - """Graph-mode only. Returns the list of ops to create all summary writers.""" + """Graph-mode only. Returns the list of ops to create all summary writers. + + Returns: + The initializer ops. + + Raises: + RuntimeError: If in Eager mode. + """ if context.in_eager_mode(): raise RuntimeError( "tf.contrib.summary.summary_writer_initializer_op is only " @@ -367,21 +428,72 @@ def audio(name, tensor, sample_rate, max_outputs, family=None, return summary_writer_function(name, tensor, function, family=family) +def graph(param, step=None, name=None): + """Writes a TensorFlow graph to the summary interface. + + The graph summary is, strictly speaking, not a summary. Conditions + like @{tf.contrib.summary.never_record_summaries} do not apply. Only + a single graph can be associated with a particular run. If multiple + graphs are written, then only the last one will be considered by + TensorBoard. + + When not using eager execution mode, the user should consider passing + the `graph` parameter to @{tf.contrib.summary.initialize} instead of + calling this function. Otherwise special care needs to be taken when + using the graph to record the graph. + + Args: + param: A @{tf.Tensor} containing a serialized graph proto. When + eager execution is enabled, this function will automatically + coerce @{tf.Graph}, @{tf.GraphDef}, and string types. + step: The global step variable. This doesn't have useful semantics + for graph summaries, but is used anyway, due to the structure of + event log files. This defaults to the global step. + name: A name for the operation (optional). + + Returns: + The created @{tf.Operation} or a @{tf.no_op} if summary writing has + not been enabled for this context. + + Raises: + TypeError: If `param` isn't already a @{tf.Tensor} in graph mode. + """ + if not context.in_eager_mode() and not isinstance(param, ops.Tensor): + raise TypeError("graph() needs a tf.Tensor (e.g. tf.placeholder) in graph " + "mode, but was: %s" % type(param)) + writer = context.context().summary_writer_resource + if writer is None: + return control_flow_ops.no_op() + with ops.device("cpu:0"): + if step is None: + step = training_util.get_global_step() + else: + step = ops.convert_to_tensor(step, dtypes.int64) + if isinstance(param, (ops.Graph, graph_pb2.GraphDef)): + tensor = ops.convert_to_tensor(_serialize_graph(param), dtypes.string) + else: + tensor = array_ops.identity(param) + return gen_summary_ops.write_graph_summary(writer, step, tensor, name=name) + +_graph = graph # for functions with a graph parameter + + def import_event(tensor, name=None): - """Writes a tf.Event binary proto. + """Writes a @{tf.Event} binary proto. When using create_summary_db_writer(), this can be used alongside - tf.TFRecordReader to load event logs into the database. Please note - that this is lower level than the other summary functions and will - ignore any conditions set by methods like should_record_summaries(). + @{tf.TFRecordReader} to load event logs into the database. Please + note that this is lower level than the other summary functions and + will ignore any conditions set by methods like + @{tf.contrib.summary.should_record_summaries}. Args: - tensor: A `Tensor` of type `string` containing a serialized `Event` - proto. + tensor: A @{tf.Tensor} of type `string` containing a serialized + @{tf.Event} proto. name: A name for the operation (optional). Returns: - The created Operation. + The created @{tf.Operation}. """ return gen_summary_ops.import_event( context.context().summary_writer_resource, tensor, name=name) @@ -390,3 +502,10 @@ def import_event(tensor, name=None): def eval_dir(model_dir, name=None): """Construct a logdir for an eval summary writer.""" return os.path.join(model_dir, "eval" if not name else "eval_" + name) + + +def _serialize_graph(arbitrary_graph): + if isinstance(arbitrary_graph, ops.Graph): + return arbitrary_graph.as_graph_def(add_shapes=True).SerializeToString() + else: + return arbitrary_graph.SerializeToString() diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py new file mode 100644 index 0000000000..8f85f67a25 --- /dev/null +++ b/tensorflow/contrib/summary/summary_ops_graph_test.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.summary import summary_ops +from tensorflow.contrib.summary import summary_test_internal +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import training_util + +get_all = summary_test_internal.get_all + + +class DbTest(summary_test_internal.SummaryDbTest): + + def testGraphPassedToGraph_isForbiddenForThineOwnSafety(self): + with self.assertRaises(TypeError): + summary_ops.graph(ops.Graph()) + with self.assertRaises(TypeError): + summary_ops.graph('') + + def testGraphSummary(self): + training_util.get_or_create_global_step() + name = 'hi' + graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),)) + with self.test_session(): + with self.create_summary_db_writer().as_default(): + summary_ops.initialize(graph=graph) + six.assertCountEqual(self, [name], + get_all(self.db, 'SELECT node_name FROM Nodes')) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 6e1a746815..09169fa6d7 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -12,20 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools -import os import tempfile import six -import sqlite3 from tensorflow.contrib.summary import summary_ops +from tensorflow.contrib.summary import summary_test_internal from tensorflow.contrib.summary import summary_test_util +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import dtypes @@ -36,6 +35,9 @@ from tensorflow.python.ops import state_ops from tensorflow.python.platform import gfile from tensorflow.python.training import training_util +get_all = summary_test_internal.get_all +get_one = summary_test_internal.get_one + class TargetTest(test_util.TensorFlowTestCase): @@ -108,22 +110,7 @@ class TargetTest(test_util.TensorFlowTestCase): self.assertEqual(events[1].summary.value[0].tag, 'scalar') -class DbTest(test_util.TensorFlowTestCase): - - def setUp(self): - self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') - if os.path.exists(self.db_path): - os.unlink(self.db_path) - self.db = sqlite3.connect(self.db_path) - self.create_summary_db_writer = functools.partial( - summary_ops.create_summary_db_writer, - db_uri=self.db_path, - experiment_name='experiment', - run_name='run', - user_name='user') - - def tearDown(self): - self.db.close() +class DbTest(summary_test_internal.SummaryDbTest): def testIntegerSummaries(self): step = training_util.create_global_step() @@ -186,13 +173,15 @@ class DbTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): self.create_summary_db_writer(user_name='@') - -def get_one(db, q, *p): - return db.execute(q, p).fetchone()[0] - - -def get_all(db, q, *p): - return unroll(db.execute(q, p).fetchall()) + def testGraphSummary(self): + training_util.get_or_create_global_step() + name = 'hi' + graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),)) + with summary_ops.always_record_summaries(): + with self.create_summary_db_writer().as_default(): + summary_ops.graph(graph) + six.assertCountEqual(self, [name], + get_all(self.db, 'SELECT node_name FROM Nodes')) def get_tensor(db, tag_id, step): @@ -205,9 +194,5 @@ def int64(x): return array_ops.constant(x, dtypes.int64) -def unroll(list_of_tuples): - return sum(list_of_tuples, ()) - - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/summary/summary_test_internal.py b/tensorflow/contrib/summary/summary_test_internal.py new file mode 100644 index 0000000000..54233f2f50 --- /dev/null +++ b/tensorflow/contrib/summary/summary_test_internal.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Internal helpers for tests in this directory.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os +import sqlite3 + +from tensorflow.contrib.summary import summary_ops +from tensorflow.python.framework import test_util + + +class SummaryDbTest(test_util.TensorFlowTestCase): + """Helper for summary database testing.""" + + def setUp(self): + super(SummaryDbTest, self).setUp() + self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') + if os.path.exists(self.db_path): + os.unlink(self.db_path) + self.db = sqlite3.connect(self.db_path) + self.create_summary_db_writer = functools.partial( + summary_ops.create_summary_db_writer, + db_uri=self.db_path, + experiment_name='experiment', + run_name='run', + user_name='user') + + def tearDown(self): + self.db.close() + super(SummaryDbTest, self).tearDown() + + +def get_one(db, q, *p): + return db.execute(q, p).fetchone()[0] + + +def get_all(db, q, *p): + return unroll(db.execute(q, p).fetchall()) + + +def unroll(list_of_tuples): + return sum(list_of_tuples, ()) |