diff options
author | 2016-05-09 13:53:57 -0800 | |
---|---|---|
committer | 2016-05-09 15:01:56 -0700 | |
commit | 6e02bf0299687b74557ff31b641b8d6d51cba4bf (patch) | |
tree | 17fdecef8ffcdd45b8366634ac45690a73dd9b87 /tensorflow/python/framework/importer.py | |
parent | bf7e5fe193c5ae870cff3e4c8c3dffc9bdaa74b9 (diff) |
Support limited forward compatibility when importing a MetaGraphDef.
Change: 121880284
Diffstat (limited to 'tensorflow/python/framework/importer.py')
-rw-r--r-- | tensorflow/python/framework/importer.py | 41 |
1 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 3d8b4e460e..afb3b55d83 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -23,8 +23,8 @@ import contextlib from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import types_pb2 -from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import dtypes +from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.util import compat @@ -142,8 +142,15 @@ def _MaybeDevice(device): yield +def _FindAttrInOpDef(attr_name, op_def): + for attr_def in op_def.attr: + if attr_name == attr_def.name: + return attr_def + return None + + def import_graph_def(graph_def, input_map=None, return_elements=None, - name=None, op_dict=None): + name=None, op_dict=None, producer_op_list=None): """Imports the TensorFlow graph in `graph_def` into the Python `Graph`. This function provides a way to import a serialized TensorFlow @@ -167,6 +174,12 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos. Must contain an `OpDef` proto for each op type named in `graph_def`. If omitted, uses the `OpDef` protos registered in the global registry. + producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped) + list of `OpDef`s used by the producer of the graph. If provided, attrs + for ops in `graph_def` that are not in `op_dict` that have their default + value according to `producer_op_list` will be removed. This will allow + some more `GraphDef`s produced by later binaries to be accepted by + earlier binaries. Returns: A list of `Operation` and/or `Tensor` objects from the imported graph, @@ -213,6 +226,11 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, if op_dict is None: op_dict = op_def_registry.get_registered_ops() + if producer_op_list is None: + producer_op_dict = None + else: + producer_op_dict = {op.name: op for op in producer_op_list.op} + with ops.op_scope(input_map.values(), name, 'import'): g = ops.get_default_graph() g.graph_def_versions.CopyFrom(graph_def.versions) @@ -233,6 +251,21 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, value = node.attr[key] if value is None or value.WhichOneof('value') is None: node.attr[key].CopyFrom(attr_def.default_value) + if producer_op_dict: + # Remove any default attr values that aren't in op_def. + if node.op in producer_op_dict: + producer_op_def = producer_op_dict[node.op] + # We make a copy of node.attr to iterate through since we + # may modify node.attr inside the loop. + for key in list(node.attr): + if _FindAttrInOpDef(key, op_def) is None: + # No attr_def in consumer, look in producer. + attr_def = _FindAttrInOpDef(key, producer_op_def) + if (attr_def and attr_def.HasField('default_value') and + node.attr[key] == attr_def.default_value): + # Unknown attr had default value in producer, delete it + # so it can be understood by consumer. + del node.attr[key] output_types = _OutputTypes(node, op_dict) name_to_op[node.name] = g.create_op( @@ -326,8 +359,8 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, _InvalidNodeMessage( node, 'Input types mismatch (expected %r but got %r)' - % (", ".join(dtypes.as_dtype(x).name for x in input_types), - ", ".join(x.name for x in op._input_dtypes)))) + % (', '.join(dtypes.as_dtype(x).name for x in input_types), + ', '.join(x.name for x in op._input_dtypes)))) # pylint: enable=protected_access # Execute shape inference for this op. |