From 8453e23a8b65423a4f2bfb2a98a928954771498f Mon Sep 17 00:00:00 2001 From: Mingsheng Hong Date: Wed, 8 Aug 2018 18:00:05 -0700 Subject: Added forked versions of stateless If and While ops. They should only be used, when the if then/else body of If or the While body funcs do not have stateful ops. The are lowered to the same XLA ops. One use case is in the S4TF compiler: https://github.com/apple/swift/pull/18509 PiperOrigin-RevId: 207977126 --- tensorflow/c/c_api_function_test.cc | 61 +++++++++++++++++++++++++++++++++++++ tensorflow/c/c_test_util.cc | 18 +++++++++++ tensorflow/c/c_test_util.h | 5 +++ 3 files changed, 84 insertions(+) (limited to 'tensorflow/c') diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index bb9433ce25..73fe73769b 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1619,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) { TF_DeleteFunction(func1); } +// This test only works when the TF build includes XLA compiler. One way to set +// this up is via bazel build option "--define with_xla_support=true". +// +// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to +// something like TENSORFLOW_CAPI_USE_XLA. +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST_F(CApiFunctionTest, StatelessIf_XLA) { + TF_Function* func; + const std::string funcName = "BranchFunc"; + DefineFunction(funcName.c_str(), &func); + TF_GraphCopyFunction(host_graph_, func, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* feed = Placeholder(host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* true_cond = ScalarConst(true, host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_OperationDescription* desc = + TF_NewOperation(host_graph_, "StatelessIf", "IfNode"); + TF_AddInput(desc, {true_cond, 0}); + TF_Output inputs[] = {{feed, 0}}; + TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs)); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_SetAttrType(desc, "Tcond", TF_BOOL); + TF_DataType inputType = TF_INT32; + TF_SetAttrTypeList(desc, "Tin", &inputType, 1); + TF_SetAttrTypeList(desc, "Tout", &inputType, 1); + TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size()); + TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size()); + TF_SetDevice(desc, "/device:XLA_CPU:0"); + auto op = TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_NE(op, nullptr); + + // Create a session for this graph. + CSession csession(host_graph_, s_, /*use_XLA*/ true); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Run the graph. + csession.SetInputs({{feed, Int32Tensor(17)}}); + csession.SetOutputs({op}); + csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); + int32* output_contents = static_cast(TF_TensorData(out)); + EXPECT_EQ(-17, *output_contents); + + // Clean up + csession.CloseAndDelete(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_DeleteFunction(func); +} +#endif // TENSORFLOW_EAGER_USE_XLA + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 24eb6c069b..f15d9ee20a 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -26,6 +26,10 @@ limitations under the License. using tensorflow::GraphDef; using tensorflow::NodeDef; +static void BoolDeallocator(void* data, size_t, void* arg) { + delete[] static_cast(data); +} + static void Int32Deallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } @@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } +TF_Tensor* BoolTensor(bool v) { + const int num_bytes = sizeof(bool); + bool* values = new bool[1]; + values[0] = v; + return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator, + nullptr); +} + TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; for (int i = 0; i < num_dims; ++i) { @@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, return op; } +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name) { unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 38313d647c..7eeb1ee5e1 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -31,6 +31,8 @@ using ::tensorflow::string; typedef std::unique_ptr unique_tensor_ptr; +TF_Tensor* BoolTensor(int32_t v); + // Create a tensor with values of type TF_INT8 provided by `values`. TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); @@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name = "const"); +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name = "scalar"); -- cgit v1.2.3