author    Peter Hawkins <phawkins@google.com>  2017-09-20 11:19:28 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-09-20 11:26:24 -0700
commit    453d06b059e3b0d8eb151423e42bc3cda2768d4d
tree      3687af485be6592b09e6fbb391b52220adaf113e
parent    3e4521bd290e4654a8b1e432d16ca893181ab018
[TF:XLA] Simplify the XlaCompiler API: unconditionally build an XLA computation, even if the computation is empty.

Reduces code complexity by removing an "optimization" of dubious value.

PiperOrigin-RevId: 169421426
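The core of the change is in BuildComputation (xla_compiler.cc below): previously the computation was only built when it had return values or known side effects, and every caller had to cope with a possibly-null xla::Computation. A minimal before/after sketch distilled from that hunk (argument plumbing and the removal of the has_side_effects parameter elided):

    // Before: build only when there is something to return; otherwise
    // leave *computation null and force IsNull() checks on callers.
    if (!elems.empty() || has_side_effects) {
      builder->Tuple(elems);
      xla::StatusOr<xla::Computation> computation_status = builder->Build();
      if (!computation_status.ok()) {
        return computation_status.status();
      }
      *computation = computation_status.ConsumeValueOrDie();
    }

    // After: always wrap the (possibly empty) return values in a tuple
    // and build, so *computation is always a valid computation.
    builder->Tuple(elems);
    xla::StatusOr<xla::Computation> computation_status = builder->Build();
    if (!computation_status.ok()) {
      return computation_status.status();
    }
    *computation = computation_status.ConsumeValueOrDie();

The follow-on deletions — the IsNull() branches in xla_local_launch_op.cc, xla_compilation_cache.cc, tf2xla.cc, and CompileGraph, plus the side-effect tracking in XlaContext and XlaOpKernelContext — existed only to service the old conditional.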
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_local_launch_op.cc  118
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.cc          3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc        1
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla.cc                      5
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.cc               27
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.h                 3
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler_test.cc           5
-rw-r--r--  tensorflow/compiler/tf2xla/xla_context.cc                 4
-rw-r--r--  tensorflow/compiler/tf2xla/xla_context.h                  8
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc               5
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h                3
11 files changed, 70 insertions(+), 112 deletions(-)
diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
index 5cbff88780..d9b5b2dd69 100644
--- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
@@ -259,72 +259,70 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
XlaLocalRuntimeContext local_runtime_context;
std::unique_ptr<xla::ShapedBuffer> output;
- if (!kernel->computation->IsNull()) {
- // Build xla::ShapedBuffers that point directly to the Tensor buffers.
- std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
- arg_buffers.reserve(kernel->xla_input_shapes.size() + 1);
- arg_buffers.resize(kernel->xla_input_shapes.size());
- std::vector<xla::ShapedBuffer*> arg_ptrs(arg_buffers.size());
-
- const int first_variable_arg = ctx->num_inputs() - num_resource_args_;
- // Pass remaining parameters.
- const Tensor* t;
- for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
- int arg_num = kernel->input_mapping[i];
- const xla::Shape& shape = kernel->xla_input_shapes[i];
- if (arg_num >= first_variable_arg) {
- t = &(variables[arg_num - first_variable_arg].value);
- } else {
- t = &(ctx->input(arg_num));
- }
+ // Build xla::ShapedBuffers that point directly to the Tensor buffers.
+ std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
+ arg_buffers.reserve(kernel->xla_input_shapes.size() + 1);
+ arg_buffers.resize(kernel->xla_input_shapes.size());
+ std::vector<xla::ShapedBuffer*> arg_ptrs(arg_buffers.size());
+
+ const int first_variable_arg = ctx->num_inputs() - num_resource_args_;
+ // Pass remaining parameters.
+ const Tensor* t;
+ for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
+ int arg_num = kernel->input_mapping[i];
+ const xla::Shape& shape = kernel->xla_input_shapes[i];
+ if (arg_num >= first_variable_arg) {
+ t = &(variables[arg_num - first_variable_arg].value);
+ } else {
+ t = &(ctx->input(arg_num));
+ }
- gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase(
- const_cast<char*>(t->tensor_data().data()), t->tensor_data().size());
+ gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase(
+ const_cast<char*>(t->tensor_data().data()), t->tensor_data().size());
- arg_buffers[i] =
- xla::ShapedBuffer::MakeArrayShapedBuffer(
- shape, client->platform(), client->default_device_ordinal(), dmem)
- .ConsumeValueOrDie();
- arg_ptrs[i] = arg_buffers[i].get();
- }
+ arg_buffers[i] =
+ xla::ShapedBuffer::MakeArrayShapedBuffer(
+ shape, client->platform(), client->default_device_ordinal(), dmem)
+ .ConsumeValueOrDie();
+ arg_ptrs[i] = arg_buffers[i].get();
+ }
- // Make the final parameter point at local_runtime_context.
- if (kernel->requires_runtime_context) {
- gpu::DeviceMemoryBase local_runtime_context_dmem(
- &local_runtime_context, sizeof(local_runtime_context));
- arg_buffers.push_back(
- xla::ShapedBuffer::MakeArrayShapedBuffer(
- xla::ShapeUtil::MakeOpaqueShape(), client->platform(),
- client->default_device_ordinal(), local_runtime_context_dmem)
- .ConsumeValueOrDie());
- arg_ptrs.push_back(arg_buffers.back().get());
- }
+ // Make the final parameter point at local_runtime_context.
+ if (kernel->requires_runtime_context) {
+ gpu::DeviceMemoryBase local_runtime_context_dmem(
+ &local_runtime_context, sizeof(local_runtime_context));
+ arg_buffers.push_back(
+ xla::ShapedBuffer::MakeArrayShapedBuffer(
+ xla::ShapeUtil::MakeOpaqueShape(), client->platform(),
+ client->default_device_ordinal(), local_runtime_context_dmem)
+ .ConsumeValueOrDie());
+ arg_ptrs.push_back(arg_buffers.back().get());
+ }
- // Execute the computation.
- VLOG(2) << "Executing computation.";
- xla::ExecutableRunOptions run_options;
- run_options.set_stream(stream);
- run_options.set_allocator(&xla_allocator);
- run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
- Env* env = Env::Default();
- auto start_time = env->NowMicros();
- auto run_result = executable->Run(arg_ptrs, run_options);
- OP_REQUIRES(ctx, run_result.ok(), run_result.status());
-
- if (local_runtime_context.error) {
- ctx->CtxFailure(errors::InvalidArgument(
- "Compiled kernel returned error: ", local_runtime_context.error_msg));
- return;
- }
+ // Execute the computation.
+ VLOG(2) << "Executing computation.";
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(&xla_allocator);
+ run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+ auto run_result = executable->Run(arg_ptrs, run_options);
+ OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+ if (local_runtime_context.error) {
+ ctx->CtxFailure(errors::InvalidArgument("Compiled kernel returned error: ",
+ local_runtime_context.error_msg));
+ return;
+ }
- output = std::move(run_result.ValueOrDie());
- auto elapsed = env->NowMicros() - start_time;
- VLOG(2) << "Elapsed time: " << elapsed << "us";
+ output = std::move(run_result.ValueOrDie());
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time: " << elapsed << "us";
- // Computation output should always be a tuple.
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "Result tuple shape: " << output->shape().DebugString();
- }
+ // Computation output should always be a tuple.
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Result tuple shape: " << output->shape().DebugString();
}
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index d3939002aa..4fd155565c 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -309,8 +309,7 @@ Status XlaCompilationCache::Compile(
}
*compilation_result = &entry->compilation_result;
if (entry->compilation_status.ok() && executable) {
- if (entry->executable == nullptr &&
- !entry->compilation_result.computation->IsNull()) {
+ if (entry->executable == nullptr) {
XlaCompiler compiler(options);
entry->compilation_status = BuildExecutable(
options, entry->compilation_result, &entry->executable);
diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
index 2a24529850..ed818c56ed 100644
--- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
@@ -45,7 +45,6 @@ void SendOp::Compile(XlaOpKernelContext* ctx) {
xla::ChannelHandle channel;
OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel));
ctx->builder()->Send(ctx->Input(0), channel);
- ctx->SetOpHasSideEffects();
}
REGISTER_XLA_OP(Name("_XLASend"), SendOp);
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index b29c92190d..b7213a6cc1 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -304,11 +304,6 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
" constant results. The configuration of "
"the output args (i.e. fetch ids) is probably wrong.");
}
- if (computation->IsNull()) {
- return errors::Aborted(
- "Conversion from TensorFlow graph to XLA resulted in an empty "
- "computation.");
- }
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 34b1246be2..0b583b54bf 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -343,8 +343,6 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
// `retvals` is the list of retvals produced by _Retval operators, in index
// order. `variable_map` is a map from variable ID numbers to XlaOpContext
// variable states, generated by the symbolic evaluation.
-// If `has_side_effects` is true, the computation has side effects and should be
-// built even if it has no outputs.
// If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
@@ -358,7 +356,7 @@ Status BuildComputation(
const std::vector<XlaCompiler::Argument>& args,
const std::vector<XlaExpression>& retvals,
const std::vector<std::unique_ptr<XlaResource>>& resources,
- bool has_side_effects, bool return_updated_values_for_all_resources,
+ bool return_updated_values_for_all_resources,
xla::ComputationBuilder* builder, xla::Computation* computation,
int* num_computation_outputs, int* num_nonconst_outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
@@ -412,18 +410,14 @@ Status BuildComputation(
}
*num_computation_outputs = elems.size();
- if (!elems.empty() || has_side_effects) {
- // Builds a empty tuple return value for computations that have side effects
- // but have no return values.
- builder->Tuple(elems);
-
- // Builds the XLA computation.
- xla::StatusOr<xla::Computation> computation_status = builder->Build();
- if (!computation_status.ok()) {
- return computation_status.status();
- }
- *computation = computation_status.ConsumeValueOrDie();
+
+ // Builds the XLA computation.
+ builder->Tuple(elems);
+ xla::StatusOr<xla::Computation> computation_status = builder->Build();
+ if (!computation_status.ok()) {
+ return computation_status.status();
}
+ *computation = computation_status.ConsumeValueOrDie();
return Status::OK();
}
@@ -484,7 +478,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
result->computation = std::make_shared<xla::Computation>();
TF_RETURN_IF_ERROR(BuildComputation(
args, context->retvals(), context->resources(),
- context->has_side_effects(),
options.return_updated_values_for_all_resources, &builder,
result->computation.get(), &num_computation_outputs,
&num_nonconst_outputs, &result->resource_updates));
@@ -508,10 +501,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
}
}
- if (result->computation->IsNull()) {
- return Status::OK();
- }
-
// Compute the output shapes, if there is a computation with non-constant
// outputs.
auto computation_shape = client()->GetComputationShape(*result->computation);
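Since BuildComputation now always calls builder->Tuple(elems), a graph with no return values compiles to a computation that returns the empty tuple (), rather than to no computation at all — which is why CompileGraph above can call GetComputationShape unconditionally. A standalone sketch of the same builder pattern (hypothetical setup; assumes a connected xla::Client* client and the 2017-era ComputationBuilder constructor):

    xla::ComputationBuilder builder(client, "empty_computation");
    builder.Tuple({});  // zero-element tuple; the result shape is ()
    xla::StatusOr<xla::Computation> built = builder.Build();
    // built.ValueOrDie() is a valid, executable (if trivial) computation.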
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index cf78e2cc13..35159dbad4 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -194,8 +194,7 @@ class XlaCompiler {
// results in the outputs of XLA computation.
std::vector<ResourceUpdate> resource_updates;
- // The XLA computation built from the tensorflow subgraph. May be null
- // if the output consists solely of compile-time constants.
+ // The XLA computation built from the tensorflow subgraph.
std::shared_ptr<xla::Computation> computation;
};
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index f516dd867a..531725a623 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -151,7 +151,7 @@ class XlaCompilerTest : public ::testing::Test {
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
};
-// Tests compilation of an empty graph.
+// Tests compilation and execution of an empty graph.
TEST_F(XlaCompilerTest, EmptyReturnValues) {
XlaCompiler compiler(DefaultOptions());
@@ -161,8 +161,7 @@ TEST_F(XlaCompilerTest, EmptyReturnValues) {
std::move(graph),
/*args=*/{}, &result));
- // No computation should be generated.
- EXPECT_EQ(0, result.computation->handle().handle());
+ TF_ASSERT_OK(client_->Execute(*result.computation, {}).status());
}
// Tests compilation and execution of a graph that adds two tensors.
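The rewritten EmptyReturnValues test above doubles as a usage example: an empty graph now compiles to an executable computation instead of a null handle. A condensed sketch of that flow (the CompileGraph arguments preceding std::move(graph) are elided in the hunk; the options and name used here are assumptions):

    XlaCompiler compiler(DefaultOptions());
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "empty",
                                       std::move(graph),
                                       /*args=*/{}, &result));
    // No IsNull()/handle() check is needed anymore: executing the
    // computation with no arguments succeeds and yields an empty tuple.
    TF_ASSERT_OK(client_->Execute(*result.computation, {}).status());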
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 35219feca4..cf1cc6ab49 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -119,10 +119,6 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
return Status::OK();
}
-void XlaContext::AddSideEffects() {
- has_side_effects_ = true;
-}
-
xla::ComputationBuilder* XlaContext::builder() { return builder_; }
Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num,
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 5e8149e3aa..70b81c111f 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -76,11 +76,6 @@ class XlaContext : public ResourceBase {
Status AddConstRetval(int retval_index, DataType dtype,
const xla::Literal& literal);
- // Mark the computation as having side effects (e.g., Send operators).
- void AddSideEffects();
-
- bool has_side_effects() const { return has_side_effects_; }
-
// Creates a resource with resource `kind` and initial type `type` and
// value `handle`. `name` is a descriptive name for use in error messages.
// Fails if the resource already exists.
@@ -133,9 +128,6 @@ class XlaContext : public ResourceBase {
// Return values of the Tensorflow graph, indexed by _Retval index.
std::vector<XlaExpression> retvals_;
- // Does the computation have side effects, i.e., Send() calls?
- bool has_side_effects_ = false;
-
// Holds ownership of resources. The resources are not ordered.
std::vector<std::unique_ptr<XlaResource>> resources_;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 459d9edb87..b0607cfa0c 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -345,7 +345,6 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
Status XlaOpKernelContext::AssignVariable(
int input_index, DataType type, const xla::ComputationDataHandle& handle) {
TF_RET_CHECK(handle.handle() != 0);
- SetOpHasSideEffects();
const XlaExpression* expression =
CastExpressionFromTensor(context_->input(input_index));
@@ -363,10 +362,6 @@ Status XlaOpKernelContext::AssignVariable(
return Status::OK();
}
-void XlaOpKernelContext::SetOpHasSideEffects() {
- XlaContext::Get(context_).AddSideEffects();
-}
-
XlaCompiler* XlaOpKernelContext::compiler() const {
return XlaContext::Get(context_).compiler();
}
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 30b794c8c1..c8b5370faa 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -143,9 +143,6 @@ class XlaOpKernelContext {
void SetStatus(const Status& status) { context_->SetStatus(status); }
Status status() { return context_->status(); }
- // Mark the op has having side effects (i.e., via Send).
- void SetOpHasSideEffects();
-
// Variables
// Sets '*resource' to the resource associated with input `index`.