aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Justine Tunney <jart@google.com>2017-11-15 11:31:43 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-15 11:35:49 -0800
commit6fb721d608c4cd3855fe8793099a629428b9853c (patch)
treefaef08ed8bac4f5a8b065825a4405ef8a12e875f /tensorflow
parentb7b183b90aee8a4f4808f7d90a2c7a54a942e640 (diff)
Add graph writer op to contrib/summary
This change also defines a simple SQL data model for tf.GraphDef, which should move us closer to a world where TensorBoard can render the graph explorer without having to download the entire thing to the browser, as that could potentially be hundreds of megabytes. PiperOrigin-RevId: 175854921
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/summary/BUILD29
-rw-r--r--tensorflow/contrib/summary/summary.py3
-rw-r--r--tensorflow/contrib/summary/summary_ops.py149
-rw-r--r--tensorflow/contrib/summary/summary_ops_graph_test.py52
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py47
-rw-r--r--tensorflow/contrib/summary/summary_test_internal.py59
-rw-r--r--tensorflow/contrib/tensorboard/db/schema.cc141
-rw-r--r--tensorflow/contrib/tensorboard/db/summary_db_writer.cc272
-rw-r--r--tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc78
-rw-r--r--tensorflow/core/kernels/BUILD1
-rw-r--r--tensorflow/core/kernels/summary_interface.cc10
-rw-r--r--tensorflow/core/kernels/summary_interface.h4
-rw-r--r--tensorflow/core/kernels/summary_kernels.cc25
-rw-r--r--tensorflow/core/ops/summary_ops.cc13
-rw-r--r--tensorflow/tools/pip_package/pip_smoke_test.py3
15 files changed, 751 insertions, 135 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD
index d1beafcb28..3892654f25 100644
--- a/tensorflow/contrib/summary/BUILD
+++ b/tensorflow/contrib/summary/BUILD
@@ -25,13 +25,12 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":summary_ops",
+ ":summary_test_internal",
":summary_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:ops",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
@@ -41,6 +40,20 @@ py_test(
],
)
+py_test(
+ name = "summary_ops_graph_test",
+ srcs = ["summary_ops_graph_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":summary_ops",
+ ":summary_test_internal",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
+ ],
+)
+
py_library(
name = "summary_ops",
srcs = ["summary_ops.py"],
@@ -98,3 +111,15 @@ py_library(
"//tensorflow/python:platform",
],
)
+
+py_library(
+ name = "summary_test_internal",
+ testonly = 1,
+ srcs = ["summary_test_internal.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:private"],
+ deps = [
+ "//tensorflow/python:lib",
+ "//tensorflow/python:platform",
+ ],
+)
diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py
index 813e8b2b09..a73193f460 100644
--- a/tensorflow/contrib/summary/summary.py
+++ b/tensorflow/contrib/summary/summary.py
@@ -32,11 +32,14 @@ from tensorflow.contrib.summary.summary_ops import create_summary_db_writer
from tensorflow.contrib.summary.summary_ops import create_summary_file_writer
from tensorflow.contrib.summary.summary_ops import eval_dir
from tensorflow.contrib.summary.summary_ops import generic
+from tensorflow.contrib.summary.summary_ops import graph
from tensorflow.contrib.summary.summary_ops import histogram
from tensorflow.contrib.summary.summary_ops import image
from tensorflow.contrib.summary.summary_ops import import_event
+from tensorflow.contrib.summary.summary_ops import initialize
from tensorflow.contrib.summary.summary_ops import never_record_summaries
from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps
from tensorflow.contrib.summary.summary_ops import scalar
from tensorflow.contrib.summary.summary_ops import should_record_summaries
from tensorflow.contrib.summary.summary_ops import summary_writer_initializer_op
+from tensorflow.contrib.summary.summary_ops import SummaryWriter
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index f6be99f6ae..a72c0c80aa 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -27,6 +27,7 @@ import time
import six
from tensorflow.contrib.summary import gen_summary_ops
+from tensorflow.core.framework import graph_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -99,25 +100,32 @@ def never_record_summaries():
class SummaryWriter(object):
- """Encapsulates a summary writer."""
+ """Encapsulates a stateful summary writer resource.
- def __init__(self, resource):
+ See also:
+ - @{tf.contrib.summary.create_summary_file_writer}
+ - @{tf.contrib.summary.create_summary_db_writer}
+ """
+
+ def __init__(self, resource):
self._resource = resource
if context.in_eager_mode():
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device="cpu:0")
def set_as_default(self):
+ """Enables this summary writer for the current thread."""
context.context().summary_writer_resource = self._resource
@tf_contextlib.contextmanager
def as_default(self):
+ """Enables summary writing within a `with` block."""
if self._resource is None:
- yield
+ yield self
else:
old = context.context().summary_writer_resource
context.context().summary_writer_resource = self._resource
- yield
+ yield self
# Flushes the summary writer in eager mode or in graph functions, but not
# in legacy graph mode (you're on your own there).
with ops.device("cpu:0"):
@@ -125,6 +133,43 @@ class SummaryWriter(object):
context.context().summary_writer_resource = old
+def initialize(
+ graph=None, # pylint: disable=redefined-outer-name
+ session=None):
+ """Initializes summary writing for graph execution mode.
+
+ This helper method provides a higher-level alternative to using
+ @{tf.contrib.summary.summary_writer_initializer_op} and
+ @{tf.contrib.summary.graph}.
+
+ Most users will also want to call @{tf.train.create_global_step}
+ which can happen before or after this function is called.
+
+ Args:
+ graph: A @{tf.Graph} or @{tf.GraphDef} to output to the writer.
+ This function will not write the default graph by default. When
+ writing to an event log file, the associated step will be zero.
+ session: So this method can call @{tf.Session.run}. This defaults
+ to @{tf.get_default_session}.
+
+ Raises:
+ RuntimeError: If in eager mode, or if the current thread has no
+ default @{tf.contrib.summary.SummaryWriter}.
+ ValueError: If session wasn't passed and no default session.
+ """
+ if context.context().summary_writer_resource is None:
+ raise RuntimeError("No default tf.contrib.summary.SummaryWriter found")
+ if session is None:
+ session = ops.get_default_session()
+ if session is None:
+ raise ValueError("session must be passed if no default session exists")
+ session.run(summary_writer_initializer_op())
+ if graph is not None:
+ data = _serialize_graph(graph)
+ x = array_ops.placeholder(dtypes.string)
+ session.run(_graph(x, 0), feed_dict={x: data})
+
+
def create_summary_file_writer(logdir,
max_queue=None,
flush_millis=None,
@@ -192,10 +237,10 @@ def create_summary_db_writer(db_uri,
Experiment will not be associated with a User. Must be valid as
both a DNS label and Linux username.
name: Shared name for this SummaryWriter resource stored to default
- Graph.
+ @{tf.Graph}.
Returns:
- A new SummaryWriter instance.
+ A @{tf.contrib.summary.SummaryWriter} instance.
"""
with ops.device("cpu:0"):
if experiment_name is None:
@@ -240,7 +285,16 @@ def _nothing():
def all_summary_ops():
- """Graph-mode only. Returns all summary ops."""
+ """Graph-mode only. Returns all summary ops.
+
+ Please note this excludes @{tf.contrib.summary.graph} ops.
+
+ Returns:
+ The summary ops.
+
+ Raises:
+ RuntimeError: If in Eager mode.
+ """
if context.in_eager_mode():
raise RuntimeError(
"tf.contrib.summary.all_summary_ops is only supported in graph mode.")
@@ -248,7 +302,14 @@ def all_summary_ops():
def summary_writer_initializer_op():
- """Graph-mode only. Returns the list of ops to create all summary writers."""
+ """Graph-mode only. Returns the list of ops to create all summary writers.
+
+ Returns:
+ The initializer ops.
+
+ Raises:
+ RuntimeError: If in Eager mode.
+ """
if context.in_eager_mode():
raise RuntimeError(
"tf.contrib.summary.summary_writer_initializer_op is only "
@@ -367,21 +428,72 @@ def audio(name, tensor, sample_rate, max_outputs, family=None,
return summary_writer_function(name, tensor, function, family=family)
+def graph(param, step=None, name=None):
+ """Writes a TensorFlow graph to the summary interface.
+
+ The graph summary is, strictly speaking, not a summary. Conditions
+ like @{tf.contrib.summary.never_record_summaries} do not apply. Only
+ a single graph can be associated with a particular run. If multiple
+ graphs are written, then only the last one will be considered by
+ TensorBoard.
+
+ When not using eager execution mode, the user should consider passing
+ the `graph` parameter to @{tf.contrib.summary.initialize} instead of
+ calling this function. Otherwise special care needs to be taken when
+ using the graph to record the graph.
+
+ Args:
+ param: A @{tf.Tensor} containing a serialized graph proto. When
+ eager execution is enabled, this function will automatically
+ coerce @{tf.Graph}, @{tf.GraphDef}, and string types.
+ step: The global step variable. This doesn't have useful semantics
+ for graph summaries, but is used anyway, due to the structure of
+ event log files. This defaults to the global step.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created @{tf.Operation} or a @{tf.no_op} if summary writing has
+ not been enabled for this context.
+
+ Raises:
+ TypeError: If `param` isn't already a @{tf.Tensor} in graph mode.
+ """
+ if not context.in_eager_mode() and not isinstance(param, ops.Tensor):
+ raise TypeError("graph() needs a tf.Tensor (e.g. tf.placeholder) in graph "
+ "mode, but was: %s" % type(param))
+ writer = context.context().summary_writer_resource
+ if writer is None:
+ return control_flow_ops.no_op()
+ with ops.device("cpu:0"):
+ if step is None:
+ step = training_util.get_global_step()
+ else:
+ step = ops.convert_to_tensor(step, dtypes.int64)
+ if isinstance(param, (ops.Graph, graph_pb2.GraphDef)):
+ tensor = ops.convert_to_tensor(_serialize_graph(param), dtypes.string)
+ else:
+ tensor = array_ops.identity(param)
+ return gen_summary_ops.write_graph_summary(writer, step, tensor, name=name)
+
+_graph = graph # for functions with a graph parameter
+
+
def import_event(tensor, name=None):
- """Writes a tf.Event binary proto.
+ """Writes a @{tf.Event} binary proto.
When using create_summary_db_writer(), this can be used alongside
- tf.TFRecordReader to load event logs into the database. Please note
- that this is lower level than the other summary functions and will
- ignore any conditions set by methods like should_record_summaries().
+ @{tf.TFRecordReader} to load event logs into the database. Please
+ note that this is lower level than the other summary functions and
+ will ignore any conditions set by methods like
+ @{tf.contrib.summary.should_record_summaries}.
Args:
- tensor: A `Tensor` of type `string` containing a serialized `Event`
- proto.
+ tensor: A @{tf.Tensor} of type `string` containing a serialized
+ @{tf.Event} proto.
name: A name for the operation (optional).
Returns:
- The created Operation.
+ The created @{tf.Operation}.
"""
return gen_summary_ops.import_event(
context.context().summary_writer_resource, tensor, name=name)
@@ -390,3 +502,10 @@ def import_event(tensor, name=None):
def eval_dir(model_dir, name=None):
"""Construct a logdir for an eval summary writer."""
return os.path.join(model_dir, "eval" if not name else "eval_" + name)
+
+
+def _serialize_graph(arbitrary_graph):
+ if isinstance(arbitrary_graph, ops.Graph):
+ return arbitrary_graph.as_graph_def(add_shapes=True).SerializeToString()
+ else:
+ return arbitrary_graph.SerializeToString()
diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py
new file mode 100644
index 0000000000..8f85f67a25
--- /dev/null
+++ b/tensorflow/contrib/summary/summary_ops_graph_test.py
@@ -0,0 +1,52 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.contrib.summary import summary_ops
+from tensorflow.contrib.summary import summary_test_internal
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import training_util
+
+get_all = summary_test_internal.get_all
+
+
+class DbTest(summary_test_internal.SummaryDbTest):
+
+ def testGraphPassedToGraph_isForbiddenForThineOwnSafety(self):
+ with self.assertRaises(TypeError):
+ summary_ops.graph(ops.Graph())
+ with self.assertRaises(TypeError):
+ summary_ops.graph('')
+
+ def testGraphSummary(self):
+ training_util.get_or_create_global_step()
+ name = 'hi'
+ graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),))
+ with self.test_session():
+ with self.create_summary_db_writer().as_default():
+ summary_ops.initialize(graph=graph)
+ six.assertCountEqual(self, [name],
+ get_all(self.db, 'SELECT node_name FROM Nodes'))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 6e1a746815..09169fa6d7 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -12,20 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
-import os
import tempfile
import six
-import sqlite3
from tensorflow.contrib.summary import summary_ops
+from tensorflow.contrib.summary import summary_test_internal
from tensorflow.contrib.summary import summary_test_util
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import node_def_pb2
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
@@ -36,6 +35,9 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util
+get_all = summary_test_internal.get_all
+get_one = summary_test_internal.get_one
+
class TargetTest(test_util.TensorFlowTestCase):
@@ -108,22 +110,7 @@ class TargetTest(test_util.TensorFlowTestCase):
self.assertEqual(events[1].summary.value[0].tag, 'scalar')
-class DbTest(test_util.TensorFlowTestCase):
-
- def setUp(self):
- self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
- if os.path.exists(self.db_path):
- os.unlink(self.db_path)
- self.db = sqlite3.connect(self.db_path)
- self.create_summary_db_writer = functools.partial(
- summary_ops.create_summary_db_writer,
- db_uri=self.db_path,
- experiment_name='experiment',
- run_name='run',
- user_name='user')
-
- def tearDown(self):
- self.db.close()
+class DbTest(summary_test_internal.SummaryDbTest):
def testIntegerSummaries(self):
step = training_util.create_global_step()
@@ -186,13 +173,15 @@ class DbTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
self.create_summary_db_writer(user_name='@')
-
-def get_one(db, q, *p):
- return db.execute(q, p).fetchone()[0]
-
-
-def get_all(db, q, *p):
- return unroll(db.execute(q, p).fetchall())
+ def testGraphSummary(self):
+ training_util.get_or_create_global_step()
+ name = 'hi'
+ graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),))
+ with summary_ops.always_record_summaries():
+ with self.create_summary_db_writer().as_default():
+ summary_ops.graph(graph)
+ six.assertCountEqual(self, [name],
+ get_all(self.db, 'SELECT node_name FROM Nodes'))
def get_tensor(db, tag_id, step):
@@ -205,9 +194,5 @@ def int64(x):
return array_ops.constant(x, dtypes.int64)
-def unroll(list_of_tuples):
- return sum(list_of_tuples, ())
-
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/summary/summary_test_internal.py b/tensorflow/contrib/summary/summary_test_internal.py
new file mode 100644
index 0000000000..54233f2f50
--- /dev/null
+++ b/tensorflow/contrib/summary/summary_test_internal.py
@@ -0,0 +1,59 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Internal helpers for tests in this directory."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import os
+import sqlite3
+
+from tensorflow.contrib.summary import summary_ops
+from tensorflow.python.framework import test_util
+
+
+class SummaryDbTest(test_util.TensorFlowTestCase):
+ """Helper for summary database testing."""
+
+ def setUp(self):
+ super(SummaryDbTest, self).setUp()
+ self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
+ if os.path.exists(self.db_path):
+ os.unlink(self.db_path)
+ self.db = sqlite3.connect(self.db_path)
+ self.create_summary_db_writer = functools.partial(
+ summary_ops.create_summary_db_writer,
+ db_uri=self.db_path,
+ experiment_name='experiment',
+ run_name='run',
+ user_name='user')
+
+ def tearDown(self):
+ self.db.close()
+ super(SummaryDbTest, self).tearDown()
+
+
+def get_one(db, q, *p):
+ return db.execute(q, p).fetchone()[0]
+
+
+def get_all(db, q, *p):
+ return unroll(db.execute(q, p).fetchall())
+
+
+def unroll(list_of_tuples):
+ return sum(list_of_tuples, ())
diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc
index 98fff9e0ae..d63b2c6cc2 100644
--- a/tensorflow/contrib/tensorboard/db/schema.cc
+++ b/tensorflow/contrib/tensorboard/db/schema.cc
@@ -135,8 +135,7 @@ class SqliteSchema {
/// the database. This field will be mutated if the run is
/// restarted.
/// description: Optional markdown information.
- /// graph: Snappy tf.GraphDef proto with node field cleared. That
- /// field can be recreated using GraphNodes and NodeDefs.
+ /// graph_id: ID of associated Graphs row.
Status CreateRunsTable() {
return Run(R"sql(
CREATE TABLE IF NOT EXISTS Runs (
@@ -147,7 +146,7 @@ class SqliteSchema {
inserted_time REAL,
started_time REAL,
description TEXT,
- graph BLOB
+ graph_id INTEGER
)
)sql");
}
@@ -205,46 +204,78 @@ class SqliteSchema {
)sql");
}
- /// \brief Creates NodeDefs table.
- ///
- /// This table stores NodeDef protos which define the GraphDef for a
- /// Run. This functions like a hash table so rows can be shared by
- /// multiple Runs in an Experiment.
+ /// \brief Creates Graphs table.
///
/// Fields:
/// rowid: Ephemeral b-tree ID dictating locality.
- /// experiment_id: Optional int64 for grouping rows.
- /// node_def_id: Permanent >0 unique ID.
- /// fingerprint: Optional farmhash::Fingerprint64() of uncompressed
- /// node_def bytes, coerced to int64.
- /// node_def: BLOB containing a Snappy tf.NodeDef proto.
- Status CreateNodeDefsTable() {
+ /// graph_id: Permanent >0 unique ID.
+ /// inserted_time: Float UNIX timestamp with µs precision. This is
+ /// always the wall time of when the row was inserted into the
+ /// DB. It may be used as a hint for an archival job.
+  ///   graph_def: Contains Snappy tf.GraphDef proto. All fields that are
+  ///     expressed in SQL columns will be cleared; the rest are kept.
+ Status CreateGraphsTable() {
return Run(R"sql(
- CREATE TABLE IF NOT EXISTS NodeDefs (
+ CREATE TABLE IF NOT EXISTS Graphs (
rowid INTEGER PRIMARY KEY,
- experiment_id INTEGER,
- node_def_id INTEGER NOT NULL,
- fingerprint INTEGER,
- node_def TEXT
+ graph_id INTEGER NOT NULL,
+ inserted_time REAL,
+ graph_def BLOB
)
)sql");
}
- /// \brief Creates RunNodeDefs table.
+ /// \brief Creates Nodes table.
///
- /// Table mapping Runs to NodeDefs. This is used to recreate the node
- /// field of the GraphDef proto.
+ /// Fields:
+ /// rowid: Ephemeral b-tree ID dictating locality.
+ /// graph_id: Permanent >0 unique ID.
+ /// node_id: ID for this node. This is more like a 0-index within
+ /// the Graph. Please note indexes are allowed to be removed.
+ /// node_name: Unique name for this Node within Graph. This is
+ /// copied from the proto so it can be indexed. This is allowed
+ /// to be NULL to save space on the index, in which case the
+ /// node_def.name proto field must not be cleared.
+ /// op: Copied from tf.NodeDef proto.
+ /// device: Copied from tf.NodeDef proto.
+ /// node_def: Contains Snappy tf.NodeDef proto. All fields will be
+ /// cleared except those not expressed in SQL.
+ Status CreateNodesTable() {
+ return Run(R"sql(
+ CREATE TABLE IF NOT EXISTS Nodes (
+ rowid INTEGER PRIMARY KEY,
+ graph_id INTEGER NOT NULL,
+ node_id INTEGER NOT NULL,
+ node_name TEXT,
+ op TEXT,
+ device TEXT,
+ node_def BLOB
+ )
+ )sql");
+ }
+
+ /// \brief Creates NodeInputs table.
///
/// Fields:
/// rowid: Ephemeral b-tree ID dictating locality.
- /// run_id: Mandatory ID of associated Run.
- /// node_def_id: Mandatory ID of associated NodeDef.
- Status CreateRunNodeDefsTable() {
+ /// graph_id: Permanent >0 unique ID.
+ /// node_id: Index of Node in question. This can be considered the
+ /// 'to' vertex.
+ /// idx: Used for ordering inputs on a given Node.
+ /// input_node_id: Nodes.node_id of the corresponding input node.
+ /// This can be considered the 'from' vertex.
+ /// is_control: If non-zero, indicates this input is a controlled
+ /// dependency, which means this isn't an edge through which
+ /// tensors flow. NULL means 0.
+ Status CreateNodeInputsTable() {
return Run(R"sql(
- CREATE TABLE IF NOT EXISTS RunNodeDefs (
+ CREATE TABLE IF NOT EXISTS NodeInputs (
rowid INTEGER PRIMARY KEY,
- run_id INTEGER NOT NULL,
- node_def_id INTEGER NOT NULL
+ graph_id INTEGER NOT NULL,
+ node_id INTEGER NOT NULL,
+ idx INTEGER NOT NULL,
+ input_node_id INTEGER NOT NULL,
+ is_control INTEGER
)
)sql");
}
@@ -297,11 +328,27 @@ class SqliteSchema {
)sql");
}
- /// \brief Uniquely indexes node_def_id on NodeDefs table.
- Status CreateNodeDefIdIndex() {
+ /// \brief Uniquely indexes graph_id on Graphs table.
+ Status CreateGraphIdIndex() {
return Run(R"sql(
- CREATE UNIQUE INDEX IF NOT EXISTS NodeDefIdIndex
- ON NodeDefs (node_def_id)
+ CREATE UNIQUE INDEX IF NOT EXISTS GraphIdIndex
+ ON Graphs (graph_id)
+ )sql");
+ }
+
+ /// \brief Uniquely indexes (graph_id, node_id) on Nodes table.
+ Status CreateNodeIdIndex() {
+ return Run(R"sql(
+ CREATE UNIQUE INDEX IF NOT EXISTS NodeIdIndex
+ ON Nodes (graph_id, node_id)
+ )sql");
+ }
+
+ /// \brief Uniquely indexes (graph_id, node_id, idx) on NodeInputs table.
+ Status CreateNodeInputsIndex() {
+ return Run(R"sql(
+ CREATE UNIQUE INDEX IF NOT EXISTS NodeInputsIndex
+ ON NodeInputs (graph_id, node_id, idx)
)sql");
}
@@ -350,20 +397,12 @@ class SqliteSchema {
)sql");
}
- /// \brief Indexes (experiment_id, fingerprint) on NodeDefs table.
- Status CreateNodeDefFingerprintIndex() {
- return Run(R"sql(
- CREATE INDEX IF NOT EXISTS NodeDefFingerprintIndex
- ON NodeDefs (experiment_id, fingerprint)
- WHERE fingerprint IS NOT NULL
- )sql");
- }
-
- /// \brief Uniquely indexes (run_id, node_def_id) on RunNodeDefs table.
- Status CreateRunNodeDefIndex() {
+ /// \brief Uniquely indexes (graph_id, node_name) on Nodes table.
+ Status CreateNodeNameIndex() {
return Run(R"sql(
- CREATE UNIQUE INDEX IF NOT EXISTS RunNodeDefIndex
- ON RunNodeDefs (run_id, node_def_id)
+ CREATE UNIQUE INDEX IF NOT EXISTS NodeNameIndex
+ ON Nodes (graph_id, node_name)
+ WHERE node_name IS NOT NULL
)sql");
}
@@ -387,22 +426,24 @@ Status SetupTensorboardSqliteDb(std::shared_ptr<Sqlite> db) {
TF_RETURN_IF_ERROR(s.CreateRunsTable());
TF_RETURN_IF_ERROR(s.CreateExperimentsTable());
TF_RETURN_IF_ERROR(s.CreateUsersTable());
- TF_RETURN_IF_ERROR(s.CreateNodeDefsTable());
- TF_RETURN_IF_ERROR(s.CreateRunNodeDefsTable());
+ TF_RETURN_IF_ERROR(s.CreateGraphsTable());
+ TF_RETURN_IF_ERROR(s.CreateNodeInputsTable());
+ TF_RETURN_IF_ERROR(s.CreateNodesTable());
TF_RETURN_IF_ERROR(s.CreateTensorIndex());
TF_RETURN_IF_ERROR(s.CreateTensorChunkIndex());
TF_RETURN_IF_ERROR(s.CreateTagIdIndex());
TF_RETURN_IF_ERROR(s.CreateRunIdIndex());
TF_RETURN_IF_ERROR(s.CreateExperimentIdIndex());
TF_RETURN_IF_ERROR(s.CreateUserIdIndex());
- TF_RETURN_IF_ERROR(s.CreateNodeDefIdIndex());
+ TF_RETURN_IF_ERROR(s.CreateGraphIdIndex());
+ TF_RETURN_IF_ERROR(s.CreateNodeIdIndex());
+ TF_RETURN_IF_ERROR(s.CreateNodeInputsIndex());
TF_RETURN_IF_ERROR(s.CreateTagNameIndex());
TF_RETURN_IF_ERROR(s.CreateRunNameIndex());
TF_RETURN_IF_ERROR(s.CreateExperimentNameIndex());
TF_RETURN_IF_ERROR(s.CreateUserNameIndex());
TF_RETURN_IF_ERROR(s.CreateUserEmailIndex());
- TF_RETURN_IF_ERROR(s.CreateNodeDefFingerprintIndex());
- TF_RETURN_IF_ERROR(s.CreateRunNodeDefIndex());
+ TF_RETURN_IF_ERROR(s.CreateNodeNameIndex());
return Status::OK();
}
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
index a26ad61660..ae063d24ef 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
@@ -15,17 +15,29 @@ limitations under the License.
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
#include "tensorflow/contrib/tensorboard/db/schema.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/db/sqlite.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/util/event.pb.h"
namespace tensorflow {
namespace {
+double GetWallTime(Env* env) {
+ // TODO(@jart): Follow precise definitions for time laid out in schema.
+ // TODO(@jart): Use monotonic clock from gRPC codebase.
+ return static_cast<double>(env->NowMicros()) / 1.0e6;
+}
+
int64 MakeRandomId() {
+ // TODO(@jart): Try generating ID in 2^24 space, falling back to 2^63
+ // https://sqlite.org/src4/doc/trunk/www/varint.wiki
int64 id = static_cast<int64>(random::New64() & ((1ULL << 63) - 1));
if (id == 0) {
++id;
@@ -33,10 +45,201 @@ int64 MakeRandomId() {
return id;
}
+Status Serialize(const protobuf::MessageLite& proto, string* output) {
+ output->clear();
+ if (!proto.SerializeToString(output)) {
+ return errors::DataLoss("SerializeToString failed");
+ }
+ return Status::OK();
+}
+
+Status Compress(const string& data, string* output) {
+ output->clear();
+ if (!port::Snappy_Compress(data.data(), data.size(), output)) {
+ return errors::FailedPrecondition("TensorBase needs Snappy");
+ }
+ return Status::OK();
+}
+
+Status BindProto(SqliteStatement* stmt, int parameter,
+ const protobuf::MessageLite& proto) {
+ string serialized;
+ TF_RETURN_IF_ERROR(Serialize(proto, &serialized));
+ string compressed;
+ TF_RETURN_IF_ERROR(Compress(serialized, &compressed));
+ stmt->BindBlobUnsafe(parameter, compressed);
+ return Status::OK();
+}
+
+Status BindTensor(SqliteStatement* stmt, int parameter, const Tensor& t) {
+ // TODO(@jart): Make portable between little and big endian systems.
+ // TODO(@jart): Use TensorChunks with minimal copying for big tensors.
+ // TODO(@jart): Add field to indicate encoding.
+ // TODO(@jart): Allow crunch tool to re-compress with zlib instead.
+ TensorProto p;
+ t.AsProtoTensorContent(&p);
+ return BindProto(stmt, parameter, p);
+}
+
+class Transactor {
+ public:
+ explicit Transactor(std::shared_ptr<Sqlite> db)
+ : db_(std::move(db)),
+ begin_(db_->Prepare("BEGIN TRANSACTION")),
+ commit_(db_->Prepare("COMMIT TRANSACTION")),
+ rollback_(db_->Prepare("ROLLBACK TRANSACTION")) {}
+
+ template <typename T, typename... Args>
+ Status Transact(T callback, Args&&... args) {
+ TF_RETURN_IF_ERROR(begin_.StepAndReset());
+ Status s = callback(std::forward<Args>(args)...);
+ if (s.ok()) {
+ TF_RETURN_IF_ERROR(commit_.StepAndReset());
+ } else {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(rollback_.StepAndReset(), s.ToString());
+ }
+ return s;
+ }
+
+ private:
+ std::shared_ptr<Sqlite> db_;
+ SqliteStatement begin_;
+ SqliteStatement commit_;
+ SqliteStatement rollback_;
+};
+
+class GraphSaver {
+ public:
+ static Status SaveToRun(Env* env, Sqlite* db, GraphDef* graph, int64 run_id) {
+ auto get = db->Prepare("SELECT graph_id FROM Runs WHERE run_id = ?");
+ get.BindInt(1, run_id);
+ bool is_done;
+ TF_RETURN_IF_ERROR(get.Step(&is_done));
+ int64 graph_id = is_done ? 0 : get.ColumnInt(0);
+ if (graph_id == 0) {
+ graph_id = MakeRandomId();
+ // TODO(@jart): Check for ID collision.
+ auto set = db->Prepare("UPDATE Runs SET graph_id = ? WHERE run_id = ?");
+ set.BindInt(1, graph_id);
+ set.BindInt(2, run_id);
+ TF_RETURN_IF_ERROR(set.StepAndReset());
+ }
+ return Save(env, db, graph, graph_id);
+ }
+
+ static Status Save(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id) {
+ GraphSaver saver{env, db, graph, graph_id};
+ saver.MapNameToNodeId();
+ TF_RETURN_IF_ERROR(saver.SaveNodeInputs());
+ TF_RETURN_IF_ERROR(saver.SaveNodes());
+ TF_RETURN_IF_ERROR(saver.SaveGraph());
+ return Status::OK();
+ }
+
+ private:
+ GraphSaver(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id)
+ : env_(env), db_(db), graph_(graph), graph_id_(graph_id) {}
+
+ void MapNameToNodeId() {
+ size_t toto = static_cast<size_t>(graph_->node_size());
+ name_copies_.reserve(toto);
+ name_to_node_id_.reserve(toto);
+ for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
+ // Copy name into memory region, since we call clear_name() later.
+ // Then wrap in StringPiece so we can compare slices without copy.
+ name_copies_.emplace_back(graph_->node(node_id).name());
+ name_to_node_id_.emplace(name_copies_.back(), node_id);
+ }
+ }
+
+ Status SaveNodeInputs() {
+ auto purge = db_->Prepare("DELETE FROM NodeInputs WHERE graph_id = ?");
+ purge.BindInt(1, graph_id_);
+ TF_RETURN_IF_ERROR(purge.StepAndReset());
+ auto insert = db_->Prepare(R"sql(
+ INSERT INTO NodeInputs (graph_id, node_id, idx, input_node_id, is_control)
+ VALUES (?, ?, ?, ?, ?)
+ )sql");
+ for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
+ const NodeDef& node = graph_->node(node_id);
+ for (int idx = 0; idx < node.input_size(); ++idx) {
+ StringPiece name = node.input(idx);
+ insert.BindInt(1, graph_id_);
+ insert.BindInt(2, node_id);
+ insert.BindInt(3, idx);
+ if (!name.empty() && name[0] == '^') {
+ name.remove_prefix(1);
+ insert.BindInt(5, 1);
+ }
+ auto e = name_to_node_id_.find(name);
+ if (e == name_to_node_id_.end()) {
+ return errors::DataLoss("Could not find node: ", name);
+ }
+ insert.BindInt(4, e->second);
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(),
+ " -> ", name);
+ }
+ }
+ return Status::OK();
+ }
+
+ Status SaveNodes() {
+ auto purge = db_->Prepare("DELETE FROM Nodes WHERE graph_id = ?");
+ purge.BindInt(1, graph_id_);
+ TF_RETURN_IF_ERROR(purge.StepAndReset());
+ auto insert = db_->Prepare(R"sql(
+ INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def)
+ VALUES (?, ?, ?, ?, ?, ?)
+ )sql");
+ for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
+ NodeDef* node = graph_->mutable_node(node_id);
+ insert.BindInt(1, graph_id_);
+ insert.BindInt(2, node_id);
+ insert.BindText(3, node->name());
+ node->clear_name();
+ if (!node->op().empty()) {
+ insert.BindText(4, node->op());
+ node->clear_op();
+ }
+ if (!node->device().empty()) {
+ insert.BindText(5, node->device());
+ node->clear_device();
+ }
+ node->clear_input();
+ TF_RETURN_IF_ERROR(BindProto(&insert, 6, *node));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name());
+ }
+ return Status::OK();
+ }
+
+ Status SaveGraph() {
+ auto insert = db_->Prepare(R"sql(
+ INSERT OR REPLACE INTO Graphs (graph_id, inserted_time, graph_def)
+ VALUES (?, ?, ?)
+ )sql");
+ insert.BindInt(1, graph_id_);
+ insert.BindDouble(2, GetWallTime(env_));
+ graph_->clear_node();
+ TF_RETURN_IF_ERROR(BindProto(&insert, 3, *graph_));
+ return insert.StepAndReset();
+ }
+
+ Env* env_;
+ Sqlite* db_;
+ GraphDef* graph_;
+ int64 graph_id_;
+ std::vector<string> name_copies_;
+ std::unordered_map<StringPiece, int64, StringPiece::Hasher> name_to_node_id_;
+};
+
class SummaryDbWriter : public SummaryWriterInterface {
public:
SummaryDbWriter(Env* env, std::shared_ptr<Sqlite> db)
- : SummaryWriterInterface(), env_(env), db_(std::move(db)), run_id_(-1) {}
+ : SummaryWriterInterface(),
+ env_(env),
+ db_(std::move(db)),
+ txn_(db_),
+ run_id_{0LL} {}
~SummaryDbWriter() override {}
Status Initialize(const string& experiment_name, const string& run_name,
@@ -76,7 +279,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
// TODO(@jart): Check for random ID collisions without needing txn retry.
insert_tensor_.BindInt(1, tag_id);
insert_tensor_.BindInt(2, global_step);
- insert_tensor_.BindDouble(3, GetWallTime());
+ insert_tensor_.BindDouble(3, GetWallTime(env_));
switch (t.dtype()) {
case DT_INT64:
insert_tensor_.BindInt(4, t.scalar<int64>()());
@@ -85,22 +288,41 @@ class SummaryDbWriter : public SummaryWriterInterface {
insert_tensor_.BindDouble(4, t.scalar<double>()());
break;
default:
- TF_RETURN_IF_ERROR(BindTensor(t));
+ TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t));
break;
}
return insert_tensor_.StepAndReset();
}
- Status WriteEvent(std::unique_ptr<Event> e) override {
+ Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override {
mutex_lock ml(mu_);
TF_RETURN_IF_ERROR(InitializeParents());
- if (e->what_case() == Event::WhatCase::kSummary) {
- const Summary& summary = e->summary();
- for (int i = 0; i < summary.value_size(); ++i) {
- TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i)));
+ return txn_.Transact(GraphSaver::SaveToRun, env_, db_.get(), g.get(),
+ run_id_);
+ }
+
+ Status WriteEvent(std::unique_ptr<Event> e) override {
+ switch (e->what_case()) {
+ case Event::WhatCase::kSummary: {
+ mutex_lock ml(mu_);
+ TF_RETURN_IF_ERROR(InitializeParents());
+ const Summary& summary = e->summary();
+ for (int i = 0; i < summary.value_size(); ++i) {
+ TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i)));
+ }
+ return Status::OK();
}
+ case Event::WhatCase::kGraphDef: {
+ std::unique_ptr<GraphDef> graph{new GraphDef};
+ if (!ParseProtoUnlimited(graph.get(), e->graph_def())) {
+ return errors::DataLoss("parse event.graph_def failed");
+ }
+ return WriteGraph(e->step(), std::move(graph));
+ }
+ default:
+ // TODO(@jart): Handle other stuff.
+ return Status::OK();
}
- return Status::OK();
}
Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
@@ -136,33 +358,8 @@ class SummaryDbWriter : public SummaryWriterInterface {
string DebugString() override { return "SummaryDbWriter"; }
private:
- double GetWallTime() {
- // TODO(@jart): Follow precise definitions for time laid out in schema.
- // TODO(@jart): Use monotonic clock from gRPC codebase.
- return static_cast<double>(env_->NowMicros()) / 1.0e6;
- }
-
- Status BindTensor(const Tensor& t) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- // TODO(@jart): Make portable between little and big endian systems.
- // TODO(@jart): Use TensorChunks with minimal copying for big tensors.
- TensorProto p;
- t.AsProtoTensorContent(&p);
- string encoded;
- if (!p.SerializeToString(&encoded)) {
- return errors::DataLoss("SerializeToString failed");
- }
- // TODO(@jart): Put byte at beginning of blob to indicate encoding.
- // TODO(@jart): Allow crunch tool to re-compress with zlib instead.
- string compressed;
- if (!port::Snappy_Compress(encoded.data(), encoded.size(), &compressed)) {
- return errors::FailedPrecondition("TensorBase needs Snappy");
- }
- insert_tensor_.BindBlobUnsafe(4, compressed);
- return Status::OK();
- }
-
Status InitializeParents() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (run_id_ >= 0) {
+ if (run_id_ > 0) {
return Status::OK();
}
int64 user_id;
@@ -195,7 +392,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
)sql");
insert_user.BindInt(1, *user_id);
insert_user.BindText(2, user_name);
- insert_user.BindDouble(3, GetWallTime());
+ insert_user.BindDouble(3, GetWallTime(env_));
TF_RETURN_IF_ERROR(insert_user.StepAndReset());
}
return Status::OK();
@@ -249,7 +446,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
}
insert.BindInt(2, *id);
insert.BindText(3, name);
- insert.BindDouble(4, GetWallTime());
+ insert.BindDouble(4, GetWallTime(env_));
TF_RETURN_IF_ERROR(insert.StepAndReset());
}
return Status::OK();
@@ -276,6 +473,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
mutex mu_;
Env* env_;
std::shared_ptr<Sqlite> db_ GUARDED_BY(mu_);
+ Transactor txn_ GUARDED_BY(mu_);
SqliteStatement insert_tensor_ GUARDED_BY(mu_);
SqliteStatement update_metadata_ GUARDED_BY(mu_);
string user_name_ GUARDED_BY(mu_);
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
index c1af51e7b7..3431842ca2 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/db/sqlite.h"
@@ -212,5 +214,81 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
kTolerance);
}
+TEST_F(SummaryDbWriterTest, WriteGraph) {
+ TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_));
+ env_.AdvanceByMillis(23);
+ GraphDef graph;
+ NodeDef* node = graph.add_node();
+ node->set_name("x");
+ node->set_op("Placeholder");
+ node = graph.add_node();
+ node->set_name("y");
+ node->set_op("Placeholder");
+ node = graph.add_node();
+ node->set_name("z");
+ node->set_op("Love");
+ node = graph.add_node();
+ node->set_name("+");
+ node->set_op("Add");
+ node->add_input("x");
+ node->add_input("y");
+ node->add_input("^z");
+ node->set_device("tpu/lol");
+ std::unique_ptr<Event> e{new Event};
+ graph.SerializeToString(e->mutable_graph_def());
+ TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
+ TF_ASSERT_OK(writer_->Flush());
+ ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs"));
+ ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Graphs"));
+ ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes"));
+ ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs"));
+
+ int64 graph_id = QueryInt("SELECT graph_id FROM Graphs");
+ EXPECT_GT(graph_id, 0LL);
+ EXPECT_EQ(graph_id, QueryInt("SELECT graph_id FROM Runs"));
+ EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs"));
+ EXPECT_FALSE(QueryString("SELECT graph_def FROM Graphs").empty());
+
+ EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0"));
+ EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1"));
+ EXPECT_EQ("z", QueryString("SELECT node_name FROM Nodes WHERE node_id = 2"));
+ EXPECT_EQ("+", QueryString("SELECT node_name FROM Nodes WHERE node_id = 3"));
+
+ EXPECT_EQ("Placeholder",
+ QueryString("SELECT op FROM Nodes WHERE node_id = 0"));
+ EXPECT_EQ("Placeholder",
+ QueryString("SELECT op FROM Nodes WHERE node_id = 1"));
+ EXPECT_EQ("Love", QueryString("SELECT op FROM Nodes WHERE node_id = 2"));
+ EXPECT_EQ("Add", QueryString("SELECT op FROM Nodes WHERE node_id = 3"));
+
+ EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 0"));
+ EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 1"));
+ EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 2"));
+ EXPECT_EQ("tpu/lol",
+ QueryString("SELECT device FROM Nodes WHERE node_id = 3"));
+
+ EXPECT_EQ(graph_id,
+ QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 0"));
+ EXPECT_EQ(graph_id,
+ QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 1"));
+ EXPECT_EQ(graph_id,
+ QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 2"));
+
+ EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 0"));
+ EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 1"));
+ EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 2"));
+
+ EXPECT_EQ(0LL,
+ QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 0"));
+ EXPECT_EQ(1LL,
+ QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 1"));
+ EXPECT_EQ(2LL,
+ QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 2"));
+
+ EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 0"));
+ EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 1"));
+ EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2"));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 5e19effe3d..b7386abdea 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -6247,6 +6247,7 @@ tf_kernel_library(
"//tensorflow/contrib/tensorboard/db:summary_db_writer",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:summary_ops_op_lib",
"//tensorflow/core/lib/db:sqlite",
],
diff --git a/tensorflow/core/kernels/summary_interface.cc b/tensorflow/core/kernels/summary_interface.cc
index cd366f8c13..ad28d77ffd 100644
--- a/tensorflow/core/kernels/summary_interface.cc
+++ b/tensorflow/core/kernels/summary_interface.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/summary.pb.h"
@@ -393,6 +394,15 @@ class SummaryWriterImpl : public SummaryWriterInterface {
return WriteEvent(std::move(e));
}
+ Status WriteGraph(int64 global_step,
+ std::unique_ptr<GraphDef> graph) override {
+ std::unique_ptr<Event> e{new Event};
+ e->set_step(global_step);
+ e->set_wall_time(GetWallTime());
+ graph->SerializeToString(e->mutable_graph_def());
+ return WriteEvent(std::move(e));
+ }
+
Status WriteEvent(std::unique_ptr<Event> event) override {
mutex_lock ml(mu_);
queue_.emplace_back(std::move(event));
diff --git a/tensorflow/core/kernels/summary_interface.h b/tensorflow/core/kernels/summary_interface.h
index ccf3459e56..da1c28709f 100644
--- a/tensorflow/core/kernels/summary_interface.h
+++ b/tensorflow/core/kernels/summary_interface.h
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
+#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/util/event.pb.h"
@@ -46,6 +47,9 @@ class SummaryWriterInterface : public ResourceBase {
virtual Status WriteAudio(int64 global_step, Tensor t, const string& tag,
int max_outputs_, float sample_rate) = 0;
+ virtual Status WriteGraph(int64 global_step,
+ std::unique_ptr<GraphDef> graph) = 0;
+
virtual Status WriteEvent(std::unique_ptr<Event> e) = 0;
};
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index 1fe2fc5b66..3706f51cf4 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
+#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/summary_interface.h"
@@ -268,4 +269,28 @@ class WriteAudioSummaryOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU),
WriteAudioSummaryOp);
+class WriteGraphSummaryOp : public OpKernel {
+ public:
+ explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ SummaryWriterInterface* s;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
+ core::ScopedUnref unref(s);
+ const Tensor* t;
+ OP_REQUIRES_OK(ctx, ctx->input("global_step", &t));
+ const int64 global_step = t->scalar<int64>()();
+ OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
+ std::unique_ptr<GraphDef> graph{new GraphDef};
+ if (!ParseProtoUnlimited(graph.get(), t->scalar<string>()())) {
+ ctx->CtxFailureWithWarning(
+ errors::DataLoss("Bad tf.GraphDef binary proto tensor string"));
+ return;
+ }
+ OP_REQUIRES_OK(ctx, s->WriteGraph(global_step, std::move(graph)));
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("WriteGraphSummary").Device(DEVICE_CPU),
+ WriteGraphSummaryOp);
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc
index 5efbac7ad7..7f6d8b06cd 100644
--- a/tensorflow/core/ops/summary_ops.cc
+++ b/tensorflow/core/ops/summary_ops.cc
@@ -256,4 +256,17 @@ sample_rate: The sample rate of the signal in hertz.
max_outputs: Max number of batch elements to generate audio for.
)doc");
+REGISTER_OP("WriteGraphSummary")
+ .Input("writer: resource")
+ .Input("global_step: int64")
+ .Input("tensor: string")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Writes a `GraphDef` protocol buffer to a `SummaryWriter`.
+
+writer: Handle of `SummaryWriter`.
+global_step: The step to write the summary for.
+tensor: A scalar string of the serialized tf.GraphDef proto.
+)doc");
+
} // namespace tensorflow
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index cc46dd5162..3677aaa886 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -66,6 +66,9 @@ BLACKLIST = [
"//tensorflow/contrib/timeseries/examples:data/period_trend.csv", # pylint:disable=line-too-long
"//tensorflow/contrib/timeseries/python/timeseries:test_utils",
"//tensorflow/contrib/timeseries/python/timeseries/state_space_models:test_utils", # pylint:disable=line-too-long
+
+ # TODO(yifeif): Remove when py_library(testonly=1) is ignored.
+ "//tensorflow/contrib/summary:summary_test_internal",
]