aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_test_util.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-07-17 08:29:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-17 08:33:52 -0700
commitbed7cbe389420abe760d86cd345e88d45eb624ab (patch)
treec4d5c9d789b06e7596b6473902c140638fbb5de7 /tensorflow/c/c_test_util.cc
parenta925c14596135b19bedb73e0f6ae3cd170180106 (diff)
Split out new while_loop_test.cc from c_api_test.cc
This change also separates shared functionality into c_test_util.h/cc. This brings c_api_test.cc to a mere 1715 LOC (further splits can be more easily done now too). PiperOrigin-RevId: 162216399
Diffstat (limited to 'tensorflow/c/c_test_util.cc')
-rw-r--r--tensorflow/c/c_test_util.cc304
1 files changed, 304 insertions, 0 deletions
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
new file mode 100644
index 0000000000..21603c1a07
--- /dev/null
+++ b/tensorflow/c/c_test_util.cc
@@ -0,0 +1,304 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/c_test_util.h"
+
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/platform/logging.h"
+
+using tensorflow::GraphDef;
+using tensorflow::NodeDef;
+
+static void Int32Deallocator(void* data, size_t, void* arg) {
+ delete[] static_cast<int32_t*>(data);
+}
+
+TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
+ int64_t num_values = 1;
+ for (int i = 0; i < num_dims; ++i) {
+ num_values *= dims[i];
+ }
+ TF_Tensor* t =
+ TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values);
+ memcpy(TF_TensorData(t), values, sizeof(char) * num_values);
+ return t;
+}
+
+TF_Tensor* Int32Tensor(int32_t v) {
+ const int num_bytes = sizeof(int32_t);
+ int32_t* values = new int32_t[1];
+ values[0] = v;
+ return TF_NewTensor(TF_INT32, nullptr, 0, values, num_bytes,
+ &Int32Deallocator, nullptr);
+}
+
+TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
+ TF_SetAttrType(desc, "dtype", TF_INT32);
+ return TF_FinishOperation(desc, s);
+}
+
+TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
+ TF_SetAttrTensor(desc, "value", t, s);
+ if (TF_GetCode(s) != TF_OK) return nullptr;
+ TF_SetAttrType(desc, "dtype", TF_TensorType(t));
+ return TF_FinishOperation(desc, s);
+}
+
+TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);
+ return Const(tensor.get(), graph, s, name);
+}
+
+TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
+ TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
+ TF_AddInputList(desc, add_inputs, 2);
+ return TF_FinishOperation(desc, s);
+}
+
+TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
+ TF_Output inputs[2] = {l, r};
+ TF_AddInputList(desc, inputs, 2);
+ return TF_FinishOperation(desc, s);
+}
+
+TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg");
+ TF_Output neg_input = {n, 0};
+ TF_AddInput(desc, neg_input);
+ return TF_FinishOperation(desc, s);
+}
+
+TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
+ TF_Status* s) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than");
+ TF_AddInput(desc, l);
+ TF_AddInput(desc, r);
+ return TF_FinishOperation(desc, s);
+}
+
+bool IsPlaceholder(const tensorflow::NodeDef& node_def) {
+ if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
+ return false;
+ }
+ bool found_dtype = false;
+ bool found_shape = false;
+ for (const auto& attr : node_def.attr()) {
+ if (attr.first == "dtype") {
+ if (attr.second.type() == tensorflow::DT_INT32) {
+ found_dtype = true;
+ } else {
+ return false;
+ }
+ } else if (attr.first == "shape") {
+ found_shape = true;
+ }
+ }
+ return found_dtype && found_shape;
+}
+
+bool IsScalarConst(const tensorflow::NodeDef& node_def, int v) {
+ if (node_def.op() != "Const" || node_def.name() != "scalar") {
+ return false;
+ }
+ bool found_dtype = false;
+ bool found_value = false;
+ for (const auto& attr : node_def.attr()) {
+ if (attr.first == "dtype") {
+ if (attr.second.type() == tensorflow::DT_INT32) {
+ found_dtype = true;
+ } else {
+ return false;
+ }
+ } else if (attr.first == "value") {
+ if (attr.second.has_tensor() &&
+ attr.second.tensor().int_val_size() == 1 &&
+ attr.second.tensor().int_val(0) == v) {
+ found_value = true;
+ } else {
+ return false;
+ }
+ }
+ }
+ return found_dtype && found_value;
+}
+
+bool IsAddN(const tensorflow::NodeDef& node_def, int n) {
+ if (node_def.op() != "AddN" || node_def.name() != "add" ||
+ node_def.input_size() != n) {
+ return false;
+ }
+ bool found_t = false;
+ bool found_n = false;
+ for (const auto& attr : node_def.attr()) {
+ if (attr.first == "T") {
+ if (attr.second.type() == tensorflow::DT_INT32) {
+ found_t = true;
+ } else {
+ return false;
+ }
+ } else if (attr.first == "N") {
+ if (attr.second.i() == n) {
+ found_n = true;
+ } else {
+ return false;
+ }
+ }
+ }
+ return found_t && found_n;
+}
+
+bool IsNeg(const tensorflow::NodeDef& node_def, const string& input) {
+ return node_def.op() == "Neg" && node_def.name() == "neg" &&
+ node_def.input_size() == 1 && node_def.input(0) == input;
+}
+
+bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def) {
+ TF_Status* s = TF_NewStatus();
+ TF_Buffer* buffer = TF_NewBuffer();
+ TF_GraphToGraphDef(graph, buffer, s);
+ bool ret = TF_GetCode(s) == TF_OK;
+ EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
+ TF_DeleteBuffer(buffer);
+ TF_DeleteStatus(s);
+ return ret;
+}
+
+bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) {
+ TF_Status* s = TF_NewStatus();
+ TF_Buffer* buffer = TF_NewBuffer();
+ TF_OperationToNodeDef(oper, buffer, s);
+ bool ret = TF_GetCode(s) == TF_OK;
+ EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ if (ret) ret = node_def->ParseFromArray(buffer->data, buffer->length);
+ TF_DeleteBuffer(buffer);
+ TF_DeleteStatus(s);
+ return ret;
+}
+
+bool GetAttrValue(TF_Operation* oper, const char* attr_name,
+ tensorflow::AttrValue* attr_value, TF_Status* s) {
+ TF_Buffer* buffer = TF_NewBuffer();
+ TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
+ bool ret = TF_GetCode(s) == TF_OK;
+ if (ret) ret = attr_value->ParseFromArray(buffer->data, buffer->length);
+ TF_DeleteBuffer(buffer);
+ return ret;
+}
+
+CSession::CSession(TF_Graph* graph, TF_Status* s) {
+ TF_SessionOptions* opts = TF_NewSessionOptions();
+ session_ = TF_NewSession(graph, opts, s);
+ TF_DeleteSessionOptions(opts);
+}
+
+CSession::CSession(TF_Session* session) : session_(session) {}
+
+CSession::~CSession() {
+ TF_Status* s = TF_NewStatus();
+ CloseAndDelete(s);
+ EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_DeleteStatus(s);
+}
+
+void CSession::SetInputs(
+ std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs) {
+ DeleteInputValues();
+ inputs_.clear();
+ for (const auto& p : inputs) {
+ inputs_.emplace_back(TF_Output{p.first, 0});
+ input_values_.emplace_back(p.second);
+ }
+}
+
+void CSession::SetOutputs(std::initializer_list<TF_Operation*> outputs) {
+ ResetOutputValues();
+ outputs_.clear();
+ for (TF_Operation* o : outputs) {
+ outputs_.emplace_back(TF_Output{o, 0});
+ }
+ output_values_.resize(outputs_.size());
+}
+
+void CSession::SetOutputs(const std::vector<TF_Output>& outputs) {
+ ResetOutputValues();
+ outputs_ = outputs;
+ output_values_.resize(outputs_.size());
+}
+
+void CSession::SetTargets(std::initializer_list<TF_Operation*> targets) {
+ targets_.clear();
+ for (TF_Operation* t : targets) {
+ targets_.emplace_back(t);
+ }
+}
+
+void CSession::Run(TF_Status* s) {
+ if (inputs_.size() != input_values_.size()) {
+ ADD_FAILURE() << "Call SetInputs() before Run()";
+ return;
+ }
+ ResetOutputValues();
+ output_values_.resize(outputs_.size(), nullptr);
+
+ const TF_Output* inputs_ptr = inputs_.empty() ? nullptr : &inputs_[0];
+ TF_Tensor* const* input_values_ptr =
+ input_values_.empty() ? nullptr : &input_values_[0];
+
+ const TF_Output* outputs_ptr = outputs_.empty() ? nullptr : &outputs_[0];
+ TF_Tensor** output_values_ptr =
+ output_values_.empty() ? nullptr : &output_values_[0];
+
+ TF_Operation* const* targets_ptr = targets_.empty() ? nullptr : &targets_[0];
+
+ TF_SessionRun(session_, nullptr, inputs_ptr, input_values_ptr, inputs_.size(),
+ outputs_ptr, output_values_ptr, outputs_.size(), targets_ptr,
+ targets_.size(), nullptr, s);
+
+ DeleteInputValues();
+}
+
+void CSession::CloseAndDelete(TF_Status* s) {
+ DeleteInputValues();
+ ResetOutputValues();
+ if (session_ != nullptr) {
+ TF_CloseSession(session_, s);
+ EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_DeleteSession(session_, s);
+ session_ = nullptr;
+ }
+}
+
+void CSession::DeleteInputValues() {
+ for (size_t i = 0; i < input_values_.size(); ++i) {
+ TF_DeleteTensor(input_values_[i]);
+ }
+ input_values_.clear();
+}
+
+void CSession::ResetOutputValues() {
+ for (size_t i = 0; i < output_values_.size(); ++i) {
+ if (output_values_[i] != nullptr) TF_DeleteTensor(output_values_[i]);
+ }
+ output_values_.clear();
+}