diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-09 16:17:08 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-09 16:20:56 -0800 |
commit | 4f333b63f7b46a3122f91b5358f2763e6c2e8206 (patch) | |
tree | 538bf42dde608a8e1db05fac207691e56fd4c893 | |
parent | c5e8d4819a897a5701470ae291e09811f5b4762f (diff) |
[XLA] Add a whole graph execution interface.
PiperOrigin-RevId: 188554206
-rw-r--r-- | tensorflow/compiler/xla/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/service.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/service.h | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service_interface.h | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/xla.proto | 9 |
5 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index c7cb69215f..cd13db4d30 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -52,6 +52,7 @@ xla_proto_library( visibility = ["//visibility:public"], deps = [ ":xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:session_proto", ], ) diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 25c2fe97e4..8edd457281 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -937,6 +937,11 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, return tensorflow::Status::OK(); } +tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* /*arg*/, + ExecuteResponse* /*result*/) { + return Unimplemented("execute-graph is not yet implemented"); +} + tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, ExecuteAsyncResponse* result) { VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index e047df2648..96352d9096 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -112,6 +112,12 @@ class Service : public ServiceInterface { tensorflow::Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; + // Executes a computation with the provided global data passed as + // immutable arguments. The request contains the whole computation graph. + // Returns global data output and execution timing. + tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; + // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 809941d8fe..d8235113dd 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -54,6 +54,9 @@ class ServiceInterface { virtual tensorflow::Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) = 0; + virtual tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) = 0; + virtual tensorflow::Status ExecuteParallel( const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0; diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 56162ab44e..edf1b07af8 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -16,6 +16,7 @@ limitations under the License. syntax = "proto3"; import "tensorflow/compiler/xla/xla_data.proto"; +import "tensorflow/compiler/xla/service/hlo.proto"; import "tensorflow/compiler/xla/service/session.proto"; package xla; @@ -342,6 +343,14 @@ message ExecuteRequest { ExecutionOptions execution_options = 5; } +message ExecuteGraphRequest { + HloModuleProto computation = 1; + repeated GlobalDataHandle arguments = 2; + + // Options that affect how XLA compiles and runs code to service this request. + ExecutionOptions execution_options = 3; +} + message ExecuteParallelRequest { repeated ExecuteRequest requests = 1; } |