aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tools/replay_computation.cc
blob: 89b26b8916b67eeb38852c9e91314187fc8a7d48 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Usage: replay_computation some_binary_snapshot_proto*
//
// Replays computations and shows the results on the command line.
//
// some_binary_snapshot_proto is obtained by serializing the SessionModule from
// ServiceInterface::SnapshotComputation to disk.
//
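// Computations that require arguments can be replayed using fake data by
// passing --use_fake_data on the command line. If --use_fake_data is false,
// the argument data recorded in the proto (if any) is used instead.
//
// Computations that read from infeed can be replayed by passing
// --fake_infeed_shape=<shape>; fake data of that shape is then transferred to
// the infeed repeatedly from a helper thread.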
//
// The output format is:
//
// file_path: computation_name :: type:literal_str

#include <stdio.h>
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/testing.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace tools {
namespace {

// Invokes the given computation, passing arbitrary data for every (unbound)
// parameter if use_fake_data is true; otherwise the argument literals recorded
// in `module` (if any) are transferred to the server and used.
//
// If fake_infeed_shape is non-empty, a fake literal of that shape is built up
// front and transferred to the infeed repeatedly from a helper thread until
// the computation has finished executing; otherwise no infeed is performed.
//
// Returns the transferred result of the computation, or an error status if
// the snapshot cannot be loaded, an argument cannot be transferred,
// fake_infeed_shape cannot be parsed or turned into a literal, or execution
// fails.
StatusOr<std::unique_ptr<Literal>> ReplayComputation(
    const SessionModule& module, tensorflow::StringPiece fake_infeed_shape,
    bool use_fake_data, Client* client) {
  TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module));

  std::vector<std::unique_ptr<GlobalData>> arguments;
  if (use_fake_data) {
    arguments = MakeFakeArgumentsOrDie(computation, client);
  } else {  // use recorded data if available
    for (const auto& proto : module.arguments()) {
      Literal literal(proto);
      TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
                          client->TransferToServer(literal));
      arguments.push_back(std::move(data));
    }
  }

  // Build the fake infeed literal on this thread so that a malformed
  // fake_infeed_shape surfaces as an error status for the caller instead of
  // a TF_CHECK_OK crash inside the infeed thread.
  std::unique_ptr<Literal> infeed_data;
  if (!fake_infeed_shape.empty()) {
    TF_ASSIGN_OR_RETURN(Shape infeed_shape,
                        ShapeUtil::ParseShapeString(fake_infeed_shape));
    TF_ASSIGN_OR_RETURN(infeed_data, MakeFakeLiteral(infeed_shape));
  }

  // Tells the infeed loop to exit. Declared before `pool` so that it (and
  // `infeed_data` above) outlive the pool: the ThreadPool destructor waits for
  // scheduled work to finish, and the infeed loop only finishes once this
  // flag is set. Without it this function could never return.
  std::atomic<bool> stop_infeed{false};

  // We only instantiate the thread pool if the user has requested that a
  // concurrent infeed occur via the fake_infeed_shape.
  tensorflow::gtl::optional<tensorflow::thread::ThreadPool> pool;
  if (infeed_data != nullptr) {
    pool.emplace(tensorflow::Env::Default(), "infeed",
                 /*num_threads=*/1);
    const Literal* infeed_literal = infeed_data.get();
    pool->Schedule([infeed_literal, client, &stop_infeed]() {
      // NOTE(review): this transfers as fast as TransferToInfeed returns, for
      // as long as the computation runs; confirm the backend bounds or
      // consumes the queued buffers.
      while (!stop_infeed.load()) {
        TF_CHECK_OK(client->TransferToInfeed(*infeed_literal));
      }
    });
  }

  std::vector<GlobalData*> execute_arguments;
  execute_arguments.reserve(arguments.size());
  for (auto& argument : arguments) {
    execute_arguments.push_back(argument.get());
  }
  StatusOr<std::unique_ptr<Literal>> result =
      client->ExecuteAndTransfer(computation, execute_arguments);
  // Let the infeed loop exit so `pool`'s destructor can join its thread.
  stop_infeed.store(true);
  return result;
}

// Replays every snapshot file named in `args` with a local client. For each
// file, it prints "file_path: computation_name :: type:literal_str" to stdout.
// If the snapshot also recorded a result, that result is printed on a
// following "was type:literal_str" line for comparison.
//
// A file that cannot be read or replayed is reported on stderr as
// "file_path: error: <status>" and skipped. Processing then continues with the
// remaining files.
//
// Returns EXIT_SUCCESS if every file was replayed, EXIT_FAILURE otherwise.
int RealMain(tensorflow::gtl::ArraySlice<char*> args,
             tensorflow::StringPiece fake_infeed_shape, bool use_fake_data) {
  Client* client = ClientLibrary::LocalClientOrDie();
  tensorflow::Env* env = tensorflow::Env::Default();
  int exit_status = EXIT_SUCCESS;
  for (char* arg : args) {
    SessionModule module;
    // Treat an unreadable snapshot like any other per-file failure rather than
    // aborting the process, so one bad path does not prevent replaying the
    // rest of the files.
    tensorflow::Status read_status =
        tensorflow::ReadBinaryProto(env, arg, &module);
    if (!read_status.ok()) {
      fprintf(stderr, "%s: error: %s\n", arg, read_status.ToString().c_str());
      exit_status = EXIT_FAILURE;
      continue;
    }
    StatusOr<std::unique_ptr<Literal>> result_status =
        ReplayComputation(module, fake_infeed_shape, use_fake_data, client);
    if (!result_status.ok()) {
      fprintf(stderr, "%s: error: %s\n", arg,
              result_status.status().ToString().c_str());
      exit_status = EXIT_FAILURE;
      continue;
    }
    std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
    fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(),
            ShapeUtil::HumanString(result->shape()).c_str(),
            result->ToString().c_str());
    if (module.has_result()) {
      fprintf(stdout, "was %s:%s\n",
              ShapeUtil::HumanString(module.result().shape()).c_str(),
              Literal(module.result()).ToString().c_str());
    }
  }
  return exit_status;
}

}  // namespace
}  // namespace tools
}  // namespace xla

int main(int argc, char** argv) {
  // Flags
  xla::string fake_infeed_shape;
  bool use_fake_data = false;
  const std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("use_fake_data", &use_fake_data,
                       "Replay computation using fake data"),
      tensorflow::Flag("fake_infeed_shape", &fake_infeed_shape,
                       "Shape of fake data to construct for (infinite) infeed"),
  };
  xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  if (argc < 2 || !parse_ok) {
    LOG(QFATAL) << usage;
  }

  tensorflow::gtl::ArraySlice<char*> args(argv, argc);
  args.pop_front();  // Pop off the binary name, argv[0]
  return xla::tools::RealMain(args, fake_infeed_shape, use_fake_data);
}