aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/function_def_to_graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/function_def_to_graph.py')
-rw-r--r--tensorflow/python/framework/function_def_to_graph.py32
1 files changed, 15 insertions, 17 deletions
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
index 46c9c4c14a..1b09506662 100644
--- a/tensorflow/python/framework/function_def_to_graph.py
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -25,7 +25,7 @@ from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.framework import function
from tensorflow.python.framework import importer
-from tensorflow.python.framework import op_def_registry
+from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
from tensorflow.python.ops import cond_v2_impl
@@ -114,6 +114,10 @@ def function_def_to_graph_def(fdef, input_shapes=None):
producer=versions.GRAPH_DEF_VERSION,
min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))
+ # Copy *all* functions from outer graph to `graph_def` so that both direct
+ # and indirect references are safely handled.
+ ops.get_default_graph()._copy_functions_to_graph_def(graph_def, 0) # pylint: disable=protected-access
+
if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
raise ValueError("Length of input_shapes must match the number of " +
"input_args. len(input_shapes): {} len(input_arg): {}".
@@ -142,24 +146,18 @@ def function_def_to_graph_def(fdef, input_shapes=None):
nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
for node_def in fdef.node_def:
- op_def = op_def_registry.get_registered_ops().get(node_def.op)
- if not op_def:
- # TODO(b/80470245): Support functions which refer other functions.
- raise NotImplementedError(
- "No op registered for {},".format(node_def.op) +
- " it may be a function. function_def_to_graph_def " +
- "currently does not support converting functions with " +
- "references to other graph functions.")
+ op_def = ops.get_default_graph()._get_op_def(node_def.op) # pylint: disable=protected-access
for attr in op_def.attr:
- if attr.type in ("func", "list(func)"):
- # TODO(b/80470245): Support functions which refer other functions.
- raise NotImplementedError("Unsupported attr {} ".format(attr.name) +
- " with type {}".format(attr.type) +
- " in op {}. ".format(op_def.name) +
- "function_def_to_graph_def currently does " +
- "not support converting functions with " +
- "references to other graph functions.")
+ if attr.type == "func":
+ fname = node_def.attr[attr.name].func.name
+ if not ops.get_default_graph()._is_function(fname): # pylint: disable=protected-access
+ raise ValueError("%s function not found." % fname)
+ elif attr.type == "list(func)":
+ for fn in node_def.attr[attr.name].list.func:
+ fname = fn.name
+ if not ops.get_default_graph()._is_function(fname): # pylint: disable=protected-access
+ raise ValueError("%s function not found." % fname)
# Iterate over output_args in op_def to build the map.
# Index of the output tensor in the flattened list of *all* output