diff options
author | 2017-03-15 12:12:46 -0800 | |
---|---|---|
committer | 2017-03-15 13:24:58 -0700 | |
commit | 70096fcdf81cb19e6b59311bab3d227bd7bd6175 (patch) | |
tree | 8b0dab94a8dfb8f5dca7512f6a531cbd1729a999 /tensorflow/compiler/xla/service/local_service.cc | |
parent | 348dd7c3845f74f30fb5317dd264ae65a42f0c1e (diff) |
[XLA] Add support for dumping computations during CompileAheadOfTime. Remove '/' and '\' characters from path names of dumped graphs.
Change: 150231912
Diffstat (limited to 'tensorflow/compiler/xla/service/local_service.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/local_service.cc | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index e041734836..03cbe84561 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -140,6 +141,21 @@ LocalService::CompileAheadOfTime( VersionedComputationHandle versioned_handle = user_computation->GetVersionedHandle(); + // Dump computation proto state if flag is set. + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + const string& directory_path = flags->xla_dump_computations_to; + if (!directory_path.empty()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr<SessionModule> session_module, + computation_tracker_.SnapshotComputation(versioned_handle.handle)); + string filename = tensorflow::strings::StrCat( + "computation_", versioned_handle.handle.handle(), "__", + session_module->entry().name(), "__version_", + versioned_handle.version); + TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, + *session_module)); + } + TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module, computation_tracker_.BuildHloModule( versioned_handle, |