diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-10-04 11:54:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 12:00:12 -0700 |
commit | 2390b48b11efda60a0f68a683c94af9612a5306f (patch) | |
tree | 91ca82cd7b7e23487d2ec1e2b3261673e6589713 /tensorflow/python | |
parent | 31619b408551907030dc25d8270f8997a0d9e6aa (diff) |
Add a separator between shape and dtype in cache key encoding.
It was possible that we could mix shapes and types (T111 could mean a tensor of dtype 1 and shape (1, 1) or a tensor of dtype 11 and shape (1)).
PiperOrigin-RevId: 215777629
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/eager/function_test.py | 44 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 34 |
2 files changed, 58 insertions, 20 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 9ce367a837..a2cfb4b476 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1255,6 +1255,44 @@ class FunctionTest(test.TestCase): defined(Foo()) self.assertEqual(len(defined._function_cache), 2) + def testCacheTensorShapeDtypeCollision(self): + + def func(t): + return t + t + + defined = function.defun(func) + t = constant_op.constant([[1.0]], dtype=dtypes.complex64) + defined(t) + self.assertEqual(len(defined._function_cache), 1) + + t = constant_op.constant([1.0], dtype=dtypes.complex128) + defined(t) + self.assertEqual(len(defined._function_cache), 2) + + def testCacheTensorUnknownShapesCollision(self): + + def func(t): + return t + t + + with context.graph_mode(), self.cached_session(): + defined = function.defun(func) + + p = array_ops.placeholder(dtype=dtypes.float32, shape=None) + defined(p) + self.assertEqual(len(defined._function_cache), 1) + + p = array_ops.placeholder(dtype=dtypes.float32, shape=[None]) + defined(p) + self.assertEqual(len(defined._function_cache), 2) + + p = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None]) + defined(p) + self.assertEqual(len(defined._function_cache), 3) + + t = constant_op.constant(1.0, dtype=dtypes.float32) + defined(t) + self.assertEqual(len(defined._function_cache), 4) + def testPythonFunctionWithDefaultArgs(self): def func(foo, bar=1, baz=2): @@ -1271,17 +1309,17 @@ class FunctionTest(test.TestCase): return tuple(key[0] for key in defined._function_cache) # `True` corresponds to the fact that we're executing eagerly - self.assertIn(('tRRR', (0, 1, 20)), cache_keys()) + self.assertIn(('URRR', (0, 1, 20)), cache_keys()) defined(1) # bar=1, baz=2 - self.assertIn(('tRRR', (1, 1, 2)), cache_keys()) + self.assertIn(('URRR', (1, 1, 2)), cache_keys()) # This matches the previous call. defined(foo=1) self.assertEqual(len(defined._function_cache), 2) defined(1, 2, 3) - self.assertIn(('tRRR', (1, 2, 3)), cache_keys()) + self.assertIn(('URRR', (1, 2, 3)), cache_keys()) # This matches the previous call. defined(1, bar=2, baz=3) diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index ae1e12f9c3..6193f40ce8 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -2747,11 +2747,15 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, } namespace { - -tensorflow::int64 GetPyNoneHash() { - tensorflow::int64 py_none_hash = PyObject_Hash(Py_None); - return py_none_hash; -} +const char kTensor[] = "T"; +const char kIndexedSlices[] = "I"; +const char kList[] = "L"; +const char kTuple[] = "U"; +const char kDict[] = "D"; +const char kRaw[] = "R"; +const char kShape[] = "s"; +const char kDType[] = "d"; +const char kNone[] = "n"; struct EncodeResult { string str; @@ -2784,8 +2788,10 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) { TFE_TensorHandle* t = EagerTensor_Handle(arg); tensorflow::TensorShape tensor_shape; TF_RETURN_IF_ERROR(t->handle->Shape(&tensor_shape)); - absl::StrAppend(&result->str, t->handle->dtype); + absl::StrAppend(&result->str, kDType, t->handle->dtype); + + absl::StrAppend(&result->str, kShape); for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) { absl::StrAppend(&result->str, dim_size); } @@ -2812,7 +2818,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) { tensorflow::DataType dtype = static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get())); - absl::StrAppend(&result->str, dtype); + absl::StrAppend(&result->str, kDType, dtype); static char _shape_tuple[] = "_shape_tuple"; tensorflow::Safe_PyObjectPtr shape_tuple( PyObject_CallMethod(arg, _shape_tuple, nullptr)); @@ -2824,10 +2830,11 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) { if (shape_tuple.get() == Py_None) { // Unknown shape, encode that directly. - absl::StrAppend(&result->str, GetPyNoneHash()); + absl::StrAppend(&result->str, kNone); return tensorflow::Status::OK(); } + absl::StrAppend(&result->str, kShape); tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast( shape_tuple.get(), "shape_tuple didn't return a sequence")); @@ -2835,7 +2842,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) { for (int i = 0; i < len; ++i) { PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i); if (item == Py_None) { - absl::StrAppend(&result->str, GetPyNoneHash()); + absl::StrAppend(&result->str, kNone); } else { absl::StrAppend(&result->str, MakeInt(item)); } @@ -2844,13 +2851,6 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) { return tensorflow::Status::OK(); } -const char kTensor[] = "T"; -const char kIndexedSlices[] = "I"; -const char kList[] = "L"; -const char kTuple[] = "t"; -const char kDict[] = "D"; -const char kRaw[] = "R"; - tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result); // This function doesn't set the type of sequence before @@ -2864,7 +2864,7 @@ tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type, for (int i = 0; i < len; ++i) { PyObject* item = PySequence_Fast_GET_ITEM(arg_seq.get(), i); if (item == Py_None) { - absl::StrAppend(&result->str, GetPyNoneHash()); + absl::StrAppend(&result->str, kNone); } else { TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(item, result)); } |