-rw-r--r--  tensorflow/compiler/aot/BUILD                                  |   2
-rw-r--r--  tensorflow/compiler/aot/compile.cc                             |  17
-rw-r--r--  tensorflow/compiler/xla/client/BUILD                           |  22
-rw-r--r--  tensorflow/compiler/xla/client/client_library.cc               |  38
-rw-r--r--  tensorflow/compiler/xla/client/client_library.h                |  22
-rw-r--r--  tensorflow/compiler/xla/client/compile_only_client.cc          |  59
-rw-r--r--  tensorflow/compiler/xla/client/compile_only_client.h           |  66
-rw-r--r--  tensorflow/compiler/xla/client/local_client.cc                 |  32
-rw-r--r--  tensorflow/compiler/xla/client/local_client.h                  |  26
-rw-r--r--  tensorflow/compiler/xla/service/BUILD                          |  21
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.cc        | 131
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.h         | 128
-rw-r--r--  tensorflow/compiler/xla/service/local_service.cc               |  64
-rw-r--r--  tensorflow/compiler/xla/service/local_service.h                |  16
-rw-r--r--  tensorflow/compiler/xla/service/service.cc                     |  30
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc |   4
-rw-r--r--  tensorflow/opensource_only/eigen.threadpool                    |   1
17 files changed, 512 insertions(+), 167 deletions(-)
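
This commit moves ahead-of-time (AOT) compilation out of LocalClient/LocalService
and into a new CompileOnlyClient/CompileOnlyService pair, which needs only a
Compiler for the target platform rather than a full execution backend. Below is a
minimal sketch of the resulting entry point, mirroring the tfcompile flow in
compile.cc; the triple, entry-point name, and CpuAotCompilationOptions arguments
are illustrative placeholders, not part of this diff:

    #include "tensorflow/compiler/xla/client/client_library.h"
    #include "tensorflow/compiler/xla/client/compile_only_client.h"
    #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
    #include "tensorflow/compiler/xla/status_macros.h"

    // Compiles |computation| ahead of time for the default ("Host") platform.
    xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
    CompileForHost(const xla::Computation& computation, const xla::Shape* arg,
                   const xla::Shape* result) {
      // Passing no platform selects the default one (see client_library.h).
      TF_ASSIGN_OR_RETURN(xla::CompileOnlyClient * client,
                          xla::ClientLibrary::GetOrCreateCompileOnlyClient());
      xla::CompileOnlyClient::AotComputationInstance instance;
      instance.computation = &computation;
      instance.argument_layouts = {arg};
      instance.result_layout = result;
      // Illustrative target options; tfcompile derives these from its flags.
      xla::cpu::CpuAotCompilationOptions aot_opts(
          /*triple=*/"x86_64-unknown-linux-gnu", /*cpu_name=*/"",
          /*features=*/"", /*entry_point_name=*/"entry",
          xla::cpu::CpuAotCompilationOptions::RelocationModel::Static);
      return client->CompileAheadOfTime({instance}, aot_opts);
    }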
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index c52a56b642..c12005a4ca 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -73,7 +73,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 4b5534c164..3955cabedf 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -274,7 +274,8 @@ Status CreateXlaArgs(const Graph& graph,
// Converts the TensorFlow graph into an XLA computation, by executing the
// graph symbolically, with each op building up the XLA HLO.
-Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
+Status ConvertGraphToXla(xla::CompileOnlyClient* client,
+ std::unique_ptr<Graph> graph,
xla::Computation* computation, bool* has_context_arg) {
// Create a device and context to convert the graph into an XLA computation.
XlaOpRegistry::RegisterCompilationKernels();
@@ -333,7 +334,8 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
}
// Compiles the XLA computation into executable code.
-Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
+Status CompileXla(xla::CompileOnlyClient* client,
+ const xla::Computation& computation,
const xla::cpu::CpuAotCompilationOptions& aot_opts,
CompileResult* compile_result) {
// Retrieves arg and result layouts from the computation.
@@ -350,7 +352,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
for (int i = 0; i < pshape->parameters_size(); ++i) {
arg_layouts.push_back(pshape->mutable_parameters(i));
}
- xla::LocalClient::AheadOfTimeComputationInstance instance;
+ xla::CompileOnlyClient::AotComputationInstance instance;
instance.computation = &computation;
instance.argument_layouts = std::move(arg_layouts);
instance.result_layout = &pshape->result();
@@ -365,7 +367,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
std::move(aot_or.ValueOrDie().back()));
compile_result->entry_point = aot_opts.entry_point_name();
compile_result->pointer_size =
- xla::LocalClient::PointerSizeForTriple(aot_opts.triple());
+ xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
return Status::OK();
}
@@ -394,8 +396,9 @@ Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
namespace gpu = perftools::gputools;
gpu::Platform* cpu_platform =
gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
- xla::LocalClient* client =
- xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie();
+ xla::CompileOnlyClient* client =
+ xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
+ .ValueOrDie();
xla::Computation computation;
TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation,
&compile_result->has_context_arg));
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 3e9dfe2a92..2d96128e25 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -99,6 +99,26 @@ cc_library(
],
)
+cc_library(
+ name = "compile_only_client",
+ srcs = ["compile_only_client.cc"],
+ hdrs = ["compile_only_client.h"],
+ deps = [
+ ":client",
+ ":computation",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:compile_only_service",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "@llvm//:support",
+ ],
+)
+
# This target is used to instantiate the XLA service in-process and create
# a client for it.
cc_library(
@@ -106,12 +126,14 @@ cc_library(
srcs = ["client_library.cc"],
hdrs = ["client_library.h"],
deps = [
+ ":compile_only_client",
":local_client",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/compiler/xla/service:compile_only_service",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util",
diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc
index 93437023bc..eb9a7ff2ac 100644
--- a/tensorflow/compiler/xla/client/client_library.cc
+++ b/tensorflow/compiler/xla/client/client_library.cc
@@ -69,8 +69,8 @@ ClientLibrary::~ClientLibrary() = default;
TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
}
- auto it = client_library.instances_.find(platform->id());
- if (it != client_library.instances_.end()) {
+ auto it = client_library.local_instances_.find(platform->id());
+ if (it != client_library.local_instances_.end()) {
return it->second->client.get();
}
@@ -78,13 +78,13 @@ ClientLibrary::~ClientLibrary() = default;
service_options.set_platform(platform);
service_options.set_number_of_replicas(replica_count);
- std::unique_ptr<LocalInstance> instance = MakeUnique<LocalInstance>();
+ auto instance = MakeUnique<LocalInstance>();
TF_ASSIGN_OR_RETURN(instance->service,
LocalService::NewService(service_options));
instance->client = MakeUnique<LocalClient>(instance->service.get());
LocalClient* cl = instance->client.get();
- client_library.instances_.insert(
+ client_library.local_instances_.insert(
std::make_pair(platform->id(), std::move(instance)));
return cl;
}
@@ -99,9 +99,35 @@ ClientLibrary::~ClientLibrary() = default;
perftools::gputools::Platform* platform) {
ClientLibrary& client_library = Singleton();
tensorflow::mutex_lock lock(client_library.service_mutex_);
- auto it = client_library.instances_.find(platform->id());
- CHECK(it != client_library.instances_.end());
+ auto it = client_library.local_instances_.find(platform->id());
+ CHECK(it != client_library.local_instances_.end());
return it->second->service.get();
}
+/* static */ StatusOr<CompileOnlyClient*>
+ClientLibrary::GetOrCreateCompileOnlyClient(
+ perftools::gputools::Platform* platform) {
+ ClientLibrary& client_library = Singleton();
+ tensorflow::mutex_lock lock(client_library.service_mutex_);
+
+ if (platform == nullptr) {
+ TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
+ }
+
+ auto it = client_library.compile_only_instances_.find(platform->id());
+ if (it != client_library.compile_only_instances_.end()) {
+ return it->second->client.get();
+ }
+
+ auto instance = MakeUnique<CompileOnlyInstance>();
+ TF_ASSIGN_OR_RETURN(instance->service,
+ CompileOnlyService::NewService(platform));
+ instance->client = MakeUnique<CompileOnlyClient>(instance->service.get());
+ CompileOnlyClient* cl = instance->client.get();
+
+ client_library.compile_only_instances_.insert(
+ std::make_pair(platform->id(), std::move(instance)));
+ return cl;
+}
+
} // namespace xla
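
A usage note on the caching above (a sketch, not part of this diff): clients are
memoized per platform id, in a map separate from the local-client instances, so
repeated calls return the same object:

    #include "tensorflow/compiler/xla/client/client_library.h"
    #include "tensorflow/core/platform/logging.h"

    void CompileOnlyClientIsMemoized() {
      auto* c1 = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie();
      auto* c2 = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie();
      CHECK_EQ(c1, c2);  // Second call is served from compile_only_instances_.
    }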
diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h
index 2bc319f933..49f4541437 100644
--- a/tensorflow/compiler/xla/client/client_library.h
+++ b/tensorflow/compiler/xla/client/client_library.h
@@ -26,7 +26,9 @@ limitations under the License.
#include <string>
#include <vector>
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/compile_only_service.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -76,6 +78,13 @@ class ClientLibrary {
// access user computations from client.
static LocalService* GetXlaService(perftools::gputools::Platform* platform);
+ // Singleton constructor-or-accessor for compile-only clients. Arguments:
+ //
+ // platform : The platform the underlying XLA service should target. If
+ // null then the default platform is used.
+ static StatusOr<CompileOnlyClient*> GetOrCreateCompileOnlyClient(
+ perftools::gputools::Platform* platform = nullptr);
+
private:
// Returns the singleton instance of ClientLibrary.
static ClientLibrary& Singleton();
@@ -90,10 +99,21 @@ class ClientLibrary {
std::unique_ptr<LocalClient> client;
};
+ struct CompileOnlyInstance {
+ // Service that is wrapped by the singleton client object.
+ std::unique_ptr<CompileOnlyService> service;
+ // Singleton client object.
+ std::unique_ptr<CompileOnlyClient> client;
+ };
+
tensorflow::mutex service_mutex_; // Guards the singleton creation state.
std::unordered_map<perftools::gputools::Platform::Id,
std::unique_ptr<LocalInstance>>
- instances_ GUARDED_BY(service_mutex_);
+ local_instances_ GUARDED_BY(service_mutex_);
+
+ std::unordered_map<perftools::gputools::Platform::Id,
+ std::unique_ptr<CompileOnlyInstance>>
+ compile_only_instances_ GUARDED_BY(service_mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary);
};
diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc
new file mode 100644
index 0000000000..2ff6f0b300
--- /dev/null
+++ b/tensorflow/compiler/xla/client/compile_only_client.cc
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
+
+#include "external/llvm/include/llvm/ADT/Triple.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+CompileOnlyClient::CompileAheadOfTime(
+ const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
+ const AotCompilationOptions& options) {
+ std::vector<CompileOnlyService::AotComputationInstance> service_instances;
+ service_instances.reserve(computations.size());
+ for (const AotComputationInstance& instance : computations) {
+ service_instances.push_back({});
+ CompileOnlyService::AotComputationInstance& service_instance =
+ service_instances.back();
+ TF_RET_CHECK(instance.computation != nullptr);
+ service_instance.computation = instance.computation->handle();
+ service_instance.argument_layouts = instance.argument_layouts;
+ service_instance.result_layout = instance.result_layout;
+ }
+ return compiler_service_->CompileAheadOfTime(service_instances, options);
+}
+
+int64 CompileOnlyClient::PointerSizeForTriple(
+ tensorflow::StringPiece target_triple) {
+ llvm::Triple triple(
+ llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple)));
+ if (triple.isArch64Bit()) {
+ return 8;
+ } else if (triple.isArch32Bit()) {
+ return 4;
+ } else {
+ CHECK(triple.isArch16Bit());
+ return 2;
+ }
+}
+
+} // namespace xla
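
The moved PointerSizeForTriple normalizes the triple and switches on LLVM's
architecture-width classification. A sketch of the expected mapping, assuming
LLVM's usual classification of these example triples:

    #include "tensorflow/compiler/xla/client/compile_only_client.h"
    #include "tensorflow/core/platform/logging.h"

    void PointerSizeExamples() {
      CHECK_EQ(8, xla::CompileOnlyClient::PointerSizeForTriple(
                      "x86_64-none-linux-gnu"));
      CHECK_EQ(4, xla::CompileOnlyClient::PointerSizeForTriple(
                      "i686-none-linux-gnu"));
      // msp430 is a 16-bit architecture in LLVM's Triple classification.
      CHECK_EQ(2, xla::CompileOnlyClient::PointerSizeForTriple("msp430-none-elf"));
    }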
diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h
new file mode 100644
index 0000000000..5900048711
--- /dev/null
+++ b/tensorflow/compiler/xla/client/compile_only_client.h
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/service/compile_only_service.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// An XLA Client specialization for doing ahead-of-time compilation. This does
+// not require (or attempt to instantiate) an execution-capable backend for the
+// relevant platform.
+class CompileOnlyClient : public Client {
+ public:
+ explicit CompileOnlyClient(CompileOnlyService* service)
+ : Client(service), compiler_service_(service) {}
+
+ CompileOnlyClient(const CompileOnlyClient&) = delete;
+ void operator=(const CompileOnlyClient&) = delete;
+
+ // A description of a computation to compile using CompileAheadOfTime.
+ struct AotComputationInstance {
+ const Computation* computation;
+ // Inform the compiler of the expected layout for arguments.
+ std::vector<const Shape*> argument_layouts;
+ // Specifies the expected result layout.
+ const Shape* result_layout;
+ };
+
+ // Compiles a list of computations for ahead-of-time execution. This is
+ // intended for use in static compilation. The |options| parameter describes
+ // the target for which the compiler should emit code.
+ StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+ CompileAheadOfTime(
+ const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
+ const AotCompilationOptions& options);
+
+ // Returns the size of a pointer in bytes for a given triple.
+ static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
+
+ private:
+ CompileOnlyService* compiler_service_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index bfd14bc1c0..aaed34f4c3 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -261,38 +261,6 @@ tensorflow::Status LocalClient::ResolveArguments(
argument_ptrs);
}
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-LocalClient::CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
- computations,
- const AotCompilationOptions& options) {
- std::vector<LocalService::AheadOfTimeComputationInstance> service_instances;
- service_instances.reserve(computations.size());
- for (const AheadOfTimeComputationInstance& instance : computations) {
- service_instances.push_back({});
- LocalService::AheadOfTimeComputationInstance& service_instance =
- service_instances.back();
- TF_RET_CHECK(instance.computation != nullptr);
- service_instance.computation = instance.computation->handle();
- service_instance.argument_layouts = instance.argument_layouts;
- service_instance.result_layout = instance.result_layout;
- }
- return local_service_->CompileAheadOfTime(service_instances, options);
-}
-
-int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) {
- llvm::Triple triple(
- llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple)));
- if (triple.isArch64Bit()) {
- return 8;
- } else if (triple.isArch32Bit()) {
- return 4;
- } else {
- CHECK(triple.isArch16Bit());
- return 2;
- }
-}
-
se::Platform* LocalClient::platform() const {
return local_service_->backend().platform();
}
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 2c467efcea..94d5610639 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -148,7 +148,7 @@ class LocalExecutable {
const ExecutableBuildOptions& build_options_;
};
-// An XLA service client object for use when the client and service run in
+// An XLA Client specialization for use when the client and service run in
// the same process.
class LocalClient : public Client {
public:
@@ -182,30 +182,6 @@ class LocalClient : public Client {
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const ExecutableBuildOptions& options);
- // A description of a computation to compile using CompileAheadOfTime.
- struct AheadOfTimeComputationInstance {
- const Computation* computation;
- // Inform the compiler of the expected layout for arguments.
- std::vector<const Shape*> argument_layouts;
- // Specifies the expected result layout.
- const Shape* result_layout;
- };
-
- // Compiles a list of computations for ahead-of-time execution. This is
- // intended for use in static compilation. The |options| parameter describes
- // the target for which the compiler should emit code.
- //
- // TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its
- // own library.
- StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
- CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
- computations,
- const AotCompilationOptions& options);
-
- // Returns the size of a pointer in bytes for a given triple.
- static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
-
// Returns the platform that the underlying service targets.
perftools::gputools::Platform* platform() const;
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 2452158efa..fd47ffe806 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -409,6 +409,27 @@ cc_library(
)
cc_library(
+ name = "compile_only_service",
+ srcs = ["compile_only_service.cc"],
+ hdrs = ["compile_only_service.h"],
+ deps = [
+ ":backend",
+ ":compiler",
+ ":computation_layout",
+ ":computation_tracker",
+ ":platform_util",
+ ":service",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
name = "cpu_plugin",
deps = [
":cpu_transfer_manager",
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
new file mode 100644
index 0000000000..ac1906c88c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -0,0 +1,131 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/compile_only_service.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+/* static */ StatusOr<std::unique_ptr<CompileOnlyService>>
+CompileOnlyService::NewService(perftools::gputools::Platform* platform) {
+ ServiceOptions default_options;
+ default_options.set_platform(platform);
+ return NewService(default_options);
+}
+
+/* static */ StatusOr<std::unique_ptr<CompileOnlyService>>
+CompileOnlyService::NewService(const ServiceOptions& options) {
+ perftools::gputools::Platform* platform = options.platform();
+ if (platform == nullptr) {
+ TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
+ }
+
+ TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
+ CreateComputeConstantBackend());
+ std::unique_ptr<CompileOnlyService> service(
+ new CompileOnlyService(compiler, std::move(compute_constant_backend)));
+ return std::move(service);
+}
+
+CompileOnlyService::CompileOnlyService(
+ Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend)
+ : Service(/*backend=*/nullptr, std::move(compute_constant_backend)),
+ compiler_(compiler) {
+ runs_in_client_process_ = true;
+}
+
+StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+CompileOnlyService::CompileAheadOfTime(
+ const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
+ const AotCompilationOptions& options) {
+ std::vector<std::unique_ptr<HloModule>> hlo_modules;
+ std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
+ for (const AotComputationInstance& instance : computations) {
+ TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
+ computation_tracker_.Resolve(instance.computation));
+ VersionedComputationHandle versioned_handle =
+ user_computation->GetVersionedHandle();
+
+ // Dump computation proto state if flag is set.
+ legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
+ const string& directory_path = flags->xla_dump_computations_to;
+ if (!directory_path.empty()) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<SessionModule> session_module,
+ computation_tracker_.SnapshotComputation(versioned_handle.handle));
+ string filename = tensorflow::strings::StrCat(
+ "computation_", versioned_handle.handle.handle(), "__",
+ session_module->entry().name(), "__version_",
+ versioned_handle.version);
+ TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
+ *session_module));
+ }
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
+ computation_tracker_.BuildHloModule(
+ versioned_handle,
+ /*include_unreachable_instructions=*/true));
+ hlo_modules.push_back(std::move(hlo_module));
+
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const ProgramShape> program_shape,
+ user_computation->ComputeProgramShape(versioned_handle.version));
+
+ module_configs.push_back(MakeUnique<HloModuleConfig>(*program_shape));
+ HloModuleConfig* module_config = module_configs.back().get();
+ auto* computation_layout =
+ module_config->mutable_entry_computation_layout();
+ if (flags->xla_hlo_profile) {
+ module_config->enable_hlo_profiling(true);
+ }
+ for (int i = 0; i < instance.argument_layouts.size(); ++i) {
+ const Shape& argument_layout = *instance.argument_layouts[i];
+ if (ShapeUtil::IsTuple(argument_layout)) {
+ return Unimplemented("tuple arguments not supported yet");
+ }
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+ argument_layout));
+ }
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ *instance.result_layout));
+ }
+
+ return compiler_->CompileAheadOfTime(std::move(hlo_modules),
+ std::move(module_configs),
+ MakeHloDumper(), options);
+}
+
+} // namespace xla
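
The layouts consumed above are ordinary xla::Shape protos carrying an explicit
layout (tuple arguments are rejected by the loop). A sketch of constructing one,
assuming ShapeUtil's layout helper:

    #include "tensorflow/compiler/xla/shape_util.h"

    // An f32[128,256] argument in column-major order (minor_to_major = {0, 1}).
    xla::Shape arg_layout = xla::ShapeUtil::MakeShapeWithLayout(
        xla::F32, /*dimensions=*/{128, 256}, /*minor_to_major=*/{0, 1});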
diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h
new file mode 100644
index 0000000000..06735b21ca
--- /dev/null
+++ b/tensorflow/compiler/xla/service/compile_only_service.h
@@ -0,0 +1,128 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/service.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// An XLA Service specialization for ahead-of-time compilation. This only
+// instantiates a Compiler object for the relevant platform; it does not
+// instantiate or require an execution backend.
+class CompileOnlyService : public Service {
+ public:
+ // Factory for creating a CompileOnlyService. The parameter platform is the
+ // platform that the service should target. If platform is null then the
+ // default platform is used.
+ static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
+ perftools::gputools::Platform* platform);
+ static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
+ const ServiceOptions& options);
+
+ // A description of a computation to compile using CompileAheadOfTime.
+ struct AotComputationInstance {
+ ComputationHandle computation;
+ std::vector<const Shape*> argument_layouts;
+ const Shape* result_layout = nullptr;
+ };
+
+ // Compiles a list of computations for ahead-of-time execution. This is
+ // intended for use in static compilation. See
+ // |CompileOnlyClient::CompileAheadOfTime| for additional details.
+ StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+ CompileAheadOfTime(
+ const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
+ const AotCompilationOptions& options);
+
+ // Override Service methods that require an execute backend.
+ tensorflow::Status Execute(const ExecuteRequest* arg,
+ ExecuteResponse* result) override {
+ return Unimplemented("CompileOnlyService does not support execution.");
+ }
+ tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
+ ExecuteParallelResponse* result) override {
+ return Unimplemented("CompileOnlyService does not support execution.");
+ }
+ tensorflow::Status GetDeviceHandles(
+ const GetDeviceHandlesRequest* arg,
+ GetDeviceHandlesResponse* result) override {
+ return Unimplemented("CompileOnlyService does not support devices.");
+ }
+ tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
+ ExecuteAsyncResponse* result) override {
+ return Unimplemented("CompileOnlyService does not support execution.");
+ }
+ tensorflow::Status WaitForExecution(
+ const WaitForExecutionRequest* arg,
+ WaitForExecutionResponse* result) override {
+ return Unimplemented("CompileOnlyService does not support execution.");
+ }
+ tensorflow::Status TransferToClient(
+ const TransferToClientRequest* arg,
+ TransferToClientResponse* result) override {
+ return Unimplemented("CompileOnlyService does not support data transfers.");
+ }
+ tensorflow::Status TransferToClientInProcess(
+ const TransferToClientInProcessRequest* arg,
+ TransferToClientInProcessResponse* result) override {
+ return Unimplemented("CompileOnlyService does not support data transfers.");
+ }
+ tensorflow::Status TransferToServer(
+ const TransferToServerRequest* arg,
+ TransferToServerResponse* result) override {
+ return Unimplemented("CompileOnlyService does not support data transfers.");
+ }
+ tensorflow::Status TransferToInfeed(
+ const TransferToInfeedRequest* arg,
+ TransferToInfeedResponse* result) override {
+ return Unimplemented("CompileOnlyService does not support data transfers.");
+ }
+ tensorflow::Status TransferFromOutfeed(
+ const TransferFromOutfeedRequest* arg,
+ TransferFromOutfeedResponse* result) override {
+ return Unimplemented("CompileOnlyService does not support data transfers.");
+ }
+ tensorflow::Status TransferToServerInProcess(
+ const TransferToServerInProcessRequest* arg,
+ TransferToServerInProcessResponse* result) override {
+ return Unimplemented("CompileOnlyService does not support data transfers.");
+ }
+ tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
+ ResetDeviceResponse* result) override {
+ return Unimplemented("CompileOnlyService does not support devices.");
+ }
+
+ private:
+ explicit CompileOnlyService(
+ Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend);
+ CompileOnlyService(const CompileOnlyService&) = delete;
+ void operator=(const CompileOnlyService&) = delete;
+
+ // The compiler for the target platform. This is included in place of
+ // the Service::execute_backend_'s compiler, since execute_backend_ is a
+ // nullptr in CompileOnlyService.
+ Compiler* compiler_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
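
Because execute_backend_ is null in this service, every execution, transfer, and
device RPC above is stubbed out. A sketch of what a caller should expect (the
service pointer is assumed to come from CompileOnlyService::NewService):

    #include "tensorflow/compiler/xla/service/compile_only_service.h"
    #include "tensorflow/core/platform/logging.h"

    void ExpectExecutionUnimplemented(xla::CompileOnlyService* service) {
      xla::ExecuteRequest request;
      xla::ExecuteResponse response;
      tensorflow::Status status = service->Execute(&request, &response);
      CHECK(!status.ok());  // "CompileOnlyService does not support execution."
    }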
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 17d7b97b21..6947c5d2e1 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -128,70 +128,6 @@ StatusOr<GlobalDataHandle> LocalService::AllocateBufferOnDevice(
allocation_size));
}
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-LocalService::CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
- computations,
- const AotCompilationOptions& options) {
- std::vector<std::unique_ptr<HloModule>> hlo_modules;
- std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
- for (const AheadOfTimeComputationInstance& instance : computations) {
- TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
- computation_tracker_.Resolve(instance.computation));
- VersionedComputationHandle versioned_handle =
- user_computation->GetVersionedHandle();
-
- // Dump computation proto state if flag is set.
- legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
- const string& directory_path = flags->xla_dump_computations_to;
- if (!directory_path.empty()) {
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<SessionModule> session_module,
- computation_tracker_.SnapshotComputation(versioned_handle.handle));
- string filename = tensorflow::strings::StrCat(
- "computation_", versioned_handle.handle.handle(), "__",
- session_module->entry().name(), "__version_",
- versioned_handle.version);
- TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
- *session_module));
- }
-
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
- computation_tracker_.BuildHloModule(
- versioned_handle,
- /*include_unreachable_instructions=*/true));
- hlo_modules.push_back(std::move(hlo_module));
-
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> program_shape,
- user_computation->ComputeProgramShape(versioned_handle.version));
-
- module_configs.push_back(MakeUnique<HloModuleConfig>(*program_shape));
- HloModuleConfig* module_config = module_configs.back().get();
- auto* computation_layout =
- module_config->mutable_entry_computation_layout();
- if (flags->xla_hlo_profile) {
- module_config->enable_hlo_profiling(true);
- }
- for (int i = 0; i < instance.argument_layouts.size(); ++i) {
- const Shape& argument_layout = *instance.argument_layouts[i];
- if (ShapeUtil::IsTuple(argument_layout)) {
- return Unimplemented("tuple arguments not supported yet");
- }
- TF_RETURN_IF_ERROR(
- computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
- argument_layout));
- }
- TF_RETURN_IF_ERROR(
- computation_layout->mutable_result_layout()->CopyLayoutFromShape(
- *instance.result_layout));
- }
-
- return execute_backend_->compiler()->CompileAheadOfTime(
- std::move(hlo_modules), std::move(module_configs), MakeHloDumper(),
- options);
-}
-
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
const ComputationHandle& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index df27f0a7a6..a1a2ef98e9 100644
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -59,22 +59,6 @@ class LocalService : public Service {
const Shape& shape, int device_ordinal,
bool allocate_space_for_deep_copy);
- // A description of a computation to compile using CompileAheadOfTime.
- struct AheadOfTimeComputationInstance {
- ComputationHandle computation;
- std::vector<const Shape*> argument_layouts;
- const Shape* result_layout = nullptr;
- };
-
- // Compiles a list of computations for ahead-of-time execution. This is
- // intended for use in static compilation. See
- // |LocalClient::CompileAheadOfTime| for additional details.
- StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
- CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
- computations,
- const AotCompilationOptions& Options);
-
// Builds an Executable with the given argument layouts and options. If
// result_layout is non-null, then the executable is compiled to produce a
// result of the given layout.
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 451bb8c7ea..892265f5b6 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -180,20 +180,24 @@ Service::Service(std::unique_ptr<Backend> execute_backend,
std::unique_ptr<Backend> compute_constant_backend)
: execute_backend_(std::move(execute_backend)),
compute_constant_backend_(std::move(compute_constant_backend)) {
- LOG(INFO) << Printf(
- "XLA service %p executing computations on platform %s. Devices:", this,
- execute_backend_->platform()->Name().c_str());
- for (int i = 0; i < execute_backend_->device_count(); ++i) {
- if (execute_backend_->device_ordinal_supported(i)) {
- se::StreamExecutor* executor =
- execute_backend_->stream_executor(i).ValueOrDie();
- const auto& description = executor->GetDeviceDescription();
- LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i,
- description.name().c_str(),
- description.platform_version().c_str());
- } else {
- LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i);
+ if (execute_backend_) {
+ LOG(INFO) << Printf(
+ "XLA service %p executing computations on platform %s. Devices:", this,
+ execute_backend_->platform()->Name().c_str());
+ for (int i = 0; i < execute_backend_->device_count(); ++i) {
+ if (execute_backend_->device_ordinal_supported(i)) {
+ se::StreamExecutor* executor =
+ execute_backend_->stream_executor(i).ValueOrDie();
+ const auto& description = executor->GetDeviceDescription();
+ LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i,
+ description.name().c_str(),
+ description.platform_version().c_str());
+ } else {
+ LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i);
+ }
}
+ } else {
+ VLOG(1) << "XLA compile-only service constructed";
}
}
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index 7ea83a9e95..52816dc72c 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -42,7 +42,7 @@ xla::Computation Doubler(xla::Client* client) {
int main(int argc, char** argv) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
- auto client = xla::ClientLibrary::LocalClientOrDie();
+ auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie();
xla::ComputationBuilder builder(client, "aot_test_helper");
auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
@@ -74,7 +74,7 @@ int main(int argc, char** argv) {
llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string));
xla::Computation computation = builder.Build().ConsumeValueOrDie();
- xla::LocalClient::AheadOfTimeComputationInstance instance{
+ xla::CompileOnlyClient::AotComputationInstance instance{
&computation, /*argument_layouts=*/{&opaque_shape}, &r0f32};
xla::cpu::CpuAotCompilationOptions options(
diff --git a/tensorflow/opensource_only/eigen.threadpool b/tensorflow/opensource_only/eigen.threadpool
new file mode 100644
index 0000000000..d2639af4d9
--- /dev/null
+++ b/tensorflow/opensource_only/eigen.threadpool
@@ -0,0 +1 @@
+#include "unsupported/Eigen/CXX11/ThreadPool"