From 1fefd1af5b30bfe6213271da558c5131fd33ce0a Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Mon, 11 Jun 2018 11:57:16 -0700 Subject: [XLA] Allow replay_computation to take an HLO textual string as input. PiperOrigin-RevId: 200088845 --- tensorflow/compiler/xla/tools/BUILD | 1 + .../compiler/xla/tools/replay_computation.cc | 52 ++++++++++++++++------ 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index ff5340ee3f..e4a052c8f1 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -85,6 +85,7 @@ cc_library( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/compiler/xla/tests:test_utils", diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index be094b7890..f7574e0b1c 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -24,6 +24,9 @@ limitations under the License. // passing --use_fake_data on the command line. If the real data is available // in the proto and --use_fake_data is false, the real data is used. // +// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a +// textual HLO string. +// // The output format is: // // file_path: computation_name :: type:literal_str @@ -43,6 +46,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -195,25 +199,45 @@ StatusOr ReplayComputation(const HloSnapshot& module, return std::move(*result_literal); } +StatusOr ParseInputFile(const string& filename, + const Options& opts) { + tensorflow::Env* env = tensorflow::Env::Default(); + HloSnapshot snapshot; + if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) { + return snapshot; + } + CHECK(opts.use_fake_data) + << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " + "and textual HLO don't carry real data."; + fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", + filename.c_str()); + + if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) { + return snapshot; + } + fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str()); + string contents; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents)); + StatusOr> module = ParseHloString(contents); + if (module.ok()) { + *snapshot.mutable_hlo()->mutable_hlo_module() = + module.ValueOrDie()->ToProto(); + return snapshot; + } + fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n", + filename.c_str()); + return InvalidArgument("Could not parse %s.", filename.c_str()); +} + int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { LocalClient* client = ClientLibrary::LocalClientOrDie(); - tensorflow::Env* env = tensorflow::Env::Default(); int exit_status = EXIT_SUCCESS; for (char* arg : args) { - HloSnapshot snapshot; - auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot); - if (!status.ok()) { - fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", arg); - status = tensorflow::ReadBinaryProto(env, arg, snapshot.mutable_hlo()); - if (!status.ok()) { - fprintf(stderr, "%s: is not HloSnapshot or HloProto: %s.\n", arg, - status.ToString().c_str()); - continue; - } - CHECK(opts.use_fake_data) - << "HloProto input must be handled with --use_fake_data"; + StatusOr maybe_snapshot = ParseInputFile(arg, opts); + if (!maybe_snapshot.ok()) { + continue; } - + HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie(); StatusOr result_status = ReplayComputation(snapshot, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, -- cgit v1.2.3