aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/meta_graph.py
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2016-10-06 16:08:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-06 17:18:08 -0700
commit752c24070f890519729e1c8f4c562dd25b9e98bd (patch)
tree1629770a728a7d9f2b7036445586555ceebb28c8 /tensorflow/python/framework/meta_graph.py
parent12167ee7a192fcfe0e9f468af88ebcced2bc4a2e (diff)
Allow writing the MetaGraphDef to an event summary.
Supervisor and graph_actions now do this by default if saving a graph is requested. TensorBoard can now fall back to using the MetaGraphDef if the GraphDef is not available in the Event. Other changes include: * factor out metagraph code from training.saver to framework.meta_graph * update the summary writer, session hooks, and supervisor to support these new summaries Change: 135424442
Diffstat (limited to 'tensorflow/python/framework/meta_graph.py')
-rw-r--r--tensorflow/python/framework/meta_graph.py245
1 files changed, 245 insertions, 0 deletions
diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py
new file mode 100644
index 0000000000..fdc0179652
--- /dev/null
+++ b/tensorflow/python/framework/meta_graph.py
@@ -0,0 +1,245 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Helpers to manipulate metagraphs in python.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from google.protobuf.any_pb2 import Any
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.framework import op_def_registry
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import compat
+
+
+def ops_used_by_graph_def(graph_def):
+ """Collect the list of ops used by a graph.
+
+ Does not validate that the ops are all registered.
+
+ Args:
+ graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.
+
+ Returns:
+ A list of strings, each naming an op used by the graph.
+ """
+ # Map function names to definitions
+ name_to_function = {}
+ for fun in graph_def.library.function:
+ name_to_function[fun.signature.name] = fun
+
+ # Collect the list of op names. Since functions can reference functions, we
+ # need a recursive traversal.
+ used_ops = set() # Includes both primitive ops and functions
+ functions_to_process = [] # A subset of used_ops
+
+ def mark_op_as_used(op):
+ if op not in used_ops and op in name_to_function:
+ functions_to_process.append(name_to_function[op])
+ used_ops.add(op)
+
+ for node in graph_def.node:
+ mark_op_as_used(node.op)
+ while functions_to_process:
+ fun = functions_to_process.pop()
+ if fun.node_def:
+ for node in fun.node_def:
+ mark_op_as_used(node.op)
+ else: # TODO(josh11b): Eventually remove this case.
+ for node in fun.node:
+ mark_op_as_used(node.op)
+
+ return [op for op in used_ops if op not in name_to_function]
+
+
+def stripped_op_list_for_graph(graph_def):
+ """Collect the stripped OpDefs for ops used by a graph.
+
+ This function computes the `stripped_op_list` field of `MetaGraphDef` and
+ similar protos. The result can be communicated from the producer to the
+ consumer, which can then use the C++ function
+ `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.
+
+ Args:
+ graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.
+
+ Returns:
+ An `OpList` of ops used by the graph.
+
+ Raises:
+ ValueError: If an unregistered op is used.
+ """
+ # This is the Python equivalent of StrippedOpListForGraph in C++.
+ # Unfortunately, since the Python op registry can differ from that in C++, we
+ # can't remove the duplication using swig (at least naively).
+ # TODO(irving): Support taking graphs directly.
+
+ used_ops = ops_used_by_graph_def(graph_def)
+
+ # Verify that all used ops are registered.
+ registered_ops = op_def_registry.get_registered_ops()
+ # These internal ops used by functions are not registered, so we need to
+ # whitelist them. # TODO(irving): Do something better here.
+ op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
+ for op in used_ops:
+ if op not in registered_ops and op not in op_whitelist:
+ raise ValueError("Op %s is used by the graph, but is not registered" % op)
+
+ # Build the stripped op list in sorted order
+ return op_def_pb2.OpList(op=[registered_ops[op] for op in sorted(used_ops)
+ if op in registered_ops])
+
+
+def _get_kind_name(item):
+ """Returns the kind name in CollectionDef.
+
+ Args:
+ item: A data item.
+
+ Returns:
+ The string representation of the kind in CollectionDef.
+ """
+ if isinstance(item, (six.string_types, six.binary_type)):
+ kind = "bytes_list"
+ elif isinstance(item, six.integer_types):
+ kind = "int64_list"
+ elif isinstance(item, float):
+ kind = "float_list"
+ elif isinstance(item, Any):
+ kind = "any_list"
+ else:
+ kind = "node_list"
+ return kind
+
+
+def add_collection_def(meta_graph_def, key):
+ """Adds a collection to MetaGraphDef protocol buffer.
+
+ Args:
+ meta_graph_def: MetaGraphDef protocol buffer.
+ key: One of the GraphKeys or user-defined string.
+ """
+ if not isinstance(key, six.string_types) and not isinstance(key, bytes):
+ logging.warning("Only collections with string type keys will be "
+ "serialized. This key has %s", type(key))
+ return
+ collection_list = ops.get_collection(key)
+ if not collection_list:
+ return
+ try:
+ col_def = meta_graph_def.collection_def[key]
+ to_proto = ops.get_to_proto_function(key)
+ proto_type = ops.get_collection_proto_type(key)
+ if to_proto:
+ kind = "bytes_list"
+ for x in collection_list:
+ # Additional type check to make sure the returned proto is indeed
+ # what we expect.
+ proto = to_proto(x)
+ if not isinstance(proto, proto_type):
+ raise TypeError("proto %s is not type %s" % (proto, proto_type))
+ getattr(col_def, kind).value.append(proto.SerializeToString())
+ else:
+ kind = _get_kind_name(collection_list[0])
+ if kind == "node_list":
+ getattr(col_def, kind).value.extend([x.name for x in collection_list])
+ elif kind == "bytes_list":
+ # NOTE(opensource): This force conversion is to work around the fact
+ # that Python3 distinguishes between bytes and strings.
+ getattr(col_def, kind).value.extend(
+ [compat.as_bytes(x) for x in collection_list])
+ else:
+ getattr(col_def, kind).value.extend([x for x in collection_list])
+ except Exception as e: # pylint: disable=broad-except
+ logging.warning("Error encountered when serializing %s.\n"
+ "Type is unsupported, or the types of the items don't "
+ "match field type in CollectionDef.\n%s", key, str(e))
+ if key in meta_graph_def.collection_def:
+ del meta_graph_def.collection_def[key]
+ return
+
+
+def create_meta_graph_def(meta_info_def=None,
+ graph_def=None,
+ saver_def=None,
+ collection_list=None):
+ """Construct and returns a `MetaGraphDef` protocol buffer.
+
+ Args:
+ meta_info_def: `MetaInfoDef` protocol buffer.
+ graph_def: `GraphDef` protocol buffer.
+ saver_def: `SaverDef` protocol buffer.
+ collection_list: List of string keys to collect.
+
+ Returns:
+ MetaGraphDef protocol buffer.
+
+ Raises:
+ TypeError: If the arguments are not of the correct proto buffer type.
+ """
+ # Type check.
+ if meta_info_def and not isinstance(meta_info_def,
+ meta_graph_pb2.MetaGraphDef.MetaInfoDef):
+ raise TypeError("meta_info_def must be of type MetaInfoDef, not %s",
+ type(meta_info_def))
+ if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
+ raise TypeError("graph_def must be of type GraphDef, not %s",
+ type(graph_def))
+ if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
+ raise TypeError("saver_def must be of type SaverDef, not %s",
+ type(saver_def))
+
+ # Creates a MetaGraphDef proto.
+ meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ # Adds meta_info_def.
+ if meta_info_def:
+ meta_graph_def.meta_info_def.MergeFrom(meta_info_def)
+
+ # Adds graph_def or the default.
+ if not graph_def:
+ meta_graph_def.graph_def.MergeFrom(ops.get_default_graph().as_graph_def(
+ add_shapes=True))
+ else:
+ meta_graph_def.graph_def.MergeFrom(graph_def)
+
+ # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
+ # pylint: disable=g-explicit-length-test
+ if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
+ meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
+ stripped_op_list_for_graph(meta_graph_def.graph_def))
+ # pylint: enable=g-explicit-length-test
+
+ # Adds saver_def.
+ if saver_def:
+ meta_graph_def.saver_def.MergeFrom(saver_def)
+
+ # Adds collection_list.
+ if collection_list:
+ clist = collection_list
+ else:
+ clist = ops.get_all_collection_keys()
+ for ctype in clist:
+ add_collection_def(meta_graph_def, ctype)
+ return meta_graph_def