aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/lib
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-03-12 13:00:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-12 13:07:14 -0700
commitdc15b875893d55793c419840446dc809bcb7383f (patch)
tree0502da2518b97df43d66ae256091e6a05cf8cd54 /tensorflow/python/lib
parentbae670486f2cf87983476067103a019bbdf86333 (diff)
Fix another eager PyObject leak
Shockingly this one was also due to PySequence_GetItem. PiperOrigin-RevId: 188765548
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r--tensorflow/python/lib/core/py_seq_tensor.cc9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 317bdc2e14..8247d354db 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -84,6 +84,7 @@ bool IsPyDimension(PyObject* obj) {
}
Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
+ std::vector<Safe_PyObjectPtr> refs_to_clean;
while (true) {
// We test strings first, in case a string is considered a sequence.
if (IsPyString(obj)) {
@@ -93,6 +94,7 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
if (length > 0) {
shape->AddDim(length);
obj = PySequence_GetItem(obj, 0);
+ refs_to_clean.push_back(make_safe(obj));
continue;
} else if (length == 0) {
shape->AddDim(length);
@@ -167,14 +169,15 @@ const char ErrorFoundFloat[] =
if (shape.dims() > 1) { \
/* Iterate over outer dim, and recursively convert each element. */ \
const int64 s = shape.dim_size(0); \
- if (TF_PREDICT_FALSE(s != PySequence_Length(obj))) { \
+ Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, "")); \
+ if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) { \
return ErrorRectangular; \
} \
TensorShape rest = shape; \
rest.RemoveDim(0); \
for (int64 i = 0; i < s; ++i) { \
- const char* error = \
- FUNCTION##Helper(PySequence_GetItem(obj, i), rest, buf); \
+ const char* error = FUNCTION##Helper( \
+ PySequence_Fast_GET_ITEM(seq.get(), i), rest, buf); \
if (TF_PREDICT_FALSE(error != nullptr)) return error; \
} \
} else { \