aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/summary
diff options
context:
space:
mode:
authorGravatar Dandelion Mané <dandelion@google.com>2017-03-06 17:27:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-06 17:49:29 -0800
commit1bfe682a62b35778aecaefd9a582a932e2792acf (patch)
treebb82416d4e48df90346a34a4e3f97c074e156960 /tensorflow/python/summary
parenteaffcab3a4af34ef08cc489e73d445db9af61424 (diff)
Add a system for TensorBoard plugin asset management to TensorFlow.
We need a system for getting arbitrary metadata and assets for TensorBoard plugins. Examples include configuration protobufs, vocabulary files, and sprite images. We need to be able to declare the plugin assets at graph construction time, and then serialize them to disk using the tf.summary.FileWriter. And we want to do so using the FileWriter.add_graph API, so that new features will Just Work for existing code that instantiates a FileWriter and adds the graph. This CL adds a system meeting those requirements. Plugin authors will create an asset class that extends tf.summary.PluginAsset, and implement a serialize method that dumps all the required assets to disk. At the time of graph construction, tf.summary.get_plugin_asset(AssetClass) gives them the current instance for their plugin assets (constructing it if necessary). They may then configure it. When tf.summary.FileWriter.add_graph is called with the graph, it retrieves all PluginAssets and makes a subdirectory for each one. Then, it calls serialize_to_directory on each asset. Change: 149367010
Diffstat (limited to 'tensorflow/python/summary')
-rw-r--r--tensorflow/python/summary/event_accumulator_test.py53
-rw-r--r--tensorflow/python/summary/plugin_asset.py140
-rw-r--r--tensorflow/python/summary/plugin_asset_test.py81
-rw-r--r--tensorflow/python/summary/summary.py6
-rw-r--r--tensorflow/python/summary/writer/writer.py17
-rw-r--r--tensorflow/python/summary/writer/writer_test.py34
6 files changed, 306 insertions, 25 deletions
diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py
index 49b1115624..0577486df5 100644
--- a/tensorflow/python/summary/event_accumulator_test.py
+++ b/tensorflow/python/summary/event_accumulator_test.py
@@ -52,7 +52,8 @@ class _EventGenerator(object):
Has additional convenience methods for adding test events.
"""
- def __init__(self, zero_out_timestamps=False):
+ def __init__(self, testcase, zero_out_timestamps=False):
+ self._testcase = testcase
self.items = []
self.zero_out_timestamps = zero_out_timestamps
@@ -155,6 +156,10 @@ class _EventGenerator(object):
"""Match the EventWriter API."""
self.AddEvent(event)
+ def get_logdir(self): # pylint: disable=invalid-name
+ """Return a temp directory for asset writing."""
+ return self._testcase.get_temp_dir()
+
class EventAccumulatorTest(test.TestCase):
@@ -215,13 +220,13 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
ea._GeneratorFromPath = self._real_generator
def testEmptyAccumulator(self):
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
x = ea.EventAccumulator(gen)
x.Reload()
self.assertTagsEqual(x.Tags(), {})
def testTags(self):
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
gen.AddScalar('s1')
gen.AddScalar('s2')
gen.AddHistogram('hst1')
@@ -241,7 +246,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
})
def testReload(self):
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
acc.Reload()
self.assertTagsEqual(acc.Tags(), {})
@@ -263,7 +268,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
})
def testScalars(self):
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
s1 = ea.ScalarEvent(wall_time=1, step=10, value=32)
s2 = ea.ScalarEvent(wall_time=2, step=12, value=64)
@@ -289,7 +294,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
self.assertEqual(expected_value, gotten_event.value[i])
def testHealthPills(self):
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
gen.AddHealthPill(13371337, 41, 'Add', 0, range(1, 13))
gen.AddHealthPill(13381338, 42, 'Add', 1, range(42, 54))
@@ -318,7 +323,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
gotten_events[1])
def testHistograms(self):
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
val1 = ea.HistogramValue(
@@ -367,7 +372,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
self.assertEqual(acc.Histograms('hst2'), [hst2])
def testCompressedHistograms(self):
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen, compression_bps=(0, 2500, 5000, 7500, 10000))
gen.AddHistogram(
@@ -418,7 +423,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
self.assertEqual(acc.CompressedHistograms('hst2'), [expected_cmphst2])
def testCompressedHistogramsWithEmptyHistogram(self):
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen, compression_bps=(0, 2500, 5000, 7500, 10000))
gen.AddHistogram(
@@ -471,7 +476,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
self.assertAlmostEqual(vals[8].value, 1.0)
def testImages(self):
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
im1 = ea.ImageEvent(
wall_time=1,
@@ -504,7 +509,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
self.assertEqual(acc.Images('im2'), [im2])
def testAudio(self):
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
snd1 = ea.AudioEvent(
wall_time=1,
@@ -541,7 +546,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
self.assertEqual(acc.Audio('snd2'), [snd2])
def testKeyError(self):
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
acc.Reload()
with self.assertRaises(KeyError):
@@ -565,7 +570,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
def testNonValueEvents(self):
"""Tests that non-value events in the generator don't cause early exits."""
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
gen.AddScalar('s1', wall_time=1, step=10, value=20)
gen.AddEvent(event_pb2.Event(wall_time=2, step=20, file_version='nots2'))
@@ -596,7 +601,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
warnings = []
self.stubs.Set(logging, 'warn', warnings.append)
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
gen.AddEvent(
@@ -619,7 +624,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
def testOrphanedDataNotDiscardedIfFlagUnset(self):
"""Tests that events are not discarded if purge_orphaned_data is false.
"""
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen, purge_orphaned_data=False)
gen.AddEvent(
@@ -653,7 +658,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
warnings = []
self.stubs.Set(logging, 'warn', warnings.append)
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
gen.AddEvent(
@@ -680,7 +685,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
def testOnlySummaryEventsTriggerDiscards(self):
"""Test that file version event does not trigger data purge."""
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
gen.AddScalar('s1', wall_time=1, step=100, value=20)
ev1 = event_pb2.Event(wall_time=2, step=0, file_version='brain.Event:1')
@@ -698,7 +703,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
but this logic can only be used for event protos which have the SessionLog
enum, which was introduced to event.proto for file_version >= brain.Event:2.
"""
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
gen.AddEvent(
event_pb2.Event(
@@ -720,7 +725,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
def testFirstEventTimestamp(self):
"""Test that FirstEventTimestamp() returns wall_time of the first event."""
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
gen.AddEvent(
event_pb2.Event(
@@ -730,7 +735,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
def testReloadPopulatesFirstEventTimestamp(self):
"""Test that Reload() means FirstEventTimestamp() won't load events."""
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
gen.AddEvent(
event_pb2.Event(
@@ -746,7 +751,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
def testFirstEventTimestampLoadsEvent(self):
"""Test that FirstEventTimestamp() doesn't discard the loaded event."""
- gen = _EventGenerator()
+ gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
gen.AddEvent(
event_pb2.Event(
@@ -758,7 +763,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
def testTFSummaryScalar(self):
"""Verify processing of tf.summary.scalar."""
- event_sink = _EventGenerator(zero_out_timestamps=True)
+ event_sink = _EventGenerator(self, zero_out_timestamps=True)
writer = SummaryToEventTransformer(event_sink)
with self.test_session() as sess:
ipt = array_ops.placeholder(dtypes.float32)
@@ -792,7 +797,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
def testTFSummaryImage(self):
"""Verify processing of tf.summary.image."""
- event_sink = _EventGenerator(zero_out_timestamps=True)
+ event_sink = _EventGenerator(self, zero_out_timestamps=True)
writer = SummaryToEventTransformer(event_sink)
with self.test_session() as sess:
ipt = array_ops.ones([10, 4, 4, 3], dtypes.uint8)
@@ -828,7 +833,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
def testTFSummaryTensor(self):
"""Verify processing of tf.summary.tensor."""
- event_sink = _EventGenerator(zero_out_timestamps=True)
+ event_sink = _EventGenerator(self, zero_out_timestamps=True)
writer = SummaryToEventTransformer(event_sink)
with self.test_session() as sess:
summary_lib.tensor_summary('scalar', constant_op.constant(1.0))
diff --git a/tensorflow/python/summary/plugin_asset.py b/tensorflow/python/summary/plugin_asset.py
new file mode 100644
index 0000000000..18b6bd4fc4
--- /dev/null
+++ b/tensorflow/python/summary/plugin_asset.py
@@ -0,0 +1,140 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorBoard Plugin asset abstract class.
+
+TensorBoard plugins may need to write arbitrary assets to disk, such as
+configuration information for specific outputs, or vocabulary files, or sprite
+images, etc.
+
+This module contains methods that allow plugin assets to be specified at graph
+construction time. Plugin authors define a PluginAsset which is treated as a
+singleton on a per-graph basis. The PluginAsset has a serialize_to_directory
+method which writes its assets to disk within a special plugin directory
+that the tf.summary.FileWriter creates.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.python.framework import ops
+
+_PLUGIN_ASSET_PREFIX = "__tensorboard_plugin_asset__"
+
+
+def get_plugin_asset(plugin_asset_cls, graph=None):
+ """Acquire singleton PluginAsset instance from a graph.
+
+ PluginAssets are always singletons, and are stored in tf Graph collections.
+ This way, they can be defined anywhere the graph is being constructed, and
+ if the same plugin is configured at many different points, the user can always
+ modify the same instance.
+
+ Args:
+ plugin_asset_cls: The PluginAsset class
+ graph: (optional) The graph to retrieve the instance from. If not specified,
+ the default graph is used.
+
+ Returns:
+ An instance of the plugin_asset_class
+
+ Raises:
+ ValueError: If we have a plugin name collision, or if we unexpectedly find
+ the wrong number of items in a collection.
+ """
+ if graph is None:
+ graph = ops.get_default_graph()
+ if not plugin_asset_cls.plugin_name:
+ raise ValueError("Class %s has no plugin_name" % plugin_asset_cls.__name__)
+
+ name = _PLUGIN_ASSET_PREFIX + plugin_asset_cls.plugin_name
+ container = graph.get_collection(name)
+ if container:
+ if len(container) is not 1:
+ raise ValueError("Collection for %s had %d items, expected 1" %
+ (name, len(container)))
+ instance = container[0]
+ if not isinstance(instance, plugin_asset_cls):
+ raise ValueError("Plugin name collision between classes %s and %s" %
+ (plugin_asset_cls.__name__, instance.__class__.__name__))
+ else:
+ instance = plugin_asset_cls()
+ graph.add_to_collection(name, instance)
+ graph.add_to_collection(_PLUGIN_ASSET_PREFIX, plugin_asset_cls.plugin_name)
+ return instance
+
+
+def get_all_plugin_assets(graph=None):
+ """Retrieve all PluginAssets stored in the graph collection.
+
+ Args:
+ graph: Optionally, the graph to get assets from. If unspecified, the default
+ graph is used.
+
+ Returns:
+ A list with all PluginAsset instances in the graph.
+
+ Raises:
+ ValueError: if we unexpectedly find a collection with the wrong number of
+ PluginAssets.
+
+ """
+ if graph is None:
+ graph = ops.get_default_graph()
+
+ out = []
+ for name in graph.get_collection(_PLUGIN_ASSET_PREFIX):
+ collection = graph.get_collection(_PLUGIN_ASSET_PREFIX + name)
+ if len(collection) is not 1:
+ raise ValueError("Collection for %s had %d items, expected 1" %
+ (name, len(collection)))
+ out.append(collection[0])
+ return out
+
+
+class PluginAsset(object):
+ """This abstract base class allows TensorBoard to serialize assets to disk.
+
+ Plugin authors are expected to extend the PluginAsset class, so that it:
+ - has a unique plugin_name
+ - provides a serialize_to_directory method that dumps its assets in that dir
+ - takes no constructor arguments
+
+ LifeCycle of a PluginAsset instance:
+ - It is constructed when get_plugin_asset is called on the class for
+ the first time.
+ - It is configured by code that follows the calls to get_plugin_asset
+ - When the containing graph is serialized by the tf.summary.FileWriter, the
+ writer calls serialize_to_directory and the PluginAsset instance dumps its
+ contents to disk.
+ """
+ __metaclass__ = abc.ABCMeta
+
+ plugin_name = None
+
+ @abc.abstractmethod
+ def serialize_to_directory(self, directory):
+ """Serialize the assets in this PluginAsset to given directory.
+
+ The directory will be specific to this plugin (as determined by the
+ plugin_key property on the class). This method will be called when the graph
+ containing this PluginAsset is given to a tf.summary.FileWriter.
+
+ Args:
+ directory: The directory path (as string) that serialize should write to.
+ """
+ raise NotImplementedError()
diff --git a/tensorflow/python/summary/plugin_asset_test.py b/tensorflow/python/summary/plugin_asset_test.py
new file mode 100644
index 0000000000..5f80a1a4db
--- /dev/null
+++ b/tensorflow/python/summary/plugin_asset_test.py
@@ -0,0 +1,81 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.summary import plugin_asset
+
+
+class _UnnamedPluginAsset(plugin_asset.PluginAsset):
+ """An example asset with a dummy serialize method provided, but no name."""
+
+ def serialize_to_directory(self, unused_directory):
+ pass
+
+
+class _ExamplePluginAsset(_UnnamedPluginAsset):
+ """Simple example asset."""
+ plugin_name = "_ExamplePluginAsset"
+
+
+class _OtherExampleAsset(_UnnamedPluginAsset):
+ """Simple example asset."""
+ plugin_name = "_OtherExampleAsset"
+
+
+class _ExamplePluginThatWillCauseCollision(_UnnamedPluginAsset):
+ plugin_name = "_ExamplePluginAsset"
+
+
+class PluginAssetTest(test_util.TensorFlowTestCase):
+
+ def testGetPluginAsset(self):
+ epa = plugin_asset.get_plugin_asset(_ExamplePluginAsset)
+ self.assertIsInstance(epa, _ExamplePluginAsset)
+ epa2 = plugin_asset.get_plugin_asset(_ExamplePluginAsset)
+ self.assertIs(epa, epa2)
+ opa = plugin_asset.get_plugin_asset(_OtherExampleAsset)
+ self.assertIsNot(epa, opa)
+
+ def testUnnamedPluginFails(self):
+ with self.assertRaises(ValueError):
+ plugin_asset.get_plugin_asset(_UnnamedPluginAsset)
+
+ def testPluginCollisionDetected(self):
+ plugin_asset.get_plugin_asset(_ExamplePluginAsset)
+ with self.assertRaises(ValueError):
+ plugin_asset.get_plugin_asset(_ExamplePluginThatWillCauseCollision)
+
+ def testGetAllPluginAssets(self):
+ epa = plugin_asset.get_plugin_asset(_ExamplePluginAsset)
+ opa = plugin_asset.get_plugin_asset(_OtherExampleAsset)
+ self.assertItemsEqual(plugin_asset.get_all_plugin_assets(), [epa, opa])
+
+ def testRespectsGraphArgument(self):
+ g1 = ops.Graph()
+ g2 = ops.Graph()
+ e1 = plugin_asset.get_plugin_asset(_ExamplePluginAsset, g1)
+ e2 = plugin_asset.get_plugin_asset(_ExamplePluginAsset, g2)
+
+ self.assertEqual(e1, plugin_asset.get_all_plugin_assets(g1)[0])
+ self.assertEqual(e2, plugin_asset.get_all_plugin_assets(g2)[0])
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py
index 6b0f5a85e6..44a6686cd6 100644
--- a/tensorflow/python/summary/summary.py
+++ b/tensorflow/python/summary/summary.py
@@ -27,6 +27,8 @@ See the @{$python/summary} guide.
@@merge
@@merge_all
@@get_summary_description
+@@get_plugin_asset
+@@get_all_plugin_assets
"""
from __future__ import absolute_import
@@ -36,6 +38,7 @@ from __future__ import print_function
import re as _re
from google.protobuf import json_format as _json_format
+
# exports Summary, SummaryDescription, Event, TaggedRunMetadata, SessionLog
# pylint: disable=unused-import
from tensorflow.core.framework.summary_pb2 import Summary
@@ -48,16 +51,19 @@ from tensorflow.core.util.event_pb2 import TaggedRunMetadata
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.ops import gen_logging_ops as _gen_logging_ops
+
# exports tensor_summary
# pylint: disable=unused-import
from tensorflow.python.ops.summary_ops import tensor_summary
# pylint: enable=unused-import
from tensorflow.python.platform import tf_logging as _logging
+
# exports FileWriter, FileWriterCache
# pylint: disable=unused-import
from tensorflow.python.summary.writer.writer import FileWriter
from tensorflow.python.summary.writer.writer_cache import FileWriterCache
# pylint: enable=unused-import
+
from tensorflow.python.util import compat as _compat
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py
index a737fd2c16..dec3ee96d5 100644
--- a/tensorflow/python/summary/writer/writer.py
+++ b/tensorflow/python/summary/writer/writer.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os.path
import time
from tensorflow.core.framework import graph_pb2
@@ -26,10 +27,15 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.summary import plugin_asset
from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
+_PLUGINS_DIR = "plugins"
+
+
class SummaryToEventTransformer(object):
"""Abstractly implements the SummaryWriter API.
@@ -63,7 +69,7 @@ class SummaryToEventTransformer(object):
Args:
- event_writer: An EventWriter. Implements add_event method.
+ event_writer: An EventWriter. Implements add_event and get_logdir.
graph: A `Graph` object, such as `sess.graph`.
graph_def: DEPRECATED: Use the `graph` argument instead.
"""
@@ -158,6 +164,7 @@ class SummaryToEventTransformer(object):
# Serialize the graph with additional info.
true_graph_def = graph.as_graph_def(add_shapes=True)
+ self._write_tensorboard_metadata(graph)
elif (isinstance(graph, graph_pb2.GraphDef) or
isinstance(graph_def, graph_pb2.GraphDef)):
# The user passed a `GraphDef`.
@@ -178,6 +185,14 @@ class SummaryToEventTransformer(object):
# Finally, add the graph_def to the summary writer.
self._add_graph_def(true_graph_def, global_step)
+ def _write_tensorboard_metadata(self, graph):
+ assets = plugin_asset.get_all_plugin_assets(graph)
+ logdir = self.event_writer.get_logdir()
+ for asset in assets:
+ plugin_dir = os.path.join(logdir, _PLUGINS_DIR, asset.plugin_name)
+ gfile.MakeDirs(plugin_dir)
+ asset.serialize_to_directory(plugin_dir)
+
def add_meta_graph(self, meta_graph_def, global_step=None):
"""Adds a `MetaGraphDef` to the event file.
diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py
index 245f752324..f9e9a8ff33 100644
--- a/tensorflow/python/summary/writer/writer_test.py
+++ b/tensorflow/python/summary/writer/writer_test.py
@@ -33,7 +33,9 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
+from tensorflow.python.summary import plugin_asset
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer
from tensorflow.python.summary.writer import writer_cache
@@ -354,5 +356,37 @@ class SummaryWriterCacheTest(test.TestCase):
self.assertFalse(sw1 == sw2)
+class ExamplePluginAsset(plugin_asset.PluginAsset):
+ plugin_name = "example"
+
+ def serialize_to_directory(self, directory):
+ foo = os.path.join(directory, "foo.txt")
+ bar = os.path.join(directory, "bar.txt")
+ with gfile.Open(foo, "w") as f:
+ f.write("foo!")
+ with gfile.Open(bar, "w") as f:
+ f.write("bar!")
+
+
+class PluginAssetsTest(test.TestCase):
+
+ def testPluginAssetSerialized(self):
+ with ops.Graph().as_default() as g:
+ plugin_asset.get_plugin_asset(ExamplePluginAsset)
+
+ logdir = self.get_temp_dir()
+ fw = writer.FileWriter(logdir)
+ fw.add_graph(g)
+ plugin_dir = os.path.join(logdir, writer._PLUGINS_DIR, "example")
+
+ with gfile.Open(os.path.join(plugin_dir, "foo.txt"), "r") as f:
+ content = f.read()
+ self.assertEqual(content, "foo!")
+
+ with gfile.Open(os.path.join(plugin_dir, "bar.txt"), "r") as f:
+ content = f.read()
+ self.assertEqual(content, "bar!")
+
+
if __name__ == "__main__":
test.main()