aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2016-09-17 16:38:39 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-17 17:47:30 -0700
commit9814f5efc297692ef04309d94eb582d19bf2aef0 (patch)
tree3b8ebf627caecb81bb501b2477ee8d9cf115a4b9
parent697b9a911355562e7fbeb5578bcffd4515169b2c (diff)
SWIG simplification: Move the FeedVector conversion into C++.
This change moves the %typemap(in) FeedVector into C++ code in tf_session_helper.cc, which now implements the conversion from a string-to-array dictionary into the arguments to `TF_Run()`. Change: 133494142
-rw-r--r--tensorflow/python/client/tf_session.i51
-rw-r--r--tensorflow/python/client/tf_session_helper.cc69
-rw-r--r--tensorflow/python/client/tf_session_helper.h16
3 files changed, 49 insertions, 87 deletions
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 50824b0bc0..4490124585 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -76,57 +76,6 @@ tensorflow::ImportNumpy();
// BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper()
////////////////////////////////////////////////////////////////////////////////
-// The wrapper takes a vector of pairs of feed names and feed
-// values. In Python this is represented as dictionary mapping strings
-// to numpy arrays.
-%typemap(in) const tensorflow::FeedVector& inputs (
- tensorflow::FeedVector temp,
- tensorflow::Safe_PyObjectPtr temp_string_list(tensorflow::make_safe(nullptr)),
- tensorflow::Safe_PyObjectPtr temp_array_list(tensorflow::make_safe(nullptr))) {
- if (!PyDict_Check($input)) {
- SWIG_fail;
- }
-
- temp_string_list = tensorflow::make_safe(PyList_New(0));
- if (!temp_string_list) {
- SWIG_fail;
- }
- temp_array_list = tensorflow::make_safe(PyList_New(0));
- if (!temp_array_list) {
- SWIG_fail;
- }
-
- PyObject* key;
- PyObject* value;
- Py_ssize_t pos = 0;
- while (PyDict_Next($input, &pos, &key, &value)) {
- char* key_string = PyBytes_AsString(key);
- if (!key_string) {
- SWIG_fail;
- }
-
- // The ndarray must be stored as contiguous bytes in C (row-major) order.
- PyObject* array_object = PyArray_FromAny(
- value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr);
- if (!array_object) {
- SWIG_fail;
- }
- PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_object);
-
- // Keep a reference to the key and the array, in case the incoming dict is
- // modified, and/or to avoid leaking references on failure.
- if (PyList_Append(temp_string_list.get(), key) == -1) {
- SWIG_fail;
- }
- if (PyList_Append(temp_array_list.get(), array_object) == -1) {
- SWIG_fail;
- }
-
- temp.push_back(std::make_pair(key_string, array));
- }
- $1 = &temp;
-}
-
// The wrapper also takes a list of fetch and target names. In Python this is
// represented as a list of strings.
%typemap(in) const tensorflow::NameVector& (
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index 5becc2de63..f36360edac 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -394,34 +394,52 @@ Safe_PyObjectPtr make_safe(PyObject* o) {
}
void TF_Run_wrapper_helper(TF_Session* session, const char* handle,
- const TF_Buffer* run_options,
- const FeedVector& inputs,
+ const TF_Buffer* run_options, PyObject* feed_dict,
const NameVector& output_names,
const NameVector& target_nodes,
TF_Status* out_status, PyObjectVector* out_values,
TF_Buffer* run_outputs) {
+ static const char* kFeedDictErrorMsg =
+ "feed_dict must be a dictionary mapping strings to NumPy arrays.";
+
// 1. Convert the feed inputs to the appropriate form for TF_Run.
+ if (!PyDict_Check(feed_dict)) {
+ Set_TF_Status_from_Status(out_status,
+ errors::InvalidArgument(kFeedDictErrorMsg));
+ return;
+ }
+
NameVector input_names;
- Safe_PyObjectVector
- py_inputs_safe; // Used to decref the input arrays on failure.
Safe_TF_TensorVector inputs_safe; // Used to delete tensors on failure.
TF_TensorVector inputs_unsafe; // Used to contain the arg to TF_Run.
- for (const auto& name_and_array : inputs) {
- py_inputs_safe.emplace_back(
- make_safe(reinterpret_cast<PyObject*>(name_and_array.second)));
- }
+ PyObject* key;
+ PyObject* value;
+ Py_ssize_t pos = 0;
+ int index = 0;
+ Status s;
+ while (PyDict_Next(feed_dict, &pos, &key, &value)) {
+ char* key_string = PyBytes_AsString(key);
+ if (!key_string) {
+ Set_TF_Status_from_Status(out_status,
+ errors::InvalidArgument(kFeedDictErrorMsg));
+ return;
+ }
+ input_names.push_back(key_string);
- Status result;
- for (size_t i = 0; i < inputs.size(); ++i) {
- input_names.push_back(inputs[i].first);
- PyArrayObject* array = inputs[i].second;
+ PyArrayObject* array = reinterpret_cast<PyArrayObject*>(
+ PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
+ if (!array) {
+ Set_TF_Status_from_Status(out_status,
+ errors::InvalidArgument(kFeedDictErrorMsg));
+ return;
+ }
// Convert numpy dtype to TensorFlow dtype.
TF_DataType dtype = TF_FLOAT;
- result = PyArray_TYPE_to_TF_DataType(array, &dtype);
- if (!result.ok()) {
- Set_TF_Status_from_Status(out_status, result);
+ s = PyArray_TYPE_to_TF_DataType(array, &dtype);
+ if (!s.ok()) {
+ Set_TF_Status_from_Status(out_status, s);
return;
}
@@ -446,9 +464,6 @@ void TF_Run_wrapper_helper(TF_Session* session, const char* handle,
TF_AllocateTensor(dtype, dims.data(), dims.size(), size);
std::memcpy(TF_TensorData(tensor), PyArray_DATA(array), size);
inputs_safe.emplace_back(make_safe(tensor));
- // The destruction of the numpy array will now be handled by the
- // inputs_safe destructor.
- py_inputs_safe[i].reset();
} else {
size_t size = 0;
void* encoded = nullptr;
@@ -463,11 +478,9 @@ void TF_Run_wrapper_helper(TF_Session* session, const char* handle,
delete[] reinterpret_cast<char*>(data);
},
array)));
- // The destruction of the numpy array will now be handled by the
- // inputs_safe destructor.
- py_inputs_safe[i].reset();
}
inputs_unsafe.push_back(inputs_safe.back().get());
+ ++index;
}
// 2. Allocate a container for the output data.
@@ -513,9 +526,9 @@ void TF_Run_wrapper_helper(TF_Session* session, const char* handle,
Safe_PyObjectVector py_outputs_safe;
for (size_t i = 0; i < output_names.size(); ++i) {
PyObject* py_array;
- result = TF_Tensor_to_PyObject(outputs[i], &py_array);
- if (!result.ok()) {
- Set_TF_Status_from_Status(out_status, result);
+ s = TF_Tensor_to_PyObject(outputs[i], &py_array);
+ if (!s.ok()) {
+ Set_TF_Status_from_Status(out_status, s);
return;
}
py_outputs_safe.emplace_back(make_safe(py_array));
@@ -532,10 +545,10 @@ void TF_Run_wrapper_helper(TF_Session* session, const char* handle,
// If *out_status is OK, the caller becomes the owner of the PyObjects
// in *out_values.
void TF_Run_wrapper(TF_Session* session, const TF_Buffer* run_options,
- const FeedVector& inputs, const NameVector& output_names,
+ PyObject* feed_dict, const NameVector& output_names,
const NameVector& target_nodes, TF_Status* out_status,
PyObjectVector* out_values, TF_Buffer* run_outputs) {
- TF_Run_wrapper_helper(session, nullptr, run_options, inputs, output_names,
+ TF_Run_wrapper_helper(session, nullptr, run_options, feed_dict, output_names,
target_nodes, out_status, out_values, run_outputs);
}
@@ -558,9 +571,9 @@ void TF_PRunSetup_wrapper(TF_Session* session, const NameVector& input_names,
// If *out_status is OK, the caller becomes the owner of the PyObjects
// in *out_values.
void TF_PRun_wrapper(TF_Session* session, const char* handle,
- const FeedVector& inputs, const NameVector& output_names,
+ PyObject* feed_dict, const NameVector& output_names,
TF_Status* out_status, PyObjectVector* out_values) {
- TF_Run_wrapper_helper(session, handle, nullptr, inputs, output_names,
+ TF_Run_wrapper_helper(session, handle, nullptr, feed_dict, output_names,
NameVector(), out_status, out_values, nullptr);
}
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 83cab586d8..6b19740368 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -30,12 +30,6 @@ namespace tensorflow {
// Container types for the various arguments and temporary values used
// in the wrapper.
-// A FeedVector is a vector of tensor name and numpy array pairs. The
-// name is a borrowed C string.
-typedef tensorflow::gtl::InlinedVector<std::pair<const char*, PyArrayObject*>,
- 8>
- FeedVector;
-
// A NameVector is a vector of tensor or operation names, as borrowed
// C strings.
typedef tensorflow::gtl::InlinedVector<const char*, 8> NameVector;
@@ -56,6 +50,9 @@ Safe_PyObjectPtr make_safe(PyObject* o);
// stolen by the implementation (i.e. the implementation will
// eventually call Py_DECREF on each array input).
//
+// The PyObject* feed_dict must be a dictionary mapping strings to
+// NumPy arrays. This function does not modify its reference count.
+//
// On success, the tensors corresponding to output_names[0,noutputs-1]
// are placed in out_values[], and these outputs[] become the property
// of the caller (the caller must eventually call Py_DECREF on them).
@@ -63,7 +60,7 @@ Safe_PyObjectPtr make_safe(PyObject* o);
// On failure, out_status contains a tensorflow::Status with an error
// message.
void TF_Run_wrapper(TF_Session* session, const TF_Buffer* run_options,
- const FeedVector& inputs, const NameVector& output_names,
+ PyObject* feed_dict, const NameVector& output_names,
const NameVector& target_nodes, TF_Status* out_status,
PyObjectVector* out_values, TF_Buffer* run_outputs);
@@ -84,6 +81,9 @@ void TF_PRunSetup_wrapper(TF_Session* session, const NameVector& input_names,
// Continue to run the graph with additional feeds and fetches. The
// execution state is uniquely identified by the handle.
//
+// The PyObject* feed_dict must be a dictionary mapping strings to
+// NumPy arrays. This function does not modify its reference count.
+//
// On success, the tensors corresponding to output_names[0,noutputs-1]
// are placed in out_values[], and these outputs[] become the property
// of the caller (the caller must eventually call Py_DECREF on them).
@@ -93,7 +93,7 @@ void TF_PRunSetup_wrapper(TF_Session* session, const NameVector& input_names,
//
// NOTE: This is EXPERIMENTAL and subject to change.
void TF_PRun_wrapper(TF_Session* session, const char* handle,
- const FeedVector& inputs, const NameVector& output_names,
+ PyObject* feed_dict, const NameVector& output_names,
TF_Status* out_status, PyObjectVector* out_values);
// Wrapper for TF_Reset that converts the string vectors to character arrays.