author:    Peter Hawkins <phawkins@google.com>              2017-05-11 08:30:16 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>  2017-05-11 11:13:19 -0700
commit:    851982c0e3c543306f3942f92bc4bdbc0d50d1f7
tree:      5121012e0cd4ef5cc2541c2dadc47c0ff729b6e4
parent:    1e390f5f7992f3ae5f9ecfc91ebe3d711efe0b7d
[TF:XLA] Add function compilation cache.
PiperOrigin-RevId: 155751485
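
In brief, this change memoizes `XlaCompiler::CompileFunction`: before instantiating and compiling a function, the compiler now consults a `cache_` map keyed on the canonicalized function name plus the full argument list. `SignatureHash` hashes only the name, and `Argument::operator==` (comparing kind, type, shape, name, and the constant value's shape and bytes) resolves exact matches within a bucket. `CompilationResult::computation` becomes a `std::shared_ptr<xla::Computation>` so a cached result can be copied out of the map cheaply. The sketch below illustrates the same lookup-or-compile pattern in isolation; `Argument`, `Result`, and `Compiler` are simplified stand-ins for the real TensorFlow types, not the actual API.

```cpp
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Stand-in for XlaCompiler::Argument: full equality on every field, used to
// disambiguate entries that hash to the same bucket.
struct Argument {
  int kind = 0;
  std::string name;
  bool operator==(const Argument& other) const {
    return kind == other.kind && name == other.name;
  }
};

// Stand-in for XlaCompiler::CompilationResult. In the real change the
// computation is a std::shared_ptr<xla::Computation> precisely so that this
// struct stays cheap to copy out of the cache.
struct Result {
  std::string computation;
};

using Signature = std::pair<std::string, std::vector<Argument>>;

// Hash only the canonicalized function name. Different argument lists for the
// same function collide, and the map's key equality (pair operator==, hence
// Argument::operator==) picks out the exact entry.
struct SignatureHash {
  uint64_t operator()(const Signature& signature) const {
    return std::hash<std::string>()(signature.first);
  }
};

class Compiler {
 public:
  Result CompileFunction(const std::string& function_id,
                         const std::vector<Argument>& args) {
    // Cache hit: reuse the previously compiled result.
    auto it = cache_.find({function_id, args});
    if (it != cache_.end()) return it->second;

    // Cache miss: "compile" (a placeholder here), then memoize.
    Result result;
    result.computation = "compiled:" + function_id;
    cache_[{function_id, args}] = result;
    return result;
  }

 private:
  std::unordered_map<Signature, Result, SignatureHash> cache_;
};
```

Hashing only the name keeps the hash trivial and is still correct: `std::unordered_map` falls back on full key equality within a bucket, exactly as the real `SignatureHash` in this commit does.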
-rw-r--r--  tensorflow/compiler/aot/compile.cc                          |  2
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_device_launch_op.cc    |  4
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_local_launch_op.cc     |  2
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.cc           |  2
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.cc                 | 34
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.h                  | 13
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler_test.cc            |  8
-rw-r--r--  tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc       |  2
-rw-r--r--  tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc  |  2
9 files changed, 53 insertions(+), 16 deletions(-)
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 4c572fd390..162e719ade 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -301,7 +301,7 @@ Status ConvertGraphToXla(xla::CompileOnlyClient* client,
       "tfcompile", std::move(graph), xla_args, &result));
   *has_context_arg = result.requires_runtime_context;
-  *computation = std::move(result.computation);
+  *computation = std::move(*result.computation);
 
   int num_const_results = 0;
   for (int i = 0; i < result.outputs.size(); ++i) {
diff --git a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc
index 497c813d99..cb6c14901b 100644
--- a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc
@@ -121,7 +121,7 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
   // Runs the computation, if any. There might not be a computation if all
   // outputs were compile-time constants.
   std::vector<std::unique_ptr<xla::GlobalData>> outputs;
-  if (!kernel->computation.IsNull()) {
+  if (!kernel->computation->IsNull()) {
     auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
 
     // Builds the inputs to the computation.
@@ -152,7 +152,7 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
     Env* env = Env::Default();
     auto start_time = env->NowMicros();
     VLOG(1) << "Executing XLA Computation...";
-    auto result = cache->client()->Execute(kernel->computation, arg_ptrs,
+    auto result = cache->client()->Execute(*kernel->computation, arg_ptrs,
                                            &execution_options, &profile);
     auto elapsed = env->NowMicros() - start_time;
     OP_REQUIRES(ctx, result.ok(), result.status());
diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
index 0a2af4050e..032ded54e6 100644
--- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
@@ -225,7 +225,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   std::unique_ptr<xla::ShapedBuffer> output;
   bool output_is_tuple;
 
-  if (!kernel->computation.IsNull()) {
+  if (!kernel->computation->IsNull()) {
     // Build xla::ShapedBuffers that point directly to the Tensor buffers.
     std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
     arg_buffers.reserve(kernel->xla_input_shapes.size() + 1);
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 8238f7d919..82af304169 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -273,7 +273,7 @@ Status XlaCompilationCache::Compile(
   *compilation_result = &entry->compilation_result;
   if (entry->compilation_status.ok() && executable) {
     if (entry->executable == nullptr &&
-        !entry->compilation_result.computation.IsNull()) {
+        !entry->compilation_result.computation->IsNull()) {
       XlaCompiler compiler(options);
       entry->compilation_status = compiler.BuildExecutable(
           entry->compilation_result, &entry->executable);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 5ed4393bc5..d246e7f9ac 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -57,6 +57,18 @@ Status CheckSignature(const DataTypeVector& types,
 
 }  // namespace
 
+bool XlaCompiler::Argument::operator==(
+    const XlaCompiler::Argument& other) const {
+  if (std::tie(kind, type, shape, name) !=
+      std::tie(other.kind, other.type, other.shape, other.name)) {
+    return false;
+  }
+  if (constant_value.shape() != other.constant_value.shape()) {
+    return false;
+  }
+  return constant_value.tensor_data() == other.constant_value.tensor_data();
+}
+
 XlaCompiler::XlaCompiler(XlaCompiler::Options options)
     : options_(std::move(options)),
       initialization_status_(Status::OK()),
@@ -85,6 +97,11 @@ int64 XlaCompiler::NextStepId() {
   return next_step_id_++;
 }
 
+uint64 XlaCompiler::SignatureHash::operator()(
+    const std::pair<string, std::vector<Argument>>& signature) const {
+  return std::hash<string>()(signature.first);
+}
+
 Status XlaCompiler::CompileFunction(
     const XlaCompiler::CompileOptions& options, const NameAttrList& function,
     const std::vector<XlaCompiler::Argument>& args,
@@ -92,6 +109,12 @@ Status XlaCompiler::CompileFunction(
   const string function_id = Canonicalize(function.name(), function.attr());
   VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
 
+  auto it = cache_.find({function_id, args});
+  if (it != cache_.end()) {
+    *result = it->second;
+    return Status::OK();
+  }
+
   FunctionLibraryRuntime::Handle handle;
   TF_RETURN_IF_ERROR(
       flib_runtime_->Instantiate(function.name(), function.attr(), &handle));
@@ -129,6 +152,7 @@ Status XlaCompiler::CompileFunction(
       CompileGraph(options, function_id, std::move(graph), args, result));
   VLOG(1) << "====================================================";
 
+  cache_[{function_id, args}] = *result;
   return Status::OK();
 }
 
@@ -155,7 +179,7 @@ Status XlaCompiler::BuildExecutable(
   build_options.set_has_hybrid_result(
       options_.local_executable_has_hybrid_result);
 
-  auto compile_result = local_client->Compile(result.computation,
+  auto compile_result = local_client->Compile(*result.computation,
                                               argument_layouts, build_options);
   if (!compile_result.ok()) {
     return compile_result.status();
@@ -403,10 +427,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
                                       flib_runtime_.get(), NextStepId()));
 
   int num_nonconst_outputs;
+  result->computation = std::make_shared<xla::Computation>();
   TF_RETURN_IF_ERROR(BuildComputation(
       context->retvals(), context->variables(), context->has_side_effects(),
       options.return_updated_values_for_all_variables, &builder,
-      &result->computation, &num_nonconst_outputs, &result->variable_updates));
+      result->computation.get(), &num_nonconst_outputs,
+      &result->variable_updates));
 
   result->requires_runtime_context = context->has_context_parameter();
 
@@ -427,13 +453,13 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
     }
   }
 
-  if (result->computation.IsNull()) {
+  if (result->computation->IsNull()) {
     return Status::OK();
   }
 
   // Compute the output shapes, if there is a computation with non-constant
   // outputs.
-  auto computation_shape = client()->GetComputationShape(result->computation);
+  auto computation_shape = client()->GetComputationShape(*result->computation);
   if (!computation_shape.ok()) {
     return computation_shape.status();
   }
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 1a089155ce..15f723ad78 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -113,6 +113,8 @@ class XlaCompiler {
 
     // The name of this argument, used for debugging.
     string name;
+
+    bool operator==(const Argument& other) const;
   };
 
   struct OutputDescription {
@@ -173,7 +175,7 @@ class XlaCompiler {
 
     // The XLA computation built from the tensorflow subgraph. May be null
     // if the output consists solely of compile-time constants.
-    xla::Computation computation;
+    std::shared_ptr<xla::Computation> computation;
   };
 
   struct Options {
@@ -290,6 +292,15 @@ class XlaCompiler {
 
   std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
 
+  struct SignatureHash {
+    uint64 operator()(
+        const std::pair<string, std::vector<Argument>>& signature) const;
+  };
+
+  std::unordered_map<std::pair<string, std::vector<Argument>>,
+                     CompilationResult, SignatureHash>
+      cache_;
+
   std::unordered_map<string, xla::ChannelHandle> channels_ GUARDED_BY(mu_);
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 11c811613e..58d74057d1 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -131,7 +131,7 @@ TEST_F(XlaCompilerTest, EmptyReturnValues) {
                                      /*args=*/{}, &result));
 
   // No computation should be generated.
-  EXPECT_EQ(0, result.computation.handle().handle());
+  EXPECT_EQ(0, result.computation->handle().handle());
 }
 
 // Tests compilation and execution of a graph that adds two tensors.
@@ -173,7 +173,7 @@ TEST_F(XlaCompilerTest, Simple) {
   std::unique_ptr<xla::GlobalData> actual =
       client_
-          ->Execute(result.computation, {param0_data.get(), param1_data.get()})
+          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
           .ConsumeValueOrDie();
   std::unique_ptr<xla::Literal> actual_literal =
       client_->Transfer(*actual).ConsumeValueOrDie();
@@ -230,7 +230,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 
   std::unique_ptr<xla::GlobalData> actual =
-      client_->Execute(result.computation, {param0_data.get()})
+      client_->Execute(*result.computation, {param0_data.get()})
           .ConsumeValueOrDie();
   std::unique_ptr<xla::Literal> actual_literal =
       client_->Transfer(*actual).ConsumeValueOrDie();
@@ -265,7 +265,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 
   std::unique_ptr<xla::GlobalData> actual =
-      client_->Execute(result.computation, {param0_data.get()})
+      client_->Execute(*result.computation, {param0_data.get()})
          .ConsumeValueOrDie();
   std::unique_ptr<xla::Literal> actual_literal =
       client_->Transfer(*actual).ConsumeValueOrDie();
diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc
index 87f7cba5f4..302aa6457a 100644
--- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc
+++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc
@@ -228,7 +228,7 @@ ConvertTfGraphToXlaSessionModule(const std::vector<XlaCompiler::Argument>& args,
   TF_CHECK_OK(compiler->CompileGraph(XlaCompiler::CompileOptions(), GRAPH_NAME,
                                      std::move(graph), args, &result));
 
-  return result.computation.Snapshot();
+  return result.computation->Snapshot();
 }
 
 xla::StatusOr<std::unordered_map<int64, XlaNode>>
diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
index 7162627562..beb4c8009b 100644
--- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
+++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
@@ -79,7 +79,7 @@ static void DumpHloGraphForDebug(const std::vector<XlaCompiler::Argument>& args,
                                  std::move(graph), args, &result));
 
   // Convert to hlo
-  xla::Computation& computation = result.computation;
+  xla::Computation& computation = *result.computation;
   xla::Service* service(
       static_cast<xla::Service*>(xla::ClientLibrary::GetXlaService(