path: root/tensorflow/core/distributed_runtime
author    Akshay Modi <nareshmodi@google.com>  2018-07-24 14:10:01 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-24 14:14:04 -0700
commit    e9398c43cf470a7388df7d20baf6dd10a3b42edb (patch)
tree      b1f901eaa30d36533f5056de3e6ccdc6805fb66e /tensorflow/core/distributed_runtime
parent    76e8f7b7fdf89b131e0406022129d5dde6b89e40 (diff)
Push tensors from client to workers.
At times, a server cannot open a reverse connection back to the client. Such a connection is required by the _Send/_Recv path, where the client enqueues a tensor and the server pulls it. This change instead adds a way to push tensors directly from the client to the worker. Currently, pushing tensors always happens in sync mode.

PiperOrigin-RevId: 205888825
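As a minimal sketch of how a client might use the new RPC: the field names below come from the SendTensor test added in this change, while the SendTensorAsync signature is assumed to follow the same (request, response, callback) pattern as EnqueueAsync in eager_client.h; `eager_client`, `tensor`, and `context_id` are assumed to already exist (e.g. from a prior CreateContext call).

// Hedged sketch, not part of this commit: push a tensor so that later remote
// ops can reference it as output 0 of op_id 1.
tensorflow::eager::SendTensorRequest request;
request.set_context_id(context_id);  // id returned by CreateContext
request.set_op_id(1);                // client-chosen id; later ops refer to {1, 0}
tensor.AsProtoTensorContent(request.add_tensors());

tensorflow::eager::SendTensorResponse response;
tensorflow::Notification done;
tensorflow::Status status;
eager_client->SendTensorAsync(&request, &response,
                              [&](const tensorflow::Status& s) {
                                status.Update(s);
                                done.Notify();
                              });
done.WaitForNotification();  // pushes currently complete synchronously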
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r--  tensorflow/core/distributed_runtime/eager/eager_client.h                 |  1
-rw-r--r--  tensorflow/core/distributed_runtime/eager/eager_service_impl.cc          | 36
-rw-r--r--  tensorflow/core/distributed_runtime/eager/eager_service_impl.h           |  3
-rw-r--r--  tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc     | 79
-rw-r--r--  tensorflow/core/distributed_runtime/eager/remote_execute_node.h          | 19
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc       |  1
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc      | 14
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h       | 15
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc |  1
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h  |  1
10 files changed, 151 insertions, 19 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/eager_client.h b/tensorflow/core/distributed_runtime/eager/eager_client.h
index 9ba8c8d80c..707f3234b9 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_client.h
+++ b/tensorflow/core/distributed_runtime/eager/eager_client.h
@@ -39,6 +39,7 @@ class EagerClient {
CLIENT_METHOD(KeepAlive);
CLIENT_METHOD(CloseContext);
CLIENT_METHOD(RegisterFunction);
+ CLIENT_METHOD(SendTensor);
#undef CLIENT_METHOD
};
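For context, CLIENT_METHOD declares one asynchronous virtual per RPC, so the added line presumably expands to roughly the following declaration. This is a hedged reconstruction; the macro body is defined earlier in eager_client.h and does not appear in this hunk, and StatusCallback is assumed to be the usual std::function<void(const Status&)> alias used elsewhere in the distributed runtime.

// Approximate expansion of CLIENT_METHOD(SendTensor) -- hedged, since the
// macro body is outside this hunk.
virtual void SendTensorAsync(const SendTensorRequest* request,
                             SendTensorResponse* response,
                             StatusCallback done) = 0;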
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 466e779fab..916c8720f0 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -81,10 +81,11 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
CreateContextResponse* response) {
- //make sure env_ , env_->rendezvous_mgr available
+ // make sure env_ , env_->rendezvous_mgr available
if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
- return tensorflow::errors::Internal("invalid eager env_ or env_->rendezvous_mgr.");
- }
+ return tensorflow::errors::Internal(
+ "invalid eager env_ or env_->rendezvous_mgr.");
+ }
std::vector<tensorflow::Device*> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
@@ -266,6 +267,35 @@ Status EagerServiceImpl::RegisterFunction(
return context->Context()->AddFunctionDef(request->function_def());
}
+Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
+ SendTensorResponse* response) {
+ ServerContext* context = nullptr;
+ TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
+ core::ScopedUnref context_unref(context);
+
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors;
+ for (const auto& tensor_proto : request->tensors()) {
+ Tensor tensor;
+ if (!tensor.FromProto(tensor_proto)) {
+ return errors::InvalidArgument("Unable to parse tensor proto");
+ }
+
+ TensorHandle* tensor_handle =
+ new TensorHandle(tensor, nullptr, nullptr, nullptr);
+
+ TensorHandle* copied_handle = nullptr;
+ TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(),
+ request->device_name().c_str(),
+ &copied_handle));
+ tensors.push_back(copied_handle);
+ tensor_handle->Unref();
+ }
+
+ context->AddOperationOutputs(tensors, request->op_id());
+
+ return Status::OK();
+}
+
tensorflow::Status EagerServiceImpl::GetServerContext(
uint64 context_id, ServerContext** server_context) {
mutex_lock l(contexts_mu_);
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
index b0e4aa84b9..718b4e2457 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
@@ -62,6 +62,9 @@ class EagerServiceImpl {
Status RegisterFunction(const RegisterFunctionRequest* request,
RegisterFunctionResponse* response);
+ Status SendTensor(const SendTensorRequest* request,
+ SendTensorResponse* response);
+
protected:
// This is the server-side execution context. All state regarding execution of
// a client's ops is held in this server-side context (all generated tensors,
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index b98386ba86..d1f2a6da8f 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -84,7 +84,7 @@ class EagerServiceImplTest : public ::testing::Test {
std::unique_ptr<DeviceMgr> device_mgr_;
};
-void SetTensorProto(AttrValue* val) {
+void SetTensorProto(TensorProto* tensor_proto) {
int64_t dims[] = {2, 2};
float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
TF_Tensor* t = TF_AllocateTensor(
@@ -92,7 +92,7 @@ void SetTensorProto(AttrValue* val) {
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
tensorflow::Tensor tensor;
TF_ASSERT_OK(tensorflow::TF_TensorToTensor(t, &tensor));
- tensor.AsProtoTensorContent(val->mutable_tensor());
+ tensor.AsProtoTensorContent(tensor_proto);
TF_DeleteTensor(t);
}
@@ -175,7 +175,7 @@ TEST_F(EagerServiceImplTest, BasicTest) {
val.set_type(tensorflow::DataType::DT_FLOAT);
const_attrs.insert({"dtype", val});
val.Clear();
- SetTensorProto(&val);
+ SetTensorProto(val.mutable_tensor());
const_attrs.insert({"value", val});
AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
@@ -260,7 +260,7 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) {
const_attrs.insert({"dtype", val});
val.Clear();
- SetTensorProto(&val);
+ SetTensorProto(val.mutable_tensor());
const_attrs.insert({"value", val});
AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
@@ -294,6 +294,77 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) {
&close_context_response));
}
+// This test creates a context, sends a tensor to it using the new SendTensor
+// RPC, and then uses that tensor as an input to a remote MatMul.
+TEST_F(EagerServiceImplTest, SendTensorTest) {
+ TestEagerServiceImpl eager_service_impl(&worker_env_);
+
+ CreateContextRequest request;
+ request.mutable_server_def()->set_job_name("localhost");
+ request.mutable_server_def()->set_task_index(0);
+ request.set_rendezvous_id(random::New64());
+ CreateContextResponse response;
+
+ TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
+
+ uint64 context_id = response.context_id();
+
+ SendTensorRequest send_tensor_request;
+ send_tensor_request.set_context_id(context_id);
+ send_tensor_request.set_op_id(1);
+ SetTensorProto(send_tensor_request.add_tensors());
+ SendTensorResponse send_tensor_response;
+
+ TF_ASSERT_OK(eager_service_impl.SendTensor(&send_tensor_request,
+ &send_tensor_response));
+
+ EnqueueRequest remote_enqueue_request;
+ remote_enqueue_request.set_context_id(context_id);
+ EnqueueResponse remote_enqueue_response;
+
+ std::unordered_map<string, AttrValue> attrs;
+ AttrValue val;
+ val.Clear();
+ val.set_type(tensorflow::DataType::DT_FLOAT);
+ attrs.insert({"T", val});
+ val.Clear();
+ val.set_b(false);
+ attrs.insert({"transpose_a", val});
+ attrs.insert({"transpose_b", val});
+
+ AddOperationToEnqueueRequest(2, "MatMul", {{1, 0}, {1, 0}}, attrs,
+ "/job:localhost/replica:0/task:0/device:CPU:0",
+ &remote_enqueue_request);
+
+ TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
+ &remote_enqueue_response));
+
+ const tensorflow::Tensor* t = nullptr;
+ tensorflow::TensorHandle* tensor_handle;
+ TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
+ response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle));
+ TF_ASSERT_OK(tensor_handle->Tensor(&t));
+
+ Device* device = nullptr;
+ TF_ASSERT_OK(tensor_handle->Device(&device));
+ EXPECT_NE(device, nullptr);
+ EXPECT_EQ(device->name(), "/job:localhost/replica:0/task:0/device:CPU:0");
+
+ auto actual = t->flat<float>();
+ EXPECT_EQ(4, actual.size());
+
+ EXPECT_EQ(7, actual(0));
+ EXPECT_EQ(10, actual(1));
+ EXPECT_EQ(15, actual(2));
+ EXPECT_EQ(22, actual(3));
+
+ CloseContextRequest close_context_request;
+ close_context_request.set_context_id(context_id);
+ CloseContextResponse close_context_response;
+ TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
+ &close_context_response));
+}
+
} // namespace
} // namespace eager
} // namespace tensorflow
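For reference, the expected values in the test above follow from multiplying the pushed 2x2 tensor by itself: with [[1, 2], [3, 4]] as both MatMul inputs (each referenced as output 0 of op 1), the result is [[1*1 + 2*3, 1*2 + 2*4], [3*1 + 4*3, 3*2 + 4*4]] = [[7, 10], [15, 22]], which matches the EXPECT_EQ checks on actual(0) through actual(3).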
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
index 28b68c3b88..0e3a68c4d8 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
@@ -29,8 +29,8 @@ namespace eager {
class RemoteExecuteNode : public tensorflow::EagerNode {
public:
RemoteExecuteNode(
- tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request,
- tensorflow::eager::EagerClient* eager_client,
+ tensorflow::uint64 id, std::unique_ptr<EnqueueRequest> request,
+ EagerClient* eager_client,
const gtl::InlinedVector<TensorHandle*, 4>& inputs,
std::function<void(const Status& status, const EnqueueResponse& response)>
done_callback)
@@ -45,8 +45,8 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
}
RemoteExecuteNode(tensorflow::uint64 id,
- const tensorflow::eager::EnqueueRequest& request,
- tensorflow::eager::EagerClient* eager_client)
+ std::unique_ptr<EnqueueRequest> request,
+ EagerClient* eager_client)
: tensorflow::EagerNode(id),
request_(std::move(request)),
eager_client_(eager_client) {}
@@ -58,10 +58,10 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
}
tensorflow::Status Run() override {
- tensorflow::eager::EnqueueResponse response;
- tensorflow::Status status;
+ EnqueueResponse response;
+ Status status;
Notification n;
- eager_client_->EnqueueAsync(&request_, &response,
+ eager_client_->EnqueueAsync(request_.get(), &response,
[&n, &status](const tensorflow::Status& s) {
status.Update(s);
n.Notify();
@@ -76,9 +76,8 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
}
private:
- EnqueueRequest request_;
- tensorflow::eager::EagerClient*
- eager_client_; // Not owned, and must outlive the RemoteExecuteNode.
+ std::unique_ptr<EnqueueRequest> request_;
+ EagerClient* eager_client_; // Not owned, and must outlive this node.
// This is required to ensure that the tensor handles stay alive across the
// execution.
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
index b23466037f..181422118c 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
@@ -49,6 +49,7 @@ class GrpcEagerClient : public EagerClient {
CLIENT_METHOD(KeepAlive);
CLIENT_METHOD(CloseContext);
CLIENT_METHOD(RegisterFunction);
+ CLIENT_METHOD(SendTensor);
#undef CLIENT_METHOD
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc
index 39ab6856c5..ab3aa3fd1d 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc
@@ -36,6 +36,7 @@ static const char* grpcEagerService_method_names[] = {
"/tensorflow.eager.EagerService/KeepAlive",
"/tensorflow.eager.EagerService/CloseContext",
"/tensorflow.eager.EagerService/RegisterFunction",
+ "/tensorflow.eager.EagerService/SendTensor",
};
std::unique_ptr<EagerService::Stub> EagerService::NewStub(
@@ -62,7 +63,9 @@ EagerService::Stub::Stub(
::grpc::internal::RpcMethod::NORMAL_RPC, channel),
rpcmethod_RegisterFunction_(grpcEagerService_method_names[5],
::grpc::internal::RpcMethod::NORMAL_RPC,
- channel) {}
+ channel),
+ rpcmethod_SendTensor_(grpcEagerService_method_names[6],
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel) {}
::grpc::Status EagerService::Stub::CreateContext(
::grpc::ClientContext* context, const CreateContextRequest& request,
@@ -106,8 +109,15 @@ EagerService::Stub::Stub(
channel_.get(), rpcmethod_RegisterFunction_, context, request, response);
}
+::grpc::Status EagerService::Stub::SendTensor(::grpc::ClientContext* context,
+ const SendTensorRequest& request,
+ SendTensorResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_SendTensor_, context, request, response);
+}
+
EagerService::AsyncService::AsyncService() {
- for (int i = 0; i < 6; ++i) {
+ for (int i = 0; i < 7; ++i) {
AddMethod(new ::grpc::internal::RpcServiceMethod(
grpcEagerService_method_names[i],
::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h
index 66458186ad..521e0ac4fa 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h
@@ -69,6 +69,9 @@ class EagerService final {
virtual ::grpc::Status RegisterFunction(
::grpc::ClientContext* context, const RegisterFunctionRequest& request,
RegisterFunctionResponse* response) = 0;
+ virtual ::grpc::Status SendTensor(::grpc::ClientContext* context,
+ const SendTensorRequest& request,
+ SendTensorResponse* response) = 0;
};
class Stub final : public StubInterface {
public:
@@ -91,6 +94,9 @@ class EagerService final {
::grpc::Status RegisterFunction(
::grpc::ClientContext* context, const RegisterFunctionRequest& request,
RegisterFunctionResponse* response) override;
+ ::grpc::Status SendTensor(::grpc::ClientContext* context,
+ const SendTensorRequest& request,
+ SendTensorResponse* response) override;
private:
std::shared_ptr< ::grpc::ChannelInterface> channel_;
@@ -100,6 +106,7 @@ class EagerService final {
const ::grpc::internal::RpcMethod rpcmethod_KeepAlive_;
const ::grpc::internal::RpcMethod rpcmethod_CloseContext_;
const ::grpc::internal::RpcMethod rpcmethod_RegisterFunction_;
+ const ::grpc::internal::RpcMethod rpcmethod_SendTensor_;
};
static std::unique_ptr<Stub> NewStub(
const std::shared_ptr< ::grpc::ChannelInterface>& channel,
@@ -157,6 +164,14 @@ class EagerService final {
::grpc::Service::RequestAsyncUnary(5, context, request, response,
new_call_cq, notification_cq, tag);
}
+ void RequestSendTensor(
+ ::grpc::ServerContext* context, SendTensorRequest* request,
+ ::grpc::ServerAsyncResponseWriter<SendTensorResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(6, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
};
};
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
index 44e880de04..f511674e1f 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
@@ -48,6 +48,7 @@ void GrpcEagerServiceImpl::HandleRPCsLoop() {
ENQUEUE_REQUEST(KeepAlive);
ENQUEUE_REQUEST(CloseContext);
ENQUEUE_REQUEST(RegisterFunction);
+ ENQUEUE_REQUEST(SendTensor);
#undef ENQUEUE_REQUEST
void* tag; // Matches the operation started against this cq_.
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
index 502f3ef529..537e9043bd 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
@@ -62,6 +62,7 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface {
HANDLER(KeepAlive);
HANDLER(CloseContext);
HANDLER(RegisterFunction);
+ HANDLER(SendTensor);
#undef HANDLER
const WorkerEnv* const env_; // Not owned.