diff options
author | 2018-08-06 16:15:54 -0700 | |
---|---|---|
committer | 2018-08-06 16:27:27 -0700 | |
commit | 6ee02aafc8aab5c138e407e74ac423f41836ef1c (patch) | |
tree | 37c72c21c5729718cf1cf5327275fffdaa82e418 | |
parent | b8b5866d82ce7adbb34acccb8e6392fb8a130886 (diff) |
Don't store NumPy arrays in the defun cache. Instead, treat them like Tensors.
PiperOrigin-RevId: 207626998
-rw-r--r-- | tensorflow/python/eager/function.py | 7 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 16 |
2 files changed, 19 insertions, 4 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index ca0c6b18eb..a8d1ee633e 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -838,11 +838,10 @@ def _encode_arg(arg): _TensorType(arg.values.dtype, arg.values._shape_tuple()), _TensorType(arg.indices.dtype, arg.indices._shape_tuple()), ]) - # pylint: enable=protected-access elif isinstance(arg, np.ndarray): - # TODO(akshayka): Consider instead converting this NumPy array to a Tensor - # and encoding it with a _TensorType. - return ("array", arg.shape, tuple(arg.reshape(-1))) + tensor = ops.convert_to_tensor(arg) + return _TensorType(tensor.dtype, tensor._shape_tuple()) + # pylint: enable=protected-access elif isinstance(arg, (list, tuple)): return tuple([_encode_arg(elem) for elem in arg]) elif isinstance(arg, dict): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index b568af9bce..b9e29635f8 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -247,6 +247,22 @@ class FunctionTest(test.TestCase): y = f(x) self.assertAllEqual(self.evaluate(t.gradient(y, x)), 4.0) + def testDefunNumpyArraysConvertedToTensors(self): + + def f(x): + return x + + x = random_ops.random_uniform([2, 2]).numpy() + defined = function.defun(f) + defined(x) + self.assertEqual(len(defined._arguments_to_functions), 1) + + x = random_ops.random_uniform([2, 2]).numpy() + defined(x) + # A NumPy array with different values but the same shape and dtype + # shouldn't trigger another function definition. + self.assertEqual(len(defined._arguments_to_functions), 1) + def testDefunCapturedInt32(self): x = constant_op.constant(1, dtype=dtypes.int32) |