aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-12 11:42:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-12 12:03:49 -0800
commit2eaaadae5a0afc0a92ed81cca550d57bb9b29cc1 (patch)
tree48bbc063f51cb14d9d8f62476bb2a528cac87c1f
parent573d8ca0f0ceae9706958b34d758420d05a4908b (diff)
Hashing function input/output args should be a part of the function
name hash when the name is not given explicitly. Change: 141790846
-rw-r--r--tensorflow/python/framework/function.py6
-rw-r--r--tensorflow/python/framework/function_test.py35
2 files changed, 40 insertions, 1 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 349198f1fb..32800ac307 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -694,6 +694,12 @@ class _DefinedFunction(object):
for s in slist:
update_str(s)
+ for adef in self._definition.signature.input_arg:
+ update_str(adef.SerializeToString())
+
+ for adef in self._definition.signature.output_arg:
+ update_str(adef.SerializeToString())
+
for n in sorted(self._definition.node_def, key=lambda n: n.name):
update_str(n.name)
update_str(n.op)
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 3c559dee95..1b53714c5d 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -583,7 +583,40 @@ class FunctionTest(tf.test.TestCase):
def Foo(x, y, z):
return tf.tanh(tf.matmul(x, y) + z)
- self.assertEqual("Foo_e0cb6030", Foo.instantiate([tf.float32] * 3).name)
+ self.assertEqual("Foo_d643acf7", Foo.instantiate([tf.float32] * 3).name)
+
+ def testSignatureHash(self):
+ # Foo.Inner and Bar.Inner have identical function body but have
+ # different signatures. They should be treated as two different functions.
+
+ @function.Defun()
+ def Foo(x):
+
+ @function.Defun()
+ def Inner(x):
+ return x + 10.
+
+ return Inner(x)
+
+ @function.Defun()
+ def Bar(x):
+
+ @function.Defun()
+ def Inner(x, unused_y, unused_z):
+ return x + 10.
+
+ return Inner(x, 2., 3.)
+
+ g = tf.Graph()
+ with g.as_default():
+ x = tf.constant(10.0)
+ y = Foo(x)
+ z = Bar(x)
+
+ with self.test_session(graph=g) as sess:
+ v0, v1 = sess.run([y, z])
+ self.assertAllEqual(v0, 20.)
+ self.assertAllEqual(v1, 20.)
class FunctionOverloadTest(tf.test.TestCase):