aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Tong Shen <endlessroad@google.com>2018-09-07 18:41:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 18:45:44 -0700
commit3e1b06ee93d7a638db1fdd5f733d66064c1acf59 (patch)
tree98ef8a3a7ce89114e8f9a296ee6252fc739d0943
parent3ea43a044e7515388ecf322437b08f4ced5674aa (diff)
Add XLA token input/output to XlaIf and XlaWhile when necessary.
PiperOrigin-RevId: 212070721
-rw-r--r--tensorflow/compiler/tf2xla/BUILD12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/if_op.cc30
-rw-r--r--tensorflow/compiler/tf2xla/kernels/if_op.h2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/while_op.cc31
-rw-r--r--tensorflow/compiler/tf2xla/kernels/while_op.h2
-rw-r--r--tensorflow/compiler/tf2xla/side_effect_util.cc67
-rw-r--r--tensorflow/compiler/tf2xla/side_effect_util.h47
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc113
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h23
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc68
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.cc11
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.h3
13 files changed, 403 insertions, 8 deletions
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 95004534b9..3821dced63 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -191,6 +191,7 @@ cc_library(
":functionalize_control_flow",
":host_compute_metadata_proto",
":sharding_util",
+ ":side_effect_util",
":tf2xla_util",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal",
@@ -360,6 +361,7 @@ tf_cc_test(
name = "xla_compiler_test",
srcs = ["xla_compiler_test.cc"],
deps = [
+ ":side_effect_util",
":xla_compiler",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
@@ -371,6 +373,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:core_cpu_internal",
@@ -632,3 +635,12 @@ tf_cc_test(
"@com_google_absl//absl/strings",
],
)
+
+cc_library(
+ name = "side_effect_util",
+ srcs = ["side_effect_util.cc"],
+ hdrs = ["side_effect_util.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ ],
+)
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index c78538114f..46794f7b50 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -178,6 +178,7 @@ tf_kernel_library(
hdrs = ["while_op.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
@@ -195,6 +196,7 @@ tf_kernel_library(
hdrs = ["if_op.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index 6e1dbf5472..56da50f140 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/if_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
+ if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+ has_token_input_output_ = false;
+ } else {
+ has_token_input_output_ = !token_input_nodes_.empty();
+ }
}
// TODO(b/35949885): There is duplication here with the handling of the
@@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
options.resolve_compile_time_constants = false;
options.return_updated_values_for_all_resources = true;
options.is_entry_computation = false;
+ options.add_token_input_output = has_token_input_output_;
XlaCompiler* compiler = ctx->compiler();
XlaCompiler::CompilationResult then_result;
@@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = then_result.input_mapping[i] + 1;
- if (ctx->input_type(input_num) == DT_RESOURCE) {
+ if (has_token_input_output_ && i == num_inputs - 1) {
+ // Set token input for this "if" op.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const string& node_name : token_input_nodes_) {
+ auto token_or = compiler->GetNodeToken(node_name);
+ OP_REQUIRES_OK(ctx, token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ inputs[i] = xla::AfterAll(b, token_inputs);
+ } else if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
@@ -219,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
ctx->SetOutput(i, output_handle);
}
+ if (has_token_input_output_) {
+ // Set token output for this "if" op.
+ xla::XlaOp token_output =
+ xla::GetTupleElement(outputs, output_types_.size());
+ auto shape_or = b->GetShape(token_output);
+ OP_REQUIRES_OK(ctx, shape_or.status());
+ OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+ errors::FailedPrecondition(
+ "Token output is not token type: ",
+ xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+ OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+ }
// Updates the values of any resource variables modified by the conditional
// bodies.
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h
index f9bc98a198..7783e13a8a 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.h
@@ -52,6 +52,8 @@ class XlaIfOp : public XlaOpKernel {
DataType cond_type_;
DataTypeVector input_types_;
DataTypeVector output_types_;
+ bool has_token_input_output_;
+ std::vector<string> token_input_nodes_;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 296518229e..559414eeaa 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/while_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -90,6 +91,11 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
cond_name_attr_ = *name_attr;
OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
body_name_attr_ = *name_attr;
+ if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+ has_token_input_output_ = false;
+ } else {
+ has_token_input_output_ = !token_input_nodes_.empty();
+ }
}
void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
@@ -120,6 +126,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
body_options.return_updated_values_for_all_resources = true;
body_options.resolve_compile_time_constants = false;
body_options.is_entry_computation = false;
+ body_options.add_token_input_output = has_token_input_output_;
XlaCompiler::CompilationResult body;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
arguments, &body));
@@ -192,6 +199,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
cond_options.use_tuple_arg = true;
cond_options.resolve_compile_time_constants = false;
cond_options.is_entry_computation = false;
+ cond_options.add_token_input_output = has_token_input_output_;
XlaCompiler::CompilationResult cond;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
arguments, &cond));
@@ -238,7 +246,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = body.input_mapping[i];
- if (ctx->input_type(input_num) == DT_RESOURCE) {
+ if (has_token_input_output_ && i == num_inputs - 1) {
+ // Set token input for this "while" op.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const string& node_name : token_input_nodes_) {
+ auto token_or = compiler->GetNodeToken(node_name);
+ OP_REQUIRES_OK(ctx, token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ inputs[i] = xla::AfterAll(builder, token_inputs);
+ } else if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
@@ -273,6 +290,18 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
xla::GetTupleElement(while_result, i));
}
}
+ if (has_token_input_output_) {
+ // Set token output for this "while" op.
+ xla::XlaOp token_output =
+ xla::GetTupleElement(while_result, ctx->num_outputs());
+ auto shape_or = builder->GetShape(token_output);
+ OP_REQUIRES_OK(ctx, shape_or.status());
+ OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+ errors::FailedPrecondition(
+ "Token output is not token type: ",
+ xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+ OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+ }
// Updates the values of any resource variables modified by the loop.
for (int i = 0; i < body.resource_updates.size(); ++i) {
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h
index 67edebabf9..aeeff40e68 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.h
@@ -56,6 +56,8 @@ class XlaWhileOp : public XlaOpKernel {
private:
NameAttrList cond_name_attr_;
NameAttrList body_name_attr_;
+ bool has_token_input_output_;
+ std::vector<string> token_input_nodes_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp);
};
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc
new file mode 100644
index 0000000000..6cd7b24592
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/side_effect_util.cc
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
+
+#include "tensorflow/core/graph/algorithm.h"
+
+namespace tensorflow {
+
+const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes";
+
+const char kXlaTokenArgNodeName[] = "_xla_token_arg_node";
+
+std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g) {
+ std::set<std::string> results;
+ Node* first_side_effecting_node_on_path = nullptr;
+ ReverseDFS(g,
+ [&](Node* n) {
+ std::vector<string> token_input_nodes;
+ if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName,
+ &token_input_nodes)
+ .ok() ||
+ token_input_nodes.empty()) {
+ return;
+ }
+
+ if (first_side_effecting_node_on_path != nullptr) {
+ return;
+ }
+
+ first_side_effecting_node_on_path = n;
+ results.insert(n->name());
+ },
+ [&](Node* n) {
+ if (first_side_effecting_node_on_path == n) {
+ first_side_effecting_node_on_path = nullptr;
+ }
+ },
+ NodeComparatorName());
+ return results;
+}
+
+bool HasSideEffectingNodes(const Graph& g) {
+ for (Node* n : g.nodes()) {
+ std::vector<string> token_input_nodes;
+ if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes)
+ .ok() &&
+ !token_input_nodes.empty()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h
new file mode 100644
index 0000000000..ad07624729
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/side_effect_util.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
+
+#include <vector>
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Side-effecting nodes will have this attribute set. Its value is the list of
+// node names which this node has side-effect dependencies on.
+//
+// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute,
+// because they always have side-effect.
+// If and While nodes may or may not have this attribute, depending on whether
+// their bodies have side-effecting nodes.
+extern const char kXlaTokenInputNodesAttrName[];
+
+// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a
+// node has side-effect dependency on current graph's token input.
+extern const char kXlaTokenArgNodeName[];
+
+// Calculates side-effect dependencies for the graph's token output.
+// Returns a set of node names representing these dependencies.
+std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g);
+
+// Returns whether a graph contains side-effecting nodes.
+bool HasSideEffectingNodes(const Graph& g);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 41d305d461..dcb455779d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
@@ -291,6 +292,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
"Invalid resource type in XLAShapeForArgument()");
}
}
+ case XlaCompiler::Argument::kToken: {
+ *xla_shape = xla::ShapeUtil::MakeTokenShape();
+ return Status::OK();
+ }
case XlaCompiler::Argument::kInvalid:
return errors::Internal("Invalid argument type in XLAShapeForArgument()");
}
@@ -489,7 +494,8 @@ Status XlaCompiler::BuildArguments(
}
break;
- case XlaCompiler::Argument::kParameter: {
+ case XlaCompiler::Argument::kParameter:
+ case XlaCompiler::Argument::kToken: {
input_mapping->push_back(i);
break;
}
@@ -616,6 +622,10 @@ Status XlaCompiler::BuildArguments(
arg_expression.set_handle(arg_handles[i]);
}
break;
+ case XlaCompiler::Argument::kToken: {
+ arg_expression.set_handle(arg_handles[i]);
+ break;
+ }
case XlaCompiler::Argument::kConstant:
case XlaCompiler::Argument::kInvalid:
return errors::Internal(
@@ -757,23 +767,71 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
&options_.shape_representation_fn);
core::ScopedUnref context_unref(context);
+ std::vector<XlaCompiler::Argument> real_args(args);
+ int token_input_index = -1;
+ if (options.add_token_input_output) {
+ // Add extra token input.
+ token_input_index = real_args.size();
+
+ XlaCompiler::Argument token_arg;
+ token_arg.kind = XlaCompiler::Argument::kToken;
+ real_args.push_back(token_arg);
+ }
+
std::vector<XlaExpression> arg_expressions;
std::vector<int> arg_cores;
- TF_RETURN_IF_ERROR(
- BuildArguments(*graph, args, options.use_tuple_arg, &builder, context,
- &arg_cores, &arg_expressions, &result->input_mapping,
- &result->xla_input_shapes, options.is_entry_computation));
+ TF_RETURN_IF_ERROR(BuildArguments(
+ *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores,
+ &arg_expressions, &result->input_mapping, &result->xla_input_shapes,
+ options.is_entry_computation));
context->set_args(std::move(arg_expressions));
+ PushNodeTokenMapping();
+ // Use std::set instead of std::unordered_set to ensure determinism.
+ std::set<std::string> output_node_token_inputs;
+ if (token_input_index != -1) {
+ // Original token comes from input.
+ auto arg_expression = context->args()[token_input_index];
+ TF_RETURN_IF_ERROR(
+ SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
+
+ // Calculate token inputs for output token.
+ output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
+
+ // If there's no side-effecting op in the graph, use token input as token
+ // output.
+ if (output_node_token_inputs.empty()) {
+ output_node_token_inputs.insert(kXlaTokenArgNodeName);
+ }
+ } else if (options.is_entry_computation) {
+ // Original token is manually created.
+ if (HasSideEffectingNodes(*graph)) {
+ TF_RETURN_IF_ERROR(
+ SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
+ }
+ }
+
TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
flib_runtime_, NextStepId()));
+ if (token_input_index != -1) {
+ // Add extra token output.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const auto& node_name : output_node_token_inputs) {
+ auto token_or = GetNodeToken(node_name);
+ TF_RETURN_IF_ERROR(token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ TF_RETURN_IF_ERROR(
+ context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs)));
+ }
+ TF_RETURN_IF_ERROR(PopNodeTokenMapping());
int num_nonconst_outputs;
int num_computation_outputs;
result->computation = std::make_shared<xla::XlaComputation>();
result->outputs.resize(context->retvals().size());
TF_RETURN_IF_ERROR(BuildComputation(
- args, arg_cores, context->retvals(), context->resources(),
+ real_args, arg_cores, context->retvals(), context->resources(),
options.return_updated_values_for_all_resources,
options.always_return_tuple, &builder, result->computation.get(),
&num_computation_outputs, &num_nonconst_outputs, &result->outputs,
@@ -912,4 +970,47 @@ Status XlaCompiler::SetHostComputeControlDependency(
return Status::OK();
}
+void XlaCompiler::PushNodeTokenMapping() {
+ node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
+}
+
+Status XlaCompiler::PopNodeTokenMapping() {
+ if (node_token_mapping_stack_.empty()) {
+ return errors::FailedPrecondition(
+ "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
+ "empty.");
+ }
+ node_token_mapping_stack_.pop();
+ return Status::OK();
+}
+
+Status XlaCompiler::SetNodeToken(const string& node_name,
+ const xla::XlaOp& op) {
+ if (node_token_mapping_stack_.empty()) {
+ return errors::FailedPrecondition(
+ "Calling SetNodeToken() when node_token_mapping_stack_ is "
+ "empty.");
+ }
+ auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
+ if (!insert_result.second) {
+ return errors::FailedPrecondition("Token mapping already exists for node ",
+ node_name);
+ }
+ return Status::OK();
+}
+
+xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
+ if (node_token_mapping_stack_.empty()) {
+ return errors::FailedPrecondition(
+ "Calling GetNodeToken() when node_token_mapping_stack_ is "
+ "empty.");
+ }
+ auto iter = node_token_mapping_stack_.top().find(node_name);
+ if (iter == node_token_mapping_stack_.top().end()) {
+ return errors::FailedPrecondition("Cannot find token mapping for node ",
+ node_name);
+ }
+ return iter->second;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 8f4a9858ed..2cc603a580 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
+#include <stack>
+
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -26,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
@@ -106,6 +109,9 @@ class XlaCompiler {
// Argument is a run-time parameter.
kParameter,
+
+ // Argument is an XLA token.
+ kToken,
};
Kind kind = kInvalid;
@@ -179,6 +185,9 @@ class XlaCompiler {
// True when compiling the entry computation, false for subcomputations
// (while, call, etc.)
bool is_entry_computation = true;
+
+ // True when we should add XLA input & output to the graph/function.
+ bool add_token_input_output = false;
};
struct OutputDescription {
@@ -384,6 +393,11 @@ class XlaCompiler {
xla::Client* client() const { return options_.client; }
FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
+ void PushNodeTokenMapping();
+ Status PopNodeTokenMapping();
+ Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
+ xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
+
private:
// Sets the function body `fbody` to the one registered as `function`.
Status FindFunctionBody(const NameAttrList& function,
@@ -448,6 +462,15 @@ class XlaCompiler {
std::unordered_map<string, xla::XlaOp> host_compute_control_output_;
+ // This is used to store <node name, token output> mapping. Side-effecting
+ // ops call SetNodeToken() to record its token output, so later side-effecting
+ // ops can use GetNodeToken() to get it and use it as token input.
+ //
+ // It's a stack because we need a mapping like this for each level of nested
+ // CompileGraph() call. In CompileGraph(), we will push a new mapping to the
+ // stack, and pop the mapping before returning.
+ std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;
+
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index be3c93ae47..40ce9fb41c 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -20,10 +20,12 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -32,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -1274,5 +1277,70 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
}
}
+class DummySideEffectingOp : public XlaOpKernel {
+ public:
+ explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ void Compile(XlaOpKernelContext* ctx) override {
+ OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(
+ name(), xla::CreateToken(ctx->builder())));
+ }
+};
+
+REGISTER_OP("DummySideEffectingOp");
+
+REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp);
+
+TEST_F(XlaCompilerTest, TokenInputAndOutput) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ NodeDef side_effecting_op;
+ side_effecting_op.set_name("DummySideEffectingOp");
+ side_effecting_op.set_op("DummySideEffectingOp");
+ AddNodeAttr(kXlaTokenInputNodesAttrName,
+ std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
+ Status status;
+ graph->AddNode(side_effecting_op, &status);
+ TF_ASSERT_OK(status);
+ EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get()));
+
+ const std::vector<XlaCompiler::Argument> empty_args;
+ {
+ // The case for entry computation: we don't add token input/output. Instead,
+ // we use CreateToken HLO to create the entry token.
+ XlaCompiler::CompileOptions options;
+ options.is_entry_computation = true;
+ options.add_token_input_output = false;
+ XlaCompiler compiler(DefaultOptions());
+
+ std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+ CopyGraph(*graph, graph_copy.get());
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
+ empty_args, &result));
+ EXPECT_EQ(result.xla_input_shapes.size(), 0);
+ EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
+ EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0);
+ }
+ {
+ // The case for non-entry computation (e.g. while loop body). We add token
+ // input/output.
+ XlaCompiler::CompileOptions options;
+ options.is_entry_computation = false;
+ options.add_token_input_output = true;
+ XlaCompiler compiler(DefaultOptions());
+
+ std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+ CopyGraph(*graph, graph_copy.get());
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
+ empty_args, &result));
+ EXPECT_EQ(result.xla_input_shapes.size(), 1);
+ EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0]));
+ EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
+ EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
+ EXPECT_TRUE(xla::ShapeUtil::IsToken(
+ xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0)));
+ }
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index e8b4b0eb36..f247570d72 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -119,6 +119,17 @@ Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
return Status::OK();
}
+Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) {
+ VLOG(1) << "Adding retval index " << retvals_.size()
+ << " with token to XLA computation";
+ XlaExpression e;
+ e.set_handle(token);
+ // We use DT_INVALID because there is no TF DataType which corresponds to XLA
+ // token. XlaCompiler handles this case separately, so putting it here is OK.
+ retvals_.push_back(Retval{DT_INVALID, TensorShape(), e});
+ return Status::OK();
+}
+
xla::XlaBuilder* XlaContext::builder() { return builder_; }
Status XlaContext::CreateResource(
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 4da891634e..d7dbdc957f 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -89,6 +89,9 @@ class XlaContext : public ResourceBase {
// As for Retval, but for return values that are resource handles.
Status AddResourceRetval(int retval_index, XlaResource* resource);
+ // As for Retval, but for return values that are XLA tokens.
+ Status AppendTokenRetval(const xla::XlaOp& token);
+
// Creates a resource with resource `kind` and initial value `handle`. `name`
// is a descriptive name for use in error messages. See the `XlaResource`
// constructor for a description of the remaining arguments.