diff options
author | Mingsheng Hong <hongm@google.com> | 2018-08-08 18:00:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 18:06:53 -0700 |
commit | 8453e23a8b65423a4f2bfb2a98a928954771498f (patch) | |
tree | de12dd18c507eea466fb495fc8f58774a74e204a | |
parent | f2ff7b160794f2fd8ea5bafb1e910f1c14966202 (diff) |
Added forked versions of stateless If and While ops. They should only be used,
when the if then/else body of If or the While body funcs do not have stateful
ops.
The are lowered to the same XLA ops.
One use case is in the S4TF compiler: https://github.com/apple/swift/pull/18509
PiperOrigin-RevId: 207977126
-rw-r--r-- | tensorflow/c/c_api_function_test.cc | 61 | ||||
-rw-r--r-- | tensorflow/c/c_test_util.cc | 18 | ||||
-rw-r--r-- | tensorflow/c/c_test_util.h | 5 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/if_op.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/while_op.cc | 1 | ||||
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_StatelessIf.pbtxt | 43 | ||||
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_StatelessWhile.pbtxt | 36 | ||||
-rw-r--r-- | tensorflow/core/api_def/python_api/api_def_StatelessIf.pbtxt | 1 | ||||
-rw-r--r-- | tensorflow/core/api_def/python_api/api_def_StatelessWhile.pbtxt | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/functional_ops.cc | 7 | ||||
-rw-r--r-- | tensorflow/core/ops/functional_ops.cc | 26 |
11 files changed, 198 insertions, 2 deletions
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index bb9433ce25..73fe73769b 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1619,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) { TF_DeleteFunction(func1); } +// This test only works when the TF build includes XLA compiler. One way to set +// this up is via bazel build option "--define with_xla_support=true". +// +// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to +// something like TENSORFLOW_CAPI_USE_XLA. +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST_F(CApiFunctionTest, StatelessIf_XLA) { + TF_Function* func; + const std::string funcName = "BranchFunc"; + DefineFunction(funcName.c_str(), &func); + TF_GraphCopyFunction(host_graph_, func, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* feed = Placeholder(host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* true_cond = ScalarConst(true, host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_OperationDescription* desc = + TF_NewOperation(host_graph_, "StatelessIf", "IfNode"); + TF_AddInput(desc, {true_cond, 0}); + TF_Output inputs[] = {{feed, 0}}; + TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs)); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_SetAttrType(desc, "Tcond", TF_BOOL); + TF_DataType inputType = TF_INT32; + TF_SetAttrTypeList(desc, "Tin", &inputType, 1); + TF_SetAttrTypeList(desc, "Tout", &inputType, 1); + TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size()); + TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size()); + TF_SetDevice(desc, "/device:XLA_CPU:0"); + auto op = TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_NE(op, nullptr); + + // Create a session for this graph. + CSession csession(host_graph_, s_, /*use_XLA*/ true); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Run the graph. + csession.SetInputs({{feed, Int32Tensor(17)}}); + csession.SetOutputs({op}); + csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); + int32* output_contents = static_cast<int32*>(TF_TensorData(out)); + EXPECT_EQ(-17, *output_contents); + + // Clean up + csession.CloseAndDelete(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_DeleteFunction(func); +} +#endif // TENSORFLOW_EAGER_USE_XLA + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 24eb6c069b..f15d9ee20a 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -26,6 +26,10 @@ limitations under the License. using tensorflow::GraphDef; using tensorflow::NodeDef; +static void BoolDeallocator(void* data, size_t, void* arg) { + delete[] static_cast<bool*>(data); +} + static void Int32Deallocator(void* data, size_t, void* arg) { delete[] static_cast<int32_t*>(data); } @@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) { delete[] static_cast<float*>(data); } +TF_Tensor* BoolTensor(bool v) { + const int num_bytes = sizeof(bool); + bool* values = new bool[1]; + values[0] = v; + return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator, + nullptr); +} + TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; for (int i = 0; i < num_dims; ++i) { @@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, return op; } +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name) { unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 38313d647c..7eeb1ee5e1 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -31,6 +31,8 @@ using ::tensorflow::string; typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> unique_tensor_ptr; +TF_Tensor* BoolTensor(int32_t v); + // Create a tensor with values of type TF_INT8 provided by `values`. TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); @@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name = "const"); +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name = "scalar"); diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index ceb2af756c..462e0e4395 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -247,6 +247,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp); +REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes(), XlaIfOp); REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 1e8a376765..296518229e 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -301,6 +301,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp); +REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp); REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); } // namespace tensorflow diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessIf.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessIf.pbtxt new file mode 100644 index 0000000000..c0a6ba15e6 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StatelessIf.pbtxt @@ -0,0 +1,43 @@ +op { + graph_op_name: "StatelessIf" + in_arg { name: "cond" description: "The predicate." } + in_arg { + name: "cond" + description: <<END + A Tensor. If the tensor is a scalar of non-boolean type, the + scalar is converted to a boolean according to the + following rule: if the scalar is a numerical value, non-zero means + `True` and zero means False; if the scalar is a string, non-empty + means `True` and empty means `False`. If the tensor is not a scalar, + being empty means False and being non-empty means True. + + This should only be used when the if then/else body functions do not + have stateful ops. +END + } + in_arg { + name: "input" + description: "A list of input tensors." + } + out_arg { + name: "output" + description: "A list of return values." + } + attr { name: "Tin" description: "A list of input types." } + attr { name: "Tout" description: "A list of output types." } + attr { + name: "then_branch" + description: <<END + A function that takes 'inputs' and returns a list of tensors, whose + types are the same as what else_branch returns. +END + } + attr { + name: "else_branch" + description: <<END + A function that takes 'inputs' and returns a list of tensors, whose + types are the same as what then_branch returns. +END + } + summary: "output = cond ? then_branch(input) : else_branch(input)" +} diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessWhile.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessWhile.pbtxt new file mode 100644 index 0000000000..87c0e09673 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StatelessWhile.pbtxt @@ -0,0 +1,36 @@ +op { + graph_op_name: "StatelessWhile" + in_arg { + name: "input" + description: "A list of input tensors whose types are T." + } + out_arg { + name: "output" + description: "A list of output tensors whose types are T." + } + attr { name: "T" description: "dtype in use." } + attr { + name: "cond" + description: <<END + A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. + + This should only be used when the while condition and body functions + do not have stateful ops. +END + } + attr { + name: "body" + description: <<END + A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified + by T. +END + } + summary: "output = input; While (Cond(output)) { output = Body(output) }" +} diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessIf.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessIf.pbtxt new file mode 100644 index 0000000000..0298c4852c --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_StatelessIf.pbtxt @@ -0,0 +1 @@ +op { graph_op_name: "StatelessIf" visibility: HIDDEN } diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessWhile.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessWhile.pbtxt new file mode 100644 index 0000000000..c138a71087 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_StatelessWhile.pbtxt @@ -0,0 +1 @@ +op { graph_op_name: "StatelessWhile" visibility: HIDDEN } diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index 1c0abf26cd..1529d2e336 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -218,6 +218,10 @@ REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"), REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp); REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp); +REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp); +REGISTER_KERNEL_BUILDER( + Name("StatelessIf").Device(DEVICE_GPU).HostMemory("cond"), IfOp); + class WhileOp : public AsyncOpKernel { public: explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { @@ -379,6 +383,9 @@ REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp); REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp); REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp); +REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_CPU), WhileOp); +REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_GPU), WhileOp); + Status GetScalar(OpKernelContext* ctx, int index, int32* value, const char* label) { Tensor t = ctx->input(index); diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index a16ecccf00..bda4a75c5d 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -90,6 +90,17 @@ else_branch: A function that takes 'inputs' and returns a list of tensors. whose types are the same as what then_branch returns. )doc"); +REGISTER_OP("StatelessIf") + .Input("cond: Tcond") + .Input("input: Tin") + .Output("output: Tout") + .Attr("Tcond: type") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .Attr("then_branch: func") + .Attr("else_branch: func") + .SetShapeFn(shape_inference::UnknownShape); + REGISTER_OP("If") .Input("cond: Tcond") .Input("input: Tin") @@ -133,8 +144,6 @@ body: A function that takes a list of tensors and returns another by T. )doc"); -// TODO(b/37549631) setting the While Op to always be stateful is too -// conservative. REGISTER_OP("While") .Input("input: T") .Output("output: T") @@ -149,6 +158,19 @@ REGISTER_OP("While") return Status::OK(); }); +REGISTER_OP("StatelessWhile") + .Input("input: T") + .Output("output: T") + .Attr("T: list(type) >= 0") + .Attr("cond: func") + .Attr("body: func") + .SetShapeFn([](shape_inference::InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(i)); + } + return Status::OK(); + }); + REGISTER_OP("For") .Input("start: int32") .Input("limit: int32") |