aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Jacques Pienaar <jpienaar@google.com>2018-06-13 10:00:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-13 10:03:26 -0700
commit65cefda2f9a62f29af51b3effa0725c180244576 (patch)
treedd931fcbc392ac26f12725a45c01ca7f9507ed33 /tensorflow
parentf0e053afc99c8dcf6aa196b00dafaee0a7f6923f (diff)
Add AotCompilationMetadata field to variant of CompileAheadOfTime.
Add CompileAheadOfTime parameter that can optionally be populated during compilation process. This change is to allow populating metadata even if the CompileAheadOfTime fails. PiperOrigin-RevId: 200407917
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/client/compile_only_client.cc6
-rw-r--r--tensorflow/compiler/xla/client/compile_only_client.h10
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.cc6
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.h6
-rw-r--r--tensorflow/compiler/xla/service/compiler.cc14
-rw-r--r--tensorflow/compiler/xla/service/compiler.h20
6 files changed, 54 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc
index dc69d2097e..5c9abad4c3 100644
--- a/tensorflow/compiler/xla/client/compile_only_client.cc
+++ b/tensorflow/compiler/xla/client/compile_only_client.cc
@@ -24,7 +24,8 @@ namespace xla {
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyClient::CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
- const AotCompilationOptions& options) {
+ const AotCompilationOptions& options,
+ std::unique_ptr<AotCompilationMetadata>* metadata) {
std::vector<CompileOnlyService::AotXlaComputationInstance> service_instances;
service_instances.reserve(computations.size());
for (const AotXlaComputationInstance& instance : computations) {
@@ -36,7 +37,8 @@ CompileOnlyClient::CompileAheadOfTime(
service_instance.argument_layouts = instance.argument_layouts;
service_instance.result_layout = instance.result_layout;
}
- return compiler_service_->CompileAheadOfTime(service_instances, options);
+ return compiler_service_->CompileAheadOfTime(service_instances, options,
+ metadata);
}
int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) {
diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h
index f9a7c31270..332c965036 100644
--- a/tensorflow/compiler/xla/client/compile_only_client.h
+++ b/tensorflow/compiler/xla/client/compile_only_client.h
@@ -46,13 +46,15 @@ class CompileOnlyClient : public Client {
const Shape* result_layout;
};
- // Compiles a list of xla computations for ahead-of-time execution. This is
- // intended for use in static compilation. The |options| parameter describes
- // the target for which the compiler should emit code.
+ // Compiles a list of xla computations for ahead-of-time execution.
+ // This is intended for use in static compilation. The |options|
+ // parameter describes the target for which the compiler should emit
+ // code. |metadata|, if provided, is populated during compilation.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
- const AotCompilationOptions& options);
+ const AotCompilationOptions& options,
+ std::unique_ptr<AotCompilationMetadata>* metadata = nullptr);
// Returns the size of a pointer in bytes for a given triple.
static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index d8fdccf9bb..7426672a7a 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -63,7 +63,8 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options,
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
- const AotCompilationOptions& options) {
+ const AotCompilationOptions& options,
+ std::unique_ptr<AotCompilationMetadata>* metadata) {
std::vector<std::unique_ptr<HloModule>> hlo_modules;
for (const AotXlaComputationInstance& instance : computations) {
TF_RET_CHECK(instance.computation.has_program_shape());
@@ -100,7 +101,8 @@ CompileOnlyService::CompileAheadOfTime(
hlo_modules.push_back(std::move(hlo_module));
}
- return compiler_->CompileAheadOfTime(std::move(hlo_modules), options);
+ return compiler_->CompileAheadOfTime(std::move(hlo_modules), options,
+ metadata);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h
index e6a66c202d..1ac950bdd6 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.h
+++ b/tensorflow/compiler/xla/service/compile_only_service.h
@@ -53,6 +53,12 @@ class CompileOnlyService : public Service {
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
const AotCompilationOptions& options);
+ StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+ CompileAheadOfTime(
+ const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const AotCompilationOptions& options,
+ std::unique_ptr<AotCompilationMetadata>* metadata);
+
Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
GetDeviceHandlesResponse* result) override {
return Unimplemented("CompileOnlyService does not support devices.");
diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc
index 6f06bba679..0dceed853d 100644
--- a/tensorflow/compiler/xla/service/compiler.cc
+++ b/tensorflow/compiler/xla/service/compiler.cc
@@ -35,6 +35,20 @@ Compiler::ComputeBackendConfigs(const HloInstruction& hlo,
return {};
}
+// Define a default version where metadata is not used.
+StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+Compiler::CompileAheadOfTime(
+ std::vector<std::unique_ptr<HloModule>> modules,
+ const AotCompilationOptions& options,
+ std::unique_ptr<AotCompilationMetadata>* metadata) {
+ if (metadata != nullptr) {
+ return Unimplemented(
+ "Populating AotCompilationMetadata is not implemented on this "
+ "compiler.");
+ }
+ return CompileAheadOfTime(std::move(modules), options);
+}
+
/* static */ std::map<se::Platform::Id, Compiler::CompilerFactory>*
Compiler::GetPlatformCompilerFactories() {
static auto* r = new std::map<se::Platform::Id, CompilerFactory>;
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 6c52ffd800..d1144f97bb 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -94,6 +94,19 @@ class AotCompilationOptions {
DebugOptions debug_options_;
};
+// Abstract superclass describing metadata produced during ahead-of-time
+// compilation.
+class AotCompilationMetadata {
+ public:
+ AotCompilationMetadata(const AotCompilationMetadata&) = delete;
+ AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete;
+
+ virtual ~AotCompilationMetadata() = default;
+
+ protected:
+ AotCompilationMetadata() = default;
+};
+
// Abstract compiler interface that is subclassed for compilation on a
// particular platform.
//
@@ -172,6 +185,13 @@ class Compiler {
CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
const AotCompilationOptions& options) = 0;
+ // Similar to CompileAheadOfTime above but AotCompilationMetadata
+ // has an argument that can be populated during compilation.
+ virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+ CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
+ const AotCompilationOptions& options,
+ std::unique_ptr<AotCompilationMetadata>* metadata);
+
/////
// The Compiler class also serves as a point to register compiler objects
// for the various platforms.