diff options
author | 2018-09-13 18:34:29 -0700 | |
---|---|---|
committer | 2018-09-13 18:38:09 -0700 | |
commit | 1831ef73ba693ba7f27a3ecb391b47601e6a3758 (patch) | |
tree | fcd3209a20a518a94056251ffcbcea011b106983 | |
parent | eb5cd6926ef8d2a5a748f1aa978e51148e22dd97 (diff) |
[XLA] Add hook for dump directory expansion.
Also puts a ".unoptimized" suffix on dumped HLO protobuf files
to avoid the unoptimized dumped HLO protobuf colliding with the
optimized dumped HLO protobufs when the same dump directory is
specified for both.
PiperOrigin-RevId: 212914100
-rw-r--r-- | tensorflow/compiler/xla/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/protobuf_util.cc | 29 | ||||
-rw-r--r-- | tensorflow/compiler/xla/protobuf_util.h | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/compile_only_service.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/service.cc | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/service.h | 4 |
6 files changed, 40 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 76e36f3c46..ef70c1f8ac 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -193,6 +193,7 @@ cc_library( ":types", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/synchronization", ], ) diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index 787725e884..b507a2ef79 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { @@ -49,16 +50,40 @@ string SanitizeFilename(const string& file_name) { return safe_file_name; } +std::pair<tensorflow::mutex*, std::vector<std::function<string(string)>>*> +GetDirectoryExpanders() { + static auto* mutex = new tensorflow::mutex; + static auto* singleton = new std::vector<std::function<string(string)>>; + return {mutex, singleton}; +} + +// Runs all the directory expanders over x and returns the result. +string Expand(string x) { + auto pair = GetDirectoryExpanders(); + tensorflow::mutex_lock lock(*pair.first); + for (const auto& f : *pair.second) { + x = f(x); + } + return x; +} + } // namespace Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name) { tensorflow::Env* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + string expanded_dir = Expand(directory); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir)); string safe_file_name = SanitizeFileName(file_name) + ".pb"; - const string path = tensorflow::io::JoinPath(directory, safe_file_name); + const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name); return tensorflow::WriteBinaryProto(env, path, message); } +void RegisterDirectoryExpander(const std::function<string(string)>& expander) { + auto pair = GetDirectoryExpanders(); + tensorflow::mutex_lock lock(*pair.first); + pair.second->push_back(expander); +} + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 3667621367..f22fc8b849 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -39,6 +39,10 @@ extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name); +// Registers a function that may either expand a dirpath or forward the original +// dirpath along as-is. +void RegisterDirectoryExpander(const std::function<string(string)>& expander); + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index e5a6c28478..96bd2616f5 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -97,7 +97,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr<HloModule> hlo_module, HloModule::CreateFromProto(instance.computation, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module)); hlo_modules.push_back(std::move(hlo_module)); } diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 922ebdf0e3..b27a92f2a0 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -812,7 +812,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable( TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module, HloModule::CreateFromProto(module_proto, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, @@ -1160,7 +1160,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas( return replicas; } -Status Service::MaybeDumpHloModule(const HloModule& module) const { +Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const { const string xla_dump_unoptimized_hlo_proto_to = module.config().debug_options().xla_dump_unoptimized_hlo_proto_to(); if (xla_dump_unoptimized_hlo_proto_to.empty()) { @@ -1168,7 +1168,8 @@ Status Service::MaybeDumpHloModule(const HloModule& module) const { } HloProto proto = MakeHloProto(module); return protobuf_util::DumpProtoToDirectory( - proto, xla_dump_unoptimized_hlo_proto_to, module.name()); + proto, xla_dump_unoptimized_hlo_proto_to, + StrCat(module.name(), ".unoptimized")); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 44c5248b15..1f62fad4c8 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -271,7 +271,9 @@ class Service : public ServiceInterface { StatusOr<std::vector<se::StreamExecutor*>> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; - Status MaybeDumpHloModule(const HloModule& module) const; + // Dumps the (unoptimized) module given if the corresponding DebugOptions + // field has been set. + Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const; // Returns the device handle that represents the replicated device for a // single computation that is not model-parallelized. |