aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/importer.py
diff options
context:
space:
mode:
authorGravatar Josh Levenberg <josh11b@tensorflow.org>2016-05-09 13:53:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-09 15:01:56 -0700
commit6e02bf0299687b74557ff31b641b8d6d51cba4bf (patch)
tree17fdecef8ffcdd45b8366634ac45690a73dd9b87 /tensorflow/python/framework/importer.py
parentbf7e5fe193c5ae870cff3e4c8c3dffc9bdaa74b9 (diff)
Support limited forward compatibility when importing a MetaGraphDef.
Change: 121880284
Diffstat (limited to 'tensorflow/python/framework/importer.py')
-rw-r--r--tensorflow/python/framework/importer.py41
1 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 3d8b4e460e..afb3b55d83 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -23,8 +23,8 @@ import contextlib
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
-from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import compat
@@ -142,8 +142,15 @@ def _MaybeDevice(device):
yield
+def _FindAttrInOpDef(attr_name, op_def):
+ for attr_def in op_def.attr:
+ if attr_name == attr_def.name:
+ return attr_def
+ return None
+
+
def import_graph_def(graph_def, input_map=None, return_elements=None,
- name=None, op_dict=None):
+ name=None, op_dict=None, producer_op_list=None):
"""Imports the TensorFlow graph in `graph_def` into the Python `Graph`.
This function provides a way to import a serialized TensorFlow
@@ -167,6 +174,12 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
Must contain an `OpDef` proto for each op type named in `graph_def`.
If omitted, uses the `OpDef` protos registered in the global registry.
+ producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
+ list of `OpDef`s used by the producer of the graph. If provided, attrs
+ for ops in `graph_def` that are not in `op_dict` that have their default
+ value according to `producer_op_list` will be removed. This will allow
+ some more `GraphDef`s produced by later binaries to be accepted by
+ earlier binaries.
Returns:
A list of `Operation` and/or `Tensor` objects from the imported graph,
@@ -213,6 +226,11 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
if op_dict is None:
op_dict = op_def_registry.get_registered_ops()
+ if producer_op_list is None:
+ producer_op_dict = None
+ else:
+ producer_op_dict = {op.name: op for op in producer_op_list.op}
+
with ops.op_scope(input_map.values(), name, 'import'):
g = ops.get_default_graph()
g.graph_def_versions.CopyFrom(graph_def.versions)
@@ -233,6 +251,21 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
value = node.attr[key]
if value is None or value.WhichOneof('value') is None:
node.attr[key].CopyFrom(attr_def.default_value)
+ if producer_op_dict:
+ # Remove any default attr values that aren't in op_def.
+ if node.op in producer_op_dict:
+ producer_op_def = producer_op_dict[node.op]
+ # We make a copy of node.attr to iterate through since we
+ # may modify node.attr inside the loop.
+ for key in list(node.attr):
+ if _FindAttrInOpDef(key, op_def) is None:
+ # No attr_def in consumer, look in producer.
+ attr_def = _FindAttrInOpDef(key, producer_op_def)
+ if (attr_def and attr_def.HasField('default_value') and
+ node.attr[key] == attr_def.default_value):
+ # Unknown attr had default value in producer, delete it
+ # so it can be understood by consumer.
+ del node.attr[key]
output_types = _OutputTypes(node, op_dict)
name_to_op[node.name] = g.create_op(
@@ -326,8 +359,8 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
_InvalidNodeMessage(
node,
'Input types mismatch (expected %r but got %r)'
- % (", ".join(dtypes.as_dtype(x).name for x in input_types),
- ", ".join(x.name for x in op._input_dtypes))))
+ % (', '.join(dtypes.as_dtype(x).name for x in input_types),
+ ', '.join(x.name for x in op._input_dtypes))))
# pylint: enable=protected_access
# Execute shape inference for this op.