aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/function_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/function_test.py')
-rw-r--r--tensorflow/python/eager/function_test.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index b568af9bce..b9e29635f8 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -247,6 +247,22 @@ class FunctionTest(test.TestCase):
y = f(x)
self.assertAllEqual(self.evaluate(t.gradient(y, x)), 4.0)
+ def testDefunNumpyArraysConvertedToTensors(self):
+
+ def f(x):
+ return x
+
+ x = random_ops.random_uniform([2, 2]).numpy()
+ defined = function.defun(f)
+ defined(x)
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+
+ x = random_ops.random_uniform([2, 2]).numpy()
+ defined(x)
+ # A NumPy array with different values but the same shape and dtype
+ # shouldn't trigger another function definition.
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)