path: root/tensorflow/compiler/xla/service/hlo_runner.h
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-12 10:14:02 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-12 10:16:34 -0700
commit    ffbf77de81d0b7b4b169c92d0d9fbbdef5b8842a (patch)
tree      8d68eedf28bdcac55516b6a3a176d56f6cef0fa2 /tensorflow/compiler/xla/service/hlo_runner.h
parent    8a247976484173059aedc17bfd8d770b8d1a70e1 (diff)
Introduced a tool to run an HLO module in replicated fashion, by infeeding random data and outfeeding the data generated at each step.
The arguments of the computation can be either read from the session module or randomly generated. The tool uses the raw transfer manager API to infeed and outfeed the data. PiperOrigin-RevId: 192628605
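For orientation, here is a minimal usage sketch of the API this commit introduces. This is hypothetical caller code, not part of the commit: the platform pointer, the parsed module, and the literals (argument_literal, infeed_literal, outfeed_shape) are assumed to exist and to match the module's shapes.

    #include "tensorflow/compiler/xla/service/hlo_runner.h"

    // Assumed in scope: platform (::perftools::gputools::Platform*),
    // module (std::unique_ptr<xla::HloModule>), argument_literal and
    // infeed_literal (xla::Literal), outfeed_shape (xla::Shape).
    xla::HloRunner runner(platform);

    xla::HloRunner::ReplicatedExecuteOptions options;
    options.num_replicas = 2;                        // replicate onto two devices
    options.arguments.push_back(&argument_literal);  // same arguments for every replica
    options.infeed = &infeed_literal;                // literal fed on each infeed step
    options.infeed_steps = 10;                       // should match iterations-per-loop
    options.outfeed_shape = outfeed_shape;           // shape of the module's outfeed
    std::vector<std::unique_ptr<xla::Literal>> outfeed_values;
    options.outfeed_values = &outfeed_values;        // nullptr would discard outfeed data
    options.run_hlo_passes = false;                  // saved modules are already optimized

    auto result = runner.ExecuteReplicated(std::move(module), options);

A sketch of consuming the returned result follows the diff.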
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_runner.h')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_runner.h  |  66
1 file changed, 60 insertions(+), 6 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 06ce22a5b9..f54fb44766 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -16,12 +16,16 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
+#include <map>
#include <memory>
+#include <set>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -40,9 +44,43 @@ namespace xla {
// file), or parsed from a hlo textual IR string.
class HloRunner {
public:
- HloRunner();
-
- HloRunner(::perftools::gputools::Platform* platform);
+ // The options used to configure an ExecuteReplicated() call.
+ struct ReplicatedExecuteOptions {
+ // The number of devices the HLO module should be replicated onto.
+ int64 num_replicas = 1;
+
+ // The arguments to be fed to each replica. Since this is used for a
+ // replicated execution, all the arguments are the same for all replicas.
+ std::vector<const Literal*> arguments;
+
+ // If the HLO module being run has an infeed instruction, this will be the
+ // data which will be fed to it, for as many as infeed_steps steps.
+ const Literal* infeed = nullptr;
+
+ // The number of times the infeed literal should be fed to the HLO module.
+ // For a clean exit, this should match the iterations-per-loop parameter
+ // used when generating the HLO module proto (which is usually the main
+ // while-loop boundary counter). A value higher than iterations-per-loop
+ // would leave the infeed threads feeding a computation that has already
+ // completed, while a lower value would make the ExecuteReplicated() call
+ // hang (the computation would wait for infeed data that never arrives).
+ int64 infeed_steps = -1;
+
+ // The shape of the outfeed operation. If empty, the HLO module does not
+ // generate any outfeed.
+ Shape outfeed_shape;
+
+ // A pointer to a vector where the outfeed values will be stored. If
+ // nullptr, the values will be read and discarded.
+ std::vector<std::unique_ptr<Literal>>* outfeed_values = nullptr;
+
+ // Whether the HLO passes should be run on the input module. Saved modules
+ // usually come from after the HLO pass pipeline, so running the passes
+ // again will likely cause errors.
+ bool run_hlo_passes = false;
+ };
+
+ explicit HloRunner(::perftools::gputools::Platform* platform);
~HloRunner();
@@ -86,6 +124,13 @@ class HloRunner {
return Execute(std::move(module), argument_pointers, run_hlo_passes);
}
+ // Executes a given HLO module on a set of replicas, and returns a vector
+ // of literals, one per replica, with the replica number as the index into
+ // the vector.
+ StatusOr<std::vector<std::unique_ptr<Literal>>> ExecuteReplicated(
+ std::unique_ptr<HloModule> module,
+ const ReplicatedExecuteOptions& options);
+
// If backend is not created in the constructor, creates and returns the
// default backend. If creation fails, crashes the program.
//
@@ -94,9 +139,18 @@ class HloRunner {
Backend& backend();
private:
- struct EigenThreadPoolWrapper;
-
- std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
+ // Creates an executable object given an HLO module. If run_hlo_passes is
+ // true, the HLO passes will be run first.
+ StatusOr<std::unique_ptr<Executable>> CreateExecutable(
+ std::unique_ptr<HloModule> module, bool run_hlo_passes);
+
+ // Creates a ServiceExecutableRunOptions object to configure a run on a
+ // device, using the provided stream object. If device_assignment is not
+ // nullptr, it will be used to configure the replication parameters.
+ // Replicated executions should pass the device_assignment parameter.
+ ServiceExecutableRunOptions GetServiceRunOptionsForDevice(
+ int64 device, ::perftools::gputools::Stream* stream,
+ DeviceAssignment* device_assignment);
std::unique_ptr<Backend> backend_;
};
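Continuing the sketch above (still hypothetical caller code): the StatusOr returned by ExecuteReplicated carries one literal per replica. The convenience calls used here (status(), ConsumeValueOrDie(), Literal::ToString(), TF_CHECK_OK) are assumptions based on the XLA utilities of this era.

    // Unwrap the per-replica results from the sketch above; index i holds
    // the output literal of replica i.
    TF_CHECK_OK(result.status());
    std::vector<std::unique_ptr<xla::Literal>> results =
        result.ConsumeValueOrDie();
    for (size_t i = 0; i < results.size(); ++i) {
      LOG(INFO) << "replica " << i << ": " << results[i]->ToString();
    }
    // If options.outfeed_values was set, outfeed_values now holds one
    // literal per completed outfeed step.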