aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/summary/plugin_asset.py
blob: 998fb30fa491bd35bd449fc0f7e3bc358c668074 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorBoard Plugin asset abstract class.

TensorBoard plugins may need to provide arbitrary assets, such as
configuration information for specific outputs, or vocabulary files, or sprite
images, etc.

This module contains methods that allow plugin assets to be specified at graph
construction time. Plugin authors define a PluginAsset which is treated as a
singleton on a per-graph basis. The PluginAsset has an assets method which
returns a dictionary of asset contents. The tf.summary.FileWriter
(or any other Summary writer) will serialize these assets in such a way that
TensorBoard can retrieve them.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

from tensorflow.python.framework import ops

_PLUGIN_ASSET_PREFIX = "__tensorboard_plugin_asset__"


def get_plugin_asset(plugin_asset_cls, graph=None):
  """Acquire singleton PluginAsset instance from a graph.

  PluginAssets are always singletons, and are stored in tf Graph collections.
  This way, they can be defined anywhere the graph is being constructed, and
  if the same plugin is configured at many different points, the user can always
  modify the same instance.

  Args:
    plugin_asset_cls: The PluginAsset class
    graph: (optional) The graph to retrieve the instance from. If not specified,
      the default graph is used.

  Returns:
    An instance of the plugin_asset_class

  Raises:
    ValueError: If we have a plugin name collision, or if we unexpectedly find
      the wrong number of items in a collection.
  """
  if graph is None:
    graph = ops.get_default_graph()
  if not plugin_asset_cls.plugin_name:
    raise ValueError("Class %s has no plugin_name" % plugin_asset_cls.__name__)

  name = _PLUGIN_ASSET_PREFIX + plugin_asset_cls.plugin_name
  container = graph.get_collection(name)
  if container:
    if len(container) is not 1:
      raise ValueError("Collection for %s had %d items, expected 1" %
                       (name, len(container)))
    instance = container[0]
    if not isinstance(instance, plugin_asset_cls):
      raise ValueError("Plugin name collision between classes %s and %s" %
                       (plugin_asset_cls.__name__, instance.__class__.__name__))
  else:
    instance = plugin_asset_cls()
    graph.add_to_collection(name, instance)
    graph.add_to_collection(_PLUGIN_ASSET_PREFIX, plugin_asset_cls.plugin_name)
  return instance


def get_all_plugin_assets(graph=None):
  """Retrieve all PluginAssets stored in the graph collection.

  Args:
    graph: Optionally, the graph to get assets from. If unspecified, the default
      graph is used.

  Returns:
    A list with all PluginAsset instances in the graph.

  Raises:
    ValueError: if we unexpectedly find a collection with the wrong number of
      PluginAssets.

  """
  if graph is None:
    graph = ops.get_default_graph()

  out = []
  for name in graph.get_collection(_PLUGIN_ASSET_PREFIX):
    collection = graph.get_collection(_PLUGIN_ASSET_PREFIX + name)
    if len(collection) is not 1:
      raise ValueError("Collection for %s had %d items, expected 1" %
                       (name, len(collection)))
    out.append(collection[0])
  return out


class PluginAsset(object):
  """This abstract base class allows TensorBoard to serialize assets to disk.

  Plugin authors are expected to extend the PluginAsset class, so that it:
  - has a unique plugin_name
  - provides an assets method that returns an {asset_name: asset_contents}
    dictionary. For now, asset_contents are strings, although we may add
    StringIO support later.

  LifeCycle of a PluginAsset instance:
  - It is constructed when get_plugin_asset is called on the class for
    the first time.
  - It is configured by code that follows the calls to get_plugin_asset
  - When the containing graph is serialized by the tf.summary.FileWriter, the
    writer calls assets and the PluginAsset instance provides its contents to be
    written to disk.
  """
  __metaclass__ = abc.ABCMeta

  plugin_name = None

  @abc.abstractmethod
  def assets(self):
    """Provide all of the assets contained by the PluginAsset instance.

    The assets method should return a dictionary structured as
    {asset_name: asset_contents}. asset_contents is a string.

    This method will be called by the tf.summary.FileWriter when it is time to
    write the assets out to disk.
    """
    raise NotImplementedError()