diff options
author | Derek Murray <mrry@google.com> | 2016-09-17 16:38:39 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-17 17:47:30 -0700 |
commit | 9814f5efc297692ef04309d94eb582d19bf2aef0 (patch) | |
tree | 3b8ebf627caecb81bb501b2477ee8d9cf115a4b9 | |
parent | 697b9a911355562e7fbeb5578bcffd4515169b2c (diff) |
SWIG simplification: Move the FeedVector conversion into C++.
This change moves the %typemap(in) FeedVector into C++ code in
tf_session_helper.cc, which now implements the conversion from a
string-to-array dictionary into the arguments to `TF_Run()`.
Change: 133494142
-rw-r--r-- | tensorflow/python/client/tf_session.i | 51 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.cc | 69 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.h | 16 |
3 files changed, 49 insertions, 87 deletions
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 50824b0bc0..4490124585 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -76,57 +76,6 @@ tensorflow::ImportNumpy(); // BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper() //////////////////////////////////////////////////////////////////////////////// -// The wrapper takes a vector of pairs of feed names and feed -// values. In Python this is represented as dictionary mapping strings -// to numpy arrays. -%typemap(in) const tensorflow::FeedVector& inputs ( - tensorflow::FeedVector temp, - tensorflow::Safe_PyObjectPtr temp_string_list(tensorflow::make_safe(nullptr)), - tensorflow::Safe_PyObjectPtr temp_array_list(tensorflow::make_safe(nullptr))) { - if (!PyDict_Check($input)) { - SWIG_fail; - } - - temp_string_list = tensorflow::make_safe(PyList_New(0)); - if (!temp_string_list) { - SWIG_fail; - } - temp_array_list = tensorflow::make_safe(PyList_New(0)); - if (!temp_array_list) { - SWIG_fail; - } - - PyObject* key; - PyObject* value; - Py_ssize_t pos = 0; - while (PyDict_Next($input, &pos, &key, &value)) { - char* key_string = PyBytes_AsString(key); - if (!key_string) { - SWIG_fail; - } - - // The ndarray must be stored as contiguous bytes in C (row-major) order. - PyObject* array_object = PyArray_FromAny( - value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr); - if (!array_object) { - SWIG_fail; - } - PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_object); - - // Keep a reference to the key and the array, in case the incoming dict is - // modified, and/or to avoid leaking references on failure. - if (PyList_Append(temp_string_list.get(), key) == -1) { - SWIG_fail; - } - if (PyList_Append(temp_array_list.get(), array_object) == -1) { - SWIG_fail; - } - - temp.push_back(std::make_pair(key_string, array)); - } - $1 = &temp; -} - // The wrapper also takes a list of fetch and target names. In Python this is // represented as a list of strings. %typemap(in) const tensorflow::NameVector& ( diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index 5becc2de63..f36360edac 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -394,34 +394,52 @@ Safe_PyObjectPtr make_safe(PyObject* o) { } void TF_Run_wrapper_helper(TF_Session* session, const char* handle, - const TF_Buffer* run_options, - const FeedVector& inputs, + const TF_Buffer* run_options, PyObject* feed_dict, const NameVector& output_names, const NameVector& target_nodes, TF_Status* out_status, PyObjectVector* out_values, TF_Buffer* run_outputs) { + static const char* kFeedDictErrorMsg = + "feed_dict must be a dictionary mapping strings to NumPy arrays."; + // 1. Convert the feed inputs to the appropriate form for TF_Run. + if (!PyDict_Check(feed_dict)) { + Set_TF_Status_from_Status(out_status, + errors::InvalidArgument(kFeedDictErrorMsg)); + return; + } + NameVector input_names; - Safe_PyObjectVector - py_inputs_safe; // Used to decref the input arrays on failure. Safe_TF_TensorVector inputs_safe; // Used to delete tensors on failure. TF_TensorVector inputs_unsafe; // Used to contain the arg to TF_Run. - for (const auto& name_and_array : inputs) { - py_inputs_safe.emplace_back( - make_safe(reinterpret_cast<PyObject*>(name_and_array.second))); - } + PyObject* key; + PyObject* value; + Py_ssize_t pos = 0; + int index = 0; + Status s; + while (PyDict_Next(feed_dict, &pos, &key, &value)) { + char* key_string = PyBytes_AsString(key); + if (!key_string) { + Set_TF_Status_from_Status(out_status, + errors::InvalidArgument(kFeedDictErrorMsg)); + return; + } + input_names.push_back(key_string); - Status result; - for (size_t i = 0; i < inputs.size(); ++i) { - input_names.push_back(inputs[i].first); - PyArrayObject* array = inputs[i].second; + PyArrayObject* array = reinterpret_cast<PyArrayObject*>( + PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)); + if (!array) { + Set_TF_Status_from_Status(out_status, + errors::InvalidArgument(kFeedDictErrorMsg)); + return; + } // Convert numpy dtype to TensorFlow dtype. TF_DataType dtype = TF_FLOAT; - result = PyArray_TYPE_to_TF_DataType(array, &dtype); - if (!result.ok()) { - Set_TF_Status_from_Status(out_status, result); + s = PyArray_TYPE_to_TF_DataType(array, &dtype); + if (!s.ok()) { + Set_TF_Status_from_Status(out_status, s); return; } @@ -446,9 +464,6 @@ void TF_Run_wrapper_helper(TF_Session* session, const char* handle, TF_AllocateTensor(dtype, dims.data(), dims.size(), size); std::memcpy(TF_TensorData(tensor), PyArray_DATA(array), size); inputs_safe.emplace_back(make_safe(tensor)); - // The destruction of the numpy array will now be handled by the - // inputs_safe destructor. - py_inputs_safe[i].reset(); } else { size_t size = 0; void* encoded = nullptr; @@ -463,11 +478,9 @@ void TF_Run_wrapper_helper(TF_Session* session, const char* handle, delete[] reinterpret_cast<char*>(data); }, array))); - // The destruction of the numpy array will now be handled by the - // inputs_safe destructor. - py_inputs_safe[i].reset(); } inputs_unsafe.push_back(inputs_safe.back().get()); + ++index; } // 2. Allocate a container for the output data. @@ -513,9 +526,9 @@ void TF_Run_wrapper_helper(TF_Session* session, const char* handle, Safe_PyObjectVector py_outputs_safe; for (size_t i = 0; i < output_names.size(); ++i) { PyObject* py_array; - result = TF_Tensor_to_PyObject(outputs[i], &py_array); - if (!result.ok()) { - Set_TF_Status_from_Status(out_status, result); + s = TF_Tensor_to_PyObject(outputs[i], &py_array); + if (!s.ok()) { + Set_TF_Status_from_Status(out_status, s); return; } py_outputs_safe.emplace_back(make_safe(py_array)); @@ -532,10 +545,10 @@ void TF_Run_wrapper_helper(TF_Session* session, const char* handle, // If *out_status is OK, the caller becomes the owner of the PyObjects // in *out_values. void TF_Run_wrapper(TF_Session* session, const TF_Buffer* run_options, - const FeedVector& inputs, const NameVector& output_names, + PyObject* feed_dict, const NameVector& output_names, const NameVector& target_nodes, TF_Status* out_status, PyObjectVector* out_values, TF_Buffer* run_outputs) { - TF_Run_wrapper_helper(session, nullptr, run_options, inputs, output_names, + TF_Run_wrapper_helper(session, nullptr, run_options, feed_dict, output_names, target_nodes, out_status, out_values, run_outputs); } @@ -558,9 +571,9 @@ void TF_PRunSetup_wrapper(TF_Session* session, const NameVector& input_names, // If *out_status is OK, the caller becomes the owner of the PyObjects // in *out_values. void TF_PRun_wrapper(TF_Session* session, const char* handle, - const FeedVector& inputs, const NameVector& output_names, + PyObject* feed_dict, const NameVector& output_names, TF_Status* out_status, PyObjectVector* out_values) { - TF_Run_wrapper_helper(session, handle, nullptr, inputs, output_names, + TF_Run_wrapper_helper(session, handle, nullptr, feed_dict, output_names, NameVector(), out_status, out_values, nullptr); } diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 83cab586d8..6b19740368 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -30,12 +30,6 @@ namespace tensorflow { // Container types for the various arguments and temporary values used // in the wrapper. -// A FeedVector is a vector of tensor name and numpy array pairs. The -// name is a borrowed C string. -typedef tensorflow::gtl::InlinedVector<std::pair<const char*, PyArrayObject*>, - 8> - FeedVector; - // A NameVector is a vector of tensor or operation names, as borrowed // C strings. typedef tensorflow::gtl::InlinedVector<const char*, 8> NameVector; @@ -56,6 +50,9 @@ Safe_PyObjectPtr make_safe(PyObject* o); // stolen by the implementation (i.e. the implementation will // eventually call Py_DECREF on each array input). // +// The PyObject* feed_dict must be a dictionary mapping strings to +// NumPy arrays. This function does not modify its reference count. +// // On success, the tensors corresponding to output_names[0,noutputs-1] // are placed in out_values[], and these outputs[] become the property // of the caller (the caller must eventually call Py_DECREF on them). @@ -63,7 +60,7 @@ Safe_PyObjectPtr make_safe(PyObject* o); // On failure, out_status contains a tensorflow::Status with an error // message. void TF_Run_wrapper(TF_Session* session, const TF_Buffer* run_options, - const FeedVector& inputs, const NameVector& output_names, + PyObject* feed_dict, const NameVector& output_names, const NameVector& target_nodes, TF_Status* out_status, PyObjectVector* out_values, TF_Buffer* run_outputs); @@ -84,6 +81,9 @@ void TF_PRunSetup_wrapper(TF_Session* session, const NameVector& input_names, // Continue to run the graph with additional feeds and fetches. The // execution state is uniquely identified by the handle. // +// The PyObject* feed_dict must be a dictionary mapping strings to +// NumPy arrays. This function does not modify its reference count. +// // On success, the tensors corresponding to output_names[0,noutputs-1] // are placed in out_values[], and these outputs[] become the property // of the caller (the caller must eventually call Py_DECREF on them). @@ -93,7 +93,7 @@ void TF_PRunSetup_wrapper(TF_Session* session, const NameVector& input_names, // // NOTE: This is EXPERIMENTAL and subject to change. void TF_PRun_wrapper(TF_Session* session, const char* handle, - const FeedVector& inputs, const NameVector& output_names, + PyObject* feed_dict, const NameVector& output_names, TF_Status* out_status, PyObjectVector* out_values); // Wrapper for TF_Reset that converts the string vectors to character arrays. |