diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-12-12 11:42:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-12 12:03:49 -0800 |
commit | 2eaaadae5a0afc0a92ed81cca550d57bb9b29cc1 (patch) | |
tree | 48bbc063f51cb14d9d8f62476bb2a528cac87c1f | |
parent | 573d8ca0f0ceae9706958b34d758420d05a4908b (diff) |
Hashing function input/output args should be a part of the function
name hash when the name is not given explicitly.
Change: 141790846
-rw-r--r-- | tensorflow/python/framework/function.py | 6 | ||||
-rw-r--r-- | tensorflow/python/framework/function_test.py | 35 |
2 files changed, 40 insertions, 1 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 349198f1fb..32800ac307 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -694,6 +694,12 @@ class _DefinedFunction(object): for s in slist: update_str(s) + for adef in self._definition.signature.input_arg: + update_str(adef.SerializeToString()) + + for adef in self._definition.signature.output_arg: + update_str(adef.SerializeToString()) + for n in sorted(self._definition.node_def, key=lambda n: n.name): update_str(n.name) update_str(n.op) diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 3c559dee95..1b53714c5d 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -583,7 +583,40 @@ class FunctionTest(tf.test.TestCase): def Foo(x, y, z): return tf.tanh(tf.matmul(x, y) + z) - self.assertEqual("Foo_e0cb6030", Foo.instantiate([tf.float32] * 3).name) + self.assertEqual("Foo_d643acf7", Foo.instantiate([tf.float32] * 3).name) + + def testSignatureHash(self): + # Foo.Inner and Bar.Inner have identical function body but have + # different signatures. They should be treated as two different functions. + + @function.Defun() + def Foo(x): + + @function.Defun() + def Inner(x): + return x + 10. + + return Inner(x) + + @function.Defun() + def Bar(x): + + @function.Defun() + def Inner(x, unused_y, unused_z): + return x + 10. + + return Inner(x, 2., 3.) + + g = tf.Graph() + with g.as_default(): + x = tf.constant(10.0) + y = Foo(x) + z = Bar(x) + + with self.test_session(graph=g) as sess: + v0, v1 = sess.run([y, z]) + self.assertAllEqual(v0, 20.) + self.assertAllEqual(v1, 20.) class FunctionOverloadTest(tf.test.TestCase): |