aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-10-04 11:54:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 12:00:12 -0700
commit2390b48b11efda60a0f68a683c94af9612a5306f (patch)
tree91ca82cd7b7e23487d2ec1e2b3261673e6589713 /tensorflow/python
parent31619b408551907030dc25d8270f8997a0d9e6aa (diff)
Add a separator between shape and dtype in cache key encoding.
Without a separator, the encoding could conflate dtype and shape: the key "T111" could mean either a tensor of dtype 1 with shape (1, 1) or a tensor of dtype 11 with shape (1). PiperOrigin-RevId: 215777629
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/eager/function_test.py44
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc34
2 files changed, 58 insertions, 20 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 9ce367a837..a2cfb4b476 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1255,6 +1255,44 @@ class FunctionTest(test.TestCase):
defined(Foo())
self.assertEqual(len(defined._function_cache), 2)
+ def testCacheTensorShapeDtypeCollision(self):
+
+ def func(t):
+ return t + t
+
+ defined = function.defun(func)
+ t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
+ defined(t)
+ self.assertEqual(len(defined._function_cache), 1)
+
+ t = constant_op.constant([1.0], dtype=dtypes.complex128)
+ defined(t)
+ self.assertEqual(len(defined._function_cache), 2)
+
+ def testCacheTensorUnknownShapesCollision(self):
+
+ def func(t):
+ return t + t
+
+ with context.graph_mode(), self.cached_session():
+ defined = function.defun(func)
+
+ p = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+ defined(p)
+ self.assertEqual(len(defined._function_cache), 1)
+
+ p = array_ops.placeholder(dtype=dtypes.float32, shape=[None])
+ defined(p)
+ self.assertEqual(len(defined._function_cache), 2)
+
+ p = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
+ defined(p)
+ self.assertEqual(len(defined._function_cache), 3)
+
+ t = constant_op.constant(1.0, dtype=dtypes.float32)
+ defined(t)
+ self.assertEqual(len(defined._function_cache), 4)
+
def testPythonFunctionWithDefaultArgs(self):
def func(foo, bar=1, baz=2):
@@ -1271,17 +1309,17 @@ class FunctionTest(test.TestCase):
return tuple(key[0] for key in defined._function_cache)
# `True` corresponds to the fact that we're executing eagerly
- self.assertIn(('tRRR', (0, 1, 20)), cache_keys())
+ self.assertIn(('URRR', (0, 1, 20)), cache_keys())
defined(1) # bar=1, baz=2
- self.assertIn(('tRRR', (1, 1, 2)), cache_keys())
+ self.assertIn(('URRR', (1, 1, 2)), cache_keys())
# This matches the previous call.
defined(foo=1)
self.assertEqual(len(defined._function_cache), 2)
defined(1, 2, 3)
- self.assertIn(('tRRR', (1, 2, 3)), cache_keys())
+ self.assertIn(('URRR', (1, 2, 3)), cache_keys())
# This matches the previous call.
defined(1, bar=2, baz=3)
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index ae1e12f9c3..6193f40ce8 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -2747,11 +2747,15 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
}
namespace {
-
-tensorflow::int64 GetPyNoneHash() {
- tensorflow::int64 py_none_hash = PyObject_Hash(Py_None);
- return py_none_hash;
-}
+const char kTensor[] = "T";
+const char kIndexedSlices[] = "I";
+const char kList[] = "L";
+const char kTuple[] = "U";
+const char kDict[] = "D";
+const char kRaw[] = "R";
+const char kShape[] = "s";
+const char kDType[] = "d";
+const char kNone[] = "n";
struct EncodeResult {
string str;
@@ -2784,8 +2788,10 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
TFE_TensorHandle* t = EagerTensor_Handle(arg);
tensorflow::TensorShape tensor_shape;
TF_RETURN_IF_ERROR(t->handle->Shape(&tensor_shape));
- absl::StrAppend(&result->str, t->handle->dtype);
+ absl::StrAppend(&result->str, kDType, t->handle->dtype);
+
+ absl::StrAppend(&result->str, kShape);
for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
absl::StrAppend(&result->str, dim_size);
}
@@ -2812,7 +2818,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
- absl::StrAppend(&result->str, dtype);
+ absl::StrAppend(&result->str, kDType, dtype);
static char _shape_tuple[] = "_shape_tuple";
tensorflow::Safe_PyObjectPtr shape_tuple(
PyObject_CallMethod(arg, _shape_tuple, nullptr));
@@ -2824,10 +2830,11 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
if (shape_tuple.get() == Py_None) {
// Unknown shape, encode that directly.
- absl::StrAppend(&result->str, GetPyNoneHash());
+ absl::StrAppend(&result->str, kNone);
return tensorflow::Status::OK();
}
+ absl::StrAppend(&result->str, kShape);
tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
shape_tuple.get(), "shape_tuple didn't return a sequence"));
@@ -2835,7 +2842,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
for (int i = 0; i < len; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i);
if (item == Py_None) {
- absl::StrAppend(&result->str, GetPyNoneHash());
+ absl::StrAppend(&result->str, kNone);
} else {
absl::StrAppend(&result->str, MakeInt(item));
}
@@ -2844,13 +2851,6 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
return tensorflow::Status::OK();
}
-const char kTensor[] = "T";
-const char kIndexedSlices[] = "I";
-const char kList[] = "L";
-const char kTuple[] = "t";
-const char kDict[] = "D";
-const char kRaw[] = "R";
-
tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result);
// This function doesn't set the type of sequence before
@@ -2864,7 +2864,7 @@ tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
for (int i = 0; i < len; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(arg_seq.get(), i);
if (item == Py_None) {
- absl::StrAppend(&result->str, GetPyNoneHash());
+ absl::StrAppend(&result->str, kNone);
} else {
TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(item, result));
}