aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-08-08 18:00:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 18:06:53 -0700
commit8453e23a8b65423a4f2bfb2a98a928954771498f (patch)
treede12dd18c507eea466fb495fc8f58774a74e204a /tensorflow/c
parentf2ff7b160794f2fd8ea5bafb1e910f1c14966202 (diff)
Added forked versions of stateless If and While ops. They should only be used,
when the if then/else body of If or the While body funcs do not have stateful ops. The are lowered to the same XLA ops. One use case is in the S4TF compiler: https://github.com/apple/swift/pull/18509 PiperOrigin-RevId: 207977126
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/c_api_function_test.cc61
-rw-r--r--tensorflow/c/c_test_util.cc18
-rw-r--r--tensorflow/c/c_test_util.h5
3 files changed, 84 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index bb9433ce25..73fe73769b 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -1619,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
TF_DeleteFunction(func1);
}
+// This test only works when the TF build includes XLA compiler. One way to set
+// this up is via bazel build option "--define with_xla_support=true".
+//
+// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
+// something like TENSORFLOW_CAPI_USE_XLA.
+#ifdef TENSORFLOW_EAGER_USE_XLA
+TEST_F(CApiFunctionTest, StatelessIf_XLA) {
+ TF_Function* func;
+ const std::string funcName = "BranchFunc";
+ DefineFunction(funcName.c_str(), &func);
+ TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_Operation* feed = Placeholder(host_graph_, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_OperationDescription* desc =
+ TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
+ TF_AddInput(desc, {true_cond, 0});
+ TF_Output inputs[] = {{feed, 0}};
+ TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_SetAttrType(desc, "Tcond", TF_BOOL);
+ TF_DataType inputType = TF_INT32;
+ TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
+ TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
+ TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
+ TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
+ TF_SetDevice(desc, "/device:XLA_CPU:0");
+ auto op = TF_FinishOperation(desc, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ ASSERT_NE(op, nullptr);
+
+ // Create a session for this graph.
+ CSession csession(host_graph_, s_, /*use_XLA*/ true);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Run the graph.
+ csession.SetInputs({{feed, Int32Tensor(17)}});
+ csession.SetOutputs({op});
+ csession.Run(s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_Tensor* out = csession.output_tensor(0);
+ ASSERT_TRUE(out != nullptr);
+ EXPECT_EQ(TF_INT32, TF_TensorType(out));
+ EXPECT_EQ(0, TF_NumDims(out)); // scalar
+ ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
+ int32* output_contents = static_cast<int32*>(TF_TensorData(out));
+ EXPECT_EQ(-17, *output_contents);
+
+ // Clean up
+ csession.CloseAndDelete(s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_DeleteFunction(func);
+}
+#endif // TENSORFLOW_EAGER_USE_XLA
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index 24eb6c069b..f15d9ee20a 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -26,6 +26,10 @@ limitations under the License.
using tensorflow::GraphDef;
using tensorflow::NodeDef;
+static void BoolDeallocator(void* data, size_t, void* arg) {
+ delete[] static_cast<bool*>(data);
+}
+
static void Int32Deallocator(void* data, size_t, void* arg) {
delete[] static_cast<int32_t*>(data);
}
@@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
+TF_Tensor* BoolTensor(bool v) {
+ const int num_bytes = sizeof(bool);
+ bool* values = new bool[1];
+ values[0] = v;
+ return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator,
+ nullptr);
+}
+
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
@@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
return op;
}
+TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor);
+ return Const(tensor.get(), graph, s, name);
+}
+
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name) {
unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index 38313d647c..7eeb1ee5e1 100644
--- a/tensorflow/c/c_test_util.h
+++ b/tensorflow/c/c_test_util.h
@@ -31,6 +31,8 @@ using ::tensorflow::string;
typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
unique_tensor_ptr;
+TF_Tensor* BoolTensor(int32_t v);
+
// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);
@@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
const char* name = "const");
+TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
+ const char* name = "scalar");
+
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar");