diff options
author | 2016-10-06 16:08:04 -0800 | |
---|---|---|
committer | 2016-10-06 17:18:08 -0700 | |
commit | 752c24070f890519729e1c8f4c562dd25b9e98bd (patch) | |
tree | 1629770a728a7d9f2b7036445586555ceebb28c8 /tensorflow/python/framework/meta_graph.py | |
parent | 12167ee7a192fcfe0e9f468af88ebcced2bc4a2e (diff) |
Allow writing the MetaGraphDef to an event summary.
Supervisor and graph_actions now do this by default if saving a graph is requested.
TensorBoard can now fall back to using the MetaGraphDef if the GraphDef is not
available in the Event.
Other changes include:
* factor out metagraph code from training.saver to framework.meta_graph
* update the summary writer, session hooks, and supervisor to support these new summaries
Change: 135424442
Diffstat (limited to 'tensorflow/python/framework/meta_graph.py')
-rw-r--r-- | tensorflow/python/framework/meta_graph.py | 245 |
1 files changed, 245 insertions, 0 deletions
diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py new file mode 100644 index 0000000000..fdc0179652 --- /dev/null +++ b/tensorflow/python/framework/meta_graph.py @@ -0,0 +1,245 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Helpers to manipulate metagraphs in python. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from google.protobuf.any_pb2 import Any + +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import op_def_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.framework import op_def_registry +from tensorflow.python.framework import ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat + + +def ops_used_by_graph_def(graph_def): + """Collect the list of ops used by a graph. + + Does not validate that the ops are all registered. + + Args: + graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`. + + Returns: + A list of strings, each naming an op used by the graph. + """ + # Map function names to definitions + name_to_function = {} + for fun in graph_def.library.function: + name_to_function[fun.signature.name] = fun + + # Collect the list of op names. Since functions can reference functions, we + # need a recursive traversal. + used_ops = set() # Includes both primitive ops and functions + functions_to_process = [] # A subset of used_ops + + def mark_op_as_used(op): + if op not in used_ops and op in name_to_function: + functions_to_process.append(name_to_function[op]) + used_ops.add(op) + + for node in graph_def.node: + mark_op_as_used(node.op) + while functions_to_process: + fun = functions_to_process.pop() + if fun.node_def: + for node in fun.node_def: + mark_op_as_used(node.op) + else: # TODO(josh11b): Eventually remove this case. + for node in fun.node: + mark_op_as_used(node.op) + + return [op for op in used_ops if op not in name_to_function] + + +def stripped_op_list_for_graph(graph_def): + """Collect the stripped OpDefs for ops used by a graph. + + This function computes the `stripped_op_list` field of `MetaGraphDef` and + similar protos. The result can be communicated from the producer to the + consumer, which can then use the C++ function + `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility. + + Args: + graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`. + + Returns: + An `OpList` of ops used by the graph. + + Raises: + ValueError: If an unregistered op is used. + """ + # This is the Python equivalent of StrippedOpListForGraph in C++. + # Unfortunately, since the Python op registry can differ from that in C++, we + # can't remove the duplication using swig (at least naively). + # TODO(irving): Support taking graphs directly. + + used_ops = ops_used_by_graph_def(graph_def) + + # Verify that all used ops are registered. + registered_ops = op_def_registry.get_registered_ops() + # These internal ops used by functions are not registered, so we need to + # whitelist them. # TODO(irving): Do something better here. + op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList") + for op in used_ops: + if op not in registered_ops and op not in op_whitelist: + raise ValueError("Op %s is used by the graph, but is not registered" % op) + + # Build the stripped op list in sorted order + return op_def_pb2.OpList(op=[registered_ops[op] for op in sorted(used_ops) + if op in registered_ops]) + + +def _get_kind_name(item): + """Returns the kind name in CollectionDef. + + Args: + item: A data item. + + Returns: + The string representation of the kind in CollectionDef. + """ + if isinstance(item, (six.string_types, six.binary_type)): + kind = "bytes_list" + elif isinstance(item, six.integer_types): + kind = "int64_list" + elif isinstance(item, float): + kind = "float_list" + elif isinstance(item, Any): + kind = "any_list" + else: + kind = "node_list" + return kind + + +def add_collection_def(meta_graph_def, key): + """Adds a collection to MetaGraphDef protocol buffer. + + Args: + meta_graph_def: MetaGraphDef protocol buffer. + key: One of the GraphKeys or user-defined string. + """ + if not isinstance(key, six.string_types) and not isinstance(key, bytes): + logging.warning("Only collections with string type keys will be " + "serialized. This key has %s", type(key)) + return + collection_list = ops.get_collection(key) + if not collection_list: + return + try: + col_def = meta_graph_def.collection_def[key] + to_proto = ops.get_to_proto_function(key) + proto_type = ops.get_collection_proto_type(key) + if to_proto: + kind = "bytes_list" + for x in collection_list: + # Additional type check to make sure the returned proto is indeed + # what we expect. + proto = to_proto(x) + if not isinstance(proto, proto_type): + raise TypeError("proto %s is not type %s" % (proto, proto_type)) + getattr(col_def, kind).value.append(proto.SerializeToString()) + else: + kind = _get_kind_name(collection_list[0]) + if kind == "node_list": + getattr(col_def, kind).value.extend([x.name for x in collection_list]) + elif kind == "bytes_list": + # NOTE(opensource): This force conversion is to work around the fact + # that Python3 distinguishes between bytes and strings. + getattr(col_def, kind).value.extend( + [compat.as_bytes(x) for x in collection_list]) + else: + getattr(col_def, kind).value.extend([x for x in collection_list]) + except Exception as e: # pylint: disable=broad-except + logging.warning("Error encountered when serializing %s.\n" + "Type is unsupported, or the types of the items don't " + "match field type in CollectionDef.\n%s", key, str(e)) + if key in meta_graph_def.collection_def: + del meta_graph_def.collection_def[key] + return + + +def create_meta_graph_def(meta_info_def=None, + graph_def=None, + saver_def=None, + collection_list=None): + """Construct and returns a `MetaGraphDef` protocol buffer. + + Args: + meta_info_def: `MetaInfoDef` protocol buffer. + graph_def: `GraphDef` protocol buffer. + saver_def: `SaverDef` protocol buffer. + collection_list: List of string keys to collect. + + Returns: + MetaGraphDef protocol buffer. + + Raises: + TypeError: If the arguments are not of the correct proto buffer type. + """ + # Type check. + if meta_info_def and not isinstance(meta_info_def, + meta_graph_pb2.MetaGraphDef.MetaInfoDef): + raise TypeError("meta_info_def must be of type MetaInfoDef, not %s", + type(meta_info_def)) + if graph_def and not isinstance(graph_def, graph_pb2.GraphDef): + raise TypeError("graph_def must be of type GraphDef, not %s", + type(graph_def)) + if saver_def and not isinstance(saver_def, saver_pb2.SaverDef): + raise TypeError("saver_def must be of type SaverDef, not %s", + type(saver_def)) + + # Creates a MetaGraphDef proto. + meta_graph_def = meta_graph_pb2.MetaGraphDef() + # Adds meta_info_def. + if meta_info_def: + meta_graph_def.meta_info_def.MergeFrom(meta_info_def) + + # Adds graph_def or the default. + if not graph_def: + meta_graph_def.graph_def.MergeFrom(ops.get_default_graph().as_graph_def( + add_shapes=True)) + else: + meta_graph_def.graph_def.MergeFrom(graph_def) + + # Fills in meta_info_def.stripped_op_list using the ops from graph_def. + # pylint: disable=g-explicit-length-test + if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0: + meta_graph_def.meta_info_def.stripped_op_list.MergeFrom( + stripped_op_list_for_graph(meta_graph_def.graph_def)) + # pylint: enable=g-explicit-length-test + + # Adds saver_def. + if saver_def: + meta_graph_def.saver_def.MergeFrom(saver_def) + + # Adds collection_list. + if collection_list: + clist = collection_list + else: + clist = ops.get_all_collection_keys() + for ctype in clist: + add_collection_def(meta_graph_def, ctype) + return meta_graph_def |