diff options
author | Mingsheng Hong <hongm@google.com> | 2018-08-08 18:00:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 18:06:53 -0700 |
commit | 8453e23a8b65423a4f2bfb2a98a928954771498f (patch) | |
tree | de12dd18c507eea466fb495fc8f58774a74e204a /tensorflow/c/c_test_util.cc | |
parent | f2ff7b160794f2fd8ea5bafb1e910f1c14966202 (diff) |
Added forked versions of stateless If and While ops. They should only be used,
when the if then/else body of If or the While body funcs do not have stateful
ops.
The are lowered to the same XLA ops.
One use case is in the S4TF compiler: https://github.com/apple/swift/pull/18509
PiperOrigin-RevId: 207977126
Diffstat (limited to 'tensorflow/c/c_test_util.cc')
-rw-r--r-- | tensorflow/c/c_test_util.cc | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 24eb6c069b..f15d9ee20a 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -26,6 +26,10 @@ limitations under the License. using tensorflow::GraphDef; using tensorflow::NodeDef; +static void BoolDeallocator(void* data, size_t, void* arg) { + delete[] static_cast<bool*>(data); +} + static void Int32Deallocator(void* data, size_t, void* arg) { delete[] static_cast<int32_t*>(data); } @@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) { delete[] static_cast<float*>(data); } +TF_Tensor* BoolTensor(bool v) { + const int num_bytes = sizeof(bool); + bool* values = new bool[1]; + values[0] = v; + return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator, + nullptr); +} + TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; for (int i = 0; i < num_dims; ++i) { @@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, return op; } +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name) { unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); |