aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_test_util.cc
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-08-30 21:05:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-30 21:08:53 -0700
commit9624d165f1f2c717eda96464fee8bf7229cc14f5 (patch)
tree8024d708b58b0c78f19d4c3cfc9f7c4b0c24b70c /tensorflow/c/c_test_util.cc
parent424aa9aa9559f6fa29d8ccf3d74ff25528b39209 (diff)
Add function support to Tensorflow C API
This change adds minimal functionality. Support for FunctionOptions, attributes, output name rewriting, function name generation, etc is comming next. PiperOrigin-RevId: 167091238
Diffstat (limited to 'tensorflow/c/c_test_util.cc')
-rw-r--r--tensorflow/c/c_test_util.cc131
1 files changed, 123 insertions, 8 deletions
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index 21603c1a07..9cd978c97e 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/c/c_test_util.h"
+#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
using tensorflow::GraphDef;
@@ -36,6 +38,23 @@ TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
return t;
}
+TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
+ const int32_t* values) {
+ int64_t num_values = 1;
+ for (int i = 0; i < num_dims; ++i) {
+ num_values *= dims[i];
+ }
+ TF_Tensor* t =
+ TF_AllocateTensor(TF_INT32, dims, num_dims, sizeof(int32_t) * num_values);
+ memcpy(TF_TensorData(t), values, sizeof(int32_t) * num_values);
+ return t;
+}
+
+TF_Tensor* Int32Tensor(const std::vector<int32_t>& values) {
+ int64_t dims = values.size();
+ return Int32Tensor(&dims, 1, values.data());
+}
+
TF_Tensor* Int32Tensor(int32_t v) {
const int num_bytes = sizeof(int32_t);
int32_t* values = new int32_t[1];
@@ -44,19 +63,40 @@ TF_Tensor* Int32Tensor(int32_t v) {
&Int32Deallocator, nullptr);
}
-TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
+// All the *Helper methods are used as a workaround for the restrictions that
+// one cannot call ASSERT_* methods in non-void-returning functions (when
+// exceptions are disabled during compilation)
+void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
+ TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", TF_INT32);
- return TF_FinishOperation(desc, s);
+ *op = TF_FinishOperation(desc, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
}
-TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
- const char* name) {
+TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
+ TF_Operation* op;
+ PlaceholderHelper(graph, s, name, &op);
+ return op;
+}
+
+void ConstHelper(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name,
+ TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
TF_SetAttrTensor(desc, "value", t, s);
- if (TF_GetCode(s) != TF_OK) return nullptr;
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_SetAttrType(desc, "dtype", TF_TensorType(t));
- return TF_FinishOperation(desc, s);
+ *op = TF_FinishOperation(desc, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_Operation* op;
+ ConstHelper(t, graph, s, name, &op);
+ return op;
}
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
@@ -65,11 +105,39 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
return Const(tensor.get(), graph, s, name);
}
+void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s,
+ const char* name, TF_Operation** op, bool check) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
+ TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
+ TF_AddInputList(desc, add_inputs, 2);
+ *op = TF_FinishOperation(desc, s);
+ if (check) {
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+ }
+}
+
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
+ TF_Operation* op;
+ AddHelper(l, r, graph, s, name, &op, true);
+ return op;
+}
+
+TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name) {
+ TF_Operation* op;
+ AddHelper(l, r, graph, s, name, &op, false);
+ return op;
+}
+
+TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
+ TF_Graph* graph, TF_Operation* ctrl_op,
+ TF_Status* s, const char* name) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2);
+ TF_AddControlInput(desc, ctrl_op);
return TF_FinishOperation(desc, s);
}
@@ -81,11 +149,20 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
return TF_FinishOperation(desc, s);
}
-TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
+void NegHelper(TF_Operation* n, TF_Graph* graph, TF_Status* s,
+ TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg");
TF_Output neg_input = {n, 0};
TF_AddInput(desc, neg_input);
- return TF_FinishOperation(desc, s);
+ *op = TF_FinishOperation(desc, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
+ TF_Operation* op;
+ NegHelper(n, graph, s, &op);
+ return op;
}
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
@@ -96,6 +173,32 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
return TF_FinishOperation(desc, s);
}
+void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s,
+ const char* name, TF_Operation** op) {
+ TF_Operation* zero = ScalarConst(
+ 0, graph, s, ::tensorflow::strings::StrCat(name, "_const0").c_str());
+ TF_OperationDescription* desc = TF_NewOperation(graph, "Split", name);
+ TF_AddInput(desc, {zero, 0});
+ TF_AddInput(desc, {input, 0});
+ TF_SetAttrInt(desc, "num_split", 3);
+ TF_SetAttrType(desc, "T", TF_INT32);
+ // Set device to CPU since there is no version of split for int32 on GPU
+ // TODO(iga): Convert all these helpers and tests to use floats because
+ // they are usually available on GPUs. After doing this, remove TF_SetDevice
+ // call in c_api_function_test.cc
+ TF_SetDevice(desc, "/cpu:0");
+ *op = TF_FinishOperation(desc, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_Operation* op;
+ Split3Helper(input, graph, s, name, &op);
+ return op;
+}
+
bool IsPlaceholder(const tensorflow::NodeDef& node_def) {
if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
return false;
@@ -196,6 +299,18 @@ bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) {
return ret;
}
+bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def) {
+ TF_Status* s = TF_NewStatus();
+ TF_Buffer* buffer = TF_NewBuffer();
+ TF_FunctionToFunctionDef(func, buffer, s);
+ bool ret = TF_GetCode(s) == TF_OK;
+ EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ if (ret) ret = func_def->ParseFromArray(buffer->data, buffer->length);
+ TF_DeleteBuffer(buffer);
+ TF_DeleteStatus(s);
+ return ret;
+}
+
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
tensorflow::AttrValue* attr_value, TF_Status* s) {
TF_Buffer* buffer = TF_NewBuffer();