aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_test_util.cc
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-08-08 18:00:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 18:06:53 -0700
commit8453e23a8b65423a4f2bfb2a98a928954771498f (patch)
treede12dd18c507eea466fb495fc8f58774a74e204a /tensorflow/c/c_test_util.cc
parentf2ff7b160794f2fd8ea5bafb1e910f1c14966202 (diff)
Added forked versions of stateless If and While ops. They should only be used,
when the if then/else body of If or the While body funcs do not have stateful ops. The are lowered to the same XLA ops. One use case is in the S4TF compiler: https://github.com/apple/swift/pull/18509 PiperOrigin-RevId: 207977126
Diffstat (limited to 'tensorflow/c/c_test_util.cc')
-rw-r--r--tensorflow/c/c_test_util.cc18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index 24eb6c069b..f15d9ee20a 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -26,6 +26,10 @@ limitations under the License.
using tensorflow::GraphDef;
using tensorflow::NodeDef;
+static void BoolDeallocator(void* data, size_t, void* arg) {
+ delete[] static_cast<bool*>(data);
+}
+
static void Int32Deallocator(void* data, size_t, void* arg) {
delete[] static_cast<int32_t*>(data);
}
@@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
+TF_Tensor* BoolTensor(bool v) {
+ const int num_bytes = sizeof(bool);
+ bool* values = new bool[1];
+ values[0] = v;
+ return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator,
+ nullptr);
+}
+
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
@@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
return op;
}
+TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor);
+ return Const(tensor.get(), graph, s, name);
+}
+
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name) {
unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);