aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_test_util.cc
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-03-01 14:15:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-01 14:18:50 -0800
commitf8f4a6e26cc1108495c0b9a55d9a7d6e7005c2b5 (patch)
tree47566990e1e975b7f6e6d60083b454951020fee4 /tensorflow/c/c_test_util.cc
parent80710d5c53a8b2896a57dbe026d7f742e71fc03b (diff)
Internal change.
PiperOrigin-RevId: 187532378
Diffstat (limited to 'tensorflow/c/c_test_util.cc')
-rw-r--r--tensorflow/c/c_test_util.cc31
1 files changed, 28 insertions, 3 deletions
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index 3db2852ce6..53346a8cdf 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -34,6 +34,10 @@ static void DoubleDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<double*>(data);
}
+static void FloatDeallocator(void* data, size_t, void* arg) {
+ delete[] static_cast<float*>(data);
+}
+
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
@@ -78,13 +82,21 @@ TF_Tensor* DoubleTensor(double v) {
&DoubleDeallocator, nullptr);
}
+TF_Tensor* FloatTensor(float v) {
+ const int num_bytes = sizeof(float);
+ float* values = new float[1];
+ values[0] = v;
+ return TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes,
+ &FloatDeallocator, nullptr);
+}
+
// All the *Helper methods are used as a workaround for the restrictions that
// one cannot call ASSERT_* methods in non-void-returning functions (when
// exceptions are disabled during compilation)
void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
- TF_Operation** op) {
+ TF_DataType dtype, TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
- TF_SetAttrType(desc, "dtype", TF_INT32);
+ TF_SetAttrType(desc, "dtype", dtype);
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
@@ -92,7 +104,14 @@ void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
TF_Operation* op;
- PlaceholderHelper(graph, s, name, &op);
+ PlaceholderHelper(graph, s, name, TF_INT32, &op);
+ return op;
+}
+
+TF_Operation* PlaceholderFloat(TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_Operation* op;
+ PlaceholderHelper(graph, s, name, TF_FLOAT, &op);
return op;
}
@@ -126,6 +145,12 @@ TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s,
return Const(tensor.get(), graph, s, name);
}
+TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ unique_tensor_ptr tensor(FloatTensor(v), TF_DeleteTensor);
+ return Const(tensor.get(), graph, s, name);
+}
+
void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name, TF_Operation** op,
bool check) {