diff options
author | Tong Shen <endlessroad@google.com> | 2018-09-07 18:41:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-07 18:45:44 -0700 |
commit | 3e1b06ee93d7a638db1fdd5f733d66064c1acf59 (patch) | |
tree | 98ef8a3a7ce89114e8f9a296ee6252fc739d0943 | |
parent | 3ea43a044e7515388ecf322437b08f4ced5674aa (diff) |
Add XLA token input/output to XlaIf and XlaWhile when necessary.
PiperOrigin-RevId: 212070721
-rw-r--r-- | tensorflow/compiler/tf2xla/BUILD | 12 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/if_op.cc | 30 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/if_op.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/while_op.cc | 31 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/while_op.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/side_effect_util.cc | 67 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/side_effect_util.h | 47 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 113 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.h | 23 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 68 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_context.cc | 11 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_context.h | 3 |
13 files changed, 403 insertions, 8 deletions
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 95004534b9..3821dced63 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -191,6 +191,7 @@ cc_library( ":functionalize_control_flow", ":host_compute_metadata_proto", ":sharding_util", + ":side_effect_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", @@ -360,6 +361,7 @@ tf_cc_test( name = "xla_compiler_test", srcs = ["xla_compiler_test.cc"], deps = [ + ":side_effect_util", ":xla_compiler", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", @@ -371,6 +373,7 @@ tf_cc_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:core_cpu_internal", @@ -632,3 +635,12 @@ tf_cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "side_effect_util", + srcs = ["side_effect_util.cc"], + hdrs = ["side_effect_util.h"], + deps = [ + "//tensorflow/core:core_cpu", + ], +) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index c78538114f..46794f7b50 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -178,6 +178,7 @@ tf_kernel_library( hdrs = ["while_op.h"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", @@ -195,6 +196,7 @@ tf_kernel_library( hdrs = ["if_op.h"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", 
"//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 6e1dbf5472..56da50f140 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/if_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } // TODO(b/35949885): There is duplication here with the handling of the @@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { options.resolve_compile_time_constants = false; options.return_updated_values_for_all_resources = true; options.is_entry_computation = false; + options.add_token_input_output = has_token_input_output_; XlaCompiler* compiler = ctx->compiler(); XlaCompiler::CompilationResult then_result; @@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { std::vector<xla::XlaOp> inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = then_result.input_mapping[i] + 1; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "if" op. 
+ std::vector<xla::XlaOp> token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(b, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); @@ -219,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } ctx->SetOutput(i, output_handle); } + if (has_token_input_output_) { + // Set token output for this "if" op. + xla::XlaOp token_output = + xla::GetTupleElement(outputs, output_types_.size()); + auto shape_or = b->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } // Updates the values of any resource variables modified by the conditional // bodies. diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h index f9bc98a198..7783e13a8a 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.h +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -52,6 +52,8 @@ class XlaIfOp : public XlaOpKernel { DataType cond_type_; DataTypeVector input_types_; DataTypeVector output_types_; + bool has_token_input_output_; + std::vector<string> token_input_nodes_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 296518229e..559414eeaa 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/while_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -90,6 +91,11 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { cond_name_attr_ = *name_attr; OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr)); body_name_attr_ = *name_attr; + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { @@ -120,6 +126,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { body_options.return_updated_values_for_all_resources = true; body_options.resolve_compile_time_constants = false; body_options.is_entry_computation = false; + body_options.add_token_input_output = has_token_input_output_; XlaCompiler::CompilationResult body; OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, arguments, &body)); @@ -192,6 +199,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { cond_options.use_tuple_arg = true; cond_options.resolve_compile_time_constants = false; cond_options.is_entry_computation = false; + cond_options.add_token_input_output = has_token_input_output_; XlaCompiler::CompilationResult cond; OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, arguments, &cond)); @@ -238,7 +246,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { std::vector<xla::XlaOp> inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = body.input_mapping[i]; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "while" op. 
+ std::vector<xla::XlaOp> token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(builder, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder)); @@ -273,6 +290,18 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::GetTupleElement(while_result, i)); } } + if (has_token_input_output_) { + // Set token output for this "while" op. + xla::XlaOp token_output = + xla::GetTupleElement(while_result, ctx->num_outputs()); + auto shape_or = builder->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } // Updates the values of any resource variables modified by the loop. 
for (int i = 0; i < body.resource_updates.size(); ++i) { diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h index 67edebabf9..aeeff40e68 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.h +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -56,6 +56,8 @@ class XlaWhileOp : public XlaOpKernel { private: NameAttrList cond_name_attr_; NameAttrList body_name_attr_; + bool has_token_input_output_; + std::vector<string> token_input_nodes_; TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp); }; diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc new file mode 100644 index 0000000000..6cd7b24592 --- /dev/null +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/side_effect_util.h" + +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes"; + +const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; + +std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g) { + std::set<std::string> results; + Node* first_side_effecting_node_on_path = nullptr; + ReverseDFS(g, + [&](Node* n) { + std::vector<string> token_input_nodes; + if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, + &token_input_nodes) + .ok() || + token_input_nodes.empty()) { + return; + } + + if (first_side_effecting_node_on_path != nullptr) { + return; + } + + first_side_effecting_node_on_path = n; + results.insert(n->name()); + }, + [&](Node* n) { + if (first_side_effecting_node_on_path == n) { + first_side_effecting_node_on_path = nullptr; + } + }, + NodeComparatorName()); + return results; +} + +bool HasSideEffectingNodes(const Graph& g) { + for (Node* n : g.nodes()) { + std::vector<string> token_input_nodes; + if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes) + .ok() && + !token_input_nodes.empty()) { + return true; + } + } + return false; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h new file mode 100644 index 0000000000..ad07624729 --- /dev/null +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ + +#include <vector> + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Side-effecting nodes will have this attribute set. Its value is the list of +// node names which this node has side-effect dependencies on. +// +// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute, +// because they always have side effects. +// If and While nodes may or may not have this attribute, depending on whether +// their bodies have side-effecting nodes. +extern const char kXlaTokenInputNodesAttrName[]; + +// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a +// node has a side-effect dependency on the current graph's token input. +extern const char kXlaTokenArgNodeName[]; + +// Calculates side-effect dependencies for the graph's token output. +// Returns a set of node names representing these dependencies. +std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g); + +// Returns whether a graph contains side-effecting nodes. 
+bool HasSideEffectingNodes(const Graph& g); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 41d305d461..dcb455779d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" @@ -291,6 +292,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, "Invalid resource type in XLAShapeForArgument()"); } } + case XlaCompiler::Argument::kToken: { + *xla_shape = xla::ShapeUtil::MakeTokenShape(); + return Status::OK(); + } case XlaCompiler::Argument::kInvalid: return errors::Internal("Invalid argument type in XLAShapeForArgument()"); } @@ -489,7 +494,8 @@ Status XlaCompiler::BuildArguments( } break; - case XlaCompiler::Argument::kParameter: { + case XlaCompiler::Argument::kParameter: + case XlaCompiler::Argument::kToken: { input_mapping->push_back(i); break; } @@ -616,6 +622,10 @@ Status XlaCompiler::BuildArguments( arg_expression.set_handle(arg_handles[i]); } break; + case XlaCompiler::Argument::kToken: { + arg_expression.set_handle(arg_handles[i]); + break; + } case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: return errors::Internal( @@ -757,23 +767,71 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, &options_.shape_representation_fn); core::ScopedUnref context_unref(context); + std::vector<XlaCompiler::Argument> real_args(args); + int token_input_index = -1; + if 
(options.add_token_input_output) { + // Add extra token input. + token_input_index = real_args.size(); + + XlaCompiler::Argument token_arg; + token_arg.kind = XlaCompiler::Argument::kToken; + real_args.push_back(token_arg); + } + std::vector<XlaExpression> arg_expressions; std::vector<int> arg_cores; - TF_RETURN_IF_ERROR( - BuildArguments(*graph, args, options.use_tuple_arg, &builder, context, - &arg_cores, &arg_expressions, &result->input_mapping, - &result->xla_input_shapes, options.is_entry_computation)); + TF_RETURN_IF_ERROR(BuildArguments( + *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores, + &arg_expressions, &result->input_mapping, &result->xla_input_shapes, + options.is_entry_computation)); context->set_args(std::move(arg_expressions)); + PushNodeTokenMapping(); + // Use std::set instead of std::unordered_set to ensure determinism. + std::set<std::string> output_node_token_inputs; + if (token_input_index != -1) { + // Original token comes from input. + auto arg_expression = context->args()[token_input_index]; + TF_RETURN_IF_ERROR( + SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle())); + + // Calculate token inputs for output token. + output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph); + + // If there's no side-effecting op in the graph, use token input as token + // output. + if (output_node_token_inputs.empty()) { + output_node_token_inputs.insert(kXlaTokenArgNodeName); + } + } else if (options.is_entry_computation) { + // Original token is manually created. + if (HasSideEffectingNodes(*graph)) { + TF_RETURN_IF_ERROR( + SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder))); + } + } + TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, flib_runtime_, NextStepId())); + if (token_input_index != -1) { + // Add extra token output. 
+ std::vector<xla::XlaOp> token_inputs; + for (const auto& node_name : output_node_token_inputs) { + auto token_or = GetNodeToken(node_name); + TF_RETURN_IF_ERROR(token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + TF_RETURN_IF_ERROR( + context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs))); + } + TF_RETURN_IF_ERROR(PopNodeTokenMapping()); int num_nonconst_outputs; int num_computation_outputs; result->computation = std::make_shared<xla::XlaComputation>(); result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( - args, arg_cores, context->retvals(), context->resources(), + real_args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, @@ -912,4 +970,47 @@ Status XlaCompiler::SetHostComputeControlDependency( return Status::OK(); } +void XlaCompiler::PushNodeTokenMapping() { + node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{}); +} + +Status XlaCompiler::PopNodeTokenMapping() { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is " + "empty."); + } + node_token_mapping_stack_.pop(); + return Status::OK(); +} + +Status XlaCompiler::SetNodeToken(const string& node_name, + const xla::XlaOp& op) { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling SetNodeToken() when node_token_mapping_stack_ is " + "empty."); + } + auto insert_result = node_token_mapping_stack_.top().insert({node_name, op}); + if (!insert_result.second) { + return errors::FailedPrecondition("Token mapping already exists for node ", + node_name); + } + return Status::OK(); +} + +xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) { + if (node_token_mapping_stack_.empty()) { + return 
errors::FailedPrecondition( + "Calling GetNodeToken() when node_token_mapping_stack_ is " + "empty."); + } + auto iter = node_token_mapping_stack_.top().find(node_name); + if (iter == node_token_mapping_stack_.top().end()) { + return errors::FailedPrecondition("Cannot find token mapping for node ", + node_name); + } + return iter->second; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 8f4a9858ed..2cc603a580 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ +#include <stack> + #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -26,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" @@ -106,6 +109,9 @@ class XlaCompiler { // Argument is a run-time parameter. kParameter, + + // Argument is an XLA token. + kToken, }; Kind kind = kInvalid; @@ -179,6 +185,9 @@ class XlaCompiler { // True when compiling the entry computation, false for subcomputations // (while, call, etc.) bool is_entry_computation = true; + + // True when we should add XLA token input & output to the graph/function. 
+ bool add_token_input_output = false; }; struct OutputDescription { @@ -384,6 +393,11 @@ class XlaCompiler { xla::Client* client() const { return options_.client; } FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } + void PushNodeTokenMapping(); + Status PopNodeTokenMapping(); + Status SetNodeToken(const string& node_name, const xla::XlaOp& op); + xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name); + private: // Sets the function body `fbody` to the one registered as `function`. Status FindFunctionBody(const NameAttrList& function, @@ -448,6 +462,15 @@ class XlaCompiler { std::unordered_map<string, xla::XlaOp> host_compute_control_output_; + // This is used to store <node name, token output> mapping. Side-effecting + // ops call SetNodeToken() to record its token output, so later side-effecting + // ops can use GetNodeToken() to get it and use it as token input. + // + // It's a stack because we need a mapping like this for each level of nested + // CompileGraph() call. In CompileGraph(), we will push a new mapping to the + // stack, and pop the mapping before returning. + std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index be3c93ae47..40ce9fb41c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -20,10 +20,12 @@ limitations under the License. 
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -32,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -1274,5 +1277,70 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { } } +class DummySideEffectingOp : public XlaOpKernel { + public: + explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken( + name(), xla::CreateToken(ctx->builder()))); + } +}; + +REGISTER_OP("DummySideEffectingOp"); + +REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp); + +TEST_F(XlaCompilerTest, TokenInputAndOutput) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + NodeDef side_effecting_op; + side_effecting_op.set_name("DummySideEffectingOp"); + side_effecting_op.set_op("DummySideEffectingOp"); + AddNodeAttr(kXlaTokenInputNodesAttrName, + std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op); + Status status; + graph->AddNode(side_effecting_op, &status); + TF_ASSERT_OK(status); + 
EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get())); + + const std::vector<XlaCompiler::Argument> empty_args; + { + // The case for entry computation: we don't add token input/output. Instead, + // we use CreateToken HLO to create the entry token. + XlaCompiler::CompileOptions options; + options.is_entry_computation = true; + options.add_token_input_output = false; + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), + empty_args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 0); + EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0); + } + { + // The case for non-entry computation (e.g. while loop body). We add token + // input/output. + XlaCompiler::CompileOptions options; + options.is_entry_computation = false; + options.add_token_input_output = true; + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), + empty_args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 1); + EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0])); + EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1); + EXPECT_TRUE(xla::ShapeUtil::IsToken( + xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0))); + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index e8b4b0eb36..f247570d72 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ 
b/tensorflow/compiler/tf2xla/xla_context.cc @@ -119,6 +119,17 @@ Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { return Status::OK(); } +Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) { + VLOG(1) << "Adding retval index " << retvals_.size() + << " with token to XLA computation"; + XlaExpression e; + e.set_handle(token); + // We use DT_INVALID because there is no TF DataType which corresponds to XLA + // token. XlaCompiler handles this case separately, so putting it here is OK. + retvals_.push_back(Retval{DT_INVALID, TensorShape(), e}); + return Status::OK(); +} + xla::XlaBuilder* XlaContext::builder() { return builder_; } Status XlaContext::CreateResource( diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 4da891634e..d7dbdc957f 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -89,6 +89,9 @@ class XlaContext : public ResourceBase { // As for Retval, but for return values that are resource handles. Status AddResourceRetval(int retval_index, XlaResource* resource); + // As for Retval, but for return values that are XLA tokens. + Status AppendTokenRetval(const xla::XlaOp& token); + // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` // constructor for a description of the remaining arguments. |