diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/if_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/if_op.cc | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 6e1dbf5472..56da50f140 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/if_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } // TODO(b/35949885): There is duplication here with the handling of the @@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { options.resolve_compile_time_constants = false; options.return_updated_values_for_all_resources = true; options.is_entry_computation = false; + options.add_token_input_output = has_token_input_output_; XlaCompiler* compiler = ctx->compiler(); XlaCompiler::CompilationResult then_result; @@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { std::vector<xla::XlaOp> inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = then_result.input_mapping[i] + 1; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "if" op. + std::vector<xla::XlaOp> token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(b, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); @@ -219,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } ctx->SetOutput(i, output_handle); } + if (has_token_input_output_) { + // Set token output for this "if" op. + xla::XlaOp token_output = + xla::GetTupleElement(outputs, output_types_.size()); + auto shape_or = b->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } // Updates the values of any resource variables modified by the conditional // bodies. |