aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2018-08-06 16:15:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-06 16:27:27 -0700
commit6ee02aafc8aab5c138e407e74ac423f41836ef1c (patch)
tree37c72c21c5729718cf1cf5327275fffdaa82e418
parentb8b5866d82ce7adbb34acccb8e6392fb8a130886 (diff)
Don't store NumPy arrays in the defun cache. Instead, treat them like Tensors.
PiperOrigin-RevId: 207626998
-rw-r--r--tensorflow/python/eager/function.py7
-rw-r--r--tensorflow/python/eager/function_test.py16
2 files changed, 19 insertions, 4 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index ca0c6b18eb..a8d1ee633e 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -838,11 +838,10 @@ def _encode_arg(arg):
_TensorType(arg.values.dtype, arg.values._shape_tuple()),
_TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
])
- # pylint: enable=protected-access
elif isinstance(arg, np.ndarray):
- # TODO(akshayka): Consider instead converting this NumPy array to a Tensor
- # and encoding it with a _TensorType.
- return ("array", arg.shape, tuple(arg.reshape(-1)))
+ tensor = ops.convert_to_tensor(arg)
+ return _TensorType(tensor.dtype, tensor._shape_tuple())
+ # pylint: enable=protected-access
elif isinstance(arg, (list, tuple)):
return tuple([_encode_arg(elem) for elem in arg])
elif isinstance(arg, dict):
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index b568af9bce..b9e29635f8 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -247,6 +247,22 @@ class FunctionTest(test.TestCase):
y = f(x)
self.assertAllEqual(self.evaluate(t.gradient(y, x)), 4.0)
+ def testDefunNumpyArraysConvertedToTensors(self):
+
+ def f(x):
+ return x
+
+ x = random_ops.random_uniform([2, 2]).numpy()
+ defined = function.defun(f)
+ defined(x)
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+
+ x = random_ops.random_uniform([2, 2]).numpy()
+ defined(x)
+ # A NumPy array with different values but the same shape and dtype
+ # shouldn't trigger another function definition.
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)