aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/lib/core/py_func.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/lib/core/py_func.cc')
-rw-r--r--tensorflow/python/lib/core/py_func.cc53
1 files changed, 30 insertions, 23 deletions
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index 57139986af..7c107138be 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -333,6 +333,35 @@ class NumpyTensorBuffer : public TensorBuffer {
void* data_;
};
+Status PyObjectToString(PyObject* obj, string* str) {
+ char* py_bytes;
+ Py_ssize_t size;
+ if (PyBytes_AsStringAndSize(obj, &py_bytes, &size) != -1) {
+ str->assign(py_bytes, size);
+ return Status::OK();
+ }
+#if PY_MAJOR_VERSION >= 3
+ const char* ptr = PyUnicode_AsUTF8AndSize(obj, &size);
+ if (ptr != nullptr) {
+ str->assign(ptr, size);
+ return Status::OK();
+ }
+#else
+ if (PyUnicode_Check(obj)) {
+ PyObject* unicode = PyUnicode_AsUTF8String(obj);
+ char* ptr;
+ if (unicode && PyString_AsStringAndSize(unicode, &ptr, &size) != -1) {
+ str->assign(ptr, size);
+ Py_DECREF(unicode);
+ return Status::OK();
+ }
+ Py_XDECREF(unicode);
+ }
+#endif
+ return errors::Unimplemented("Unsupported object type ",
+ obj->ob_type->tp_name);
+}
+
Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) {
PyArrayObject* input = reinterpret_cast<PyArrayObject*>(obj);
DataType dtype = DT_INVALID;
@@ -348,29 +377,7 @@ Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) {
auto tflat = t.flat<string>();
PyObject** input_data = reinterpret_cast<PyObject**>(PyArray_DATA(input));
for (int i = 0; i < tflat.dimension(0); ++i) {
- char* el;
- Py_ssize_t el_size;
- if (PyBytes_AsStringAndSize(input_data[i], &el, &el_size) == -1) {
-#if PY_MAJOR_VERSION >= 3
- el = PyUnicode_AsUTF8AndSize(input_data[i], &el_size);
-#else
- el = nullptr;
- if (PyUnicode_Check(input_data[i])) {
- PyObject* unicode = PyUnicode_AsUTF8String(input_data[i]);
- if (unicode) {
- if (PyString_AsStringAndSize(unicode, &el, &el_size) == -1) {
- Py_DECREF(unicode);
- el = nullptr;
- }
- }
- }
-#endif
- if (!el) {
- return errors::Unimplemented("Unsupported object type ",
- input_data[i]->ob_type->tp_name);
- }
- }
- tflat(i) = string(el, el_size);
+ TF_RETURN_IF_ERROR(PyObjectToString(input_data[i], &tflat(i)));
}
*ret = t;
break;