aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/if_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/if_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/if_op.cc30
1 files changed, 29 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index 6e1dbf5472..56da50f140 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/if_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
+ if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+ has_token_input_output_ = false;
+ } else {
+ has_token_input_output_ = !token_input_nodes_.empty();
+ }
}
// TODO(b/35949885): There is duplication here with the handling of the
@@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
options.resolve_compile_time_constants = false;
options.return_updated_values_for_all_resources = true;
options.is_entry_computation = false;
+ options.add_token_input_output = has_token_input_output_;
XlaCompiler* compiler = ctx->compiler();
XlaCompiler::CompilationResult then_result;
@@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = then_result.input_mapping[i] + 1;
- if (ctx->input_type(input_num) == DT_RESOURCE) {
+ if (has_token_input_output_ && i == num_inputs - 1) {
+ // Set token input for this "if" op.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const string& node_name : token_input_nodes_) {
+ auto token_or = compiler->GetNodeToken(node_name);
+ OP_REQUIRES_OK(ctx, token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ inputs[i] = xla::AfterAll(b, token_inputs);
+ } else if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
@@ -219,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
ctx->SetOutput(i, output_handle);
}
+ if (has_token_input_output_) {
+ // Set token output for this "if" op.
+ xla::XlaOp token_output =
+ xla::GetTupleElement(outputs, output_types_.size());
+ auto shape_or = b->GetShape(token_output);
+ OP_REQUIRES_OK(ctx, shape_or.status());
+ OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+ errors::FailedPrecondition(
+ "Token output is not token type: ",
+ xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+ OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+ }
// Updates the values of any resource variables modified by the conditional
// bodies.