aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_test_util.h
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-07-17 08:29:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-17 08:33:52 -0700
commitbed7cbe389420abe760d86cd345e88d45eb624ab (patch)
treec4d5c9d789b06e7596b6473902c140638fbb5de7 /tensorflow/c/c_test_util.h
parenta925c14596135b19bedb73e0f6ae3cd170180106 (diff)
Split out new while_loop_test.cc from c_api_test.cc
This change also separates shared functionality into c_test_util.h/cc. This brings c_api_test.cc to a mere 1715 LOC (further splits can be more easily done now too). PiperOrigin-RevId: 162216399
Diffstat (limited to 'tensorflow/c/c_test_util.h')
-rw-r--r--tensorflow/c/c_test_util.h102
1 files changed, 102 insertions, 0 deletions
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
new file mode 100644
index 0000000000..0c0ba667bd
--- /dev/null
+++ b/tensorflow/c/c_test_util.h
@@ -0,0 +1,102 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_
+
+#include "tensorflow/c/c_api.h"
+
+#include <vector>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+using ::tensorflow::string;
+
+typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
+ unique_tensor_ptr;
+
+// Create a tensor with values of type TF_INT8 provided by `values`.
+TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);
+
+TF_Tensor* Int32Tensor(int32_t v);
+
+TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
+ const char* name = "feed");
+
+TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
+ const char* name = "const");
+
+TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
+ const char* name = "scalar");
+
+TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name = "add");
+
+TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
+ const char* name = "add");
+
+TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s);
+
+TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s);
+
+bool IsPlaceholder(const tensorflow::NodeDef& node_def);
+
+bool IsScalarConst(const tensorflow::NodeDef& node_def, int v);
+
+bool IsAddN(const tensorflow::NodeDef& node_def, int n);
+
+bool IsNeg(const tensorflow::NodeDef& node_def, const string& input);
+
+bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
+
+bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def);
+
+bool GetAttrValue(TF_Operation* oper, const char* attr_name,
+ tensorflow::AttrValue* attr_value, TF_Status* s);
+
+class CSession {
+ public:
+ CSession(TF_Graph* graph, TF_Status* s);
+ explicit CSession(TF_Session* session);
+
+ ~CSession();
+
+ void SetInputs(std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs);
+ void SetOutputs(std::initializer_list<TF_Operation*> outputs);
+ void SetOutputs(const std::vector<TF_Output>& outputs);
+ void SetTargets(std::initializer_list<TF_Operation*> targets);
+
+ void Run(TF_Status* s);
+
+ void CloseAndDelete(TF_Status* s);
+
+ TF_Tensor* output_tensor(int i) { return output_values_[i]; }
+
+ private:
+ void DeleteInputValues();
+ void ResetOutputValues();
+
+ TF_Session* session_;
+ std::vector<TF_Output> inputs_;
+ std::vector<TF_Tensor*> input_values_;
+ std::vector<TF_Output> outputs_;
+ std::vector<TF_Tensor*> output_values_;
+ std::vector<TF_Operation*> targets_;
+};
+
+#endif // THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_