diff options
author | 2016-11-09 09:10:34 -0800 | |
---|---|---|
committer | 2016-11-09 09:22:58 -0800 | |
commit | d9da9721f45950035f5087c59f9bc6910e232271 (patch) | |
tree | 788dbff3186d9c03cc5e069081ca09a42334c931 | |
parent | 2201ada5a00bb06148a9baf0d50c4dd07ecc4864 (diff) |
C API: Slight re-organization of code that deletes input tensors.
This change itself is a no-op. However, my plan is to change the
API contract with the TF_Session*Run functions in a follow up
change so that they do NOT take ownership of the input tensor values.
This just makes it slightly easier to do so.
Change: 138646376
-rw-r--r-- | tensorflow/c/c_api.cc | 54 |
1 files changed, 29 insertions, 25 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index ced231396a..35fef4f741 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -443,25 +443,25 @@ static bool TF_Run_Inputs( std::vector<std::pair<tensorflow::string, Tensor>>* input_pairs, TF_Status* status) { const int ninputs = input_pairs->size(); - bool ok = true; for (int i = 0; i < ninputs; ++i) { TF_Tensor* src = c_inputs[i]; - if (ok) { - if (c_inputs[i]->dtype != TF_STRING) { - (*input_pairs)[i].second = tensorflow::TensorCApi::MakeTensor( - src->dtype, src->shape, src->buffer); - } else { - // TF_STRING tensors require copying since Tensor class expects - // a sequence of string objects. - ok = tensorflow::TF_Tensor_DecodeStrings(src, &(*input_pairs)[i].second, - status); - // Must keep looping through all c_inputs even if there is an error - // so that TF_DeleteTensor() is called unconditionally on all c_inputs. - } + if (c_inputs[i]->dtype != TF_STRING) { + (*input_pairs)[i].second = tensorflow::TensorCApi::MakeTensor( + src->dtype, src->shape, src->buffer); + } else if (!tensorflow::TF_Tensor_DecodeStrings( + src, &(*input_pairs)[i].second, status)) { + // TF_STRING tensors require copying since Tensor class expects + // a sequence of string objects. + return false; } - TF_DeleteTensor(src); } - return ok; + return true; +} + +static void TF_DeleteTensors(TF_Tensor* const* tensors, int num) { + for (int i = 0; i < num; ++i) { + TF_DeleteTensor(tensors[i]); + } } static void TF_Run_Helper( @@ -542,7 +542,9 @@ void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options, TF_Buffer* run_metadata, TF_Status* status) { TF_Run_Setup(noutputs, c_outputs, status); std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); - if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; + const bool ok = TF_Run_Inputs(c_inputs, &input_pairs, status); + TF_DeleteTensors(c_inputs, ninputs); + if (!ok) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = c_input_names[i]; } @@ -603,7 +605,9 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle, TF_Status* status) { TF_Run_Setup(noutputs, c_outputs, status); std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); - if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; + const bool ok = TF_Run_Inputs(c_inputs, &input_pairs, status); + TF_DeleteTensors(c_inputs, ninputs); + if (!ok) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = c_input_names[i]; } @@ -1695,9 +1699,7 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, // directly, instead of requiring us to serialize to a GraphDef and // call Session::Extend(). if (!ExtendSessionGraphHelper(session, status)) { - for (int i = 0; i < ninputs; ++i) { - TF_DeleteTensor(input_values[i]); - } + TF_DeleteTensors(input_values, ninputs); return; } @@ -1705,7 +1707,9 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, // Convert from TF_Port and TF_Tensor to a string and Tensor. std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); - if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; + const bool ok = TF_Run_Inputs(input_values, &input_pairs, status); + TF_DeleteTensors(input_values, ninputs); + if (!ok) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = PortName(inputs[i]); } @@ -1771,9 +1775,7 @@ void TF_SessionPRun(TF_Session* session, const char* handle, // directly, instead of requiring us to serialize to a GraphDef and // call Session::Extend(). if (!ExtendSessionGraphHelper(session, status)) { - for (int i = 0; i < ninputs; ++i) { - TF_DeleteTensor(input_values[i]); - } + TF_DeleteTensors(input_values, ninputs); return; } @@ -1781,7 +1783,9 @@ void TF_SessionPRun(TF_Session* session, const char* handle, // Convert from TF_Port and TF_Tensor to a string and Tensor. std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); - if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; + const bool ok = TF_Run_Inputs(input_values, &input_pairs, status); + TF_DeleteTensors(input_values, ninputs); + if (!ok) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = PortName(inputs[i]); } |