diff options
Diffstat (limited to 'tensorflow/python/eager/function_test.py')
-rw-r--r-- | tensorflow/python/eager/function_test.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index b568af9bce..b9e29635f8 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -247,6 +247,22 @@ class FunctionTest(test.TestCase): y = f(x) self.assertAllEqual(self.evaluate(t.gradient(y, x)), 4.0) + def testDefunNumpyArraysConvertedToTensors(self): + + def f(x): + return x + + x = random_ops.random_uniform([2, 2]).numpy() + defined = function.defun(f) + defined(x) + self.assertEqual(len(defined._arguments_to_functions), 1) + + x = random_ops.random_uniform([2, 2]).numpy() + defined(x) + # A NumPy array with different values but the same shape and dtype + # shouldn't trigger another function definition. + self.assertEqual(len(defined._arguments_to_functions), 1) + def testDefunCapturedInt32(self): x = constant_op.constant(1, dtype=dtypes.int32) |