aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-08 16:40:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-08 16:45:54 -0800
commit2d00e6f17df644077af331e5bcb47a0e8a0fa1b7 (patch)
tree1f646c3feb1016abf564f1e2b261c7a4c3abf770
parentd7d1c402f039d954e7bab181bb6132146f439850 (diff)
Re-enable the NodeDef version of functions, fixing issue where Python
functions create signatures with duplicate output argument names. Change: 141500446
-rw-r--r--tensorflow/core/framework/function.cc3
-rw-r--r--tensorflow/python/framework/function.py72
-rw-r--r--tensorflow/python/framework/function_test.py17
3 files changed, 72 insertions, 20 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 134bd4fadb..01fc231204 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -772,8 +772,7 @@ Status InstantiateFunction(const FunctionDef& fdef,
// Makes a copy of all attrs in fdef and substitutes placeholders.
// After this step, every attr is bound to a concrete value.
std::vector<InstantiateAttrValueMap> node_attrs;
- if (false && fdef.node_def_size() > 0) {
- // TODO(josh11b): enable this branch.
+ if (fdef.node_def_size() > 0) {
node_attrs.resize(fdef.node_def_size());
for (int i = 0; i < fdef.node_def_size(); ++i) {
for (auto attr : fdef.node_def(i).attr()) {
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index cc51966689..349198f1fb 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -42,10 +42,21 @@ def _make_argname_from_tensor_name(name):
return re.sub(":0$", "", name).replace(":", "_o")
-def _tensor_to_argdef(t, name=None):
+def _tensor_to_argdef(t, name=None, used_names=None):
+ """Convert tensor t to an argdef, with a specified name or a unique name."""
arg = op_def_pb2.OpDef.ArgDef()
if name is None:
arg.name = _make_argname_from_tensor_name(t.name)
+ if used_names is not None:
+ if arg.name in used_names:
+ i = 0
+ while True:
+ new_name = "%s_U%d" % (arg.name, i)
+ if new_name not in used_names:
+ arg.name = new_name
+ break
+ i += 1
+ used_names.add(arg.name)
else:
arg.name = name
arg.type = t.dtype.as_datatype_enum
@@ -74,6 +85,20 @@ def _add_input_array(op, start, limit, dtype, func):
return ret_name
+def _add_identity_dtype_proto(func, src, dst, dtype_proto):
+ node = function_pb2.FunctionDef.Node()
+ node.op = "Identity"
+ node.arg.append(src)
+ node.ret.append(dst)
+ node.attr["T"].CopyFrom(dtype_proto)
+ func.node.extend([node])
+
+
+def _add_identity_dtype_enum(func, src, dst, dtype):
+ dtype_proto = attr_value_pb2.AttrValue(type=dtype)
+ _add_identity_dtype_proto(func, src, dst, dtype_proto)
+
+
def _add_output_array(op, start, limit, dtype, func):
"""Adds a _ArrayToList node in the func for op.outputs[start:limit]."""
dtype_proto = attr_value_pb2.AttrValue(type=dtype)
@@ -96,12 +121,9 @@ def _add_output_array(op, start, limit, dtype, func):
# uses of each element can be added easily later. These Identity
# will be eliminated before graph execution.
for i in xrange(num):
- node = function_pb2.FunctionDef.Node()
- node.op = "Identity"
- node.arg.append(ret_name + ":" + str(i))
- node.ret.append(_make_argname_from_tensor_name(op.outputs[i].name))
- node.attr["T"].CopyFrom(dtype_proto)
- func.node.extend([node])
+ _add_identity_dtype_proto(
+ func, ret_name + ":" + str(i),
+ _make_argname_from_tensor_name(op.outputs[i].name), dtype_proto)
return arg_name
@@ -114,12 +136,10 @@ def _add_output_list(op, start, limit, dtype_lst, func):
# uses of each element can be added easily later. These Identity
# will be eliminated before graph execution.
for i in xrange(num):
- node = function_pb2.FunctionDef.Node()
- node.op = "Identity"
- node.arg.append(ret_name + ":" + str(i))
- node.ret.append(_make_argname_from_tensor_name(op.outputs[i].name))
- node.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=dtype_lst[i]))
- func.node.extend([node])
+ _add_identity_dtype_enum(func,
+ ret_name + ":" + str(i),
+ _make_argname_from_tensor_name(op.outputs[i].name),
+ dtype_lst[i])
return ret_name
@@ -274,16 +294,23 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
"""
func = function_pb2.FunctionDef()
func.signature.name = "_"
- func.signature.input_arg.extend([_tensor_to_argdef(i) for i in inputs])
+ used_names = set()
+ func.signature.input_arg.extend([_tensor_to_argdef(i, used_names=used_names)
+ for i in inputs])
if out_names is None:
- func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs])
+ used_names = set()
+ func.signature.output_arg.extend([
+ _tensor_to_argdef(o, used_names=used_names) for o in outputs])
elif len(outputs) != len(out_names):
raise ValueError(
"Length of out_names (%d) does not match number of outputs (%d): %s" %
(len(out_names), len(outputs), ", ".join(out_names)))
+ elif len(out_names) != len(set(out_names)):
+ raise ValueError(
+ "Must not have duplicates in out_names: %s" % ", ".join(out_names))
else:
func.signature.output_arg.extend([
- _tensor_to_argdef(o, n) for o, n in zip(outputs, out_names)])
+ _tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
func_arg_placeholders = set([i.name for i in inputs])
input_dict = _create_input_dict(graph, func_arg_placeholders)
@@ -293,9 +320,15 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
_add_op_node(op, func, input_dict)
if out_names is None:
- for o in outputs:
- k = _make_argname_from_tensor_name(o.name)
+ for index, o in enumerate(outputs):
+ k = func.signature.output_arg[index].name
func.ret[k] = input_dict[o.name]
+ # TODO(josh11b): Delete this once we switch fully to NodeDefs for
+ # function bodies.
+ orig = _make_argname_from_tensor_name(o.name)
+ if k != orig:
+ _add_identity_dtype_enum(func, orig, k,
+ func.signature.output_arg[index].type)
else:
for o, n in zip(outputs, out_names):
func.ret[n] = input_dict[o.name]
@@ -1006,6 +1039,9 @@ class Declare(object):
self._sig.name = func_name
def _to_argdef_list(args):
+ names = [n for n, t in args]
+ if len(names) != len(set(names)):
+ raise ValueError("Expected names to all be unique: %s" % str(names))
return [op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
for n, t in args]
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 85e4766db5..335d719083 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -54,6 +54,23 @@ class FunctionTest(tf.test.TestCase):
with tf.Session() as sess:
self.assertAllEqual([5.0], sess.run(call))
+ def testDefineFunctionDuplicateOutputs(self):
+
+ @function.Defun(tf.float32, func_name="Duplicate")
+ def Duplicate(a):
+ b = a + 1.0
+ return b, b
+
+ g = tf.Graph()
+ with g.as_default():
+ Duplicate([3.0])
+ func_sig = g.as_graph_def().library.function[0].signature
+ # The names given to both outputs should be different
+ # even though the same tensor is emitted to both.
+ out_names = [a.name for a in func_sig.output_arg]
+ self.assertEqual(2, len(out_names))
+ self.assertNotEqual(out_names[0], out_names[1])
+
def testGradientFunc(self):
@function.Defun(tf.float32, func_name="XSquarePlusOneFn")