aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-06-26 15:36:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-26 15:43:10 -0700
commit79c828ea6ddbcfccd43a2be176fc1dcad4daf34e (patch)
treeadd42271d2d23096b03af3bf8fb89384cffc9a54 /tensorflow/core/distributed_runtime
parentec34de06981eed74c2c2a47c8a6372735e9d3622 (diff)
Support shapes for remote eager tensor handles.
Since we respond with the shape, all RPCs will happen sync (note that we may still hide the python overhead, since the op is still scheduled for execution via the eager executor). PiperOrigin-RevId: 202207324
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r--tensorflow/core/distributed_runtime/eager/BUILD1
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.cc21
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.h3
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc5
-rw-r--r--tensorflow/core/distributed_runtime/eager/remote_execute_node.h34
5 files changed, 61 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD
index 710bd8d021..22d0902af2 100644
--- a/tensorflow/core/distributed_runtime/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/eager/BUILD
@@ -37,6 +37,7 @@ cc_library(
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:eager_executor",
+ "//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 2fa234c810..5a26d5bf48 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -128,8 +128,20 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
return Status::OK();
}
+Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
+ const tensorflow::Tensor* t = nullptr;
+
+ // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
+ TF_RETURN_IF_ERROR(handle->Tensor(&t));
+
+ t->shape().AsProto(proto);
+
+ return Status::OK();
+}
+
Status EagerServiceImpl::ExecuteOp(const Operation& operation,
- ServerContext* server_context) {
+ ServerContext* server_context,
+ QueueResponse* queue_response) {
std::unique_ptr<tensorflow::EagerOperation> op;
const char* name = operation.name().c_str(); // Shorthand
const tensorflow::AttrTypeMap* types;
@@ -172,6 +184,10 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
server_context->AddOperationOutputs(retvals, operation.id());
+ for (auto* handle : retvals) {
+ TF_RETURN_IF_ERROR(TensorHandleShape(handle, queue_response->add_shape()));
+ }
+
return Status::OK();
}
@@ -182,8 +198,9 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
core::ScopedUnref context_unref(context);
for (const auto& item : request->queue()) {
+ auto* queue_response = response->add_queue_response();
if (item.has_operation()) {
- TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context));
+ TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context, queue_response));
} else {
TF_RETURN_IF_ERROR(context->DeleteTensorHandle(
RemoteTensorHandleInternal(item.handle_to_decref())));
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
index ebd5269a57..b0e4aa84b9 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
@@ -135,7 +135,8 @@ class EagerServiceImpl {
tensorflow::Status GetServerContext(uint64, ServerContext**);
private:
- Status ExecuteOp(const Operation& operation, ServerContext* server_context);
+ Status ExecuteOp(const Operation& operation, ServerContext* server_context,
+ QueueResponse* queue_response);
const WorkerEnv* const env_; // Not owned.
mutex contexts_mu_;
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index 91b58698a4..b98386ba86 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -198,6 +198,11 @@ TEST_F(EagerServiceImplTest, BasicTest) {
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
&remote_enqueue_response));
+ auto& matmul_result_shape =
+ remote_enqueue_response.queue_response(1).shape(0);
+ EXPECT_EQ(matmul_result_shape.dim(0).size(), 2);
+ EXPECT_EQ(matmul_result_shape.dim(1).size(), 2);
+
tensorflow::TensorHandle* tensor_handle;
TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle));
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
index c4bd67aaed..28b68c3b88 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
@@ -27,6 +28,22 @@ namespace eager {
// via RPC in a remote EagerService.
class RemoteExecuteNode : public tensorflow::EagerNode {
public:
+ RemoteExecuteNode(
+ tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request,
+ tensorflow::eager::EagerClient* eager_client,
+ const gtl::InlinedVector<TensorHandle*, 4>& inputs,
+ std::function<void(const Status& status, const EnqueueResponse& response)>
+ done_callback)
+ : tensorflow::EagerNode(id),
+ request_(std::move(request)),
+ eager_client_(eager_client),
+ inputs_(inputs),
+ done_callback_(std::move(done_callback)) {
+ for (auto* handle : inputs_) {
+ handle->Ref();
+ }
+ }
+
RemoteExecuteNode(tensorflow::uint64 id,
const tensorflow::eager::EnqueueRequest& request,
tensorflow::eager::EagerClient* eager_client)
@@ -34,6 +51,12 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
request_(std::move(request)),
eager_client_(eager_client) {}
+ ~RemoteExecuteNode() {
+ for (auto* handle : inputs_) {
+ handle->Unref();
+ }
+ }
+
tensorflow::Status Run() override {
tensorflow::eager::EnqueueResponse response;
tensorflow::Status status;
@@ -45,6 +68,10 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
});
n.WaitForNotification();
+ if (done_callback_) {
+ done_callback_(status, response);
+ }
+
return status;
}
@@ -52,6 +79,13 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
EnqueueRequest request_;
tensorflow::eager::EagerClient*
eager_client_; // Not owned, and must outlive the RemoteExecuteNode.
+
+ // This is required to ensure that the tensor handles stay alive across the
+ // execution.
+ gtl::InlinedVector<TensorHandle*, 4> inputs_;
+
+ std::function<void(const Status& status, const EnqueueResponse& response)>
+ done_callback_;
};
} // namespace eager