diff options
Diffstat (limited to 'tensorflow/python/lib/core/py_func.cc')
-rw-r--r-- | tensorflow/python/lib/core/py_func.cc | 53 |
1 files changed, 30 insertions, 23 deletions
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 57139986af..7c107138be 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -333,6 +333,35 @@ class NumpyTensorBuffer : public TensorBuffer { void* data_; }; +Status PyObjectToString(PyObject* obj, string* str) { + char* py_bytes; + Py_ssize_t size; + if (PyBytes_AsStringAndSize(obj, &py_bytes, &size) != -1) { + str->assign(py_bytes, size); + return Status::OK(); + } +#if PY_MAJOR_VERSION >= 3 + const char* ptr = PyUnicode_AsUTF8AndSize(obj, &size); + if (ptr != nullptr) { + str->assign(ptr, size); + return Status::OK(); + } +#else + if (PyUnicode_Check(obj)) { + PyObject* unicode = PyUnicode_AsUTF8String(obj); + char* ptr; + if (unicode && PyString_AsStringAndSize(unicode, &ptr, &size) != -1) { + str->assign(ptr, size); + Py_DECREF(unicode); + return Status::OK(); + } + Py_XDECREF(unicode); + } +#endif + return errors::Unimplemented("Unsupported object type ", + obj->ob_type->tp_name); +} + Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) { PyArrayObject* input = reinterpret_cast<PyArrayObject*>(obj); DataType dtype = DT_INVALID; @@ -348,29 +377,7 @@ Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) { auto tflat = t.flat<string>(); PyObject** input_data = reinterpret_cast<PyObject**>(PyArray_DATA(input)); for (int i = 0; i < tflat.dimension(0); ++i) { - char* el; - Py_ssize_t el_size; - if (PyBytes_AsStringAndSize(input_data[i], &el, &el_size) == -1) { -#if PY_MAJOR_VERSION >= 3 - el = PyUnicode_AsUTF8AndSize(input_data[i], &el_size); -#else - el = nullptr; - if (PyUnicode_Check(input_data[i])) { - PyObject* unicode = PyUnicode_AsUTF8String(input_data[i]); - if (unicode) { - if (PyString_AsStringAndSize(unicode, &el, &el_size) == -1) { - Py_DECREF(unicode); - el = nullptr; - } - } - } -#endif - if (!el) { - return errors::Unimplemented("Unsupported object type ", - input_data[i]->ob_type->tp_name); - } - } - tflat(i) = string(el, el_size); + TF_RETURN_IF_ERROR(PyObjectToString(input_data[i], &tflat(i))); } *ret = t; break; |