path: root/tensorflow/core/distributed_runtime/worker_interface.h
author    Derek Murray <mrry@google.com>  2017-01-04 18:34:21 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-01-04 18:48:03 -0800
commit    bf00bcc5fc75d9bd1d61c67cc6c2fc55708a26ea (patch)
tree      4eb2e54ff54fb98d98864a7352e24f7dd0ef1b86 /tensorflow/core/distributed_runtime/worker_interface.h
parent    37b430c48ff0a9df80e881c2b339463d5609e9b7 (diff)
Provide multiple implementations of RPC requests on the feed path.
This CL includes wrapper classes for the protocol buffer messages `tensorflow::RunStepRequest` and `tensorflow::RunGraphRequest`. Previously the service arguments were always protocol buffer messages, which can entail copying large tensor values into and out of the request message. This change makes the backend code deal with abstract `tensorflow::RunStepRequestWrapper` and `tensorflow::RunGraphRequestWrapper` interfaces and adds three concrete implementations of each interface:

* A mutable in-memory wrapper, which maintains the tensor data in `tensorflow::Tensor` objects and provides the most efficient implementation when the client and master are in the same address space.

* A mutable protobuf wrapper, which has an implementation similar to today's client code.

* A const wrapper around a const protobuf, which has an implementation similar to today's server code.

This is another improvement for issue #6256.

Change: 143620823
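As an illustration of the pattern described above, the following is a minimal, hypothetical sketch in the same spirit. The "Sketch"-suffixed class names, the reduced accessor set, and the omission of the mutable-proto variant are simplifications for exposition; the actual interfaces live in message_wrappers.h added by this commit.

#include <cstddef>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

// Read-only view of a RunGraph request. Backend code programs against this
// interface instead of a concrete protocol buffer message.
class RunGraphRequestWrapperSketch {
 public:
  virtual ~RunGraphRequestWrapperSketch() {}
  virtual const string& graph_handle() const = 0;
  virtual size_t num_sends() const = 0;
  // Materializes the i-th fed tensor into *out_tensor.
  virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0;
};

// In-memory implementation: keeps feeds as Tensor objects, so no tensor data
// is serialized when client and master share an address space.
class InMemoryRunGraphRequestSketch : public RunGraphRequestWrapperSketch {
 public:
  const string& graph_handle() const override { return graph_handle_; }
  size_t num_sends() const override { return sends_.size(); }
  Status SendValue(size_t i, Tensor* out_tensor) const override {
    *out_tensor = sends_[i].second;  // Tensor copies share the same buffer.
    return Status::OK();
  }

 private:
  string graph_handle_;
  std::vector<std::pair<string, Tensor>> sends_;
};

// Const wrapper around a const proto: mirrors the existing server-side path,
// decoding tensor data out of the request message only when asked.
class ProtoRunGraphRequestSketch : public RunGraphRequestWrapperSketch {
 public:
  explicit ProtoRunGraphRequestSketch(const RunGraphRequest* request)
      : request_(request) {}
  const string& graph_handle() const override {
    return request_->graph_handle();
  }
  size_t num_sends() const override { return request_->send_size(); }
  Status SendValue(size_t i, Tensor* out_tensor) const override {
    if (!out_tensor->FromProto(request_->send(i).tensor())) {
      return errors::InvalidArgument("Malformed tensor in send index ", i);
    }
    return Status::OK();
  }

 private:
  const RunGraphRequest* const request_;  // Not owned.
};

}  // namespace tensorflow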
Diffstat (limited to 'tensorflow/core/distributed_runtime/worker_interface.h')
-rw-r--r-- tensorflow/core/distributed_runtime/worker_interface.h | 21
1 file changed, 19 insertions, 2 deletions
diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h
index 823a29e226..577ecf25ed 100644
--- a/tensorflow/core/distributed_runtime/worker_interface.h
+++ b/tensorflow/core/distributed_runtime/worker_interface.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <functional>
#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
@@ -47,10 +48,26 @@ class WorkerInterface {
DeregisterGraphResponse* response,
StatusCallback done) = 0;
- virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
- RunGraphResponse* response,
+ virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
+ RunGraphResponse* response,
StatusCallback done) = 0;
+ virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
+ RunGraphResponse* response, StatusCallback done) {
+ // TODO(mrry): Convert this to std::bind/std::move if the overhead
+ // of std::function copying becomes too much.
+ RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request);
+ RunGraphAsync(opts, wrapped_request, response,
+ [wrapped_request, done](const Status& s) {
+ done(s);
+ delete wrapped_request;
+ });
+ }
+
+ virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() {
+ return new MutableProtoRunGraphRequest;
+ }
+
virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) = 0;
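For reference, a hypothetical caller of the new interface might look like the sketch below. The helper name RunPartitionSketch, the assumption that MutableRunGraphRequestWrapper derives from RunGraphRequestWrapper, and the ownership convention (the caller deletes the request in the callback, mirroring the proto-based overload above) are illustrative and not taken from this commit.

#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"

namespace tensorflow {

// Hypothetical usage sketch (not part of the commit).
void RunPartitionSketch(WorkerInterface* worker, CallOptions* opts,
                        RunGraphResponse* response, StatusCallback done) {
  // Let the worker choose the request representation: an in-process worker
  // can return an in-memory wrapper, an RPC-backed worker a proto-backed one.
  MutableRunGraphRequestWrapper* request = worker->CreateRunGraphRequest();
  // ... populate *request through its mutable interface (omitted) ...
  worker->RunGraphAsync(opts, request, response,
                        [request, done](const Status& s) {
                          done(s);
                          delete request;  // Assumed: caller retains ownership.
                        });
}

}  // namespace tensorflow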