diff options
author | 2017-09-20 11:19:28 -0700 | |
---|---|---|
committer | 2017-09-20 11:26:24 -0700 | |
commit | 453d06b059e3b0d8eb151423e42bc3cda2768d4d (patch) | |
tree | 3687af485be6592b09e6fbb391b52220adaf113e | |
parent | 3e4521bd290e4654a8b1e432d16ca893181ab018 (diff) |
[TF:XLA] Simplify XlaCompiler API. Unconditionally builds an XLA computation, even if the computation is empty.
Reduces code complexity by removing an "optimization" of dubious value.
PiperOrigin-RevId: 169421426
-rw-r--r-- | tensorflow/compiler/jit/kernels/xla_local_launch_op.cc | 118 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_compilation_cache.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/tf2xla.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 27 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.h | 3 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_context.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_context.h | 8 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.h | 3 |
11 files changed, 70 insertions, 112 deletions
diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc index 5cbff88780..d9b5b2dd69 100644 --- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc @@ -259,72 +259,70 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { XlaLocalRuntimeContext local_runtime_context; std::unique_ptr<xla::ShapedBuffer> output; - if (!kernel->computation->IsNull()) { - // Build xla::ShapedBuffers that point directly to the Tensor buffers. - std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers; - arg_buffers.reserve(kernel->xla_input_shapes.size() + 1); - arg_buffers.resize(kernel->xla_input_shapes.size()); - std::vector<xla::ShapedBuffer*> arg_ptrs(arg_buffers.size()); - - const int first_variable_arg = ctx->num_inputs() - num_resource_args_; - // Pass remaining parameters. - const Tensor* t; - for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { - int arg_num = kernel->input_mapping[i]; - const xla::Shape& shape = kernel->xla_input_shapes[i]; - if (arg_num >= first_variable_arg) { - t = &(variables[arg_num - first_variable_arg].value); - } else { - t = &(ctx->input(arg_num)); - } + // Build xla::ShapedBuffers that point directly to the Tensor buffers. + std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers; + arg_buffers.reserve(kernel->xla_input_shapes.size() + 1); + arg_buffers.resize(kernel->xla_input_shapes.size()); + std::vector<xla::ShapedBuffer*> arg_ptrs(arg_buffers.size()); + + const int first_variable_arg = ctx->num_inputs() - num_resource_args_; + // Pass remaining parameters. + const Tensor* t; + for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { + int arg_num = kernel->input_mapping[i]; + const xla::Shape& shape = kernel->xla_input_shapes[i]; + if (arg_num >= first_variable_arg) { + t = &(variables[arg_num - first_variable_arg].value); + } else { + t = &(ctx->input(arg_num)); + } - gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase( - const_cast<char*>(t->tensor_data().data()), t->tensor_data().size()); + gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase( + const_cast<char*>(t->tensor_data().data()), t->tensor_data().size()); - arg_buffers[i] = - xla::ShapedBuffer::MakeArrayShapedBuffer( - shape, client->platform(), client->default_device_ordinal(), dmem) - .ConsumeValueOrDie(); - arg_ptrs[i] = arg_buffers[i].get(); - } + arg_buffers[i] = + xla::ShapedBuffer::MakeArrayShapedBuffer( + shape, client->platform(), client->default_device_ordinal(), dmem) + .ConsumeValueOrDie(); + arg_ptrs[i] = arg_buffers[i].get(); + } - // Make the final parameter point at local_runtime_context. - if (kernel->requires_runtime_context) { - gpu::DeviceMemoryBase local_runtime_context_dmem( - &local_runtime_context, sizeof(local_runtime_context)); - arg_buffers.push_back( - xla::ShapedBuffer::MakeArrayShapedBuffer( - xla::ShapeUtil::MakeOpaqueShape(), client->platform(), - client->default_device_ordinal(), local_runtime_context_dmem) - .ConsumeValueOrDie()); - arg_ptrs.push_back(arg_buffers.back().get()); - } + // Make the final parameter point at local_runtime_context. + if (kernel->requires_runtime_context) { + gpu::DeviceMemoryBase local_runtime_context_dmem( + &local_runtime_context, sizeof(local_runtime_context)); + arg_buffers.push_back( + xla::ShapedBuffer::MakeArrayShapedBuffer( + xla::ShapeUtil::MakeOpaqueShape(), client->platform(), + client->default_device_ordinal(), local_runtime_context_dmem) + .ConsumeValueOrDie()); + arg_ptrs.push_back(arg_buffers.back().get()); + } - // Execute the computation. - VLOG(2) << "Executing computation."; - xla::ExecutableRunOptions run_options; - run_options.set_stream(stream); - run_options.set_allocator(&xla_allocator); - run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); - Env* env = Env::Default(); - auto start_time = env->NowMicros(); - auto run_result = executable->Run(arg_ptrs, run_options); - OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - - if (local_runtime_context.error) { - ctx->CtxFailure(errors::InvalidArgument( - "Compiled kernel returned error: ", local_runtime_context.error_msg)); - return; - } + // Execute the computation. + VLOG(2) << "Executing computation."; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(&xla_allocator); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + auto run_result = executable->Run(arg_ptrs, run_options); + OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + + if (local_runtime_context.error) { + ctx->CtxFailure(errors::InvalidArgument("Compiled kernel returned error: ", + local_runtime_context.error_msg)); + return; + } - output = std::move(run_result.ValueOrDie()); - auto elapsed = env->NowMicros() - start_time; - VLOG(2) << "Elapsed time: " << elapsed << "us"; + output = std::move(run_result.ValueOrDie()); + auto elapsed = env->NowMicros() - start_time; + VLOG(2) << "Elapsed time: " << elapsed << "us"; - // Computation output should always be a tuple. - if (VLOG_IS_ON(2)) { - VLOG(2) << "Result tuple shape: " << output->shape().DebugString(); - } + // Computation output should always be a tuple. + if (VLOG_IS_ON(2)) { + VLOG(2) << "Result tuple shape: " << output->shape().DebugString(); } CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index d3939002aa..4fd155565c 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -309,8 +309,7 @@ Status XlaCompilationCache::Compile( } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { - if (entry->executable == nullptr && - !entry->compilation_result.computation->IsNull()) { + if (entry->executable == nullptr) { XlaCompiler compiler(options); entry->compilation_status = BuildExecutable( options, entry->compilation_result, &entry->executable); diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index 2a24529850..ed818c56ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -45,7 +45,6 @@ void SendOp::Compile(XlaOpKernelContext* ctx) { xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); ctx->builder()->Send(ctx->Input(0), channel); - ctx->SetOpHasSideEffects(); } REGISTER_XLA_OP(Name("_XLASend"), SendOp); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index b29c92190d..b7213a6cc1 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -304,11 +304,6 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client, " constant results. The configuration of " "the output args (i.e. fetch ids) is probably wrong."); } - if (computation->IsNull()) { - return errors::Aborted( - "Conversion from TensorFlow graph to XLA resulted in an empty " - "computation."); - } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 34b1246be2..0b583b54bf 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -343,8 +343,6 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args, // `retvals` is the list of retvals produced by _Retval operators, in index // order. `variable_map` is a map from variable ID numbers to XlaOpContext // variable states, generated by the symbolic evaluation. -// If `has_side_effects` is true, the computation has side effects and should be -// built even if it has no outputs. // If `return_updated_values_for_all_resources` is true, all resources will be // included in `resource_updates`, regardless of whether their value changed. // Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. @@ -358,7 +356,7 @@ Status BuildComputation( const std::vector<XlaCompiler::Argument>& args, const std::vector<XlaExpression>& retvals, const std::vector<std::unique_ptr<XlaResource>>& resources, - bool has_side_effects, bool return_updated_values_for_all_resources, + bool return_updated_values_for_all_resources, xla::ComputationBuilder* builder, xla::Computation* computation, int* num_computation_outputs, int* num_nonconst_outputs, std::vector<XlaCompiler::ResourceUpdate>* resource_updates) { @@ -412,18 +410,14 @@ Status BuildComputation( } *num_computation_outputs = elems.size(); - if (!elems.empty() || has_side_effects) { - // Builds a empty tuple return value for computations that have side effects - // but have no return values. - builder->Tuple(elems); - - // Builds the XLA computation. - xla::StatusOr<xla::Computation> computation_status = builder->Build(); - if (!computation_status.ok()) { - return computation_status.status(); - } - *computation = computation_status.ConsumeValueOrDie(); + + // Builds the XLA computation. + builder->Tuple(elems); + xla::StatusOr<xla::Computation> computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status(); } + *computation = computation_status.ConsumeValueOrDie(); return Status::OK(); } @@ -484,7 +478,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->computation = std::make_shared<xla::Computation>(); TF_RETURN_IF_ERROR(BuildComputation( args, context->retvals(), context->resources(), - context->has_side_effects(), options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->resource_updates)); @@ -508,10 +501,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, } } - if (result->computation->IsNull()) { - return Status::OK(); - } - // Compute the output shapes, if there is a computation with non-constant // outputs. auto computation_shape = client()->GetComputationShape(*result->computation); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index cf78e2cc13..35159dbad4 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -194,8 +194,7 @@ class XlaCompiler { // results in the outputs of XLA computation. std::vector<ResourceUpdate> resource_updates; - // The XLA computation built from the tensorflow subgraph. May be null - // if the output consists solely of compile-time constants. + // The XLA computation built from the tensorflow subgraph. std::shared_ptr<xla::Computation> computation; }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index f516dd867a..531725a623 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -151,7 +151,7 @@ class XlaCompilerTest : public ::testing::Test { std::unique_ptr<FunctionLibraryDefinition> flib_def_; }; -// Tests compilation of an empty graph. +// Tests compilation and execution of an empty graph. TEST_F(XlaCompilerTest, EmptyReturnValues) { XlaCompiler compiler(DefaultOptions()); @@ -161,8 +161,7 @@ TEST_F(XlaCompilerTest, EmptyReturnValues) { std::move(graph), /*args=*/{}, &result)); - // No computation should be generated. - EXPECT_EQ(0, result.computation->handle().handle()); + TF_ASSERT_OK(client_->Execute(*result.computation, {}).status()); } // Tests compilation and execution of a graph that adds two tensors. diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 35219feca4..cf1cc6ab49 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -119,10 +119,6 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, return Status::OK(); } -void XlaContext::AddSideEffects() { - has_side_effects_ = true; -} - xla::ComputationBuilder* XlaContext::builder() { return builder_; } Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num, diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 5e8149e3aa..70b81c111f 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -76,11 +76,6 @@ class XlaContext : public ResourceBase { Status AddConstRetval(int retval_index, DataType dtype, const xla::Literal& literal); - // Mark the computation as having side effects (e.g., Send operators). - void AddSideEffects(); - - bool has_side_effects() const { return has_side_effects_; } - // Creates a resource with resource `kind` and initial type `type` and // value `handle`. `name` is a descriptive name for use in error messages. // Fails if the resource already exists. @@ -133,9 +128,6 @@ class XlaContext : public ResourceBase { // Return values of the Tensorflow graph, indexed by _Retval index. std::vector<XlaExpression> retvals_; - // Does the computation have side effects, i.e., Send() calls? - bool has_side_effects_ = false; - // Holds ownership of resources. The resources are not ordered. std::vector<std::unique_ptr<XlaResource>> resources_; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 459d9edb87..b0607cfa0c 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -345,7 +345,6 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { Status XlaOpKernelContext::AssignVariable( int input_index, DataType type, const xla::ComputationDataHandle& handle) { TF_RET_CHECK(handle.handle() != 0); - SetOpHasSideEffects(); const XlaExpression* expression = CastExpressionFromTensor(context_->input(input_index)); @@ -363,10 +362,6 @@ Status XlaOpKernelContext::AssignVariable( return Status::OK(); } -void XlaOpKernelContext::SetOpHasSideEffects() { - XlaContext::Get(context_).AddSideEffects(); -} - XlaCompiler* XlaOpKernelContext::compiler() const { return XlaContext::Get(context_).compiler(); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 30b794c8c1..c8b5370faa 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -143,9 +143,6 @@ class XlaOpKernelContext { void SetStatus(const Status& status) { context_->SetStatus(status); } Status status() { return context_->status(); } - // Mark the op has having side effects (i.e., via Send). - void SetOpHasSideEffects(); - // Variables // Sets '*resource' to the resource associated with input `index`. |