aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-06-11 11:57:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-11 12:04:05 -0700
commit1fefd1af5b30bfe6213271da558c5131fd33ce0a (patch)
treee84962f630bfba2c468f5dfa58b6889d7f4bb61f
parente20ccaab7a85d729f37ad4b7b90188e97e2124fa (diff)
[XLA] Allow replay_computation to take an HLO textual string as input.
PiperOrigin-RevId: 200088845
-rw-r--r--tensorflow/compiler/xla/tools/BUILD1
-rw-r--r--tensorflow/compiler/xla/tools/replay_computation.cc52
2 files changed, 39 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index ff5340ee3f..e4a052c8f1 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -85,6 +85,7 @@ cc_library(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:testing",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service/gpu:infeed_manager",
"//tensorflow/compiler/xla/tests:test_utils",
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index be094b7890..f7574e0b1c 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -24,6 +24,9 @@ limitations under the License.
// passing --use_fake_data on the command line. If the real data is available
// in the proto and --use_fake_data is false, the real data is used.
//
+// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a
+// textual HLO string.
+//
// The output format is:
//
// file_path: computation_name :: type:literal_str
@@ -43,6 +46,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -195,25 +199,45 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
return std::move(*result_literal);
}
+StatusOr<HloSnapshot> ParseInputFile(const string& filename,
+ const Options& opts) {
+ tensorflow::Env* env = tensorflow::Env::Default();
+ HloSnapshot snapshot;
+ if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) {
+ return snapshot;
+ }
+ CHECK(opts.use_fake_data)
+ << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto "
+ "and textual HLO don't carry real data.";
+ fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",
+ filename.c_str());
+
+ if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) {
+ return snapshot;
+ }
+ fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str());
+ string contents;
+ TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents));
+ StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(contents);
+ if (module.ok()) {
+ *snapshot.mutable_hlo()->mutable_hlo_module() =
+ module.ValueOrDie()->ToProto();
+ return snapshot;
+ }
+ fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n",
+ filename.c_str());
+ return InvalidArgument("Could not parse %s.", filename.c_str());
+}
+
int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
- tensorflow::Env* env = tensorflow::Env::Default();
int exit_status = EXIT_SUCCESS;
for (char* arg : args) {
- HloSnapshot snapshot;
- auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot);
- if (!status.ok()) {
- fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", arg);
- status = tensorflow::ReadBinaryProto(env, arg, snapshot.mutable_hlo());
- if (!status.ok()) {
- fprintf(stderr, "%s: is not HloSnapshot or HloProto: %s.\n", arg,
- status.ToString().c_str());
- continue;
- }
- CHECK(opts.use_fake_data)
- << "HloProto input must be handled with --use_fake_data";
+ StatusOr<HloSnapshot> maybe_snapshot = ParseInputFile(arg, opts);
+ if (!maybe_snapshot.ok()) {
+ continue;
}
-
+ HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie();
StatusOr<Literal> result_status = ReplayComputation(snapshot, client, opts);
if (!result_status.ok()) {
fprintf(stderr, "%s: error: %s\n", arg,