diff options
author | Derek Murray <mrry@google.com> | 2017-01-04 18:34:21 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-04 18:48:03 -0800 |
commit | bf00bcc5fc75d9bd1d61c67cc6c2fc55708a26ea (patch) | |
tree | 4eb2e54ff54fb98d98864a7352e24f7dd0ef1b86 /tensorflow/core/distributed_runtime/worker_interface.h | |
parent | 37b430c48ff0a9df80e881c2b339463d5609e9b7 (diff) |
Provide multiple implementations of RPC requests on the feed path.
This CL includes wrapper classes for the protocol buffer messages
`tensorflow::RunStepRequest` and `tensorflow::RunGraphRequest`.
Previously the service arguments were always protocol buffer messages,
which can entail copying large tensor values into and out of the
request message. This change makes the backend code deal with abstract
`tensorflow::RunStepRequestWrapper` and
`tensorflow::RunGraphRequestWrapper` interfaces and adds three
concrete implementations of each interface:
* An mutable in-memory wrapper, which maintains the tensor data in
`tensorflow::Tensor` objects, and provides the most efficient
implementation when the client and master are in the same address
space.
* A mutable protobuf wrapper, which has a similar implementation to
today's client code.
* A const wrapper around a const protobuf, which has a similar
implementation to today's server code.
This is another improvement for issue #6256.
Change: 143620823
Diffstat (limited to 'tensorflow/core/distributed_runtime/worker_interface.h')
-rw-r--r-- | tensorflow/core/distributed_runtime/worker_interface.h | 21 |
1 files changed, 19 insertions, 2 deletions
diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h index 823a29e226..577ecf25ed 100644 --- a/tensorflow/core/distributed_runtime/worker_interface.h +++ b/tensorflow/core/distributed_runtime/worker_interface.h @@ -19,6 +19,7 @@ limitations under the License. #include <functional> #include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/message_wrappers.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -47,10 +48,26 @@ class WorkerInterface { DeregisterGraphResponse* response, StatusCallback done) = 0; - virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request, - RunGraphResponse* response, + virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, + RunGraphResponse* repsonse, StatusCallback done) = 0; + virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request, + RunGraphResponse* response, StatusCallback done) { + // TODO(mrry): Convert this to std::bind/std::move if the overhead + // of std::function copying becomes too much. + RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request); + RunGraphAsync(opts, wrapped_request, response, + [wrapped_request, done](const Status& s) { + done(s); + delete wrapped_request; + }); + } + + virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() { + return new MutableProtoRunGraphRequest; + } + virtual void CleanupGraphAsync(const CleanupGraphRequest* request, CleanupGraphResponse* response, StatusCallback done) = 0; |