diff options
author | 2017-11-15 11:31:43 -0800 | |
---|---|---|
committer | 2017-11-15 11:35:49 -0800 | |
commit | 6fb721d608c4cd3855fe8793099a629428b9853c (patch) | |
tree | faef08ed8bac4f5a8b065825a4405ef8a12e875f /tensorflow | |
parent | b7b183b90aee8a4f4808f7d90a2c7a54a942e640 (diff) |
Add graph writer op to contrib/summary
This change also defines a simple SQL data model for tf.GraphDef, which
should move us closer to a world where TensorBoard can render the graph
explorer without having to download the entire thing to the browser, as
that could potentially be hundreds of megabytes.
PiperOrigin-RevId: 175854921
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/summary/BUILD | 29 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary.py | 3 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops.py | 149 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_graph_test.py | 52 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 47 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_test_internal.py | 59 | ||||
-rw-r--r-- | tensorflow/contrib/tensorboard/db/schema.cc | 141 | ||||
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer.cc | 272 | ||||
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc | 78 | ||||
-rw-r--r-- | tensorflow/core/kernels/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/summary_interface.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/kernels/summary_interface.h | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/summary_kernels.cc | 25 | ||||
-rw-r--r-- | tensorflow/core/ops/summary_ops.cc | 13 | ||||
-rw-r--r-- | tensorflow/tools/pip_package/pip_smoke_test.py | 3 |
15 files changed, 751 insertions, 135 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index d1beafcb28..3892654f25 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -25,13 +25,12 @@ py_test( srcs_version = "PY2AND3", deps = [ ":summary_ops", + ":summary_test_internal", ":summary_test_util", "//tensorflow/python:array_ops", "//tensorflow/python:errors", "//tensorflow/python:framework", "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:ops", "//tensorflow/python:platform", "//tensorflow/python:state_ops", "//tensorflow/python:training", @@ -41,6 +40,20 @@ py_test( ], ) +py_test( + name = "summary_ops_graph_test", + srcs = ["summary_ops_graph_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":summary_ops", + ":summary_test_internal", + "//tensorflow/python:client_testlib", + "//tensorflow/python:ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + ], +) + py_library( name = "summary_ops", srcs = ["summary_ops.py"], @@ -98,3 +111,15 @@ py_library( "//tensorflow/python:platform", ], ) + +py_library( + name = "summary_test_internal", + testonly = 1, + srcs = ["summary_test_internal.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/python:lib", + "//tensorflow/python:platform", + ], +) diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index 813e8b2b09..a73193f460 100644 --- a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -32,11 +32,14 @@ from tensorflow.contrib.summary.summary_ops import create_summary_db_writer from tensorflow.contrib.summary.summary_ops import create_summary_file_writer from tensorflow.contrib.summary.summary_ops import eval_dir from tensorflow.contrib.summary.summary_ops import generic +from tensorflow.contrib.summary.summary_ops import graph from tensorflow.contrib.summary.summary_ops import histogram from tensorflow.contrib.summary.summary_ops import image from tensorflow.contrib.summary.summary_ops import import_event +from tensorflow.contrib.summary.summary_ops import initialize from tensorflow.contrib.summary.summary_ops import never_record_summaries from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps from tensorflow.contrib.summary.summary_ops import scalar from tensorflow.contrib.summary.summary_ops import should_record_summaries from tensorflow.contrib.summary.summary_ops import summary_writer_initializer_op +from tensorflow.contrib.summary.summary_ops import SummaryWriter diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index f6be99f6ae..a72c0c80aa 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -27,6 +27,7 @@ import time import six from tensorflow.contrib.summary import gen_summary_ops +from tensorflow.core.framework import graph_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -99,25 +100,32 @@ def never_record_summaries(): class SummaryWriter(object): - """Encapsulates a summary writer.""" + """Encapsulates a stateful summary writer resource. - def __init__(self, resource): + See also: + - @{tf.contrib.summary.create_summary_file_writer} + - @{tf.contrib.summary.create_summary_db_writer} + """ + + def __init__(self, resource): self._resource = resource if context.in_eager_mode(): self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="cpu:0") def set_as_default(self): + """Enables this summary writer for the current thread.""" context.context().summary_writer_resource = self._resource @tf_contextlib.contextmanager def as_default(self): + """Enables summary writing within a `with` block.""" if self._resource is None: - yield + yield self else: old = context.context().summary_writer_resource context.context().summary_writer_resource = self._resource - yield + yield self # Flushes the summary writer in eager mode or in graph functions, but not # in legacy graph mode (you're on your own there). with ops.device("cpu:0"): @@ -125,6 +133,43 @@ class SummaryWriter(object): context.context().summary_writer_resource = old +def initialize( + graph=None, # pylint: disable=redefined-outer-name + session=None): + """Initializes summary writing for graph execution mode. + + This helper method provides a higher-level alternative to using + @{tf.contrib.summary.summary_writer_initializer_op} and + @{tf.contrib.summary.graph}. + + Most users will also want to call @{tf.train.create_global_step} + which can happen before or after this function is called. + + Args: + graph: A @{tf.Graph} or @{tf.GraphDef} to output to the writer. + This function will not write the default graph by default. When + writing to an event log file, the associated step will be zero. + session: So this method can call @{tf.Session.run}. This defaults + to @{tf.get_default_session}. + + Raises: + RuntimeError: If in eager mode, or if the current thread has no + default @{tf.contrib.summary.SummaryWriter}. + ValueError: If session wasn't passed and no default session. + """ + if context.context().summary_writer_resource is None: + raise RuntimeError("No default tf.contrib.summary.SummaryWriter found") + if session is None: + session = ops.get_default_session() + if session is None: + raise ValueError("session must be passed if no default session exists") + session.run(summary_writer_initializer_op()) + if graph is not None: + data = _serialize_graph(graph) + x = array_ops.placeholder(dtypes.string) + session.run(_graph(x, 0), feed_dict={x: data}) + + def create_summary_file_writer(logdir, max_queue=None, flush_millis=None, @@ -192,10 +237,10 @@ def create_summary_db_writer(db_uri, Experiment will not be associated with a User. Must be valid as both a DNS label and Linux username. name: Shared name for this SummaryWriter resource stored to default - Graph. + @{tf.Graph}. Returns: - A new SummaryWriter instance. + A @{tf.contrib.summary.SummaryWriter} instance. """ with ops.device("cpu:0"): if experiment_name is None: @@ -240,7 +285,16 @@ def _nothing(): def all_summary_ops(): - """Graph-mode only. Returns all summary ops.""" + """Graph-mode only. Returns all summary ops. + + Please note this excludes @{tf.contrib.summary.graph} ops. + + Returns: + The summary ops. + + Raises: + RuntimeError: If in Eager mode. + """ if context.in_eager_mode(): raise RuntimeError( "tf.contrib.summary.all_summary_ops is only supported in graph mode.") @@ -248,7 +302,14 @@ def all_summary_ops(): def summary_writer_initializer_op(): - """Graph-mode only. Returns the list of ops to create all summary writers.""" + """Graph-mode only. Returns the list of ops to create all summary writers. + + Returns: + The initializer ops. + + Raises: + RuntimeError: If in Eager mode. + """ if context.in_eager_mode(): raise RuntimeError( "tf.contrib.summary.summary_writer_initializer_op is only " @@ -367,21 +428,72 @@ def audio(name, tensor, sample_rate, max_outputs, family=None, return summary_writer_function(name, tensor, function, family=family) +def graph(param, step=None, name=None): + """Writes a TensorFlow graph to the summary interface. + + The graph summary is, strictly speaking, not a summary. Conditions + like @{tf.contrib.summary.never_record_summaries} do not apply. Only + a single graph can be associated with a particular run. If multiple + graphs are written, then only the last one will be considered by + TensorBoard. + + When not using eager execution mode, the user should consider passing + the `graph` parameter to @{tf.contrib.summary.initialize} instead of + calling this function. Otherwise special care needs to be taken when + using the graph to record the graph. + + Args: + param: A @{tf.Tensor} containing a serialized graph proto. When + eager execution is enabled, this function will automatically + coerce @{tf.Graph}, @{tf.GraphDef}, and string types. + step: The global step variable. This doesn't have useful semantics + for graph summaries, but is used anyway, due to the structure of + event log files. This defaults to the global step. + name: A name for the operation (optional). + + Returns: + The created @{tf.Operation} or a @{tf.no_op} if summary writing has + not been enabled for this context. + + Raises: + TypeError: If `param` isn't already a @{tf.Tensor} in graph mode. + """ + if not context.in_eager_mode() and not isinstance(param, ops.Tensor): + raise TypeError("graph() needs a tf.Tensor (e.g. tf.placeholder) in graph " + "mode, but was: %s" % type(param)) + writer = context.context().summary_writer_resource + if writer is None: + return control_flow_ops.no_op() + with ops.device("cpu:0"): + if step is None: + step = training_util.get_global_step() + else: + step = ops.convert_to_tensor(step, dtypes.int64) + if isinstance(param, (ops.Graph, graph_pb2.GraphDef)): + tensor = ops.convert_to_tensor(_serialize_graph(param), dtypes.string) + else: + tensor = array_ops.identity(param) + return gen_summary_ops.write_graph_summary(writer, step, tensor, name=name) + +_graph = graph # for functions with a graph parameter + + def import_event(tensor, name=None): - """Writes a tf.Event binary proto. + """Writes a @{tf.Event} binary proto. When using create_summary_db_writer(), this can be used alongside - tf.TFRecordReader to load event logs into the database. Please note - that this is lower level than the other summary functions and will - ignore any conditions set by methods like should_record_summaries(). + @{tf.TFRecordReader} to load event logs into the database. Please + note that this is lower level than the other summary functions and + will ignore any conditions set by methods like + @{tf.contrib.summary.should_record_summaries}. Args: - tensor: A `Tensor` of type `string` containing a serialized `Event` - proto. + tensor: A @{tf.Tensor} of type `string` containing a serialized + @{tf.Event} proto. name: A name for the operation (optional). Returns: - The created Operation. + The created @{tf.Operation}. """ return gen_summary_ops.import_event( context.context().summary_writer_resource, tensor, name=name) @@ -390,3 +502,10 @@ def import_event(tensor, name=None): def eval_dir(model_dir, name=None): """Construct a logdir for an eval summary writer.""" return os.path.join(model_dir, "eval" if not name else "eval_" + name) + + +def _serialize_graph(arbitrary_graph): + if isinstance(arbitrary_graph, ops.Graph): + return arbitrary_graph.as_graph_def(add_shapes=True).SerializeToString() + else: + return arbitrary_graph.SerializeToString() diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py new file mode 100644 index 0000000000..8f85f67a25 --- /dev/null +++ b/tensorflow/contrib/summary/summary_ops_graph_test.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.summary import summary_ops +from tensorflow.contrib.summary import summary_test_internal +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import training_util + +get_all = summary_test_internal.get_all + + +class DbTest(summary_test_internal.SummaryDbTest): + + def testGraphPassedToGraph_isForbiddenForThineOwnSafety(self): + with self.assertRaises(TypeError): + summary_ops.graph(ops.Graph()) + with self.assertRaises(TypeError): + summary_ops.graph('') + + def testGraphSummary(self): + training_util.get_or_create_global_step() + name = 'hi' + graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),)) + with self.test_session(): + with self.create_summary_db_writer().as_default(): + summary_ops.initialize(graph=graph) + six.assertCountEqual(self, [name], + get_all(self.db, 'SELECT node_name FROM Nodes')) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 6e1a746815..09169fa6d7 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -12,20 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools -import os import tempfile import six -import sqlite3 from tensorflow.contrib.summary import summary_ops +from tensorflow.contrib.summary import summary_test_internal from tensorflow.contrib.summary import summary_test_util +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import dtypes @@ -36,6 +35,9 @@ from tensorflow.python.ops import state_ops from tensorflow.python.platform import gfile from tensorflow.python.training import training_util +get_all = summary_test_internal.get_all +get_one = summary_test_internal.get_one + class TargetTest(test_util.TensorFlowTestCase): @@ -108,22 +110,7 @@ class TargetTest(test_util.TensorFlowTestCase): self.assertEqual(events[1].summary.value[0].tag, 'scalar') -class DbTest(test_util.TensorFlowTestCase): - - def setUp(self): - self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') - if os.path.exists(self.db_path): - os.unlink(self.db_path) - self.db = sqlite3.connect(self.db_path) - self.create_summary_db_writer = functools.partial( - summary_ops.create_summary_db_writer, - db_uri=self.db_path, - experiment_name='experiment', - run_name='run', - user_name='user') - - def tearDown(self): - self.db.close() +class DbTest(summary_test_internal.SummaryDbTest): def testIntegerSummaries(self): step = training_util.create_global_step() @@ -186,13 +173,15 @@ class DbTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): self.create_summary_db_writer(user_name='@') - -def get_one(db, q, *p): - return db.execute(q, p).fetchone()[0] - - -def get_all(db, q, *p): - return unroll(db.execute(q, p).fetchall()) + def testGraphSummary(self): + training_util.get_or_create_global_step() + name = 'hi' + graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),)) + with summary_ops.always_record_summaries(): + with self.create_summary_db_writer().as_default(): + summary_ops.graph(graph) + six.assertCountEqual(self, [name], + get_all(self.db, 'SELECT node_name FROM Nodes')) def get_tensor(db, tag_id, step): @@ -205,9 +194,5 @@ def int64(x): return array_ops.constant(x, dtypes.int64) -def unroll(list_of_tuples): - return sum(list_of_tuples, ()) - - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/summary/summary_test_internal.py b/tensorflow/contrib/summary/summary_test_internal.py new file mode 100644 index 0000000000..54233f2f50 --- /dev/null +++ b/tensorflow/contrib/summary/summary_test_internal.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Internal helpers for tests in this directory.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os +import sqlite3 + +from tensorflow.contrib.summary import summary_ops +from tensorflow.python.framework import test_util + + +class SummaryDbTest(test_util.TensorFlowTestCase): + """Helper for summary database testing.""" + + def setUp(self): + super(SummaryDbTest, self).setUp() + self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') + if os.path.exists(self.db_path): + os.unlink(self.db_path) + self.db = sqlite3.connect(self.db_path) + self.create_summary_db_writer = functools.partial( + summary_ops.create_summary_db_writer, + db_uri=self.db_path, + experiment_name='experiment', + run_name='run', + user_name='user') + + def tearDown(self): + self.db.close() + super(SummaryDbTest, self).tearDown() + + +def get_one(db, q, *p): + return db.execute(q, p).fetchone()[0] + + +def get_all(db, q, *p): + return unroll(db.execute(q, p).fetchall()) + + +def unroll(list_of_tuples): + return sum(list_of_tuples, ()) diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc index 98fff9e0ae..d63b2c6cc2 100644 --- a/tensorflow/contrib/tensorboard/db/schema.cc +++ b/tensorflow/contrib/tensorboard/db/schema.cc @@ -135,8 +135,7 @@ class SqliteSchema { /// the database. This field will be mutated if the run is /// restarted. /// description: Optional markdown information. - /// graph: Snappy tf.GraphDef proto with node field cleared. That - /// field can be recreated using GraphNodes and NodeDefs. + /// graph_id: ID of associated Graphs row. Status CreateRunsTable() { return Run(R"sql( CREATE TABLE IF NOT EXISTS Runs ( @@ -147,7 +146,7 @@ class SqliteSchema { inserted_time REAL, started_time REAL, description TEXT, - graph BLOB + graph_id INTEGER ) )sql"); } @@ -205,46 +204,78 @@ class SqliteSchema { )sql"); } - /// \brief Creates NodeDefs table. - /// - /// This table stores NodeDef protos which define the GraphDef for a - /// Run. This functions like a hash table so rows can be shared by - /// multiple Runs in an Experiment. + /// \brief Creates Graphs table. /// /// Fields: /// rowid: Ephemeral b-tree ID dictating locality. - /// experiment_id: Optional int64 for grouping rows. - /// node_def_id: Permanent >0 unique ID. - /// fingerprint: Optional farmhash::Fingerprint64() of uncompressed - /// node_def bytes, coerced to int64. - /// node_def: BLOB containing a Snappy tf.NodeDef proto. - Status CreateNodeDefsTable() { + /// graph_id: Permanent >0 unique ID. + /// inserted_time: Float UNIX timestamp with µs precision. This is + /// always the wall time of when the row was inserted into the + /// DB. It may be used as a hint for an archival job. + /// node_def: Contains Snappy tf.GraphDef proto. All fields will be + /// cleared except those not expressed in SQL. + Status CreateGraphsTable() { return Run(R"sql( - CREATE TABLE IF NOT EXISTS NodeDefs ( + CREATE TABLE IF NOT EXISTS Graphs ( rowid INTEGER PRIMARY KEY, - experiment_id INTEGER, - node_def_id INTEGER NOT NULL, - fingerprint INTEGER, - node_def TEXT + graph_id INTEGER NOT NULL, + inserted_time REAL, + graph_def BLOB ) )sql"); } - /// \brief Creates RunNodeDefs table. + /// \brief Creates Nodes table. /// - /// Table mapping Runs to NodeDefs. This is used to recreate the node - /// field of the GraphDef proto. + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// graph_id: Permanent >0 unique ID. + /// node_id: ID for this node. This is more like a 0-index within + /// the Graph. Please note indexes are allowed to be removed. + /// node_name: Unique name for this Node within Graph. This is + /// copied from the proto so it can be indexed. This is allowed + /// to be NULL to save space on the index, in which case the + /// node_def.name proto field must not be cleared. + /// op: Copied from tf.NodeDef proto. + /// device: Copied from tf.NodeDef proto. + /// node_def: Contains Snappy tf.NodeDef proto. All fields will be + /// cleared except those not expressed in SQL. + Status CreateNodesTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Nodes ( + rowid INTEGER PRIMARY KEY, + graph_id INTEGER NOT NULL, + node_id INTEGER NOT NULL, + node_name TEXT, + op TEXT, + device TEXT, + node_def BLOB + ) + )sql"); + } + + /// \brief Creates NodeInputs table. /// /// Fields: /// rowid: Ephemeral b-tree ID dictating locality. - /// run_id: Mandatory ID of associated Run. - /// node_def_id: Mandatory ID of associated NodeDef. - Status CreateRunNodeDefsTable() { + /// graph_id: Permanent >0 unique ID. + /// node_id: Index of Node in question. This can be considered the + /// 'to' vertex. + /// idx: Used for ordering inputs on a given Node. + /// input_node_id: Nodes.node_id of the corresponding input node. + /// This can be considered the 'from' vertex. + /// is_control: If non-zero, indicates this input is a controlled + /// dependency, which means this isn't an edge through which + /// tensors flow. NULL means 0. + Status CreateNodeInputsTable() { return Run(R"sql( - CREATE TABLE IF NOT EXISTS RunNodeDefs ( + CREATE TABLE IF NOT EXISTS NodeInputs ( rowid INTEGER PRIMARY KEY, - run_id INTEGER NOT NULL, - node_def_id INTEGER NOT NULL + graph_id INTEGER NOT NULL, + node_id INTEGER NOT NULL, + idx INTEGER NOT NULL, + input_node_id INTEGER NOT NULL, + is_control INTEGER ) )sql"); } @@ -297,11 +328,27 @@ class SqliteSchema { )sql"); } - /// \brief Uniquely indexes node_def_id on NodeDefs table. - Status CreateNodeDefIdIndex() { + /// \brief Uniquely indexes graph_id on Graphs table. + Status CreateGraphIdIndex() { return Run(R"sql( - CREATE UNIQUE INDEX IF NOT EXISTS NodeDefIdIndex - ON NodeDefs (node_def_id) + CREATE UNIQUE INDEX IF NOT EXISTS GraphIdIndex + ON Graphs (graph_id) + )sql"); + } + + /// \brief Uniquely indexes (graph_id, node_id) on Nodes table. + Status CreateNodeIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS NodeIdIndex + ON Nodes (graph_id, node_id) + )sql"); + } + + /// \brief Uniquely indexes (graph_id, node_id, idx) on NodeInputs table. + Status CreateNodeInputsIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS NodeInputsIndex + ON NodeInputs (graph_id, node_id, idx) )sql"); } @@ -350,20 +397,12 @@ class SqliteSchema { )sql"); } - /// \brief Indexes (experiment_id, fingerprint) on NodeDefs table. - Status CreateNodeDefFingerprintIndex() { - return Run(R"sql( - CREATE INDEX IF NOT EXISTS NodeDefFingerprintIndex - ON NodeDefs (experiment_id, fingerprint) - WHERE fingerprint IS NOT NULL - )sql"); - } - - /// \brief Uniquely indexes (run_id, node_def_id) on RunNodeDefs table. - Status CreateRunNodeDefIndex() { + /// \brief Uniquely indexes (graph_id, node_name) on Nodes table. + Status CreateNodeNameIndex() { return Run(R"sql( - CREATE UNIQUE INDEX IF NOT EXISTS RunNodeDefIndex - ON RunNodeDefs (run_id, node_def_id) + CREATE UNIQUE INDEX IF NOT EXISTS NodeNameIndex + ON Nodes (graph_id, node_name) + WHERE node_name IS NOT NULL )sql"); } @@ -387,22 +426,24 @@ Status SetupTensorboardSqliteDb(std::shared_ptr<Sqlite> db) { TF_RETURN_IF_ERROR(s.CreateRunsTable()); TF_RETURN_IF_ERROR(s.CreateExperimentsTable()); TF_RETURN_IF_ERROR(s.CreateUsersTable()); - TF_RETURN_IF_ERROR(s.CreateNodeDefsTable()); - TF_RETURN_IF_ERROR(s.CreateRunNodeDefsTable()); + TF_RETURN_IF_ERROR(s.CreateGraphsTable()); + TF_RETURN_IF_ERROR(s.CreateNodeInputsTable()); + TF_RETURN_IF_ERROR(s.CreateNodesTable()); TF_RETURN_IF_ERROR(s.CreateTensorIndex()); TF_RETURN_IF_ERROR(s.CreateTensorChunkIndex()); TF_RETURN_IF_ERROR(s.CreateTagIdIndex()); TF_RETURN_IF_ERROR(s.CreateRunIdIndex()); TF_RETURN_IF_ERROR(s.CreateExperimentIdIndex()); TF_RETURN_IF_ERROR(s.CreateUserIdIndex()); - TF_RETURN_IF_ERROR(s.CreateNodeDefIdIndex()); + TF_RETURN_IF_ERROR(s.CreateGraphIdIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeIdIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeInputsIndex()); TF_RETURN_IF_ERROR(s.CreateTagNameIndex()); TF_RETURN_IF_ERROR(s.CreateRunNameIndex()); TF_RETURN_IF_ERROR(s.CreateExperimentNameIndex()); TF_RETURN_IF_ERROR(s.CreateUserNameIndex()); TF_RETURN_IF_ERROR(s.CreateUserEmailIndex()); - TF_RETURN_IF_ERROR(s.CreateNodeDefFingerprintIndex()); - TF_RETURN_IF_ERROR(s.CreateRunNodeDefIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeNameIndex()); return Status::OK(); } diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index a26ad61660..ae063d24ef 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -15,17 +15,29 @@ limitations under the License. #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" #include "tensorflow/contrib/tensorboard/db/schema.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/snappy.h" #include "tensorflow/core/util/event.pb.h" namespace tensorflow { namespace { +double GetWallTime(Env* env) { + // TODO(@jart): Follow precise definitions for time laid out in schema. + // TODO(@jart): Use monotonic clock from gRPC codebase. + return static_cast<double>(env->NowMicros()) / 1.0e6; +} + int64 MakeRandomId() { + // TODO(@jart): Try generating ID in 2^24 space, falling back to 2^63 + // https://sqlite.org/src4/doc/trunk/www/varint.wiki int64 id = static_cast<int64>(random::New64() & ((1ULL << 63) - 1)); if (id == 0) { ++id; @@ -33,10 +45,201 @@ int64 MakeRandomId() { return id; } +Status Serialize(const protobuf::MessageLite& proto, string* output) { + output->clear(); + if (!proto.SerializeToString(output)) { + return errors::DataLoss("SerializeToString failed"); + } + return Status::OK(); +} + +Status Compress(const string& data, string* output) { + output->clear(); + if (!port::Snappy_Compress(data.data(), data.size(), output)) { + return errors::FailedPrecondition("TensorBase needs Snappy"); + } + return Status::OK(); +} + +Status BindProto(SqliteStatement* stmt, int parameter, + const protobuf::MessageLite& proto) { + string serialized; + TF_RETURN_IF_ERROR(Serialize(proto, &serialized)); + string compressed; + TF_RETURN_IF_ERROR(Compress(serialized, &compressed)); + stmt->BindBlobUnsafe(parameter, compressed); + return Status::OK(); +} + +Status BindTensor(SqliteStatement* stmt, int parameter, const Tensor& t) { + // TODO(@jart): Make portable between little and big endian systems. + // TODO(@jart): Use TensorChunks with minimal copying for big tensors. + // TODO(@jart): Add field to indicate encoding. + // TODO(@jart): Allow crunch tool to re-compress with zlib instead. + TensorProto p; + t.AsProtoTensorContent(&p); + return BindProto(stmt, parameter, p); +} + +class Transactor { + public: + explicit Transactor(std::shared_ptr<Sqlite> db) + : db_(std::move(db)), + begin_(db_->Prepare("BEGIN TRANSACTION")), + commit_(db_->Prepare("COMMIT TRANSACTION")), + rollback_(db_->Prepare("ROLLBACK TRANSACTION")) {} + + template <typename T, typename... Args> + Status Transact(T callback, Args&&... args) { + TF_RETURN_IF_ERROR(begin_.StepAndReset()); + Status s = callback(std::forward<Args>(args)...); + if (s.ok()) { + TF_RETURN_IF_ERROR(commit_.StepAndReset()); + } else { + TF_RETURN_WITH_CONTEXT_IF_ERROR(rollback_.StepAndReset(), s.ToString()); + } + return s; + } + + private: + std::shared_ptr<Sqlite> db_; + SqliteStatement begin_; + SqliteStatement commit_; + SqliteStatement rollback_; +}; + +class GraphSaver { + public: + static Status SaveToRun(Env* env, Sqlite* db, GraphDef* graph, int64 run_id) { + auto get = db->Prepare("SELECT graph_id FROM Runs WHERE run_id = ?"); + get.BindInt(1, run_id); + bool is_done; + TF_RETURN_IF_ERROR(get.Step(&is_done)); + int64 graph_id = is_done ? 0 : get.ColumnInt(0); + if (graph_id == 0) { + graph_id = MakeRandomId(); + // TODO(@jart): Check for ID collision. + auto set = db->Prepare("UPDATE Runs SET graph_id = ? WHERE run_id = ?"); + set.BindInt(1, graph_id); + set.BindInt(2, run_id); + TF_RETURN_IF_ERROR(set.StepAndReset()); + } + return Save(env, db, graph, graph_id); + } + + static Status Save(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id) { + GraphSaver saver{env, db, graph, graph_id}; + saver.MapNameToNodeId(); + TF_RETURN_IF_ERROR(saver.SaveNodeInputs()); + TF_RETURN_IF_ERROR(saver.SaveNodes()); + TF_RETURN_IF_ERROR(saver.SaveGraph()); + return Status::OK(); + } + + private: + GraphSaver(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id) + : env_(env), db_(db), graph_(graph), graph_id_(graph_id) {} + + void MapNameToNodeId() { + size_t toto = static_cast<size_t>(graph_->node_size()); + name_copies_.reserve(toto); + name_to_node_id_.reserve(toto); + for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { + // Copy name into memory region, since we call clear_name() later. + // Then wrap in StringPiece so we can compare slices without copy. + name_copies_.emplace_back(graph_->node(node_id).name()); + name_to_node_id_.emplace(name_copies_.back(), node_id); + } + } + + Status SaveNodeInputs() { + auto purge = db_->Prepare("DELETE FROM NodeInputs WHERE graph_id = ?"); + purge.BindInt(1, graph_id_); + TF_RETURN_IF_ERROR(purge.StepAndReset()); + auto insert = db_->Prepare(R"sql( + INSERT INTO NodeInputs (graph_id, node_id, idx, input_node_id, is_control) + VALUES (?, ?, ?, ?, ?) + )sql"); + for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { + const NodeDef& node = graph_->node(node_id); + for (int idx = 0; idx < node.input_size(); ++idx) { + StringPiece name = node.input(idx); + insert.BindInt(1, graph_id_); + insert.BindInt(2, node_id); + insert.BindInt(3, idx); + if (!name.empty() && name[0] == '^') { + name.remove_prefix(1); + insert.BindInt(5, 1); + } + auto e = name_to_node_id_.find(name); + if (e == name_to_node_id_.end()) { + return errors::DataLoss("Could not find node: ", name); + } + insert.BindInt(4, e->second); + TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(), + " -> ", name); + } + } + return Status::OK(); + } + + Status SaveNodes() { + auto purge = db_->Prepare("DELETE FROM Nodes WHERE graph_id = ?"); + purge.BindInt(1, graph_id_); + TF_RETURN_IF_ERROR(purge.StepAndReset()); + auto insert = db_->Prepare(R"sql( + INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def) + VALUES (?, ?, ?, ?, ?, ?) + )sql"); + for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { + NodeDef* node = graph_->mutable_node(node_id); + insert.BindInt(1, graph_id_); + insert.BindInt(2, node_id); + insert.BindText(3, node->name()); + node->clear_name(); + if (!node->op().empty()) { + insert.BindText(4, node->op()); + node->clear_op(); + } + if (!node->device().empty()) { + insert.BindText(5, node->device()); + node->clear_device(); + } + node->clear_input(); + TF_RETURN_IF_ERROR(BindProto(&insert, 6, *node)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name()); + } + return Status::OK(); + } + + Status SaveGraph() { + auto insert = db_->Prepare(R"sql( + INSERT OR REPLACE INTO Graphs (graph_id, inserted_time, graph_def) + VALUES (?, ?, ?) + )sql"); + insert.BindInt(1, graph_id_); + insert.BindDouble(2, GetWallTime(env_)); + graph_->clear_node(); + TF_RETURN_IF_ERROR(BindProto(&insert, 3, *graph_)); + return insert.StepAndReset(); + } + + Env* env_; + Sqlite* db_; + GraphDef* graph_; + int64 graph_id_; + std::vector<string> name_copies_; + std::unordered_map<StringPiece, int64, StringPiece::Hasher> name_to_node_id_; +}; + class SummaryDbWriter : public SummaryWriterInterface { public: SummaryDbWriter(Env* env, std::shared_ptr<Sqlite> db) - : SummaryWriterInterface(), env_(env), db_(std::move(db)), run_id_(-1) {} + : SummaryWriterInterface(), + env_(env), + db_(std::move(db)), + txn_(db_), + run_id_{0LL} {} ~SummaryDbWriter() override {} Status Initialize(const string& experiment_name, const string& run_name, @@ -76,7 +279,7 @@ class SummaryDbWriter : public SummaryWriterInterface { // TODO(@jart): Check for random ID collisions without needing txn retry. insert_tensor_.BindInt(1, tag_id); insert_tensor_.BindInt(2, global_step); - insert_tensor_.BindDouble(3, GetWallTime()); + insert_tensor_.BindDouble(3, GetWallTime(env_)); switch (t.dtype()) { case DT_INT64: insert_tensor_.BindInt(4, t.scalar<int64>()()); @@ -85,22 +288,41 @@ class SummaryDbWriter : public SummaryWriterInterface { insert_tensor_.BindDouble(4, t.scalar<double>()()); break; default: - TF_RETURN_IF_ERROR(BindTensor(t)); + TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t)); break; } return insert_tensor_.StepAndReset(); } - Status WriteEvent(std::unique_ptr<Event> e) override { + Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override { mutex_lock ml(mu_); TF_RETURN_IF_ERROR(InitializeParents()); - if (e->what_case() == Event::WhatCase::kSummary) { - const Summary& summary = e->summary(); - for (int i = 0; i < summary.value_size(); ++i) { - TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i))); + return txn_.Transact(GraphSaver::SaveToRun, env_, db_.get(), g.get(), + run_id_); + } + + Status WriteEvent(std::unique_ptr<Event> e) override { + switch (e->what_case()) { + case Event::WhatCase::kSummary: { + mutex_lock ml(mu_); + TF_RETURN_IF_ERROR(InitializeParents()); + const Summary& summary = e->summary(); + for (int i = 0; i < summary.value_size(); ++i) { + TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i))); + } + return Status::OK(); } + case Event::WhatCase::kGraphDef: { + std::unique_ptr<GraphDef> graph{new GraphDef}; + if (!ParseProtoUnlimited(graph.get(), e->graph_def())) { + return errors::DataLoss("parse event.graph_def failed"); + } + return WriteGraph(e->step(), std::move(graph)); + } + default: + // TODO(@jart): Handle other stuff. + return Status::OK(); } - return Status::OK(); } Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { @@ -136,33 +358,8 @@ class SummaryDbWriter : public SummaryWriterInterface { string DebugString() override { return "SummaryDbWriter"; } private: - double GetWallTime() { - // TODO(@jart): Follow precise definitions for time laid out in schema. - // TODO(@jart): Use monotonic clock from gRPC codebase. - return static_cast<double>(env_->NowMicros()) / 1.0e6; - } - - Status BindTensor(const Tensor& t) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // TODO(@jart): Make portable between little and big endian systems. - // TODO(@jart): Use TensorChunks with minimal copying for big tensors. - TensorProto p; - t.AsProtoTensorContent(&p); - string encoded; - if (!p.SerializeToString(&encoded)) { - return errors::DataLoss("SerializeToString failed"); - } - // TODO(@jart): Put byte at beginning of blob to indicate encoding. - // TODO(@jart): Allow crunch tool to re-compress with zlib instead. - string compressed; - if (!port::Snappy_Compress(encoded.data(), encoded.size(), &compressed)) { - return errors::FailedPrecondition("TensorBase needs Snappy"); - } - insert_tensor_.BindBlobUnsafe(4, compressed); - return Status::OK(); - } - Status InitializeParents() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (run_id_ >= 0) { + if (run_id_ > 0) { return Status::OK(); } int64 user_id; @@ -195,7 +392,7 @@ class SummaryDbWriter : public SummaryWriterInterface { )sql"); insert_user.BindInt(1, *user_id); insert_user.BindText(2, user_name); - insert_user.BindDouble(3, GetWallTime()); + insert_user.BindDouble(3, GetWallTime(env_)); TF_RETURN_IF_ERROR(insert_user.StepAndReset()); } return Status::OK(); @@ -249,7 +446,7 @@ class SummaryDbWriter : public SummaryWriterInterface { } insert.BindInt(2, *id); insert.BindText(3, name); - insert.BindDouble(4, GetWallTime()); + insert.BindDouble(4, GetWallTime(env_)); TF_RETURN_IF_ERROR(insert.StepAndReset()); } return Status::OK(); @@ -276,6 +473,7 @@ class SummaryDbWriter : public SummaryWriterInterface { mutex mu_; Env* env_; std::shared_ptr<Sqlite> db_ GUARDED_BY(mu_); + Transactor txn_ GUARDED_BY(mu_); SqliteStatement insert_tensor_ GUARDED_BY(mu_); SqliteStatement update_metadata_ GUARDED_BY(mu_); string user_name_ GUARDED_BY(mu_); diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index c1af51e7b7..3431842ca2 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/db/sqlite.h" @@ -212,5 +214,81 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { kTolerance); } +TEST_F(SummaryDbWriterTest, WriteGraph) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_)); + env_.AdvanceByMillis(23); + GraphDef graph; + NodeDef* node = graph.add_node(); + node->set_name("x"); + node->set_op("Placeholder"); + node = graph.add_node(); + node->set_name("y"); + node->set_op("Placeholder"); + node = graph.add_node(); + node->set_name("z"); + node->set_op("Love"); + node = graph.add_node(); + node->set_name("+"); + node->set_op("Add"); + node->add_input("x"); + node->add_input("y"); + node->add_input("^z"); + node->set_device("tpu/lol"); + std::unique_ptr<Event> e{new Event}; + graph.SerializeToString(e->mutable_graph_def()); + TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); + TF_ASSERT_OK(writer_->Flush()); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Graphs")); + ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes")); + ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs")); + + int64 graph_id = QueryInt("SELECT graph_id FROM Graphs"); + EXPECT_GT(graph_id, 0LL); + EXPECT_EQ(graph_id, QueryInt("SELECT graph_id FROM Runs")); + EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs")); + EXPECT_FALSE(QueryString("SELECT graph_def FROM Graphs").empty()); + + EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0")); + EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1")); + EXPECT_EQ("z", QueryString("SELECT node_name FROM Nodes WHERE node_id = 2")); + EXPECT_EQ("+", QueryString("SELECT node_name FROM Nodes WHERE node_id = 3")); + + EXPECT_EQ("Placeholder", + QueryString("SELECT op FROM Nodes WHERE node_id = 0")); + EXPECT_EQ("Placeholder", + QueryString("SELECT op FROM Nodes WHERE node_id = 1")); + EXPECT_EQ("Love", QueryString("SELECT op FROM Nodes WHERE node_id = 2")); + EXPECT_EQ("Add", QueryString("SELECT op FROM Nodes WHERE node_id = 3")); + + EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 0")); + EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 1")); + EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 2")); + EXPECT_EQ("tpu/lol", + QueryString("SELECT device FROM Nodes WHERE node_id = 3")); + + EXPECT_EQ(graph_id, + QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(graph_id, + QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(graph_id, + QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 2")); + + EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 2")); + + EXPECT_EQ(0LL, + QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(1LL, + QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(2LL, + QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 2")); + + EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2")); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 5e19effe3d..b7386abdea 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -6247,6 +6247,7 @@ tf_kernel_library( "//tensorflow/contrib/tensorboard/db:summary_db_writer", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:summary_ops_op_lib", "//tensorflow/core/lib/db:sqlite", ], diff --git a/tensorflow/core/kernels/summary_interface.cc b/tensorflow/core/kernels/summary_interface.cc index cd366f8c13..ad28d77ffd 100644 --- a/tensorflow/core/kernels/summary_interface.cc +++ b/tensorflow/core/kernels/summary_interface.cc @@ -17,6 +17,7 @@ limitations under the License. #include <utility> #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/summary.pb.h" @@ -393,6 +394,15 @@ class SummaryWriterImpl : public SummaryWriterInterface { return WriteEvent(std::move(e)); } + Status WriteGraph(int64 global_step, + std::unique_ptr<GraphDef> graph) override { + std::unique_ptr<Event> e{new Event}; + e->set_step(global_step); + e->set_wall_time(GetWallTime()); + graph->SerializeToString(e->mutable_graph_def()); + return WriteEvent(std::move(e)); + } + Status WriteEvent(std::unique_ptr<Event> event) override { mutex_lock ml(mu_); queue_.emplace_back(std::move(event)); diff --git a/tensorflow/core/kernels/summary_interface.h b/tensorflow/core/kernels/summary_interface.h index ccf3459e56..da1c28709f 100644 --- a/tensorflow/core/kernels/summary_interface.h +++ b/tensorflow/core/kernels/summary_interface.h @@ -17,6 +17,7 @@ limitations under the License. #include <memory> +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/util/event.pb.h" @@ -46,6 +47,9 @@ class SummaryWriterInterface : public ResourceBase { virtual Status WriteAudio(int64 global_step, Tensor t, const string& tag, int max_outputs_, float sample_rate) = 0; + virtual Status WriteGraph(int64 global_step, + std::unique_ptr<GraphDef> graph) = 0; + virtual Status WriteEvent(std::unique_ptr<Event> e) = 0; }; diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc index 1fe2fc5b66..3706f51cf4 100644 --- a/tensorflow/core/kernels/summary_kernels.cc +++ b/tensorflow/core/kernels/summary_kernels.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/kernels/summary_interface.h" @@ -268,4 +269,28 @@ class WriteAudioSummaryOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU), WriteAudioSummaryOp); +class WriteGraphSummaryOp : public OpKernel { + public: + explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + SummaryWriterInterface* s; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); + core::ScopedUnref unref(s); + const Tensor* t; + OP_REQUIRES_OK(ctx, ctx->input("global_step", &t)); + const int64 global_step = t->scalar<int64>()(); + OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); + std::unique_ptr<GraphDef> graph{new GraphDef}; + if (!ParseProtoUnlimited(graph.get(), t->scalar<string>()())) { + ctx->CtxFailureWithWarning( + errors::DataLoss("Bad tf.GraphDef binary proto tensor string")); + return; + } + OP_REQUIRES_OK(ctx, s->WriteGraph(global_step, std::move(graph))); + } +}; +REGISTER_KERNEL_BUILDER(Name("WriteGraphSummary").Device(DEVICE_CPU), + WriteGraphSummaryOp); + } // namespace tensorflow diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc index 5efbac7ad7..7f6d8b06cd 100644 --- a/tensorflow/core/ops/summary_ops.cc +++ b/tensorflow/core/ops/summary_ops.cc @@ -256,4 +256,17 @@ sample_rate: The sample rate of the signal in hertz. max_outputs: Max number of batch elements to generate audio for. )doc"); +REGISTER_OP("WriteGraphSummary") + .Input("writer: resource") + .Input("global_step: int64") + .Input("tensor: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Writes a `GraphDef` protocol buffer to a `SummaryWriter`. + +writer: Handle of `SummaryWriter`. +global_step: The step to write the summary for. +tensor: A scalar string of the serialized tf.GraphDef proto. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index cc46dd5162..3677aaa886 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -66,6 +66,9 @@ BLACKLIST = [ "//tensorflow/contrib/timeseries/examples:data/period_trend.csv", # pylint:disable=line-too-long "//tensorflow/contrib/timeseries/python/timeseries:test_utils", "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:test_utils", # pylint:disable=line-too-long + + # TODO(yifeif): Remove when py_library(testonly=1) is ignored. + "//tensorflow/contrib/summary:summary_test_internal", ] |