aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-09 16:17:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-09 16:20:56 -0800
commit4f333b63f7b46a3122f91b5358f2763e6c2e8206 (patch)
tree538bf42dde608a8e1db05fac207691e56fd4c893
parentc5e8d4819a897a5701470ae291e09811f5b4762f (diff)
[XLA] Add a whole graph execution interface.
PiperOrigin-RevId: 188554206
-rw-r--r--tensorflow/compiler/xla/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/service.cc5
-rw-r--r--tensorflow/compiler/xla/service/service.h6
-rw-r--r--tensorflow/compiler/xla/service_interface.h3
-rw-r--r--tensorflow/compiler/xla/xla.proto9
5 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index c7cb69215f..cd13db4d30 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -52,6 +52,7 @@ xla_proto_library(
visibility = ["//visibility:public"],
deps = [
":xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:session_proto",
],
)
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 25c2fe97e4..8edd457281 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -937,6 +937,11 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
return tensorflow::Status::OK();
}
+tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* /*arg*/,
+ ExecuteResponse* /*result*/) {
+ return Unimplemented("execute-graph is not yet implemented");
+}
+
tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
ExecuteAsyncResponse* result) {
VLOG(1) << "running execute-async request: " << arg->ShortDebugString();
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index e047df2648..96352d9096 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -112,6 +112,12 @@ class Service : public ServiceInterface {
tensorflow::Status Execute(const ExecuteRequest* arg,
ExecuteResponse* result) override;
+ // Executes a computation with the provided global data passed as
+ // immutable arguments. The request contains the whole computation graph.
+ // Returns global data output and execution timing.
+ tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg,
+ ExecuteResponse* result) override;
+
// Executes one or more computations in parallel with the provided global data
// passed as immutable arguments. Returns global data output for each
// computation.
diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h
index 809941d8fe..d8235113dd 100644
--- a/tensorflow/compiler/xla/service_interface.h
+++ b/tensorflow/compiler/xla/service_interface.h
@@ -54,6 +54,9 @@ class ServiceInterface {
virtual tensorflow::Status Execute(const ExecuteRequest* arg,
ExecuteResponse* result) = 0;
+ virtual tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg,
+ ExecuteResponse* result) = 0;
+
virtual tensorflow::Status ExecuteParallel(
const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0;
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 56162ab44e..edf1b07af8 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -16,6 +16,7 @@ limitations under the License.
syntax = "proto3";
import "tensorflow/compiler/xla/xla_data.proto";
+import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/service/session.proto";
package xla;
@@ -342,6 +343,14 @@ message ExecuteRequest {
ExecutionOptions execution_options = 5;
}
+message ExecuteGraphRequest {
+ HloModuleProto computation = 1;
+ repeated GlobalDataHandle arguments = 2;
+
+ // Options that affect how XLA compiles and runs code to service this request.
+ ExecutionOptions execution_options = 3;
+}
+
message ExecuteParallelRequest {
repeated ExecuteRequest requests = 1;
}